diff --git a/.asf.yaml b/.asf.yaml index 3973431cb9d9..ac1cf1a707d6 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -16,7 +16,7 @@ # under the License. github: - description: "Open deep learning compiler stack for cpu, gpu and specialized accelerators" + description: "Open Machine Learning Compiler Framework" homepage: https://tvm.apache.org/ labels: - tvm @@ -33,6 +33,12 @@ github: - spirv - machine-learning + features: + # Enable issue management + issues: true + # Enable projects for project management boards + projects: true + # Triage perm for collaborators(test run) # # The perm is given based on needs and not based on @@ -45,10 +51,6 @@ github: # participation, permission is given on a three month # cycle. PMC may review and recycle slots when necessary. collaborators: - - hpanda-naut - - denise-k - - janetsc - - naut-thomas - tvm-bot # For automated feedback in PR review. # See https://cwiki.apache.org/confluence/display/INFRA/Git+-+.asf.yaml+features#Git.asf.yamlfeatures-Branchprotection @@ -68,3 +70,24 @@ github: required_pull_request_reviews: required_approving_review_count: 1 + + enabled_merge_buttons: + # enable squash button: + squash: true + # default commit message when merging with a squash commit + # can either be: DEFAULT | PR_TITLE | PR_TITLE_AND_COMMIT_DETAILS | PR_TITLE_AND_DESC + squash_commit_message: PR_TITLE_AND_DESC + # enable merge button: + merge: false + # default commit message when merging with a merge commit + # can either be: DEFAULT | PR_TITLE | PR_TITLE_AND_DESC + merge_commit_message: DEFAULT + # enable rebase button for rare use. 
+ rebase: true + +notifications: + commits: commits@tvm.apache.org + issues: discuss-archive@tvm.apache.org + pullrequests: discuss-archive@tvm.apache.org + jobs: discuss-archive@tvm.apache.org + discussions: discuss-archive@tvm.apache.org diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index cd7fd9197fae..8288c6f6418a 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -1,12 +1,12 @@ runs: using: "composite" steps: - - uses: actions/cache@v3 + - uses: actions/cache@v4 env: CACHE_NUMBER: 2 with: path: ~/conda_pkgs_dir - key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('conda/build-environment.yaml') }} + key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('**/conda/build-environment.yaml') }} - uses: conda-incubator/setup-miniconda@v3 continue-on-error: true id: conda1 @@ -36,3 +36,7 @@ runs: mamba list mamba info --envs mamba list --name base + - name: Install tvm-ffi pip package + shell: bash -l {0} + run: | + pip install -v ./3rdparty/tvm-ffi diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d615eb9231e4..7b55dade1429 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -44,11 +44,20 @@ jobs: submodules: 'recursive' - name: Set up environment uses: ./.github/actions/setup - - name: Conda Build + - name: Install LLVM dependencies shell: bash -l {0} - run: >- - conda build --output-folder=conda/pkg conda/recipe && - conda install tvm -c ./conda/pkg + run: | + conda install -c conda-forge llvmdev cmake ninja zlib + - name: Build TVM wheel + shell: bash -l {0} + run: | + pip install scikit-build-core + export CMAKE_ARGS="-DUSE_LLVM=ON -DBUILD_TESTING=OFF" + pip wheel --no-deps -w dist . 
-v + - name: Install TVM from wheel + shell: bash -l {0} + run: | + pip install dist/*.whl # - name: Build iOS RPC # run: | # IOS_VERSION="14.0" @@ -98,11 +107,20 @@ jobs: submodules: 'recursive' - name: Set up environment uses: ./.github/actions/setup - - name: Conda Build + - name: Install LLVM dependencies shell: cmd /C call {0} - run: >- - conda build --output-folder=conda/pkg conda/recipe && - conda install tvm -c ./conda/pkg + run: | + conda install -c conda-forge llvmdev cmake ninja zlib libxml2-devel + - name: Install TVM + shell: cmd /C call {0} + run: | + pip install scikit-build-core + set CMAKE_ARGS=-DUSE_LLVM=ON -DBUILD_TESTING=OFF + pip install --no-deps . -v + - name: Install test dependencies + shell: cmd /C call {0} + run: | + pip install psutil cloudpickle ml_dtypes numpy packaging scipy tornado typing_extensions - name: Test shell: cmd /C call {0} run: >- diff --git a/.gitmodules b/.gitmodules index a481df243882..0513981e5886 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,15 +1,9 @@ [submodule "dmlc-core"] path = 3rdparty/dmlc-core url = https://github.com/dmlc/dmlc-core.git -[submodule "dlpack"] - path = 3rdparty/dlpack - url = https://github.com/dmlc/dlpack.git [submodule "3rdparty/rang"] path = 3rdparty/rang url = https://github.com/agauniyal/rang.git -[submodule "3rdparty/libbacktrace"] - path = 3rdparty/libbacktrace - url = https://github.com/tlc-pack/libbacktrace.git [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git @@ -28,6 +22,6 @@ [submodule "3rdparty/zlib"] path = 3rdparty/zlib url = https://github.com/madler/zlib.git -[submodule "ffi/3rdparty/dlpack"] - path = ffi/3rdparty/dlpack - url = https://github.com/dmlc/dlpack.git +[submodule "3rdparty/tvm-ffi"] + path = 3rdparty/tvm-ffi + url = https://github.com/apache/tvm-ffi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 982b78180f2a..13a06a6cb3db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ 
-32,27 +32,26 @@ # default_language_version: - python: python3.6 + python: python3.9 fail_fast: True -default_stages: [push] +default_stages: [pre-push, pre-commit] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v6.0.0 hooks: - id: check-added-large-files - id: check-merge-conflict - id: check-yaml - id: end-of-file-fixer - stages: [push] - id: trailing-whitespace - stages: [push] + - repo: local hooks: - id: run-black name: Running Black... - entry: docker/lint.sh python_format + entry: docker/lint.sh python_format -i language: system - always_run: true + files: \.py$ pass_filenames: false - id: run-file-checks name: Checking File Types.... @@ -62,25 +61,25 @@ repos: pass_filenames: false - id: run-headers-check name: Checking ASF License Headers ... - entry: docker/lint.sh asf + entry: docker/lint.sh asf -i language: system always_run: true pass_filenames: false - - id: run-headers-check + - id: run-cpplint name: Linting the C++ code ... entry: docker/lint.sh cpplint language: system - always_run: true + files: \.(c|cc|cpp|h|hpp)$ pass_filenames: false - id: run-clang-format name: Checking Clang format ... - entry: docker/lint.sh clang_format + entry: docker/lint.sh clang_format -i language: system - always_run: true + files: \.(c|cc|cpp|h|hpp)$ pass_filenames: false - id: run-mypy name: Type Checking with MyPY ... 
entry: docker/lint.sh mypy language: system - always_run: true + files: \.py$ pass_filenames: false diff --git a/3rdparty/cutlass b/3rdparty/cutlass index ad7b2f5e84fc..b2dd65dc864e 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e +Subproject commit b2dd65dc864e09688245b316ac46c4a6cd07e15c diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm index c633ae800283..72b9883c986a 160000 --- a/3rdparty/cutlass_fpA_intB_gemm +++ b/3rdparty/cutlass_fpA_intB_gemm @@ -1 +1 @@ -Subproject commit c633ae800283627a62e69e064d05a28ff13d380a +Subproject commit 72b9883c986a2ff427ca61ac0b14ad59be1dc862 diff --git a/3rdparty/dlpack b/3rdparty/dlpack deleted file mode 160000 index 3ea601bb4130..000000000000 --- a/3rdparty/dlpack +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3ea601bb413074c49a77c4ce3218bc08f8c4703c diff --git a/3rdparty/libbacktrace b/3rdparty/libbacktrace deleted file mode 160000 index 08f7c7e69f8e..000000000000 --- a/3rdparty/libbacktrace +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 08f7c7e69f8ea61a0c4151359bc8023be8e9217b diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi new file mode 160000 index 000000000000..ae346ec92a3c --- /dev/null +++ b/3rdparty/tvm-ffi @@ -0,0 +1 @@ +Subproject commit ae346ec92a3c386f1376064ae086aae72947c329 diff --git a/CMakeLists.txt b/CMakeLists.txt index d8d23f90353d..f620c1fe5493 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,6 @@ tvm_option(USE_ALTERNATIVE_LINKER "Use 'mold' or 'lld' if found when invoking co tvm_option(USE_CCACHE "Use ccache if found when invoking compiler" AUTO) # 3rdparty libraries -tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") tvm_option(DMLC_PATH "Path to DMLC" "3rdparty/dmlc-core/include") tvm_option(RANG_PATH "Path to RANG" "3rdparty/rang/include") tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") @@ -122,10 +121,12 @@ tvm_option(USE_MSC "Enable 
Multi-System Compiler" OFF) tvm_option(USE_MRVL "Build with MRVL TVM support" OFF) tvm_option(USE_NVSHMEM "Build with NVSHMEM support" OFF) +# Python package options +tvm_option(TVM_BUILD_PYTHON_MODULE "Build Python module with scikit-build-core" OFF) + # include directories include_directories(${CMAKE_INCLUDE_PATH}) include_directories("include") -include_directories(SYSTEM ${DLPACK_PATH}) include_directories(SYSTEM ${DMLC_PATH}) include_directories(SYSTEM ${RANG_PATH}) include_directories(SYSTEM ${COMPILER_RT_PATH}) @@ -306,6 +307,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/analysis/*.cc src/relax/transform/*.cc src/relax/backend/vm/*.cc + src/relax/backend/adreno/*.cc src/relax/backend/task_extraction.cc src/relax/backend/pattern_registry.cc src/relax/utils.cc @@ -482,6 +484,15 @@ include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) include(cmake/modules/contrib/Mrvl.cmake) +tvm_option(USE_Z3 "Build with Z3 SMT solver support" OFF) + +if (USE_Z3) + find_package(Z3 REQUIRED) + list(APPEND COMPILER_SRCS src/target/z3/z3_prover_on.cc) +else() + list(APPEND COMPILER_SRCS src/target/z3/z3_prover_off.cc) +endif() + set(LIBINFO_FILE ${CMAKE_CURRENT_LIST_DIR}/src/support/libinfo.cc) add_lib_info(${LIBINFO_FILE}) list(REMOVE_ITEM COMPILER_SRCS ${LIBINFO_FILE}) @@ -501,7 +512,7 @@ if(NOT BUILD_DUMMY_LIBTVM) $ ${TVM_RUNTIME_EXT_OBJS} ) - + target_link_libraries(tvm PUBLIC tvm_ffi_shared) else() # dummy version of libtvm that can be used by downstream to specify dependencies # the real runner still need a full version of libtvm @@ -510,6 +521,7 @@ else() $ ${TVM_RUNTIME_EXT_OBJS} ) + target_link_libraries(tvm PUBLIC tvm_ffi_shared) endif() target_include_directories(tvm PUBLIC "$") @@ -519,7 +531,6 @@ if(BUILD_STATIC_RUNTIME) add_library(tvm_runtime STATIC $ $ - $ ${TVM_RUNTIME_EXT_OBJS} ) set(NOTICE_MULTILINE @@ -528,6 +539,7 @@ if(BUILD_STATIC_RUNTIME) string(CONCAT NOTICE ${NOTICE_MULTILINE}) add_custom_command(TARGET tvm_runtime POST_BUILD 
COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE}) + target_link_libraries(tvm_runtime PUBLIC tvm_ffi_static) else() add_library(tvm_runtime SHARED $ @@ -535,9 +547,46 @@ else() ${TVM_RUNTIME_EXT_OBJS} ) set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") + target_link_libraries(tvm_runtime PUBLIC tvm_ffi_shared) +endif() + +if(USE_Z3) + target_include_directories(tvm_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS}) + target_include_directories(tvm_runtime_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS}) + target_include_directories(tvm_libinfo_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS}) + target_link_libraries(tvm PRIVATE z3::libz3) + + if (APPLE) + # `libz3.dylib` from z3-solver on pypi have a "wrong" name `libz3.dylib`, + # so it won't be searched in rpath. We patch it to `@rpath/libz3.dylib` here. + # `POST_BUILD` command needs to be in same cmake file where the target's created. + add_custom_command(TARGET tvm POST_BUILD + COMMAND install_name_tool -change "libz3.dylib" "@rpath/libz3.dylib" $ + COMMENT "Patching libz3 reference to use @rpath" + ) + else() + find_program(PATCHELF_EXECUTABLE patchelf) + if (PATCHELF_EXECUTABLE) + execute_process( + COMMAND ${PATCHELF_EXECUTABLE} --print-soname ${Z3_LIBRARY} + OUTPUT_VARIABLE Z3_SONAME + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE Z3_SONAME_RESULT + ) + if(NOT Z3_SONAME_RESULT EQUAL "0") + message(FATAL_ERROR "Failed to get Z3 soname using patchelf") + endif() + message("-- Z3 SONAME: ${Z3_SONAME}") + add_custom_command(TARGET tvm POST_BUILD + COMMAND ${PATCHELF_EXECUTABLE} --replace-needed ${Z3_SONAME} libz3.so $ + COMMENT "Patching libz3 reference to use soname ${Z3_SONAME}" + ) + else() + message("patchelf not found, skip.") + endif() + endif() endif() - target_include_directories(tvm_runtime PUBLIC "$") set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") @@ -564,8 +613,7 @@ if(USE_IOS_RPC) add_subdirectory("apps/ios_rpc") endif() 
-add_subdirectory(ffi) - +add_subdirectory(3rdparty/tvm-ffi) if(TVM_DEBUG_WITH_ABI_CHANGE) message(STATUS "Building with debug code that may cause ABI changes...") @@ -602,10 +650,6 @@ endif() target_link_libraries(tvm PRIVATE ${TVM_RUNTIME_LINKER_LIBS}) target_link_libraries(tvm_runtime PRIVATE ${TVM_RUNTIME_LINKER_LIBS}) -target_link_libraries(tvm PUBLIC tvm_ffi_objs) -target_link_libraries(tvm_runtime PUBLIC tvm_ffi_objs) - - if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTEST}) include(FetchContent) FetchContent_Declare(googletest SOURCE_DIR "${USE_HEXAGON_GTEST}") @@ -633,6 +677,7 @@ if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") add_library(tvm_allvisible SHARED $ $ $) target_include_directories(tvm_allvisible PUBLIC "$") target_link_libraries(tvm_allvisible PRIVATE "$") + target_link_libraries(tvm_allvisible PUBLIC tvm_ffi_shared) set(TVM_TEST_LIBRARY_NAME tvm_allvisible) set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL") @@ -643,7 +688,6 @@ if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") target_link_libraries(tvm_runtime PRIVATE ${HIDE_SYMBOLS_LINKER_FLAGS}) target_compile_definitions(tvm_allvisible PUBLIC $) target_compile_definitions(tvm_allvisible PRIVATE $) - target_link_libraries(tvm_allvisible PUBLIC tvm_ffi_objs) endif() # Create the `cpptest` target if we can find GTest. 
If not, we create dummy @@ -687,19 +731,6 @@ endif() # Custom targets add_custom_target(runtime DEPENDS tvm_runtime) -# By default add cython to all build -find_package(Python) -if(NOT DEFINED ENV{CONDA_BUILD}) - message(STATUS ${CMAKE_CURRENT_BINARY_DIR}) - add_custom_target( - tvm_cython ALL - ${Python_EXECUTABLE} setup.py build_ext --inplace - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/python - ) - add_dependencies(tvm_cython tvm) - message("Add Cython build into the default build step") -endif() - # Installation rules install(TARGETS tvm DESTINATION lib${LIB_SUFFIX}) install(TARGETS tvm_runtime DESTINATION lib${LIB_SUFFIX}) @@ -713,11 +744,6 @@ if (INSTALL_DEV) FILES_MATCHING PATTERN "*.h" ) - install( - DIRECTORY "3rdparty/dlpack/include/" DESTINATION "include" - FILES_MATCHING - PATTERN "*.h" - ) install( DIRECTORY "3rdparty/dmlc-core/include/" DESTINATION "include" FILES_MATCHING @@ -779,8 +805,8 @@ if(TVM_IS_DEBUG_BUILD) endif() endif() -add_dsymutil(tvm) -add_dsymutil(tvm_runtime) +tvm_ffi_add_apple_dsymutil(tvm) +tvm_ffi_add_apple_dsymutil(tvm_runtime) if(BUILD_FOR_HEXAGON) # Wrap pthread_create to allow setting custom stack size. @@ -789,7 +815,8 @@ if(BUILD_FOR_HEXAGON) # Link tvm_runtime into the RPC skel library. Make sure it's built # as a part of the "runtime" target. 
if(USE_HEXAGON_RPC) - target_link_libraries(hexagon_rpc_skel -Wl,--whole-archive tvm_runtime -Wl,--no-whole-archive) + target_link_libraries( + hexagon_rpc_skel -Wl,--whole-archive tvm_runtime tvm_ffi_static -Wl,--no-whole-archive) add_dependencies(runtime hexagon_rpc_skel) endif() endif() @@ -839,3 +866,87 @@ if(USE_ROCM AND USE_RCCL) target_link_libraries(tvm PRIVATE rccl) target_link_libraries(tvm_runtime PRIVATE rccl) endif() + +# Python package installation configuration +# This section ensures that all necessary files are installed for the Python wheel +if(TVM_BUILD_PYTHON_MODULE) + message(STATUS "Configuring Python package installation") + + # Set RPATH for tvm and tvm_runtime to find other libraries relatively + if(APPLE) + # macOS uses @loader_path + set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path") + set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path") + elseif(LINUX) + # Linux uses $ORIGIN + set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN") + set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN") + endif() + + # Install compiled shared libraries + install(TARGETS tvm DESTINATION ".") + install(TARGETS tvm_runtime DESTINATION ".") + + # Install third-party compiled dependencies + if(TARGET fpA_intB_gemm) + install(TARGETS fpA_intB_gemm DESTINATION ".") + endif() + if(TARGET flash_attn) + install(TARGETS flash_attn DESTINATION ".") + endif() + + # Install minimal header files needed by Python extensions + install( + DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/include/tvm/runtime/" + DESTINATION "include/tvm/runtime/" + FILES_MATCHING + PATTERN "*.h" + ) + + # Install minimal CMake configuration + install( + DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/cmake/utils/" + DESTINATION "cmake/utils/" + FILES_MATCHING + PATTERN "*.cmake" + ) + + # Install CUTLASS headers only if available + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/cutlass/include") + install( + DIRECTORY 
"${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/cutlass/include/" + DESTINATION "3rdparty/cutlass/include/" + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.hpp" + ) + endif() + + # Install minimal source files + install( + DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/" + DESTINATION "src/runtime/" + FILES_MATCHING + PATTERN "*.cc" + PATTERN "*.h" + ) + + # Install web package + install(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/web/" DESTINATION "web/") + + # Install licenses (required for distribution) + install( + DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/licenses/" + DESTINATION "licenses/" + ) + + # Install essential metadata files + install(FILES + "${CMAKE_CURRENT_SOURCE_DIR}/README.md" + "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE" + "${CMAKE_CURRENT_SOURCE_DIR}/NOTICE" + DESTINATION "." + ) + + message(STATUS "Python package installation configured") +endif() diff --git a/Makefile b/Makefile index ecc891ab7630..8ebc28412313 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,7 @@ TVM_BUILD_PATH := $(abspath $(TVM_BUILD_PATH)) # Allow environment variables for 3rd-party libraries, default to # packaged version. 
DMLC_CORE_PATH ?= $(ROOTDIR)/3rdparty/dmlc-core -DLPACK_PATH ?= $(ROOTDIR)/3rdparty/dlpack +DLPACK_PATH ?= $(ROOTDIR)/3rdparty/tvm-ffi/3rdparty/dlpack all: $(addsuffix /all,$(TVM_BUILD_PATH)) @@ -107,16 +107,6 @@ mypy: cppdoc: doxygen docs/Doxyfile - -# Cython build -cython cython3: - cd python; python3 setup.py build_ext --inplace - -cyclean: - rm -rf python/tvm/*/*/*.so python/tvm/*/*/*.dylib python/tvm/*/*/*.cpp - - - # EMCC; Web related scripts web: $(MAKE) -C $(ROOTDIR)/web diff --git a/README.md b/README.md index 85e924e4ac80..fb9e9bc4a0d1 100644 --- a/README.md +++ b/README.md @@ -15,16 +15,18 @@ - Open Deep Learning Compiler Stack + Open Machine Learning Compiler Framework ============================================== [Documentation](https://tvm.apache.org/docs) | [Contributors](CONTRIBUTORS.md) | [Community](https://tvm.apache.org/community) | [Release Notes](NEWS.md) -Apache TVM is a compiler stack for deep learning systems. It is designed to close the gap between the -productivity-focused deep learning frameworks and the performance- and efficiency-focused hardware backends. -TVM works with deep learning frameworks to provide end-to-end compilation for different backends. +Apache TVM is an open machine learning compilation framework, +following the following principles: + +- Python-first development that enables quick customization of machine learning compiler pipelines. +- Universal deployment to bring models into minimum deployable modules. 
License ------- diff --git a/apps/android_rpc/app/src/main/jni/Android.mk b/apps/android_rpc/app/src/main/jni/Android.mk index 692a3390131d..d482f9429559 100644 --- a/apps/android_rpc/app/src/main/jni/Android.mk +++ b/apps/android_rpc/app/src/main/jni/Android.mk @@ -37,8 +37,8 @@ LOCAL_SRC_FILES := org_apache_tvm_native_c_api.cc LOCAL_LDFLAGS := -L$(SYSROOT)/usr/lib/ -llog LOCAL_C_INCLUDES := $(ROOT_PATH)/include \ - $(ROOT_PATH)/ffi/include \ - $(ROOT_PATH)/ffi/3rdparty/dlpack/include \ + $(ROOT_PATH)/3rdparty/tvm-ffi/include \ + $(ROOT_PATH)/3rdparty/tvm-ffi/3rdparty/dlpack/include \ $(ROOT_PATH)/3rdparty/dmlc-core/include \ $(ROOT_PATH)/3rdparty/OpenCL-Headers diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 94fc6422891f..a522f0e9968a 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -34,25 +34,24 @@ #define TVM_LOG_CUSTOMIZE 1 #define TVM_FFI_USE_LIBBACKTRACE 0 -#include "../ffi/src/ffi/container.cc" -#include "../ffi/src/ffi/dtype.cc" -#include "../ffi/src/ffi/error.cc" -#include "../ffi/src/ffi/extra/library_module.cc" -#include "../ffi/src/ffi/extra/library_module_dynamic_lib.cc" -#include "../ffi/src/ffi/extra/library_module_system_lib.cc" -#include "../ffi/src/ffi/extra/module.cc" -#include "../ffi/src/ffi/extra/testing.cc" -#include "../ffi/src/ffi/function.cc" -#include "../ffi/src/ffi/ndarray.cc" -#include "../ffi/src/ffi/object.cc" -#include "../ffi/src/ffi/traceback.cc" +#include "../3rdparty/tvm-ffi/src/ffi/backtrace.cc" +#include "../3rdparty/tvm-ffi/src/ffi/container.cc" +#include "../3rdparty/tvm-ffi/src/ffi/dtype.cc" +#include "../3rdparty/tvm-ffi/src/ffi/error.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/library_module.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/library_module_dynamic_lib.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc" +#include 
"../3rdparty/tvm-ffi/src/ffi/extra/module.cc" +#include "../3rdparty/tvm-ffi/src/ffi/extra/testing.cc" +#include "../3rdparty/tvm-ffi/src/ffi/function.cc" +#include "../3rdparty/tvm-ffi/src/ffi/object.cc" +#include "../3rdparty/tvm-ffi/src/ffi/tensor.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/device_api.cc" #include "../src/runtime/file_utils.cc" #include "../src/runtime/logging.cc" #include "../src/runtime/memory/memory_manager.cc" #include "../src/runtime/minrpc/minrpc_logger.cc" -#include "../src/runtime/ndarray.cc" #include "../src/runtime/profiling.cc" #include "../src/runtime/registry.cc" #include "../src/runtime/rpc/rpc_channel.cc" @@ -63,6 +62,7 @@ #include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" +#include "../src/runtime/tensor.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" #include "../src/runtime/workspace_pool.cc" diff --git a/apps/android_rpc/tests/android_rpc_test.py b/apps/android_rpc/tests/android_rpc_test.py index b9c6995729d0..b1548df3e177 100644 --- a/apps/android_rpc/tests/android_rpc_test.py +++ b/apps/android_rpc/tests/android_rpc_test.py @@ -72,8 +72,8 @@ def test_rpc_module(): dev = remote.cl(0) remote.upload(path_dso_cl) f1 = remote.load_module("dev_lib_cl.so") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(np.zeros(1024, dtype=A.dtype), dev) time_f = f1.time_evaluator(f1.entry_name, dev, number=10) cost = time_f(a, b).mean print("%g secs/op\n" % cost) diff --git a/apps/cpp_rpc/CMakeLists.txt b/apps/cpp_rpc/CMakeLists.txt index e16da0ee4929..6d58308c9c47 100644 --- a/apps/cpp_rpc/CMakeLists.txt +++ b/apps/cpp_rpc/CMakeLists.txt @@ -45,10 +45,11 @@ endif() target_include_directories( tvm_rpc PUBLIC "../../include" - PUBLIC DLPACK_PATH PUBLIC DMLC_PATH ) 
+target_link_libraries(tvm_rpc PUBLIC tvm_ffi_header) + if (BUILD_FOR_ANDROID AND USE_HEXAGON) get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}" DSPRPC_LIB DSPRPC_LIB_DIRS @@ -62,9 +63,9 @@ if (BUILD_FOR_ANDROID AND USE_HEXAGON) endif() if(BUILD_STATIC_RUNTIME) - list(APPEND TVM_RPC_LINKER_LIBS -Wl,--whole-archive tvm_runtime -Wl,--no-whole-archive) + list(APPEND TVM_RPC_LINKER_LIBS -Wl,--whole-archive tvm_runtime tvm_ffi_static -Wl,--no-whole-archive) else() list(APPEND TVM_RPC_LINKER_LIBS tvm_runtime) endif() -target_link_libraries(tvm_rpc ${TVM_RPC_LINKER_LIBS}) +target_link_libraries(tvm_rpc PRIVATE ${TVM_RPC_LINKER_LIBS}) diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 797692d0f503..fd8fc476bbec 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -399,9 +399,9 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track rpc.Start(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("rpc.ServerCreate", RPCServerCreate); -}); +} } // namespace runtime } // namespace tvm diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index fa2c3d8e3300..3bf6ce23cf8d 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -163,7 +163,7 @@ tvm::runtime::Module load_module(const std::string& file_name) { return tvm::runtime::Module(); } -std::ostream& operator<<(std::ostream& os, const tvm::Array& strings) { +std::ostream& operator<<(std::ostream& os, const tvm::ffi::Array& strings) { os << '['; for (int i = 0, e = strings.size(); i != e; ++i) { if (i != 0) os << ','; @@ -191,7 +191,7 @@ tvm::runtime::Module create_graph_executor(const std::string& graph_json, tvm::runtime::Module create_aot_executor(tvm::runtime::Module factory_module, tvm::Device device) { tvm::ffi::Function list_modules = get_module_func(factory_module, 
"list_module_names"); - tvm::Array module_names = list_modules(); + tvm::ffi::Array module_names = list_modules(); if (module_names.size() != 1) { LOG(WARNING) << __func__ << ": expecting single module, got: " << module_names << ", using " << module_names[0]; diff --git a/apps/hexagon_launcher/launcher_core.h b/apps/hexagon_launcher/launcher_core.h index 5e62774607ba..be5a4ee94da9 100644 --- a/apps/hexagon_launcher/launcher_core.h +++ b/apps/hexagon_launcher/launcher_core.h @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include diff --git a/apps/hexagon_launcher/launcher_hexagon.cc b/apps/hexagon_launcher/launcher_hexagon.cc index bd1df4aa62ad..64b795d8f45c 100644 --- a/apps/hexagon_launcher/launcher_hexagon.cc +++ b/apps/hexagon_launcher/launcher_hexagon.cc @@ -137,7 +137,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_set_input)(remote_handle64 handle, int inpu }; DLManagedTensor managed{tensor, /*manager_ctx*/ nullptr, /*deleter*/ nullptr}; - auto input = tvm::runtime::NDArray::FromDLPack(&managed); + auto input = tvm::runtime::Tensor::FromDLPack(&managed); tvm::ffi::Function set_input = get_module_func(TheModel->model_executor, "set_input"); set_input(input_idx, input); @@ -172,17 +172,17 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out } tvm::ffi::Function get_output = get_module_func(TheModel->model_executor, "get_output"); - tvm::runtime::NDArray output = get_output(output_idx); + tvm::runtime::Tensor output = get_output(output_idx); std::vector shape_vec{output->shape, output->shape + output->ndim}; - auto* container = new tvm::runtime::NDArray::Container( - static_cast(output_value), shape_vec, output->dtype, Model::external()); + auto* container = new tvm::runtime::Tensor::Container(static_cast(output_value), shape_vec, + output->dtype, Model::external()); container->SetDeleter([](tvm::Object* container) { - delete static_cast(container); + delete static_cast(container); }); - tvm::runtime::NDArray 
host_output(tvm::runtime::GetObjectPtr(container)); + tvm::runtime::Tensor host_output(tvm::runtime::GetObjectPtr(container)); if (meta_size != 0) { auto* meta = reinterpret_cast(output_meta); diff --git a/apps/ios_rpc/tests/ios_rpc_test.py b/apps/ios_rpc/tests/ios_rpc_test.py index 0e563ee1b688..df850812e527 100644 --- a/apps/ios_rpc/tests/ios_rpc_test.py +++ b/apps/ios_rpc/tests/ios_rpc_test.py @@ -39,7 +39,7 @@ # override metal compiler to compile to iphone -@tvm.register_func("tvm_callback_metal_compile") +@tvm.register_global_func("tvm_callback_metal_compile") def compile_metal(src, target): return xcode.compile_metal(src, sdk=sdk) @@ -72,8 +72,8 @@ def test_rpc_module(host, port, key, mode): dev = remote.metal(0) f1 = remote.load_module("dev_lib.dylib") a_np = np.random.uniform(size=1024).astype(A.dtype) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(np.zeros(1024, dtype=A.dtype), dev) time_f = f1.time_evaluator(f1.entry_name, dev, number=10) cost = time_f(a, b).mean print("Metal: %g secs/op" % cost) diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index 09ee55390959..5dfff0cd86b4 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -33,7 +33,7 @@ #if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1 // internal TVM header to achieve Library class -#include <../../../ffi/src/ffi/extra/library_module.h> +#include <../../../3rdparty/tvm-ffi/src/ffi/extra/library_module.h> #include #endif @@ -52,7 +52,7 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.rpc.server.workpath", @@ -85,7 +85,7 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s *rv = 
Module::LoadFromFile(name, fmt); LOG(INFO) << "Load module from " << name << " ..."; }); -}); +} #if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1 @@ -112,15 +112,15 @@ void Init(const std::string& name) { }; // Add UnsignedDSOLoader plugin in global registry -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("ffi.Module.load_from_file.dylib_custom", [](ffi::PackedArgs args, ffi::Any* rv) { - auto n = make_object(); + auto n = ffi::make_object(); n->Init(args[0]); *rv = tvm::ffi::CreateLibraryModule(n); }); -}); +} #endif diff --git a/ci/jenkins/data.py b/ci/jenkins/data.py index e52aaf32a4b2..3577a0ad008c 100644 --- a/ci/jenkins/data.py +++ b/ci/jenkins/data.py @@ -30,7 +30,12 @@ # runtime files "tvm_runtime": ["build/libtvm_runtime.so", "build/config.cmake"], # compiler files - "tvm_lib": ["build/libtvm.so", "build/libtvm_runtime.so", "build/config.cmake"], + "tvm_lib": [ + "build/libtvm.so", + "build/libtvm_runtime.so", + "build/lib/libtvm_ffi.so", + "build/config.cmake", + ], # gpu related compiler files "tvm_lib_gpu_extra": [ "build/3rdparty/libflash_attn/src/libflash_attn.so", diff --git a/ci/jenkins/generated/arm_jenkinsfile.groovy b/ci/jenkins/generated/arm_jenkinsfile.groovy index 9e4afc8f1393..b58ec7022107 100644 --- a/ci/jenkins/generated/arm_jenkinsfile.groovy +++ b/ci/jenkins/generated/arm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2025-06-03T18:16:35.851073 +// Generated at 2025-08-24T16:41:22.350930 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -516,7 +516,7 @@ def run_build(node_type) { cmake_build(ci_arm, 'build') make_cpp_tests(ci_arm, 'build') sh( - script: "./${jenkins_scripts_root}/s3.py --action 
upload --bucket ${s3_bucket} --prefix ${s3_prefix}/arm --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/cpptest build/build.ninja build/CMakeFiles/rules.ninja", + script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/arm --items build/libtvm.so build/libtvm_runtime.so build/lib/libtvm_ffi.so build/config.cmake build/cpptest build/build.ninja build/CMakeFiles/rules.ninja", label: 'Upload artifacts to S3', ) }) @@ -532,6 +532,9 @@ def build() { stage('Build') { try { run_build('ARM-GRAVITON3-SPOT') + } catch (hudson.AbortException abortEx) { + echo "Received normal AbortException, exit now. Details:" + abortEx.toString() + throw abortEx } catch (Throwable ex) { echo 'Exception during SPOT run ' + ex.toString() if (is_last_build()) { diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index daadc16c7631..53c74d111535 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2025-06-03T18:16:35.861918 +// Generated at 2025-08-24T16:41:22.367054 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -516,7 +516,7 @@ def run_build(node_type) { cmake_build(ci_cpu, 'build') make_cpp_tests(ci_cpu, 'build') sh( - script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/cpu --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/libtvm_allvisible.so build/cpptest build/build.ninja build/CMakeFiles/rules.ninja", + script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/cpu --items build/libtvm.so build/libtvm_runtime.so 
build/lib/libtvm_ffi.so build/config.cmake build/libtvm_allvisible.so build/cpptest build/build.ninja build/CMakeFiles/rules.ninja", label: 'Upload artifacts to S3', ) }) @@ -532,6 +532,9 @@ def build() { stage('Build') { try { run_build('CPU-SPOT') + } catch (hudson.AbortException abortEx) { + echo "Received normal AbortException, exit now. Details:" + abortEx.toString() + throw abortEx } catch (Throwable ex) { echo 'Exception during SPOT run ' + ex.toString() if (is_last_build()) { diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy index 1fc4348c6f1c..e9ade66832b1 100644 --- a/ci/jenkins/generated/gpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2025-06-03T18:16:35.885417 +// Generated at 2025-08-24T16:41:22.312666 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -512,7 +512,7 @@ def run_build(node_type) { sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" cmake_build("${ci_gpu} --no-gpu", 'build') sh( - script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/gpu --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/libtvm_allvisible.so build/3rdparty/libflash_attn/src/libflash_attn.so build/3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/libfpA_intB_gemm.so", + script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/gpu --items build/libtvm.so build/libtvm_runtime.so build/lib/libtvm_ffi.so build/config.cmake build/libtvm_allvisible.so build/3rdparty/libflash_attn/src/libflash_attn.so build/3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/libfpA_intB_gemm.so", label: 'Upload 
artifacts to S3', ) @@ -522,7 +522,7 @@ def run_build(node_type) { sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu_other.sh build" cmake_build("${ci_gpu} --no-gpu", 'build') sh( - script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/gpu2 --items build/libtvm.so build/libtvm_runtime.so build/config.cmake", + script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/gpu2 --items build/libtvm.so build/libtvm_runtime.so build/lib/libtvm_ffi.so build/config.cmake", label: 'Upload artifacts to S3', ) }) @@ -538,6 +538,9 @@ def build() { stage('Build') { try { run_build('CPU-SPOT') + } catch (hudson.AbortException abortEx) { + echo "Received normal AbortException, exit now. Details:" + abortEx.toString() + throw abortEx } catch (Throwable ex) { echo 'Exception during SPOT run ' + ex.toString() if (is_last_build()) { diff --git a/ci/jenkins/generated/hexagon_jenkinsfile.groovy b/ci/jenkins/generated/hexagon_jenkinsfile.groovy index 173506fcce7e..004798101113 100644 --- a/ci/jenkins/generated/hexagon_jenkinsfile.groovy +++ b/ci/jenkins/generated/hexagon_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2025-06-03T18:16:35.839798 +// Generated at 2025-08-24T16:41:22.257116 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -520,7 +520,7 @@ def run_build(node_type) { label: 'Build Hexagon API', ) sh( - script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/hexagon --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/cpptest build/build.ninja build/CMakeFiles/rules.ninja build/hexagon_api_output", + script: "./${jenkins_scripts_root}/s3.py --action 
upload --bucket ${s3_bucket} --prefix ${s3_prefix}/hexagon --items build/libtvm.so build/libtvm_runtime.so build/lib/libtvm_ffi.so build/config.cmake build/cpptest build/build.ninja build/CMakeFiles/rules.ninja build/hexagon_api_output", label: 'Upload artifacts to S3', ) }) @@ -536,6 +536,9 @@ def build() { stage('Build') { try { run_build('CPU-SPOT') + } catch (hudson.AbortException abortEx) { + echo "Received normal AbortException, exit now. Details:" + abortEx.toString() + throw abortEx } catch (Throwable ex) { echo 'Exception during SPOT run ' + ex.toString() if (is_last_build()) { diff --git a/ci/jenkins/generated/i386_jenkinsfile.groovy b/ci/jenkins/generated/i386_jenkinsfile.groovy index 3ef2b532bae1..e54ec2c60686 100644 --- a/ci/jenkins/generated/i386_jenkinsfile.groovy +++ b/ci/jenkins/generated/i386_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2025-06-03T18:16:35.814567 +// Generated at 2025-08-24T16:41:22.332874 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -516,7 +516,7 @@ def run_build(node_type) { cmake_build(ci_i386, 'build') make_cpp_tests(ci_i386, 'build') sh( - script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/i386 --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/cpptest build/build.ninja build/CMakeFiles/rules.ninja", + script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/i386 --items build/libtvm.so build/libtvm_runtime.so build/lib/libtvm_ffi.so build/config.cmake build/cpptest build/build.ninja build/CMakeFiles/rules.ninja", label: 'Upload artifacts to S3', ) }) @@ -532,6 +532,9 @@ def build() { stage('Build') { try { run_build('CPU-SPOT') + } catch 
(hudson.AbortException abortEx) { + echo "Received normal AbortException, exit now. Details:" + abortEx.toString() + throw abortEx } catch (Throwable ex) { echo 'Exception during SPOT run ' + ex.toString() if (is_last_build()) { diff --git a/ci/jenkins/generated/wasm_jenkinsfile.groovy b/ci/jenkins/generated/wasm_jenkinsfile.groovy index d214fb3710f3..4a6ccac25f66 100644 --- a/ci/jenkins/generated/wasm_jenkinsfile.groovy +++ b/ci/jenkins/generated/wasm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2025-06-03T18:16:35.874501 +// Generated at 2025-08-24T11:52:44.735820 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -534,6 +534,9 @@ def build() { stage('Build') { try { run_build('CPU-SPOT') + } catch (hudson.AbortException abortEx) { + echo "Received normal AbortException, exit now. Details:" + abortEx.toString() + throw abortEx } catch (Throwable ex) { echo 'Exception during SPOT run ' + ex.toString() if (is_last_build()) { diff --git a/ci/jenkins/templates/utils/macros.j2 b/ci/jenkins/templates/utils/macros.j2 index 662d9aef111c..c96432840dec 100644 --- a/ci/jenkins/templates/utils/macros.j2 +++ b/ci/jenkins/templates/utils/macros.j2 @@ -95,6 +95,9 @@ def build() { stage('Build') { try { run_build('{{ node }}-SPOT') + } catch (hudson.AbortException abortEx) { + echo "Received normal AbortException, exit now. 
Details:" + abortEx.toString() + throw abortEx } catch (Throwable ex) { echo 'Exception during SPOT run ' + ex.toString() if (is_last_build()) { diff --git a/ci/scripts/github/update_branch.py b/ci/scripts/github/update_branch.py index e49d9e47ab79..b3fa01413793 100755 --- a/ci/scripts/github/update_branch.py +++ b/ci/scripts/github/update_branch.py @@ -165,8 +165,6 @@ def update_branch(user: str, repo: str, sha: str, branch_name: str) -> None: remote = git(["config", "--get", f"remote.{args.remote}.url"]) user, repo = parse_remote(remote) - # TODO: Remove this before landing - user, repo = ("apache", "tvm") if args.testonly_json: r = json.loads(args.testonly_json) diff --git a/ci/scripts/jenkins/check_pr.py b/ci/scripts/jenkins/check_pr.py index 9af5ec5580a3..8be5c0ee46a8 100755 --- a/ci/scripts/jenkins/check_pr.py +++ b/ci/scripts/jenkins/check_pr.py @@ -69,19 +69,17 @@ def trailing_period(s: str): title_checks = [ Check(check=non_empty, error_fn=lambda d: "PR must have a title but title was empty"), Check(check=trailing_period, error_fn=lambda d: "PR must not end in a tailing '.'"), - # TODO(driazati): enable this check once https://github.com/apache/tvm/issues/12637 is done - # Check( - # check=usernames, - # error_fn=lambda d: f"PR title must not tag anyone but found these usernames: {d}", - # ), + Check( + check=usernames, + error_fn=lambda d: f"PR title must not tag anyone but found these usernames: {d}", + ), ] body_checks = [ Check(check=non_empty, error_fn=lambda d: "PR must have a body but body was empty"), - # TODO(driazati): enable this check once https://github.com/apache/tvm/issues/12637 is done - # Check( - # check=usernames, - # error_fn=lambda d: f"PR body must not tag anyone but found these usernames: {d}", - # ), + Check( + check=usernames, + error_fn=lambda d: f"PR body must not tag anyone but found these usernames: {d}", + ), ] diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 73d789e9fa94..f286d9f7d9fa 100644 --- 
a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -44,7 +44,6 @@ function(add_lib_info src_file) TVM_INFO_BUILD_DUMMY_LIBTVM="${BUILD_DUMMY_LIBTVM}" TVM_INFO_COMPILER_RT_PATH="${COMPILER_RT_PATH}" TVM_INFO_CUDA_VERSION="${TVM_INFO_CUDA_VERSION}" - TVM_INFO_DLPACK_PATH="${DLPACK_PATH}" TVM_INFO_DMLC_PATH="${DMLC_PATH}" TVM_INFO_GIT_COMMIT_HASH="${TVM_GIT_COMMIT_HASH}" TVM_INFO_GIT_COMMIT_TIME="${TVM_GIT_COMMIT_TIME}" @@ -108,6 +107,7 @@ function(add_lib_info src_file) TVM_INFO_USE_ROCM="${USE_ROCM}" TVM_INFO_USE_RCCL="${USE_RCCL}" TVM_INFO_USE_RPC="${USE_RPC}" + TVM_INFO_TVM_BUILD_PYTHON_MODULE="${TVM_BUILD_PYTHON_MODULE}" TVM_INFO_USE_RTTI="${USE_RTTI}" TVM_INFO_USE_RUST_EXT="${USE_RUST_EXT}" TVM_INFO_USE_SORT="${USE_SORT}" diff --git a/cmake/utils/FindCUDA.cmake b/cmake/utils/FindCUDA.cmake index 2036c7c32994..c4c18eef0f80 100644 --- a/cmake/utils/FindCUDA.cmake +++ b/cmake/utils/FindCUDA.cmake @@ -101,7 +101,7 @@ macro(find_cuda use_cuda use_cudnn) PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu NO_DEFAULT_PATH) find_library(CUDA_NVTX_LIBRARY - NAMES nvToolsExt nvTools nvtoolsext nvtools nvtx NVTX + NAMES nvToolsExt nvTools nvtoolsext nvtools nvtx NVTX nvtx3interop PATHS "${CUDA_CUDART_LIBRARY_DIR}" "${CUDA_TOOLKIT_ROOT_DIR}" ENV LD_LIBRARY_PATH PATH_SUFFIXES "lib64" "common/lib64" "common/lib" "lib" DOC "Location of the CUDA Toolkit Extension (NVTX) library" diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake index 2a243b06c85d..09f4dcca7fd8 100644 --- a/cmake/utils/FindLLVM.cmake +++ b/cmake/utils/FindLLVM.cmake @@ -219,6 +219,10 @@ macro(find_llvm use_llvm) # If the library file ends in .lib try to also search the llvm_libdir message(STATUS "LLVM linker flag under LLVM libdir: ${__llvm_libdir}/${__flag}") list(APPEND LLVM_LIBS "${__llvm_libdir}/${__flag}") + elseif((__flag MATCHES ".lib$") AND (EXISTS "${__llvm_libdir}/lib${__flag}")) + # If the library file ends 
in .lib try to also search the llvm_libdir with lib prefix + message(STATUS "LLVM linker flag under LLVM libdir: ${__llvm_libdir}/lib${__flag}") + list(APPEND LLVM_LIBS "${__llvm_libdir}/lib${__flag}") else() message(STATUS "LLVM linker flag: ${__flag}") list(APPEND LLVM_LIBS "${__flag}") diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index 716b2198faeb..28650499ea7c 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -11,12 +11,12 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. -# Build environment that can be used to build tvm. -name: tvm-build +# Build environment for TVM wheel building. +# This environment provides the necessary dependencies for building TVM wheels. +name: tvm-wheel-build # The conda channels to lookup the dependencies channels: @@ -24,15 +24,17 @@ channels: # The packages to install to the environment dependencies: - - conda < 24.9.0 - - conda-build < 24.9.0 - - git + # Core build tools + - cmake >=3.24 + - ninja + - make - llvmdev >=11 - - numpy - - pytest - - cython - - cmake + - python >=3.9 + - pip + - git - bzip2 - - make + - pytest + - numpy - scipy - - pillow + - cython + - libxml2-devel diff --git a/conda/recipe/install_tvm_python.bat b/conda/recipe/install_tvm_python.bat index 07c0465b8443..635897266cf6 100644 --- a/conda/recipe/install_tvm_python.bat +++ b/conda/recipe/install_tvm_python.bat @@ -16,5 +16,5 @@ :: under the License. 
echo on -cd %SRC_DIR%\python || exit /b -%PYTHON% setup.py install --single-version-externally-managed --record=%SRC_DIR%\record.txt || exit /b +cd %SRC_DIR% || exit /b +%PYTHON% -m pip install . --no-deps --no-build-isolation --record=%SRC_DIR%\record.txt || exit /b diff --git a/conda/recipe/install_tvm_python.sh b/conda/recipe/install_tvm_python.sh index 2c721c64a156..ca9f7767173f 100755 --- a/conda/recipe/install_tvm_python.sh +++ b/conda/recipe/install_tvm_python.sh @@ -19,5 +19,5 @@ set -e set -u -cd ${SRC_DIR}/python -${PYTHON} setup.py install --single-version-externally-managed --record=/tmp/record.txt +cd ${SRC_DIR} +${PYTHON} -m pip install . --no-deps --no-build-isolation --record=/tmp/record.txt diff --git a/conda/recipe/meta.yaml b/conda/recipe/meta.yaml index edf88cbca968..4a5602b4daa9 100644 --- a/conda/recipe/meta.yaml +++ b/conda/recipe/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = '0.22.dev0' %} +{% set version = '0.23.dev0' %} {% set pkg_name = 'tvm' %} {% set cuda_tag = cuda_version | replace('.', '') %} # [cuda] {% set pkg_name = pkg_name + '-cu' + cuda_tag %} # [cuda] diff --git a/configs/host/default.json b/configs/host/default.json deleted file mode 100644 index 2c29445501cc..000000000000 --- a/configs/host/default.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "targets": [ - { - "kind": "llvm" - } - ] -} diff --git a/configs/test/compile_config_test.json b/configs/test/compile_config_test.json deleted file mode 100644 index dcc6dbd27e4e..000000000000 --- a/configs/test/compile_config_test.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "targets": [ - {"kind": "cmsis-nn", "from_device": "1"}, - {"kind": "c", "mcpu": "cortex-m55"} - ], - "executor": { "kind": "aot"}, - "runtime": { "kind": "crt"}, - "pass-config": { "tir.disable_vectorize": "1"} -} diff --git a/configs/test/tune_config_test.json b/configs/test/tune_config_test.json deleted file mode 100644 index 
69babc753e87..000000000000 --- a/configs/test/tune_config_test.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "targets": [ - { "kind": "llvm" } - ], - "trials": "2" -} diff --git a/docker/Dockerfile.demo_android b/docker/Dockerfile.demo_android deleted file mode 100644 index bbe8f7d82b01..000000000000 --- a/docker/Dockerfile.demo_android +++ /dev/null @@ -1,82 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Minimum docker image for demo purposes -FROM ubuntu:22.04 - -COPY utils/apt-install-and-clear.sh /usr/local/bin/apt-install-and-clear - -RUN apt-get update --fix-missing - -COPY install/ubuntu_setup_tz.sh /install/ubuntu_setup_tz.sh -RUN bash /install/ubuntu_setup_tz.sh - -COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh -RUN bash /install/ubuntu_install_core.sh - -ENV TVM_VENV /venv/apache-tvm-py3.9 -COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles -COPY install/ubuntu_install_python.sh /install/ubuntu1804_install_python.sh -RUN bash /install/ubuntu1804_install_python.sh 3.9 -ENV PATH ${TVM_VENV}/bin:$PATH -ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. 
- -COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh -RUN bash /install/ubuntu_install_python_package.sh - -COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh -RUN bash /install/ubuntu_install_tensorflow.sh - -COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh -RUN bash /install/ubuntu_install_java.sh - -COPY install/ubuntu2204_install_llvm.sh /install/ubuntu2204_install_llvm.sh -RUN bash /install/ubuntu2204_install_llvm.sh - -COPY install/ubuntu_install_gradle.sh /install/ubuntu_install_gradle.sh -RUN bash /install/ubuntu_install_gradle.sh - -COPY install/ubuntu_install_androidsdk.sh /install/ubuntu_install_androidsdk.sh -RUN bash /install/ubuntu_install_androidsdk.sh - -COPY install/ubuntu_install_vulkan.sh /install/ubuntu_install_vulkan.sh -RUN bash /install/ubuntu_install_vulkan.sh - -ENV VULKAN_SDK=/usr - -COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh -RUN bash /install/ubuntu_install_cmake_source.sh - -RUN git clone https://github.com/KhronosGroup/OpenCL-Headers /usr/local/OpenCL-Headers/ - -# Build TVM -RUN cd /usr && \ - git clone --depth=1 https://github.com/apache/tvm tvm --recursive && \ - cd /usr/tvm && \ - mkdir -p build && \ - cd build && \ - cmake \ - -DUSE_LLVM=llvm-config-15 \ - -DUSE_RPC=ON \ - -DUSE_SORT=ON \ - -DUSE_VULKAN=ON \ - .. && \ - make -j10 - -# Environment variables -ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/vta/python:${PYTHONPATH} -ENV ANDROID_HOME=/opt/android-sdk-linux/ diff --git a/docker/Dockerfile.demo_mrvl b/docker/Dockerfile.demo_mrvl deleted file mode 100644 index b50944d2c20e..000000000000 --- a/docker/Dockerfile.demo_mrvl +++ /dev/null @@ -1,47 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# prebuild ci-cpu image -FROM tlcpack/ci-cpu:20230604-060130-0af9ff90e - -# Cloning TVM's main repo -RUN echo "Cloning TVM source & submodules" -ENV TVM_PAR_DIR="/usr" -RUN mkdir -p TVM_PAR_DIR && \ - cd ${TVM_PAR_DIR} && \ - git clone --depth=1 https://github.com/apache/tvm tvm --recursive - -# Building TVM -RUN echo "Building TVM" -ENV TVM_HOME="/usr/tvm" -ENV TVM_BUILD_DIR="${TVM_HOME}/build" -RUN mkdir -p ${TVM_BUILD_DIR} && \ - cd ${TVM_HOME} && \ - ./tests/scripts/task_config_build_mrvl.sh build && \ - cd ${TVM_BUILD_DIR} && \ - cmake .. && \ - make -j$(nproc) - -RUN echo "Building Python package" -ENV PYTHONPATH=${TVM_HOME}/python:${PYTHONPATH} -RUN cd ${TVM_HOME}/python && python3 setup.py install --user - -# Fetching Marvell binaries -RUN cd /opt && \ - git clone https://github.com/MarvellEmbeddedProcessors/MarvellMLTools.git - -ENV PATH="/opt/MarvellMLTools/bin:$PATH" diff --git a/docker/Dockerfile.demo_opencl b/docker/Dockerfile.demo_opencl deleted file mode 100644 index 9112ccc0d8ea..000000000000 --- a/docker/Dockerfile.demo_opencl +++ /dev/null @@ -1,74 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# USAGE: sudo docker build libs/tvm -f libs/tvm/docker/Dockerfile.ocl -t l4b/tvm:ocl - -# REFERENCE: https://docs.docker.com/engine/reference/builder - -FROM ubuntu:22.04 - -COPY utils/apt-install-and-clear.sh /usr/local/bin/apt-install-and-clear - -RUN echo "Labelling this image" -LABEL Description="Docker image for TVM built with OpenCL support" - -RUN echo "Preparing to install dependencies" -RUN apt-get update -# ENV DEBIAN_FRONTEND noninteractive -RUN echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections - -RUN echo "Installing utility libraries" -RUN apt-install-and-clear -y apt-utils sudo cmake g++ llvm git libopenblas-dev - -# RUN echo "Installing gtest" -# RUN apt-install-and-clear -y libgtest-dev -# RUN cd /usr/src/gtest && cmake CMakeLists.txt && make && cp *.a /usr/lib - -RUN echo "Installing Python" -RUN apt-install-and-clear -y python3-dev python3-pip -RUN pip3 install setuptools numpy pytest cython scipy tornado psutil xgboost - -RUN echo "Installing Jupyter notebook" -RUN pip3 install matplotlib Image "Pillow<7" jupyter[notebook] - -RUN echo "Installing OpenCL libraries" -RUN apt-install-and-clear -y libviennacl-dev mesa-opencl-icd ocl-icd-opencl-dev clinfo -RUN apt-install-and-clear -y libclblas-dev libclfft-dev libclsparse-dev - -RUN echo "Upgrading dependencies" -RUN apt-get upgrade -y - -RUN echo "Cloning TVM source & submodules" -ENV 
TVM_PAR_DIR="/usr" -RUN mkdir -p TVM_PAR_DIR && \ - cd ${TVM_PAR_DIR} && \ - git clone --depth=1 https://github.com/apache/tvm tvm --recursive -#RUN git submodule update --init --recursive - - -RUN echo "Building TVM" -#USE_BLAS: "openblas" | "mkl" | "atlas" | "apple" | "none" -ENV TVM_HOME="/usr/tvm" -ENV TVM_BUILD_DIR="${TVM_HOME}/build" -RUN mkdir -p ${TVM_BUILD_DIR} && \ - cd ${TVM_BUILD_DIR} && \ - cmake .. -DUSE_BLAS=openblas -DUSE_LLVM=ON -DUSE_OPENCL=ON && \ - make -j6 - -RUN echo "Building Python package" -ENV PYTHONPATH=${TVM_HOME}/python:${PYTHONPATH} -RUN cd ${TVM_HOME}/python && python3 setup.py install --user diff --git a/docker/Dockerfile.demo_rocm b/docker/Dockerfile.demo_rocm deleted file mode 100644 index 4c6095ec4802..000000000000 --- a/docker/Dockerfile.demo_rocm +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# Demo docker for ROCm -FROM ubuntu:22.04 - -COPY utils/apt-install-and-clear.sh /usr/local/bin/apt-install-and-clear - -COPY install/ubuntu_setup_tz.sh /install/ubuntu_setup_tz.sh -RUN bash /install/ubuntu_setup_tz.sh - -COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh -RUN bash /install/ubuntu_install_core.sh - -ENV TVM_VENV /venv/apache-tvm-py3.9 -COPY python/bootstrap/lockfiles /install/python/bootstrap/lockfiles -COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh 3.9 -ENV PATH ${TVM_VENV}/bin:$PATH -ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. - -COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh -RUN bash /install/ubuntu_install_python_package.sh - -COPY install/ubuntu2204_install_llvm.sh /install/ubuntu2204_install_llvm.sh -RUN bash /install/ubuntu2204_install_llvm.sh - -COPY install/ubuntu_install_rocm.sh /install/ubuntu_install_rocm.sh -RUN bash /install/ubuntu_install_rocm.sh - -ENV PATH "${PATH}:/opt/rocm/bin" diff --git a/docker/dev_common.sh b/docker/dev_common.sh index 763da67ef854..fd5a8f91bd1d 100755 --- a/docker/dev_common.sh +++ b/docker/dev_common.sh @@ -27,8 +27,7 @@ INVOCATION_PWD="$(pwd)" GIT_TOPLEVEL=$(cd $(dirname ${BASH_SOURCE[0]}) && git rev-parse --show-toplevel) -DOCKER_IS_ROOTLESS=$(docker info 2> /dev/null | grep 'Context: \+rootless') - +DOCKER_IS_ROOTLESS=$(docker info 2> /dev/null | grep 'Context: \+rootless' || true) function lookup_image_spec() { img_spec=$(python3 "${GIT_TOPLEVEL}/ci/jenkins/data.py" "$1") diff --git a/docker/lint.sh b/docker/lint.sh index 4f7bca445a9f..7225fa981fd9 100755 --- a/docker/lint.sh +++ b/docker/lint.sh @@ -55,10 +55,18 @@ function run_lint_step() { cmd=( tests/lint/cpplint.sh ) ;; flake8) - cmd=( tests/lint/flake8.sh ) + if [ $inplace_fix -eq 0 ]; then + cmd=( tests/lint/flake8.sh ) + else + cmd=( tests/lint/flake8.sh --rev origin/main ) + fi ;; pylint) - 
cmd=( tests/lint/pylint.sh ) + if [ $inplace_fix -eq 0 ]; then + cmd=( tests/lint/pylint.sh ) + else + cmd=( tests/lint/pylint.sh --rev origin/main ) + fi ;; python_format) if [ $inplace_fix -eq 0 ]; then @@ -90,7 +98,11 @@ function run_lint_step() { shift if [ $validate_only -eq 0 ]; then - run_docker -it "ci_lint" "${cmd[@]}" + if [ -t 0 ]; then + run_docker -it "ci_lint" "${cmd[@]}" + else + run_docker "ci_lint" "${cmd[@]}" + fi fi } diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst index 6015c4351076..aa7f5e67854c 100644 --- a/docs/arch/device_target_interactions.rst +++ b/docs/arch/device_target_interactions.rst @@ -153,10 +153,10 @@ then be registered with the following steps. #. Register the function to the tvm registry:: - TVM_FFI_STATIC_INIT_BLOCK({ + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("device_api.foo", FooDeviceAPI::Global); - }); + } .. _base.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h @@ -169,7 +169,7 @@ then be registered with the following steps. enum value to a string representation. This string representation should match the name given to ``GlobalDef().def``. -#. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE`` dictionaries of +#. Add entries to the ``_DEVICE_TYPE_TO_NAME`` and ``_DEVICE_NAME_TO_TYPE`` dictionaries of :py:class:`tvm.runtime.Device` for the new enum value. @@ -228,10 +228,10 @@ the same name as was used in the ``TVM_REGISTER_TARGET_KIND`` definition above. :: tvm::runtime::Module GeneratorFooCode(IRModule mod, Target target); - TVM_FFI_STATIC_INIT_BLOCK({ + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.foo", GeneratorFooCode); - }); + } The code generator takes two arguments. 
The first is the ``IRModule`` to compile, and the second is the ``Target`` that describes the device diff --git a/docs/arch/index.rst b/docs/arch/index.rst index 1acd38fb04c7..4985e91c0b7d 100644 --- a/docs/arch/index.rst +++ b/docs/arch/index.rst @@ -133,7 +133,7 @@ The main goal of TVM's runtime is to provide a minimal API for loading and execu import tvm # Example runtime execution program in python, with type annotated mod: tvm.runtime.Module = tvm.runtime.load_module("compiled_artifact.so") - arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.cuda(0)) + arr: tvm.runtime.Tensor = tvm.runtime.tensor([1, 2, 3], device=tvm.cuda(0)) fun: tvm.runtime.PackedFunc = mod["addone"] fun(arr) print(arr.numpy()) @@ -142,7 +142,7 @@ The main goal of TVM's runtime is to provide a minimal API for loading and execu :py:class:`tvm.runtime.Module` encapsulates the result of compilation. A runtime.Module contains a GetFunction method to obtain PackedFuncs by name. :py:class:`tvm.runtime.PackedFunc` is a type-erased function interface for both the generated functions. A runtime.PackedFunc can take arguments and return values with the -following types: POD types(int, float), string, runtime.PackedFunc, runtime.Module, runtime.NDArray, and other sub-classes of runtime.Object. +following types: POD types(int, float), string, runtime.PackedFunc, runtime.Module, runtime.Tensor, and other sub-classes of runtime.Object. :py:class:`tvm.runtime.Module` and :py:class:`tvm.runtime.PackedFunc` are powerful mechanisms to modularize the runtime. For example, to get the above `addone` function on CUDA, we can use LLVM to generate the host-side code to compute the launching parameters(e.g. size of the thread groups) and then call into another PackedFunc from a CUDAModule that is backed by the CUDA driver API. The same mechanism can be used for OpenCL kernels. @@ -155,7 +155,7 @@ The above example only deals with a simple `addone` function. 
The code snippet b factory: tvm.runtime.Module = tvm.runtime.load_module("resnet18.so") # Create a stateful graph execution module for resnet18 on cuda(0) gmod: tvm.runtime.Module = factory["resnet18"](tvm.cuda(0)) - data: tvm.runtime.NDArray = get_input_data() + data: tvm.runtime.Tensor = get_input_data() # set input gmod["set_input"](0, data) # execute the model diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst index 4bf3abceb0ca..ef3672058c61 100644 --- a/docs/arch/pass_infra.rst +++ b/docs/arch/pass_infra.rst @@ -93,9 +93,9 @@ needs to be executed when running under a user-provided optimization level. The .. code:: c++ class PassInfoNode : public Object { - String name; + ffi::String name; int opt_level; - Array required; + ffi::Array required; }; PassContext @@ -125,11 +125,11 @@ Python APIs to create a compilation pipeline using pass context. class PassContextNode : public Object { public: int opt_level{2}; - tvm::Array required_pass; - tvm::Array disabled_pass; - mutable Optional diag_ctx; - Map config; - Array instruments; + tvm::ffi::Array required_pass; + tvm::ffi::Array disabled_pass; + mutable ffi::Optional diag_ctx; + ffi::Map config; + ffi::Array instruments; }; class PassContext : public NodeRef { @@ -262,7 +262,7 @@ of passes for execution. class SequentialPassNode : PassNode { PassInfo pass_info; // Passes need to be executed. - Array passes; + ffi::Array passes; bool PassEnabled(const PassInfo& info) const; Module operator()(const Module& mod, const PassContext& pass_ctx) const final; }; @@ -321,22 +321,22 @@ favorably use Python APIs to create a specific pass object. 
Pass CreateFunctionPass( std::function pass_func, int opt_level, - String name, - Array required); + ffi::String name, + ffi::Array required); Pass CreatePrimFuncPass( std::function pass_func, int opt_level, - String name, - Array required); + ffi::String name, + ffi::Array required); Pass CreateModulePass( std::function pass_func, int opt_level, - String name, - Array required); + ffi::String name, + ffi::Array required); - Pass Sequential(tvm::Array passes, PassInfo pass_info); + Pass Sequential(tvm::ffi::Array passes, PassInfo pass_info); Pass Registration ^^^^^^^^^^^^^^^^^ @@ -376,10 +376,10 @@ Python when needed. return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } - TVM_FFI_STATIC_INIT_BLOCK({ + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.FoldConstant", FoldConstant); - }); + } } // namespace transform @@ -440,7 +440,7 @@ Multiple ``PassInstrument`` instances can be registed into a single class PassInstrumentNode : public Object { public: - String name; + ffi::String name; virtual void EnterPassContext() const = 0; virtual void ExitPassContext() const = 0; virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0; @@ -451,7 +451,7 @@ Multiple ``PassInstrument`` instances can be registed into a single class PassInstrument : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PassInstrument, ObjectRef, PassInstrumentNode); }; } // namespace instrument @@ -552,7 +552,7 @@ a certain scope. .. 
code:: python - @tvm.ffi.register_object("transform.PassContext") + @tvm_ffi.register_object("transform.PassContext") class PassContext(tvm.runtime.Object): def __enter__(self): _transform.EnterPassContext(self) diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index 9e663b072810..99c83de8376a 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -80,10 +80,10 @@ The following example registers PackedFunc in C++ and calls from python. .. code:: c // register a global packed function in c++ - TVM_FFI_STATIC_INIT_BLOCK({ + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("myadd", MyAdd); - }); + } .. code:: python @@ -112,13 +112,13 @@ we can pass functions from python (as PackedFunc) to C++. .. code:: c - TVM_FFI_STATIC_INIT_BLOCK({ + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("callhello", [](ffi::PackedArgs args, ffi::Any* rv) { ffi::Function f = args[0].cast(); f("hello world"); }); - }); + } .. code:: python @@ -227,12 +227,10 @@ Each ``Object`` subclass will override this to register its members. Here is an namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &IntImmNode::value); } - - static constexpr const char* _type_key = "ir.IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IntImm", IntImmNode, PrimExprNode); }; // in cc file - TVM_FFI_STATIC_INIT_BLOCK({ IntImmNode::RegisterReflection(); }); + TVM_FFI_STATIC_INIT_BLOCK() { IntImmNode::RegisterReflection(); } The RegisterReflection gives us a reflection API to register each member of the object. We can use this function to visit the node and serialize any language object recursively. 
diff --git a/docs/conf.py b/docs/conf.py index 60ac4077e87d..42a7bf25a33d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -121,6 +121,7 @@ def split_code_and_text_blocks(source_file, return_node, real_func): # This header replaces the default sphinx-gallery one in sphinx_gallery/gen_rst.py. +# The Colab button has been temporarily disabled because prebuilt packages are unavailable. COLAB_HTML_HEADER = """ .. DO NOT EDIT. THIS FILE WAS AUTOMATICALLY GENERATED BY .. TVM'S MONKEY-PATCHED VERSION OF SPHINX-GALLERY. TO MAKE @@ -132,13 +133,7 @@ def split_code_and_text_blocks(source_file, return_node, real_func): .. note:: :class: sphx-glr-download-link-note - This tutorial can be used interactively with Google Colab! You can also click - :ref:`here ` to run the Jupyter notebook locally. - - .. image:: {button_svg} - :align: center - :target: {colab_url} - :width: 300px + You can click :ref:`here ` to run the Jupyter notebook locally. .. rst-class:: sphx-glr-example-title @@ -162,7 +157,11 @@ def split_code_and_text_blocks(source_file, return_node, real_func): def save_rst_example( example_rst, example_file, time_elapsed, memory_used, gallery_conf, language, real_func ): - """Monkey-patch save_rst_example to include the "Open in Colab" button.""" + """Monkey-patch save_rst_example to customize the tutorial header. + + Note: Colab button has been temporarily disabled. The colab_url and button_svg + are still generated but not used in the header template. + """ # The url is the md5 hash of the notebook path. example_fname = os.path.relpath(example_file, gallery_conf["src_dir"]) @@ -171,6 +170,7 @@ def save_rst_example( digest = md5(notebook_path.encode()).hexdigest() # Fixed documentation versions must link to different (earlier) .ipynb notebooks. + # Note: colab_url is generated but not currently used in the header template.
colab_url = f"{COLAB_URL_BASE}/{IPYTHON_GITHUB_BASE}" if "dev" not in version: colab_url += version + "/" @@ -507,9 +507,7 @@ def force_gc(gallery_conf, fname): header_links = [ ("Community", "https://tvm.apache.org/community"), ("Download", "https://tvm.apache.org/download"), - ("Blog", "https://tvm.apache.org/blog"), ("Docs", "https://tvm.apache.org/docs"), - ("Conference", "https://tvmconf.org"), ("Github", "https://github.com/apache/tvm/"), ] diff --git a/docs/deep_dive/tensor_ir/tutorials/tir_creation.py b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py index 3d07f6227b96..74b4406061b9 100644 --- a/docs/deep_dive/tensor_ir/tutorials/tir_creation.py +++ b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py @@ -204,9 +204,9 @@ def mm_relu(a: T.handle, b: T.handle, c: T.handle): def evaluate_dynamic_shape(lib: tvm.runtime.Module, m: int, n: int, k: int): - A = tvm.nd.array(np.random.uniform(size=(m, k)).astype("float32")) - B = tvm.nd.array(np.random.uniform(size=(k, n)).astype("float32")) - C = tvm.nd.array(np.zeros((m, n), dtype="float32")) + A = tvm.runtime.tensor(np.random.uniform(size=(m, k)).astype("float32")) + B = tvm.runtime.tensor(np.random.uniform(size=(k, n)).astype("float32")) + C = tvm.runtime.tensor(np.zeros((m, n), dtype="float32")) lib(A, B, C) return C.numpy() diff --git a/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py index 702b53011b48..eb1b2eb02029 100644 --- a/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py +++ b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py @@ -72,9 +72,9 @@ def main( b_np = np.random.uniform(size=(128, 128)).astype("float32") c_np = a_np @ b_np -a_nd = tvm.nd.array(a_np) -b_nd = tvm.nd.array(b_np) -c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32")) +a_nd = tvm.runtime.tensor(a_np) +b_nd = tvm.runtime.tensor(b_np) +c_nd = tvm.runtime.tensor(np.zeros((128, 128), dtype="float32")) def evaluate(mod: tvm.IRModule): diff --git 
a/docs/download_3rdparty_embeds.py b/docs/download_3rdparty_embeds.py index b658d82d63f2..68dfe0662b97 100644 --- a/docs/download_3rdparty_embeds.py +++ b/docs/download_3rdparty_embeds.py @@ -310,5 +310,9 @@ def download_and_replace_urls(files: Optional[List[str]] = None, verbose: bool = if __name__ == "__main__": args = argparse.ArgumentParser() args.add_argument("-v", "--verbose", action="store_true") + args.add_argument("-p", "--path", type=str, default=None) args = args.parse_args() + + if args.path is not None: + HTML_DIR = args.path download_and_replace_urls(verbose=args.verbose) diff --git a/docs/get_started/tutorials/ir_module.py b/docs/get_started/tutorials/ir_module.py index c53d0ca5ef74..8bb8fb77a445 100644 --- a/docs/get_started/tutorials/ir_module.py +++ b/docs/get_started/tutorials/ir_module.py @@ -237,7 +237,7 @@ def main( vm = relax.VirtualMachine(exec, dev) raw_data = np.random.rand(1, 784).astype("float32") -data = tvm.nd.array(raw_data, dev) +data = tvm.runtime.tensor(raw_data, dev) cpu_out = vm["main"](data, *params_from_torch["main"]).numpy() print(cpu_out) @@ -267,8 +267,8 @@ def main( dev = tvm.device("cuda", 0) vm = relax.VirtualMachine(exec, dev) # Need to allocate data and params on GPU device -data = tvm.nd.array(raw_data, dev) -gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]] +data = tvm.runtime.tensor(raw_data, dev) +gpu_params = [tvm.runtime.tensor(p, dev) for p in params_from_torch["main"]] gpu_out = vm["main"](data, *gpu_params).numpy() print(gpu_out) diff --git a/docs/get_started/tutorials/quick_start.py b/docs/get_started/tutorials/quick_start.py index 1153108c9632..8762564c02bd 100644 --- a/docs/get_started/tutorials/quick_start.py +++ b/docs/get_started/tutorials/quick_start.py @@ -141,9 +141,9 @@ def forward(self, x): device = tvm.cpu() vm = relax.VirtualMachine(ex, device) data = np.random.rand(1, 784).astype("float32") -tvm_data = tvm.nd.array(data, device=device) +tvm_data = tvm.runtime.tensor(data, 
device=device) params = [np.random.rand(*param.shape).astype("float32") for _, param in param_spec] -params = [tvm.nd.array(param, device=device) for param in params] +params = [tvm.runtime.tensor(param, device=device) for param in params] print(vm["forward"](tvm_data, *params).numpy()) ################################################################################ @@ -158,15 +158,15 @@ def forward(self, x): # prefill_logits = vm["prefill"](inputs, weight, kv_cache) # decoded_logits = vm["decode"](inputs, weight, kv_cache) # -# - TVM runtime comes with native data structures, such as NDArray, can also have zero +# - TVM runtime comes with native data structures, such as Tensor, can also have zero # copy exchange with existing ecosystem (DLPack exchange with PyTorch) # # .. code-block:: Python # -# # Convert PyTorch tensor to TVM NDArray -# x_tvm = tvm.nd.from_dlpack(x_torch.to_dlpack()) -# # Convert TVM NDArray to PyTorch tensor -# x_torch = torch.from_dlpack(x_tvm.to_dlpack()) +# # Convert PyTorch tensor to TVM Tensor +# x_tvm = tvm.runtime.from_dlpack(x_torch) +# # Convert TVM Tensor to PyTorch tensor +# x_torch = torch.from_dlpack(x_tvm) # # - TVM runtime works in non-python environments, so it works on settings such as mobile # @@ -175,14 +175,14 @@ def forward(self, x): # // C++ snippet # runtime::Module vm = ex.GetFunction("load_executable")(); # vm.GetFunction("init")(...); -# NDArray out = vm.GetFunction("prefill")(data, weight, kv_cache); +# Tensor out = vm.GetFunction("prefill")(data, weight, kv_cache); # # .. 
code-block:: Java # # // Java snippet # Module vm = ex.getFunction("load_executable").invoke(); # vm.getFunction("init").pushArg(...).invoke; -# NDArray out = vm.getFunction("prefill").pushArg(data).pushArg(weight).pushArg(kv_cache).invoke(); +# Tensor out = vm.getFunction("prefill").pushArg(data).pushArg(weight).pushArg(kv_cache).invoke(); # ################################################################################ diff --git a/docs/how_to/tutorials/cross_compilation_and_rpc.py b/docs/how_to/tutorials/cross_compilation_and_rpc.py index a6b7206b3efa..ef1ca629ce4c 100644 --- a/docs/how_to/tutorials/cross_compilation_and_rpc.py +++ b/docs/how_to/tutorials/cross_compilation_and_rpc.py @@ -182,8 +182,8 @@ # create arrays on the remote device dev = remote.cpu() -a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev) -b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) +a = tvm.runtime.tensor(np.random.uniform(size=1024).astype(A.dtype), dev) +b = tvm.runtime.tensor(np.zeros(1024, dtype=A.dtype), dev) # the function will run on the remote device func(a, b) np.testing.assert_equal(b.numpy(), a.numpy() + 1) @@ -249,20 +249,334 @@ def run_opencl(): # run dev = remote.cl() - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=1024).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(1024, dtype=A.dtype), dev) func(a, b) np.testing.assert_equal(b.numpy(), a.numpy() + 1) print("OpenCL test passed!") +######################################################################### +# Deploy PyTorch Models to Remote Devices with RPC +# ------------------------------------------------ +# The above examples demonstrate cross compilation and RPC using low-level +# TensorIR (via TE). 
For deploying complete neural network models from frameworks +# like PyTorch or ONNX, TVM's Relax provides a higher-level abstraction that is +# better suited for end-to-end model compilation. +# +# This section shows a modern workflow for deploying models to **any remote device**: +# +# 1. Import a PyTorch model and convert it to Relax +# 2. Cross-compile for the target architecture (ARM, x86, RISC-V, etc.) +# 3. Deploy via RPC to a remote device +# 4. Run inference remotely +# +# This workflow is applicable to various deployment scenarios: +# +# - **ARM devices**: Raspberry Pi, NVIDIA Jetson, mobile phones +# - **x86 servers**: Remote Linux servers, cloud instances +# - **Embedded systems**: RISC-V boards, custom hardware +# - **Accelerators**: Remote machines with GPUs, TPUs, or other accelerators +# +# .. note:: +# This example uses PyTorch for demonstration, but the workflow is identical +# for ONNX models. Simply replace ``from_exported_program()`` with +# ``from_onnx(model, keep_params_in_input=True)`` and follow the same steps. + +# First, let's check if PyTorch is available +try: + import torch + from torch.export import export + + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + + +def run_pytorch_model_via_rpc(): + """ + Demonstrates the complete workflow of deploying a PyTorch model to an ARM device via RPC. + """ + if not HAS_TORCH: + print("Skipping PyTorch example (PyTorch not installed)") + return + + from tvm import relax + from tvm.relax.frontend.torch import from_exported_program + + ###################################################################### + # Step 1: Define and Export PyTorch Model + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We use a simple MLP model for demonstration. In practice, this could be + # any PyTorch model (ResNet, BERT, etc.). 
+ + class TorchMLP(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(28 * 28, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 10), + ) + + def forward(self, data: torch.Tensor) -> torch.Tensor: + return self.net(data) + + # Export the model using PyTorch 2.x export API + torch_model = TorchMLP().eval() + example_args = (torch.randn(1, 1, 28, 28, dtype=torch.float32),) + + with torch.no_grad(): + exported_program = export(torch_model, example_args) + + ###################################################################### + # Step 2: Convert to Relax and Prepare for Compilation + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Convert the exported PyTorch program to TVM's Relax representation + + mod = from_exported_program(exported_program, keep_params_as_input=True) + # Separate parameters from the model for flexible deployment + mod, params = relax.frontend.detach_params(mod) + + print("Converted PyTorch model to Relax:") + print(f" - Number of parameters: {len(params['main'])}") + + ###################################################################### + # Step 3: Cross-Compile for Target Device + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Compile the model for the target device architecture. The target + # configuration depends on your deployment scenario. 
+ + if local_demo: + # For demonstration on local machine, use local target + target = tvm.target.Target("llvm") + print("Using local target for demonstration") + else: + # Choose the appropriate target for your device: + # + # ARM devices: + # - Raspberry Pi 3/4 (32-bit): "llvm -mtriple=armv7l-linux-gnueabihf" + # - Raspberry Pi 4 (64-bit) / Jetson: "llvm -mtriple=aarch64-linux-gnu" + # - Android: "llvm -mtriple=aarch64-linux-android" + # + # x86 servers: + # - Linux x86_64: "llvm -mtriple=x86_64-linux-gnu" + # - With AVX-512: "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512" + # + # RISC-V: + # - RV64: "llvm -mtriple=riscv64-unknown-linux-gnu" + # + # GPU targets: + # - CUDA: tvm.target.Target("cuda", host="llvm -mtriple=x86_64-linux-gnu") + # - OpenCL: tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-gnu") + # + # For this example, we use ARM 64-bit + target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu") + print(f"Cross-compiling for target: {target}") + + # Apply optimization pipeline + pipeline = relax.get_pipeline() + with target: + built_mod = pipeline(mod) + + # Compile to executable + executable = tvm.compile(built_mod, target=target) + + # Export to shared library + lib_path = temp.relpath("model_deployed.so") + executable.export_library(lib_path) + print(f"Exported library to: {lib_path}") + + # Save parameters separately + import numpy as np + + params_path = temp.relpath("model_params.npz") + param_arrays = {f"p_{i}": p.numpy() for i, p in enumerate(params["main"])} + np.savez(params_path, **param_arrays) + print(f"Saved parameters to: {params_path}") + + ###################################################################### + # Step 4: Deploy to Remote Device via RPC + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Connect to the remote device, upload the compiled library and parameters, + # then run inference remotely. This works for any device with TVM RPC server. 
+ # + # Note: The following code demonstrates the RPC workflow. In local_demo mode, + # we skip actual execution to avoid LocalSession compatibility issues. + + if local_demo: + # For demonstration, show the code structure without execution + print("\nRPC workflow (works for any remote device):") + print("=" * 50) + print("1. Start RPC server on target device:") + print(" python -m tvm.exec.rpc_server --host 0.0.0.0 --port=9090") + print("\n2. Connect from local machine:") + print(" remote = rpc.connect('DEVICE_IP', 9090)") + print("\n3. Upload compiled library:") + print(" remote.upload('model_deployed.so')") + print(" remote.upload('model_params.npz')") + print("\n4. Load and run remotely:") + print(" lib = remote.load_module('model_deployed.so')") + print(" vm = relax.VirtualMachine(lib, remote.cpu())") + print(" result = vm['main'](input, *params)") + print("\nDevice examples:") + print(" - Raspberry Pi: 192.168.1.100") + print(" - Remote server: ssh tunnel or direct IP") + print(" - NVIDIA Jetson: 10.0.0.50") + print(" - Cloud instance: public IP") + print("\nTo run actual RPC, set local_demo=False") + return # Skip actual RPC execution in demo mode + + # Actual RPC workflow for real deployment + # Connect to remote device (works for ARM, x86, RISC-V, etc.) 
+ # Make sure the RPC server is running on the device: + # python -m tvm.exec.rpc_server --host 0.0.0.0 --port=9090 + device_host = "192.168.1.100" # Replace with your device IP + device_port = 9090 + remote = rpc.connect(device_host, device_port) + print(f"Connected to remote device at {device_host}:{device_port}") + + # Upload library and parameters to remote device + remote.upload(lib_path) + remote.upload(params_path) + print("Uploaded files to remote device") + + # Load the library on the remote device + lib = remote.load_module("model_deployed.so") + + # Choose device on remote machine + # For CPU: dev = remote.cpu() + # For CUDA GPU: dev = remote.cuda(0) + # For OpenCL: dev = remote.cl(0) + dev = remote.cpu() + + # Create VM and load parameters + vm = relax.VirtualMachine(lib, dev) + + # Load parameters from the uploaded file + # Note: In practice, you might load this from the remote filesystem + params_npz = np.load(params_path) + remote_params = [tvm.runtime.tensor(params_npz[f"p_{i}"], dev) for i in range(len(params_npz))] + + ###################################################################### + # Step 5: Run Inference on Remote Device + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Execute the model on the remote ARM device and retrieve results + + # Prepare input data + input_data = np.random.randn(1, 1, 28, 28).astype("float32") + remote_input = tvm.runtime.tensor(input_data, dev) + + # Run inference on remote device + output = vm["main"](remote_input, *remote_params) + + # Extract result (handle both tuple and single tensor outputs) + if isinstance(output, tvm.ir.Array) and len(output) > 0: + result = output[0] + else: + result = output + + # Retrieve result from remote device to local + result_np = result.numpy() + print(f"Inference completed on remote device") + print(f" Output shape: {result_np.shape}") + print(f" Predicted class: {np.argmax(result_np)}") + + ###################################################################### + # Step 6: 
Performance Evaluation (Optional) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Measure inference time on the remote device, excluding network overhead + + time_f = vm.time_evaluator("main", dev, number=10, repeat=3) + prof_res = time_f(remote_input, *remote_params) + print(f"Inference time on remote device: {prof_res.mean * 1000:.2f} ms") + + ###################################################################### + # Notes on Performance Optimization + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # + # For optimal performance on target devices, consider: + # + # 1. **Auto-tuning with MetaSchedule**: Use automated search to find + # optimal schedules for your specific hardware: + # + # .. code-block:: python + # + # mod = relax.get_pipeline( + # "static_shape_tuning", + # target=target, + # total_trials=2000 + # )(mod) + # + # 2. **Quick optimization with DLight**: Apply pre-defined performant schedules: + # + # .. code-block:: python + # + # from tvm import dlight as dl + # with target: + # mod = dl.ApplyDefaultSchedule()(mod) + # + # 3. **Architecture-specific optimizations**: + # + # - ARM NEON SIMD: ``-mattr=+neon`` + # - x86 AVX-512: ``-mcpu=skylake-avx512`` + # - RISC-V Vector: ``-mattr=+v`` + # + # .. code-block:: python + # + # # Example: ARM with NEON + # target = tvm.target.Target( + # "llvm -mtriple=aarch64-linux-gnu -mattr=+neon" + # ) + # + # # Example: x86 with AVX-512 + # target = tvm.target.Target( + # "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512" + # ) + # + # See :doc:`e2e_opt_model ` for detailed + # tuning examples. + + +# Run the PyTorch RPC example if PyTorch is available +if HAS_TORCH and local_demo: + try: + run_pytorch_model_via_rpc() + except Exception: + pass # Silently skip if execution fails + + ###################################################################### # Summary # ------- # This tutorial provides a walk through of cross compilation and RPC # features in TVM. # -# - Set up an RPC server on the remote device. 
-# - Set up the target device configuration to cross compile the kernels on the -# local machine. -# - Upload and run the kernels remotely via the RPC API. +# We demonstrated two approaches: +# +# **Low-level TensorIR (TE) approach** - for understanding fundamentals: +# +# - Define computations using Tensor Expression +# - Cross-compile for ARM targets +# - Deploy and run via RPC +# +# **High-level Relax approach** - for deploying complete models: +# +# - Import models from PyTorch (or ONNX) +# - Convert to Relax representation +# - Cross-compile for ARM Linux devices +# - Deploy to remote devices via RPC +# - Run inference and evaluate performance +# +# Key takeaways: +# +# - Set up an RPC server on the remote device +# - Cross-compile on a powerful local machine for resource-constrained targets +# - Upload and execute compiled modules remotely via the RPC API +# - Measure performance excluding network overhead +# +# For complete model deployment workflows, see also: +# +# - :doc:`export_and_load_executable ` - Export and load compiled models +# - :doc:`e2e_opt_model ` - End-to-end optimization with auto-tuning diff --git a/docs/how_to/tutorials/customize_opt.py b/docs/how_to/tutorials/customize_opt.py index d215654019f0..2e2747d61fc5 100644 --- a/docs/how_to/tutorials/customize_opt.py +++ b/docs/how_to/tutorials/customize_opt.py @@ -209,8 +209,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR dev = tvm.device("cuda", 0) vm = relax.VirtualMachine(ex, dev) # Need to allocate data and params on GPU device -data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev) -gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params] +data = tvm.runtime.tensor(np.random.rand(*input_shape).astype("float32"), dev) +gpu_params = [tvm.runtime.tensor(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params] gpu_out = vm["forward"](data, *gpu_params).numpy() print(gpu_out) diff --git 
a/docs/how_to/tutorials/e2e_opt_model.py b/docs/how_to/tutorials/e2e_opt_model.py index 88cc86bfa800..507864160d9f 100644 --- a/docs/how_to/tutorials/e2e_opt_model.py +++ b/docs/how_to/tutorials/e2e_opt_model.py @@ -95,13 +95,38 @@ # leverage MetaSchedule to tune the model and store the tuning logs to the database. We also # apply the database to the model to get the best performance. # +# The ResNet18 model will be divided into 20 independent tuning tasks during compilation. +# To ensure each task receives adequate tuning resources in one iteration while providing +# early feedback: +# +# - To quickly observe tuning progress, each task is allocated a maximum of 16 trials per +# iteration (controlled by ``MAX_TRIALS_PER_TASK=16``). Setting ``TOTAL_TRIALS`` +# to at least ``320 (20 tasks * 16 trials)`` ensures every task receives one full iteration +# of tuning. We set it to 512 in our configuration to allow for several more iterations, +# aiming to explore a wider parameter space and potentially achieve better performance. +# - If ``MAX_TRIALS_PER_TASK == None``, the system defaults to ``TOTAL_TRIALS`` trials per +# task per iteration. An insufficient ``TOTAL_TRIALS`` setting may lead to undersubscribed +# tuning, potentially skipping some tasks entirely. Explicitly setting both parameters +# avoids this issue and provides deterministic resource allocation across all tasks. +# +# Note: These parameter settings are optimized for quick tutorial demonstration. For production +# deployments requiring higher performance, we recommend adjusting both ``MAX_TRIALS_PER_TASK`` +# and ``TOTAL_TRIALS`` to larger values. This allows more extensive search space exploration +# and typically yields better performance outcomes.
-TOTAL_TRIALS = 8000 # Change to 20000 for better performance if needed +TOTAL_TRIALS = 512 # Change to 20000 for better performance if needed +MAX_TRIALS_PER_TASK = 16 # Change to more trials per task for better performance if needed target = tvm.target.Target("nvidia/geforce-rtx-3090-ti") # Change to your target device work_dir = "tuning_logs" if not IS_IN_CI: - mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod) + mod = relax.get_pipeline( + "static_shape_tuning", + target=target, + work_dir=work_dir, + total_trials=TOTAL_TRIALS, + max_trials_per_task=MAX_TRIALS_PER_TASK, + )(mod) # Only show the main function mod["main"].show() @@ -113,12 +138,14 @@ # We skip this step in the CI environment. if not IS_IN_CI: - ex = tvm.compile(mod, target="cuda") + with target: + mod = tvm.tir.transform.DefaultGPUSchedule()(mod) + ex = tvm.compile(mod, target=target) dev = tvm.device("cuda", 0) vm = relax.VirtualMachine(ex, dev) # Need to allocate data and params on GPU device - gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev) - gpu_params = [tvm.nd.array(p, dev) for p in params["main"]] - gpu_out = vm["main"](gpu_data, *gpu_params).numpy() + gpu_data = tvm.runtime.tensor(np.random.rand(1, 3, 224, 224).astype("float32"), dev) + gpu_params = [tvm.runtime.tensor(p, dev) for p in params["main"]] + gpu_out = vm["main"](gpu_data, *gpu_params)[0].numpy() print(gpu_out.shape) diff --git a/docs/how_to/tutorials/export_and_load_executable.py b/docs/how_to/tutorials/export_and_load_executable.py new file mode 100644 index 000000000000..9665db48cb5b --- /dev/null +++ b/docs/how_to/tutorials/export_and_load_executable.py @@ -0,0 +1,378 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _deploy_export_and_load_executable: + +Export and Load Relax Executables +================================= + +This tutorial walks through exporting a compiled Relax module to a shared +object, loading it back into the TVM runtime, and running the result either +interactively or from a standalone script. This tutorial demonstrates how +to turn Relax (or imported PyTorch / ONNX) programs into deployable artifacts +using ``tvm.relax`` APIs. + +.. note:: + This tutorial uses PyTorch as the source format, but the export/load workflow + is the same for ONNX models. For ONNX, use ``from_onnx(model, keep_params_in_input=True)`` + instead of ``from_exported_program()``, then follow the same steps for building, + exporting, and loading. +""" + +###################################################################### +# Introduction +# ------------ +# TVM builds Relax programs into ``tvm.runtime.Executable`` objects. These +# contain VM bytecode, compiled kernels, and constants. By exporting the +# executable with :py:meth:`export_library`, you obtain a shared library (for +# example ``.so`` on Linux) that can be shipped to another machine, uploaded +# via RPC, or loaded back later with the TVM runtime. This tutorial shows the +# exact steps end-to-end and explains what files are produced along the way. 
+ +import os +from pathlib import Path + +try: + import torch + from torch.export import export +except ImportError: # pragma: no cover + torch = None # type: ignore + + +###################################################################### +# Prepare a Torch MLP and Convert to Relax +# ---------------------------------------- +# We start with a small PyTorch MLP so the example remains lightweight. The +# model is exported to a :py:class:`torch.export.ExportedProgram` and then +# translated into a Relax ``IRModule``. + +import tvm +from tvm import relax +from tvm.relax.frontend.torch import from_exported_program + +# Check dependencies first +IS_IN_CI = os.getenv("CI", "").lower() == "true" +HAS_TORCH = torch is not None +RUN_EXAMPLE = HAS_TORCH and not IS_IN_CI + + +if HAS_TORCH: + + class TorchMLP(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(28 * 28, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 10), + ) + + def forward(self, data: torch.Tensor) -> torch.Tensor: # type: ignore[override] + return self.net(data) + +else: # pragma: no cover + TorchMLP = None # type: ignore[misc, assignment] + +if RUN_EXAMPLE: + torch_model = TorchMLP().eval() + example_args = (torch.randn(1, 1, 28, 28, dtype=torch.float32),) + + with torch.no_grad(): + exported_program = export(torch_model, example_args) + + mod = from_exported_program(exported_program, keep_params_as_input=True) + + # Separate model parameters so they can be bound later (or stored on disk). + mod, params = relax.frontend.detach_params(mod) + + print("Imported Relax module:") + mod.show() + + +###################################################################### +# Build and Export with ``export_library`` +# ------------------------------------------- +# We build for ``llvm`` to generate CPU code and then export the resulting +# executable. 
Passing ``workspace_dir`` keeps the intermediate packaging files, +# which is useful to inspect what was produced. + +TARGET = tvm.target.Target("llvm") +ARTIFACT_DIR = Path("relax_export_artifacts") +ARTIFACT_DIR.mkdir(exist_ok=True) + +if RUN_EXAMPLE: + # Apply the default Relax compilation pipeline before building. + pipeline = relax.get_pipeline() + with TARGET: + built_mod = pipeline(mod) + + # Build without params - we'll pass them at runtime + executable = tvm.compile(built_mod, target=TARGET) + + library_path = ARTIFACT_DIR / "mlp_cpu.so" + executable.export_library(str(library_path), workspace_dir=str(ARTIFACT_DIR)) + + print(f"Exported runtime library to: {library_path}") + + # The workspace directory now contains the shared object and supporting files. + produced_files = sorted(p.name for p in ARTIFACT_DIR.iterdir()) + print("Artifacts saved:") + for name in produced_files: + print(f" - {name}") + + # Generated files: + # - ``mlp_cpu.so``: The main deployable shared library containing VM bytecode, + # compiled kernels, and constants. Note: Since parameters are passed at runtime, + # you will also need to save a separate parameters file (see next section). + # - Intermediate object files (``devc.o``, ``lib0.o``, etc.) are kept in the + # workspace for inspection but are not required for deployment. + # + # Note: Additional files like ``*.params``, ``*.metadata.json``, or ``*.imports`` + # may appear in specific configurations but are typically embedded into the + # shared library or only generated when needed. + + +###################################################################### +# Load the Exported Library and Run It +# ------------------------------------ +# Once the shared object is produced, we can reload it back into the TVM runtime +# on any machine with a compatible instruction set. The Relax VM consumes the +# runtime module directly. 
+ +if RUN_EXAMPLE: + loaded_rt_mod = tvm.runtime.load_module(str(library_path)) + dev = tvm.cpu(0) + vm = relax.VirtualMachine(loaded_rt_mod, dev) + + # Prepare input data + input_tensor = torch.randn(1, 1, 28, 28, dtype=torch.float32) + vm_input = tvm.runtime.tensor(input_tensor.numpy(), dev) + + # Prepare parameters (allocate on target device) + vm_params = [tvm.runtime.tensor(p, dev) for p in params["main"]] + + # Run inference: pass input data followed by all parameters + tvm_output = vm["main"](vm_input, *vm_params) + + # TVM returns Array objects for tuple outputs, access via indexing. + # For models imported from PyTorch, outputs are typically tuples (even for single outputs). + # For ONNX models, outputs may be a single Tensor directly. + if isinstance(tvm_output, tvm.ir.Array) and len(tvm_output) > 0: + result_tensor = tvm_output[0] + else: + result_tensor = tvm_output + + print("VM output shape:", result_tensor.shape) + print("VM output type:", type(tvm_output), "->", type(result_tensor)) + + # You can still inspect the executable after reloading. + print("Executable stats:\n", loaded_rt_mod["stats"]()) + + +###################################################################### +# Save Parameters for Deployment +# ------------------------------- +# Since parameters are passed at runtime (not embedded in the ``.so``), we must +# save them separately for deployment. This is a required step to use the model +# on other machines or in standalone scripts. + +import numpy as np + +if RUN_EXAMPLE: + # Save parameters to disk + params_path = ARTIFACT_DIR / "model_params.npz" + param_arrays = {f"p_{i}": p.numpy() for i, p in enumerate(params["main"])} + np.savez(str(params_path), **param_arrays) + print(f"Saved parameters to: {params_path}") + +# Note: Alternatively, you can embed parameters directly into the ``.so`` to +# create a single-file deployment. Use ``keep_params_as_input=False`` when +# importing from PyTorch: +# +# .. 
code-block:: python +# +# mod = from_exported_program(exported_program, keep_params_as_input=False) +# # Parameters are now embedded as constants in the module +# executable = tvm.compile(built_mod, target=TARGET) +# # Runtime: vm["main"](input) # No need to pass params! +# +# This creates a single-file deployment (only the ``.so`` is needed), but you +# lose the flexibility to swap parameters without recompiling. For most +# production workflows, separating code and parameters (as shown above) is +# preferred for flexibility. + + +###################################################################### +# Loading and Running the Exported Model +# ----------------------------------------------------------- +# To use the exported model on another machine or in a standalone script, you need +# to load both the ``.so`` library and the parameters file. Here's a complete example +# of how to reload and run the model. Save this as ``run_mlp.py``: +# +# To make it executable from the command line: +# +# .. code-block:: bash +# +# chmod +x run_mlp.py +# ./run_mlp.py # Run it like a regular program +# +# Complete script: +# +# .. 
code-block:: python +# +# #!/usr/bin/env python3 +# import numpy as np +# import tvm +# from tvm import relax +# +# # Step 1: Load the compiled library +# lib = tvm.runtime.load_module("relax_export_artifacts/mlp_cpu.so") +# +# # Step 2: Create Virtual Machine +# device = tvm.cpu(0) +# vm = relax.VirtualMachine(lib, device) +# +# # Step 3: Load parameters from the .npz file +# params_npz = np.load("relax_export_artifacts/model_params.npz") +# params = [tvm.runtime.tensor(params_npz[f"p_{i}"], device) +# for i in range(len(params_npz))] +# +# # Step 4: Prepare input data +# data = np.random.randn(1, 1, 28, 28).astype("float32") +# input_tensor = tvm.runtime.tensor(data, device) +# +# # Step 5: Run inference (pass input followed by all parameters) +# output = vm["main"](input_tensor, *params) +# +# # Step 6: Extract result (output may be tuple or single Tensor) +# # PyTorch models typically return tuples, ONNX models may return a single Tensor +# if isinstance(output, tvm.ir.Array) and len(output) > 0: +# result = output[0] +# else: +# result = output +# +# print("Prediction shape:", result.shape) +# print("Predicted class:", np.argmax(result.numpy())) +# +# **Running on GPU:** +# To run on GPU instead of CPU, make the following changes: +# +# 1. **Compile for GPU** (in the "Build and Export" section earlier in the tutorial): +# .. code-block:: python +# +# TARGET = tvm.target.Target("cuda") # Change from "llvm" to "cuda" +# +# 2. **Use GPU device in the script**: +# .. code-block:: python +# +# device = tvm.cuda(0) # Use CUDA device instead of CPU +# vm = relax.VirtualMachine(lib, device) +# +# # Load parameters to GPU +# params = [tvm.runtime.tensor(params_npz[f"p_{i}"], device) # Note: device parameter +# for i in range(len(params_npz))] +# +# # Prepare input on GPU +# input_tensor = tvm.runtime.tensor(data, device) # Note: device parameter +# +# The rest of the script remains the same.
All tensors (parameters and inputs) +# must be allocated on the same device (GPU) as the compiled model. +# +# **Deployment Checklist:** +# When moving to another host (via RPC or SCP), you must copy **both** files: +# 1. ``mlp_cpu.so`` (or ``mlp_cuda.so`` for GPU) - The compiled model code +# 2. ``model_params.npz`` - The model parameters (serialized as NumPy arrays) +# +# The remote machine needs both files in the same directory. The script above +# assumes they are in ``relax_export_artifacts/`` relative to the current working directory. +# Adjust the paths as needed for your deployment. For GPU deployment, ensure the +# target machine has compatible CUDA drivers and the model was compiled for the +# same GPU architecture. + + +###################################################################### +# Deploying to Remote Devices +# --------------------------- +# To deploy the exported model to a remote ARM Linux device (e.g., Raspberry Pi), +# you can use TVM's RPC mechanism to cross-compile, upload, and run the model +# remotely. This workflow is useful when: +# +# - The target device has limited resources for compilation +# - You want to fine-tune performance by running on the actual hardware +# - You need to deploy to embedded devices +# +# See :doc:`cross_compilation_and_rpc` +# for a comprehensive guide on: +# +# - Setting up TVM runtime on the remote device +# - Starting an RPC server on the device +# - Cross-compiling for ARM targets (e.g., ``llvm -mtriple=aarch64-linux-gnu``) +# - Uploading exported libraries via RPC +# - Running inference remotely +# +# Quick example for ARM deployment workflow: +# +# ..
code-block:: python +# +# import tvm.rpc as rpc +# from tvm import relax +# +# # Step 1: Cross-compile for ARM target (on local machine) +# TARGET = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu") +# executable = tvm.compile(built_mod, target=TARGET) +# executable.export_library("mlp_arm.so") +# +# # Step 2: Connect to remote device RPC server +# remote = rpc.connect("192.168.1.100", 9090) # Device IP and RPC port +# +# # Step 3: Upload the compiled library and parameters +# remote.upload("mlp_arm.so") +# remote.upload("model_params.npz") +# +# # Step 4: Load and run on remote device +# lib = remote.load_module("mlp_arm.so") +# vm = relax.VirtualMachine(lib, remote.cpu()) +# # ... prepare input and params, then run inference +# +# The key difference is using an ARM target triple during compilation and +# uploading files via RPC instead of copying them directly. + + +###################################################################### +# FAQ +# --- +# **Can I run the ``.so`` as a standalone executable (like ``./mlp_cpu.so``)?** +# No. The ``.so`` file is a shared library, not a standalone executable binary. +# You cannot run it directly from the terminal. It must be loaded through a TVM +# runtime program (as shown in the "Loading and Running" section above). The +# ``.so`` bundles VM bytecode and compiled kernels, but still requires the TVM +# runtime to execute. +# +# **Which devices can run the exported library?** +# The target must match the ISA you compiled for (``llvm`` in this example). +# As long as the target triple, runtime ABI, and available devices line up, +# you can move the artifact between machines. For heterogeneous builds (CPU +# plus GPU), ship the extra device libraries as well. +# +# **What about the ``.params`` and ``metadata.json`` files?** +# These auxiliary files are only generated in specific configurations. In this +# tutorial, since we pass parameters at runtime, they are not generated. 
When +# they do appear, they may be kept alongside the ``.so`` for inspection, but +# the essential content is typically embedded in the shared object itself, so +# deploying the ``.so`` alone is usually sufficient. diff --git a/docs/how_to/tutorials/optimize_llm.py b/docs/how_to/tutorials/optimize_llm.py index 8cc674920da1..0e82b055592f 100644 --- a/docs/how_to/tutorials/optimize_llm.py +++ b/docs/how_to/tutorials/optimize_llm.py @@ -489,7 +489,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Convert params into ndarray params = [ - tvm.nd.array(param_dict[k].astype("float16"), device=dev) for k in named_params.keys() + tvm.runtime.tensor(param_dict[k].astype("float16"), device=dev) for k in named_params.keys() ] @@ -523,7 +523,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I input_len = len(prompt) # Load prompt tokens into TVM ndarray on the target device - tokens = tvm.nd.array(np.array(prompt).astype("int32"), device=dev) + tokens = tvm.runtime.tensor(np.array(prompt).astype("int32"), device=dev) ###################################################################### # Create the KVCache @@ -609,7 +609,7 @@ def sample_token(logits): print("The generated token:") while last_token != tokenizer.eos_token_id: - tokens = tvm.nd.array(np.array([last_token]).astype("int32"), device=dev) + tokens = tvm.runtime.tensor(np.array([last_token]).astype("int32"), device=dev) hidden_states = embed(tokens, params) begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([1])) logits, kv_cache = vm["decode"](hidden_states, kv_cache, params) diff --git a/docs/index.rst b/docs/index.rst index 05ca8c952bc3..2b5ef6464636 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,6 +45,7 @@ driving its costs down. how_to/tutorials/customize_opt how_to/tutorials/optimize_llm how_to/tutorials/cross_compilation_and_rpc + how_to/tutorials/export_and_load_executable how_to/dev/index .. 
The Deep Dive content is comprehensive diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index ba2190958991..ee81f8477835 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -130,6 +130,14 @@ Once ``config.cmake`` is edited accordingly, kick off build with the commands be A success build should produce ``libtvm`` and ``libtvm_runtime`` under ``build/`` directory. +Apache TVM relies on the tvm-ffi package to support its python bindings. +Therefore, after we finish the build, we need to install the tvm-ffi package. + +.. code-block:: bash + + cd 3rdparty/tvm-ffi; pip install .; cd ../.. + + Leaving the build environment ``tvm-build-venv``, there are two ways to install the successful build into your environment: - Install via environment variable @@ -137,7 +145,7 @@ Leaving the build environment ``tvm-build-venv``, there are two ways to install .. code-block:: bash export TVM_HOME=/path-to-tvm - export PYTHONPATH=$TVM_HOME/python:$PYTHONPATH + export PYTHONPATH=$TVM_HOME/python:$TVM_HOME/ffi/python:$PYTHONPATH - Install via pip local project diff --git a/docs/install/index.rst b/docs/install/index.rst index b09ddb35dd45..8e4af2821edc 100644 --- a/docs/install/index.rst +++ b/docs/install/index.rst @@ -32,12 +32,4 @@ If you are interested in deploying to mobile or embedded devices, you do not nee install the entire TVM stack on your device. Instead, you only need the runtime. If you would like to quickly try out TVM or run some demo and tutorials, you -can :ref:`install from Docker `. You can also use TVM locally through ``pip``. - -.. code-block:: - - # Linux/MacOS CPU build only! - # See tlcpack.ai for other pre-built binaries including CUDA - pip install apache-tvm - -For more details on installation of pre-built binaries, visit `tlcpack.ai `_. +can :ref:`install from Docker `.
diff --git a/docs/reference/api/python/index.rst b/docs/reference/api/python/index.rst index a233c69a0173..c63784781cb9 100644 --- a/docs/reference/api/python/index.rst +++ b/docs/reference/api/python/index.rst @@ -34,7 +34,6 @@ Python API :caption: tvm.runtime runtime/runtime - runtime/ndarray runtime/vm runtime/disco runtime/profiling diff --git a/docs/reference/api/python/runtime/ndarray.rst b/docs/reference/api/python/runtime/ndarray.rst deleted file mode 100644 index 8c794f04b193..000000000000 --- a/docs/reference/api/python/runtime/ndarray.rst +++ /dev/null @@ -1,21 +0,0 @@ -.. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - -.. http://www.apache.org/licenses/LICENSE-2.0 - -.. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -tvm.runtime.ndarray -------------------- -.. automodule:: tvm.runtime.ndarray - :members: diff --git a/docs/reference/api/python/runtime/runtime.rst b/docs/reference/api/python/runtime/runtime.rst index 4dd9d9653369..ae373080aeac 100644 --- a/docs/reference/api/python/runtime/runtime.rst +++ b/docs/reference/api/python/runtime/runtime.rst @@ -19,4 +19,3 @@ tvm.runtime ----------- .. 
automodule:: tvm.runtime :members: - :exclude-members: NDArray diff --git a/ffi/3rdparty/dlpack b/ffi/3rdparty/dlpack deleted file mode 160000 index 3ea601bb4130..000000000000 --- a/ffi/3rdparty/dlpack +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3ea601bb413074c49a77c4ce3218bc08f8c4703c diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt deleted file mode 100644 index 466571c2889f..000000000000 --- a/ffi/CMakeLists.txt +++ /dev/null @@ -1,145 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.14) - -project( - tvm_ffi - VERSION 1.0 - DESCRIPTION "TVM's FFI system" - LANGUAGES CXX C -) - -option(TVM_FFI_BUILD_TESTS "Adding test targets." 
OFF) -option(TVM_FFI_USE_LIBBACKTRACE "Enable libbacktrace" ON) -option(TVM_FFI_USE_EXTRA_CXX_API "Enable extra CXX API in shared lib" ON) -option(TVM_FFI_BACKTRACE_ON_SEGFAULT "Set signal handler to print traceback on segfault" ON) - -include(cmake/Utils/CxxWarning.cmake) -include(cmake/Utils/Sanitizer.cmake) -include(cmake/Utils/Library.cmake) -if (TVM_FFI_USE_LIBBACKTRACE) - include(cmake/Utils/AddLibbacktrace.cmake) -endif() - -########## Target: `dlpack_header` ########## - -add_library(dlpack_header INTERFACE) -target_include_directories(dlpack_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include") - -########## Target: `tvm_ffi_header` ########## - -add_library(tvm_ffi_header INTERFACE) -target_include_directories(tvm_ffi_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/include") -target_link_libraries(tvm_ffi_header INTERFACE dlpack_header) - -########## Target: `tvm_ffi` ########## - -set(tvm_ffi_objs_sources - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback_win.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/object.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/error.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/function.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/ndarray.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc" -) - -if (TVM_FFI_USE_EXTRA_CXX_API) - list(APPEND tvm_ffi_objs_sources - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_equal.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/reflection_extra.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/module.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc" - 
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_c_api.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/testing.cc" - ) -endif() - -add_library(tvm_ffi_objs OBJECT ${tvm_ffi_objs_sources}) - -set_target_properties( - tvm_ffi_objs PROPERTIES - POSITION_INDEPENDENT_CODE ON - CXX_STANDARD 17 - CXX_EXTENSIONS OFF - CXX_STANDARD_REQUIRED ON - CXX_VISIBILITY_PRESET hidden - VISIBILITY_INLINES_HIDDEN ON - PREFIX "lib" -) -add_cxx_warning(tvm_ffi_objs) -target_link_libraries(tvm_ffi_objs PRIVATE dlpack_header) -target_include_directories(tvm_ffi_objs PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include") - -if (TVM_FFI_USE_LIBBACKTRACE) - message(STATUS "Setting C++ macro TVM_FFI_USE_LIBBACKTRACE - 1") - target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_USE_LIBBACKTRACE=1) -else() - message(STATUS "Setting C++ macro TVM_FFI_USE_LIBBACKTRACE - 0") - target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_USE_LIBBACKTRACE=0) -endif() - -if (TVM_FFI_BACKTRACE_ON_SEGFAULT) - message(STATUS "Setting C++ macro TVM_FFI_BACKTRACE_ON_SEGFAULT - 1") - target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_BACKTRACE_ON_SEGFAULT=1) -else() - message(STATUS "Setting C++ macro TVM_FFI_BACKTRACE_ON_SEGFAULT - 0") - target_compile_definitions(tvm_ffi_objs PRIVATE TVM_FFI_BACKTRACE_ON_SEGFAULT=0) -endif() - -add_target_from_obj(tvm_ffi tvm_ffi_objs) - -if (TARGET libbacktrace) - target_link_libraries(tvm_ffi_objs PRIVATE libbacktrace) - target_link_libraries(tvm_ffi_shared PRIVATE libbacktrace) - target_link_libraries(tvm_ffi_static PRIVATE libbacktrace) -endif () - -if (MSVC) - target_link_libraries(tvm_ffi_objs PRIVATE DbgHelp.lib) - target_link_libraries(tvm_ffi_shared PRIVATE DbgHelp.lib) - target_link_libraries(tvm_ffi_static PRIVATE DbgHelp.lib) -endif () - 
-target_link_libraries(tvm_ffi_objs PUBLIC tvm_ffi_header) -target_link_libraries(tvm_ffi_shared PUBLIC tvm_ffi_header) -target_link_libraries(tvm_ffi_static PUBLIC tvm_ffi_header) - -install(TARGETS tvm_ffi_static DESTINATION lib${LIB_SUFFIX}) -install(TARGETS tvm_ffi_shared DESTINATION lib${LIB_SUFFIX}) - -add_msvc_flags(tvm_ffi_objs) - -########## Adding tests ########## - -if (${PROJECT_NAME} STREQUAL ${CMAKE_PROJECT_NAME}) - if (TVM_FFI_BUILD_TESTS) - enable_testing() - message(STATUS "Enable Testing") - include(cmake/Utils/AddGoogleTest.cmake) - add_subdirectory(tests/cpp/) - endif() -endif () diff --git a/ffi/cmake/Utils/AddGoogleTest.cmake b/ffi/cmake/Utils/AddGoogleTest.cmake deleted file mode 100644 index 10e59386128b..000000000000 --- a/ffi/cmake/Utils/AddGoogleTest.cmake +++ /dev/null @@ -1,59 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -include(FetchContent) -set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE) -set(BUILD_GMOCK ON CACHE BOOL "" FORCE) -set(BUILD_GTEST ON CACHE BOOL "" FORCE) -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG v1.14.0 -) -FetchContent_GetProperties(googletest) -if (NOT googletest_POPULATED) - FetchContent_Populate(googletest) - message(STATUS "Found googletest_SOURCE_DIR - ${googletest_SOURCE_DIR}") - message(STATUS "Found googletest_BINARY_DIR - ${googletest_BINARY_DIR}") - add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR}) - include(GoogleTest) - set_target_properties(gtest PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) - set_target_properties(gtest_main PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) - set_target_properties(gmock PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) - set_target_properties(gmock_main PROPERTIES EXPORT_COMPILE_COMMANDS OFF EXCLUDE_FROM_ALL ON FOLDER 3rdparty) - mark_as_advanced( - BUILD_GMOCK BUILD_GTEST BUILD_SHARED_LIBS - gmock_build_tests gtest_build_samples gtest_build_tests - gtest_disable_pthreads gtest_force_shared_crt gtest_hide_internal_symbols - ) -endif() - -macro(add_googletest target_name) - add_test( - NAME ${target_name} - COMMAND ${target_name} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - ) - target_link_libraries(${target_name} PRIVATE gtest_main) - gtest_discover_tests(${target_name} - WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} - DISCOVERY_MODE PRE_TEST - PROPERTIES - VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" - ) - set_target_properties(${target_name} PROPERTIES FOLDER tests) -endmacro() diff --git a/ffi/cmake/Utils/AddLibbacktrace.cmake b/ffi/cmake/Utils/AddLibbacktrace.cmake deleted file mode 100644 index 844a8816a6d8..000000000000 --- a/ffi/cmake/Utils/AddLibbacktrace.cmake +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the 
Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -include(ExternalProject) - -function(_libbacktrace_compile) - set(_libbacktrace_source ${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/libbacktrace) - set(_libbacktrace_prefix ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace) - if(CMAKE_SYSTEM_NAME MATCHES "Darwin" AND (CMAKE_C_COMPILER MATCHES "^/Library" OR CMAKE_C_COMPILER MATCHES "^/Applications")) - set(_cmake_c_compiler "/usr/bin/cc") - else() - set(_cmake_c_compiler "${CMAKE_C_COMPILER}") - endif() - - message(STATUS CMAKC_C_COMPILER="${CMAKE_C_COMPILER}") - - file(MAKE_DIRECTORY ${_libbacktrace_prefix}/include) - file(MAKE_DIRECTORY ${_libbacktrace_prefix}/lib) - - ExternalProject_Add(project_libbacktrace - PREFIX libbacktrace - SOURCE_DIR ${_libbacktrace_source} - BINARY_DIR ${_libbacktrace_prefix} - CONFIGURE_COMMAND - "${_libbacktrace_source}/configure" - "--prefix=${_libbacktrace_prefix}" - --with-pic - "CC=${_cmake_c_compiler}" - "CPP=${_cmake_c_compiler} -E" - "CFLAGS=${CMAKE_C_FLAGS}" - "LDFLAGS=${CMAKE_EXE_LINKER_FLAGS}" - "NM=${CMAKE_NM}" - "STRIP=${CMAKE_STRIP}" - "--host=${MACHINE_NAME}" - INSTALL_DIR ${_libbacktrace_prefix} - BUILD_COMMAND make - INSTALL_COMMAND make install - BUILD_BYPRODUCTS "${_libbacktrace_prefix}/lib/libbacktrace.a" - 
"${_libbacktrace_prefix}/include/backtrace.h" - ) - ExternalProject_Add_Step(project_libbacktrace checkout DEPENDERS configure DEPENDEES download) - set_target_properties(project_libbacktrace PROPERTIES EXCLUDE_FROM_ALL TRUE) - add_library(libbacktrace STATIC IMPORTED) - add_dependencies(libbacktrace project_libbacktrace) - set_target_properties(libbacktrace PROPERTIES - IMPORTED_LOCATION ${_libbacktrace_prefix}/lib/libbacktrace.a - INTERFACE_INCLUDE_DIRECTORIES ${_libbacktrace_prefix}/include - ) -endfunction() - -if(NOT MSVC) - _libbacktrace_compile() -endif() diff --git a/ffi/cmake/Utils/Library.cmake b/ffi/cmake/Utils/Library.cmake deleted file mode 100644 index cff7ca35a28f..000000000000 --- a/ffi/cmake/Utils/Library.cmake +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-function(add_dsymutil target_name) - # running dsymutil on macos to generate debugging symbols for backtraces - if(APPLE AND TVM_FFI_USE_LIBBACKTRACE) - find_program(DSYMUTIL dsymutil) - mark_as_advanced(DSYMUTIL) - add_custom_command(TARGET ${target_name} - POST_BUILD - COMMAND ${DSYMUTIL} ARGS $ - COMMENT "[COMMAND] dsymutil $" - VERBATIM - ) - endif() -endfunction() - -function(add_msvc_flags target_name) - # running if we are under msvc - if(MSVC) - target_compile_definitions(${target_name} PUBLIC -DWIN32_LEAN_AND_MEAN) - target_compile_definitions(${target_name} PUBLIC -D_CRT_SECURE_NO_WARNINGS) - target_compile_definitions(${target_name} PUBLIC -D_SCL_SECURE_NO_WARNINGS) - target_compile_definitions(${target_name} PUBLIC -D_ENABLE_EXTENDED_ALIGNED_STORAGE) - target_compile_definitions(${target_name} PUBLIC -DNOMINMAX) - target_compile_options(${target_name} PRIVATE "/Z7") - endif() -endfunction() - -function(add_target_from_obj target_name obj_target_name) - add_library(${target_name}_static STATIC $) - set_target_properties( - ${target_name}_static PROPERTIES - OUTPUT_NAME "${target_name}_static" - PREFIX "lib" - ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - ) - add_library(${target_name}_shared SHARED $) - set_target_properties( - ${target_name}_shared PROPERTIES - OUTPUT_NAME "${target_name}" - PREFIX "lib" - ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - ) - add_custom_target(${target_name}) - add_dependencies(${target_name} ${target_name}_static ${target_name}_shared) - if (MSVC) - target_compile_definitions(${obj_target_name} PRIVATE TVM_FFI_EXPORTS) - endif() - add_dsymutil(${target_name}_shared) - add_msvc_flags(${target_name}_shared) -endfunction() diff --git a/ffi/cmake/Utils/Sanitizer.cmake b/ffi/cmake/Utils/Sanitizer.cmake deleted file mode 100644 index a20eead0c869..000000000000 --- a/ffi/cmake/Utils/Sanitizer.cmake +++ 
/dev/null @@ -1,35 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -function(add_sanitizer_address target_name) - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") - include(CheckCXXCompilerFlag) - set (_saved_CRF ${CMAKE_REQUIRED_FLAGS}) - set(CMAKE_REQUIRED_FLAGS "-fsanitize=address") - check_cxx_source_compiles("int main() { return 0; }" COMPILER_SUPPORTS_ASAN) - set (CMAKE_REQUIRED_FLAGS ${_saved_CRF}) - get_target_property(_saved_type ${target_name} TYPE) - if (${_saved_type} STREQUAL "INTERFACE_LIBRARY") - set(_saved_type INTERFACE) - else() - set(_saved_type PRIVATE) - endif() - target_link_options(${target_name} ${_saved_type} "-fsanitize=address") - target_compile_options(${target_name} ${_saved_type} "-fsanitize=address" "-fno-omit-frame-pointer" "-g") - return() - endif() -endfunction() diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h deleted file mode 100644 index ed34328d1e67..000000000000 --- a/ffi/include/tvm/ffi/any.h +++ /dev/null @@ -1,646 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/any.h - * \brief Any value support. - */ -#ifndef TVM_FFI_ANY_H_ -#define TVM_FFI_ANY_H_ - -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { - -class Any; - -namespace details { -// Helper to perform -// unsafe operations related to object -struct AnyUnsafe; -} // namespace details - -/*! - * \brief AnyView allows us to take un-managed reference view of any value. - */ -class AnyView { - protected: - /*! \brief The underlying backing data of the any object */ - TVMFFIAny data_; - // Any can see AnyView - friend class Any; - - public: - // NOTE: the following two functions uses styl style - // since they are common functions appearing in FFI. - /*! - * \brief Reset any view to None - */ - void reset() { - data_.type_index = TypeIndex::kTVMFFINone; - // invariance: always set the union padding part to 0 - data_.zero_padding = 0; - data_.v_int64 = 0; - } - /*! - * \brief Swap this array with another Object - * \param other The other Object - */ - TVM_FFI_INLINE void swap(AnyView& other) noexcept { std::swap(data_, other.data_); } - /*! 
\return the internal type index */ - TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } - // default constructors - AnyView() { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - ~AnyView() = default; - // constructors from any view - AnyView(const AnyView&) = default; - AnyView& operator=(const AnyView&) = default; - AnyView(AnyView&& other) : data_(other.data_) { - other.data_.type_index = TypeIndex::kTVMFFINone; - other.data_.zero_padding = 0; - other.data_.v_int64 = 0; - } - TVM_FFI_INLINE AnyView& operator=(AnyView&& other) { - // copy-and-swap idiom - AnyView(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - // constructor from general types - template ::convert_enabled>> - AnyView(const T& other) { // NOLINT(*) - TypeTraits::CopyToAnyView(other, &data_); - } - template ::convert_enabled>> - TVM_FFI_INLINE AnyView& operator=(const T& other) { // NOLINT(*) - // copy-and-swap idiom - AnyView(other).swap(*this); // NOLINT(*) - return *this; - } - - /*! - * \brief Try to see if we can reinterpret the AnyView to as T object. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try run type conversion (use try_cast for that purpose). - */ - template ::convert_enabled>> - TVM_FFI_INLINE std::optional as() const { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::CopyFromAnyViewAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - /* - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. - */ - template >> - TVM_FFI_INLINE const T* as() const { - return this->as().value_or(nullptr); - } - - /** - * \brief Cast to a type T. - * - * \tparam T The type to cast to. 
- * \return The casted value, or throws an exception if the cast is not possible. - */ - template ::convert_enabled>> - TVM_FFI_INLINE T cast() const { - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /*! - * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - */ - template ::convert_enabled>> - TVM_FFI_INLINE std::optional try_cast() const { - return TypeTraits::TryCastFromAnyView(&data_); - } - - // comparison with nullptr - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { - return data_.type_index == TypeIndex::kTVMFFINone; - } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { - return data_.type_index != TypeIndex::kTVMFFINone; - } - /*! - * \brief Get the type key of the Any - * \return The type key of the Any - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } - // The following functions are only used for testing purposes - /*! - * \return The underlying supporting data of any view - * \note This function is used only for testing purposes. - */ - TVM_FFI_INLINE TVMFFIAny CopyToTVMFFIAny() const { return data_; } - /*! - * \return Create an AnyView from TVMFFIAny - * \param data the underlying ffi data. - */ - TVM_FFI_INLINE static AnyView CopyFromTVMFFIAny(TVMFFIAny data) { - AnyView view; - view.data_ = data; - return view; - } -}; - -namespace details { -/*! - * \brief Helper function to inplace convert any view to any. - * \param data The pointer that represents the format as any view. - * \param extra_any_bytes Indicate that the data may contain extra bytes following - * the TVMFFIAny data structure. 
This is reserved for future possible optimizations - * of small-string and extended any object. - */ -TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, - [[maybe_unused]] size_t extra_any_bytes = 0) { - if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data->v_obj); - } else if (data->type_index >= TypeIndex::kTVMFFIRawStr) { - if (data->type_index == TypeIndex::kTVMFFIRawStr) { - // convert raw string to owned string object - String temp(data->v_c_str); - TypeTraits::MoveToAny(std::move(temp), data); - } else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - // convert byte array to owned bytes object - Bytes temp(*static_cast(data->v_ptr)); - TypeTraits::MoveToAny(std::move(temp), data); - } else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - // convert rvalue ref to owned object - Object** obj_addr = static_cast(data->v_ptr); - TVM_FFI_ICHECK(obj_addr[0] != nullptr) << "RValueRef already moved"; - ObjectRef temp(details::ObjectUnsafe::ObjectPtrFromOwned(obj_addr[0])); - // set the rvalue ref to nullptr to avoid double move - obj_addr[0] = nullptr; - TypeTraits::MoveToAny(std::move(temp), data); - } - } -} -} // namespace details - -/*! - * \brief Managed Any that takes strong reference to a value. - * - * \note Develooper invariance: the TVMFFIAny data_ - * in the Any can be safely used in AnyView. - */ -class Any { - protected: - /*! \brief The underlying backing data of the any object */ - TVMFFIAny data_; - - public: - /*! - * \brief Reset any to None - */ - TVM_FFI_INLINE void reset() { - if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); - } - data_.type_index = TVMFFITypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - /*! 
- * \brief Swap this array with another Object - * \param other The other Object - */ - TVM_FFI_INLINE void swap(Any& other) noexcept { std::swap(data_, other.data_); } - /*! \return the internal type index */ - TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } - // default constructors - Any() { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - ~Any() { this->reset(); } - // constructors from Any - Any(const Any& other) : data_(other.data_) { - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); - } - } - Any(Any&& other) : data_(other.data_) { - other.data_.type_index = TypeIndex::kTVMFFINone; - other.data_.zero_padding = 0; - other.data_.v_int64 = 0; - } - TVM_FFI_INLINE Any& operator=(const Any& other) { - // copy-and-swap idiom - Any(other).swap(*this); // NOLINT(*) - return *this; - } - TVM_FFI_INLINE Any& operator=(Any&& other) { - // copy-and-swap idiom - Any(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - // convert from/to AnyView - Any(const AnyView& other) : data_(other.data_) { // NOLINT(*) - details::InplaceConvertAnyViewToAny(&data_); - } - TVM_FFI_INLINE Any& operator=(const AnyView& other) { - // copy-and-swap idiom - Any(other).swap(*this); // NOLINT(*) - return *this; - } - /*! \brief Any can be converted to AnyView in zero cost. */ - operator AnyView() const { return AnyView::CopyFromTVMFFIAny(data_); } - // constructor from general types - template ::convert_enabled>> - Any(T other) { // NOLINT(*) - TypeTraits::MoveToAny(std::move(other), &data_); - } - template ::convert_enabled>> - TVM_FFI_INLINE Any& operator=(T other) { // NOLINT(*) - // copy-and-swap idiom - Any(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - /** - * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. - * - * \tparam T The type to cast to. 
- * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try to run type conversion (use try_cast for that purpose). - */ - template ::storage_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional as() && { - if constexpr (std::is_same_v) { - return std::move(*this); - } else { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::MoveFromAnyAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - } - - /** - * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try to run type conversion (use try_cast for that purpose). - */ - template ::convert_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional as() const& { - if constexpr (std::is_same_v) { - return *this; - } else { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::CopyFromAnyViewAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - } - - /* - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. - */ - template >> - TVM_FFI_INLINE const T* as() const& { - return this->as().value_or(nullptr); - } - - /** - * \brief Cast to a type T, throw an exception if the cast is not possible. - * - * \tparam T The type to cast to. - */ - template ::convert_enabled>> - TVM_FFI_INLINE T cast() const& { - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /** - * \brief Cast to a type T, throw an exception if the cast is not possible. 
- * - * \tparam T The type to cast to. - */ - template ::storage_enabled>> - TVM_FFI_INLINE T cast() && { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::MoveFromAnyAfterCheck(&data_); - } - // slow path, try to do fallback convert - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /** - * \brief Try to cast to a type T. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note use STL name since it to be more consistent with cast API. - */ - template ::convert_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional try_cast() const { - if constexpr (std::is_same_v) { - return *this; - } else { - return TypeTraits::TryCastFromAnyView(&data_); - } - } - /* - * \brief Check if the two Any are same type and value in shallow comparison. - * \param other The other Any - * \return True if the two Any are same type and value, false otherwise. - */ - TVM_FFI_INLINE bool same_as(const Any& other) const noexcept { - return data_.type_index == other.data_.type_index && - data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64; - } - - /* - * \brief Check if any and ObjectRef are same type and value in shallow comparison. - * \param other The other ObjectRef - * \return True if the two Any are same type and value, false otherwise. 
- */ - TVM_FFI_INLINE bool same_as(const ObjectRef& other) const noexcept { - if (other.get() != nullptr) { - return (data_.type_index == other->type_index() && - reinterpret_cast(data_.v_obj) == other.get()); - } else { - return data_.type_index == TypeIndex::kTVMFFINone; - } - } - - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { - return data_.type_index == TypeIndex::kTVMFFINone; - } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { - return data_.type_index != TypeIndex::kTVMFFINone; - } - - /*! - * \brief Get the type key of the Any - * \return The type key of the Any - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } - - friend struct details::AnyUnsafe; - friend struct AnyHash; - friend struct AnyEqual; -}; - -// layout assert to ensure we can freely cast between the two types -static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); -static_assert(sizeof(Any) == sizeof(TVMFFIAny)); - -namespace details { - -template -struct Type2Str { - static std::string v() { return TypeTraitsNoCR::TypeStr(); } -}; - -template <> -struct Type2Str { - static std::string v() { return "Any"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "Any"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "AnyView"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "AnyView"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "void"; } -}; - -// Extra unsafe method to help any manipulation -struct AnyUnsafe : public ObjectUnsafe { - // FFI related operations - TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) { - TVMFFIAny result = any.data_; - any.data_.type_index = TypeIndex::kTVMFFINone; - any.data_.zero_padding = 0; - any.data_.v_int64 = 0; - return result; - } - - TVM_FFI_INLINE static Any MoveTVMFFIAnyToAny(TVMFFIAny&& data) { - Any any; - any.data_ = data; - data.type_index = 
TypeIndex::kTVMFFINone; - data.zero_padding = 0; - data.v_int64 = 0; - return any; - } - - template - TVM_FFI_INLINE static bool CheckAnyStrict(const Any& ref) { - return TypeTraits::CheckAnyStrict(&(ref.data_)); - } - - template - TVM_FFI_INLINE static T CopyFromAnyViewAfterCheck(const Any& ref) { - if constexpr (!std::is_same_v) { - return TypeTraits::CopyFromAnyViewAfterCheck(&(ref.data_)); - } else { - return ref; - } - } - - template - TVM_FFI_INLINE static T MoveFromAnyAfterCheck(Any&& ref) { - if constexpr (!std::is_same_v) { - return TypeTraits::MoveFromAnyAfterCheck(&(ref.data_)); - } else { - return std::move(ref); - } - } - - TVM_FFI_INLINE static Object* ObjectPtrFromAnyAfterCheck(const Any& ref) { - return reinterpret_cast(ref.data_.v_obj); - } - - TVM_FFI_INLINE static const TVMFFIAny* TVMFFIAnyPtrFromAny(const Any& ref) { - return &(ref.data_); - } - - template - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const Any& ref) { - return TypeTraits::GetMismatchTypeInfo(&(ref.data_)); - } -}; -} // namespace details - -/*! \brief String-aware Any equal functor */ -struct AnyHash { - /*! - * \brief Calculate the hash code of an Any - * \param a The given Any - * \return Hash code of a, string hash for strings and pointer address otherwise. 
- */ - uint64_t operator()(const Any& src) const { - if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) { - // for small string, we use the same type key hash as normal string - // so heap allocated string and on stack string will have the same hash - return details::StableHashCombine(TypeIndex::kTVMFFIStr, - details::StableHashSmallStrBytes(&src.data_)); - } else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) { - // use byte the same type key as bytes - return details::StableHashCombine(TypeIndex::kTVMFFIBytes, - details::StableHashSmallStrBytes(&src.data_)); - } else if (src.data_.type_index == TypeIndex::kTVMFFIStr || - src.data_.type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase* src_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(src); - return details::StableHashCombine(src.data_.type_index, - details::StableHashBytes(src_str->data, src_str->size)); - } else { - return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64); - } - } -}; - -/*! \brief String-aware Any hash functor */ -struct AnyEqual { - /*! - * \brief Check if the two Any are equal - * \param lhs left operand. - * \param rhs right operand - * \return String equality if both are strings, pointer address equality otherwise. 
- */ - bool operator()(const Any& lhs, const Any& rhs) const { - // header with type index - const int64_t* lhs_as_int64 = reinterpret_cast(&lhs.data_); - const int64_t* rhs_as_int64 = reinterpret_cast(&rhs.data_); - static_assert(sizeof(TVMFFIAny) == 16); - // fast path, check byte equality - if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) { - return true; - } - // common false case type index match, in this case we only need to pay attention to string - // equality - if (lhs.data_.type_index == rhs.data_.type_index) { - // specialy handle string hash - if (lhs.data_.type_index == TypeIndex::kTVMFFIStr || - lhs.data_.type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase* lhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - const details::BytesObjBase* rhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); - } - return false; - } else { - // type_index mismatch, if index is not string, return false - if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr && - lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) { - return false; - } - // small string and normal string comparison - if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) { - const details::BytesObjBase* lhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size, - rhs.data_.small_str_len); - } - if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) { - const details::BytesObjBase* rhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len, - rhs_str->size); - } - if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) { - const 
details::BytesObjBase* lhs_bytes = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size, - rhs.data_.small_str_len); - } - if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) { - const details::BytesObjBase* rhs_bytes = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len, - rhs_bytes->size); - } - return false; - } - } -}; -} // namespace ffi - -// Expose to the tvm namespace for usability -// Rationale: no ambiguity even in root -using tvm::ffi::Any; -using tvm::ffi::AnyView; - -} // namespace tvm -#endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h deleted file mode 100644 index 7c96b091d761..000000000000 --- a/ffi/include/tvm/ffi/base_details.h +++ /dev/null @@ -1,271 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/base_details.h - * \brief Internal detail utils that can be used by files in tvm/ffi. - * \note details header are for internal use only - * and not to be directly used by user. 
- */ -#ifndef TVM_FFI_BASE_DETAILS_H_ -#define TVM_FFI_BASE_DETAILS_H_ - -#include -#include - -#include -#include - -#if defined(_MSC_VER) -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif - -#ifndef NOMINMAX -#define NOMINMAX -#endif - -#include - -#ifdef ERROR -#undef ERROR -#endif - -#endif - -#if defined(_MSC_VER) -#define TVM_FFI_INLINE [[msvc::forceinline]] inline -#else -#define TVM_FFI_INLINE [[gnu::always_inline]] inline -#endif - -/*! - * \brief Macro helper to force a function not to be inlined. - * It is only used in places that we know not inlining is good, - * e.g. some logging functions. - */ -#if defined(_MSC_VER) -#define TVM_FFI_NO_INLINE [[msvc::noinline]] -#else -#define TVM_FFI_NO_INLINE [[gnu::noinline]] -#endif - -#if defined(_MSC_VER) -#define TVM_FFI_UNREACHABLE() __assume(false) -#else -#define TVM_FFI_UNREACHABLE() __builtin_unreachable() -#endif - -/*! \brief helper macro to suppress unused warning */ -#define TVM_FFI_ATTRIBUTE_UNUSED [[maybe_unused]] - -#define TVM_FFI_STR_CONCAT_(__x, __y) __x##__y -#define TVM_FFI_STR_CONCAT(__x, __y) TVM_FFI_STR_CONCAT_(__x, __y) - -#if defined(__GNUC__) || defined(__clang__) -#define TVM_FFI_FUNC_SIG __PRETTY_FUNCTION__ -#elif defined(_MSC_VER) -#define TVM_FFI_FUNC_SIG __FUNCSIG__ -#else -#define TVM_FFI_FUNC_SIG __func__ -#endif - -#define TVM_FFI_STATIC_INIT_BLOCK_VAR_DEF \ - TVM_FFI_ATTRIBUTE_UNUSED static inline int __##TVMFFIStaticInitReg - -/*! \brief helper macro to run code once during initialization */ -#define TVM_FFI_STATIC_INIT_BLOCK(Body) \ - TVM_FFI_STR_CONCAT(TVM_FFI_STATIC_INIT_BLOCK_VAR_DEF, __COUNTER__) = []() { Body return 0; }() - -/* - * \brief Define the default copy/move constructor and assign operator - * \param TypeName The class typename. 
- */ -#define TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - TypeName(const TypeName& other) = default; \ - TypeName(TypeName&& other) = default; \ - TypeName& operator=(const TypeName& other) = default; \ - TypeName& operator=(TypeName&& other) = default; - -/** - * \brief marks the begining of a C call that logs exception - */ -#define TVM_FFI_LOG_EXCEPTION_CALL_BEGIN() \ - try { \ - (void)0 - -/*! - * \brief Marks the end of a C call that logs exception - */ -#define TVM_FFI_LOG_EXCEPTION_CALL_END(Name) \ - } \ - catch (const std::exception& err) { \ - std::cerr << "Exception caught during " << #Name << ":\n" << err.what() << std::endl; \ - exit(-1); \ - } - -/*! - * \brief Clear the padding parts so we can safely use v_int64 for hash - * and equality check even when the value stored is a pointer. - * - * This macro is used to clear the padding parts for hash and equality check - * in 32bit platform. - */ -#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ - if constexpr (sizeof((result)->v_obj) != sizeof((result)->v_int64)) { \ - (result)->v_int64 = 0; \ - } - -namespace tvm { -namespace ffi { -namespace details { - -// for each iterator -struct for_each_dispatcher { - template - static void run(std::index_sequence, const F& f, Args&&... args) { // NOLINT(*) - (f(I, std::forward(args)), ...); - } -}; - -template -void for_each(const F& f, Args&&... args) { // NOLINT(*) - for_each_dispatcher::run(std::index_sequence_for{}, f, std::forward(args)...); -} - -/*! - * \brief hash an object and combines uint64_t key with previous keys - * - * This hash function is stable across platforms. - * - * \param key The left operand. - * \param value The right operand. - * \return the combined result. - */ -template ::value, bool> = true> -TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key, const T& value) { - // XXX: do not use std::hash in this function. This hash must be stable - // across different platforms and std::hash is implementation dependent. 
- return key ^ (uint64_t(value) + 0x9e3779b9 + (key << 6) + (key >> 2)); -} - -/*! - * \brief Hash the binary bytes - * \param data The data pointer - * \param size The size of the bytes. - * \return the hash value. - */ -TVM_FFI_INLINE uint64_t StableHashBytes(const void* data_ptr, size_t size) { - const char* data = reinterpret_cast(data_ptr); - const constexpr uint64_t kMultiplier = 1099511628211ULL; - const constexpr uint64_t kMod = 2147483647ULL; - union Union { - uint8_t a[8]; - uint64_t b; - } u; - static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)"); - const char* it = data; - const char* end = it + size; - uint64_t result = 0; - if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { - // if alignment requirement is met, directly use load - if (reinterpret_cast(it) % 8 == 0) { - for (; it + 8 <= end; it += 8) { - u.b = *reinterpret_cast(it); - result = (result * kMultiplier + u.b) % kMod; - } - } else { - // unaligned version - for (; it + 8 <= end; it += 8) { - u.a[0] = it[0]; - u.a[1] = it[1]; - u.a[2] = it[2]; - u.a[3] = it[3]; - u.a[4] = it[4]; - u.a[5] = it[5]; - u.a[6] = it[6]; - u.a[7] = it[7]; - result = (result * kMultiplier + u.b) % kMod; - } - } - } else { - // need endian swap - for (; it + 8 <= end; it += 8) { - u.a[0] = it[7]; - u.a[1] = it[6]; - u.a[2] = it[5]; - u.a[3] = it[4]; - u.a[4] = it[3]; - u.a[5] = it[2]; - u.a[6] = it[1]; - u.a[7] = it[0]; - result = (result * kMultiplier + u.b) % kMod; - } - } - - if (it < end) { - u.b = 0; - uint8_t* a = u.a; - if (it + 4 <= end) { - a[0] = it[0]; - a[1] = it[1]; - a[2] = it[2]; - a[3] = it[3]; - it += 4; - a += 4; - } - if (it + 2 <= end) { - a[0] = it[0]; - a[1] = it[1]; - it += 2; - a += 2; - } - if (it + 1 <= end) { - a[0] = it[0]; - it += 1; - a += 1; - } - if constexpr (!TVM_FFI_IO_NO_ENDIAN_SWAP) { - std::swap(u.a[0], u.a[7]); - std::swap(u.a[1], u.a[6]); - std::swap(u.a[2], u.a[5]); - std::swap(u.a[3], u.a[4]); - } - result = (result * kMultiplier + u.b) % kMod; - } - 
return result; -} - -/*! - * \brief Same as StableHashBytes, but for small string data. - * \param data The data pointer - * \return the hash value. - */ -TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const TVMFFIAny* data) { - if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { - // fast path, no endian swap, simply hash as uint64_t - const constexpr uint64_t kMod = 2147483647ULL; - return data->v_uint64 % kMod; - } - return StableHashBytes(reinterpret_cast(data), sizeof(data->v_uint64)); -} - -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_BASE_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h deleted file mode 100644 index 39b7de69fa75..000000000000 --- a/ffi/include/tvm/ffi/c_api.h +++ /dev/null @@ -1,968 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -/* - * \file tvm/ffi/c_api.h - * \brief This file defines the C convention of the FFI convention - */ -#ifndef TVM_FFI_C_API_H_ -#define TVM_FFI_C_API_H_ - -#include -#include - -// Macros to do weak linking -#ifdef _MSC_VER -#define TVM_FFI_WEAK __declspec(selectany) -#else -#define TVM_FFI_WEAK __attribute__((weak)) -#endif - -// Defines two macros -// TVM_FFI_DLL: marks the function as a DLL export/import -// depending on whether TVM_FFI_EXPORTS is defined -// TVM_FFI_DLL_EXPORT: always marks the function as a DLL export -#if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) -#include -#define TVM_FFI_DLL EMSCRIPTEN_KEEPALIVE -#define TVM_FFI_DLL_EXPORT EMSCRIPTEN_KEEPALIVE -#endif -#if !defined(TVM_FFI_DLL) && defined(_MSC_VER) -#ifdef TVM_FFI_EXPORTS -#define TVM_FFI_DLL __declspec(dllexport) -#else -#define TVM_FFI_DLL __declspec(dllimport) -#endif -#define TVM_FFI_DLL_EXPORT __declspec(dllexport) -#endif -#ifndef TVM_FFI_DLL -#define TVM_FFI_DLL __attribute__((visibility("default"))) -#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default"))) -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __cplusplus -enum TVMFFITypeIndex : int32_t { -#else -typedef enum { -#endif - - /* - * \brief The root type of all FFI objects. - * - * We include it so TypeIndex captures all possible runtime values. - * `kTVMFFIAny` code will never appear in Any::type_index. - * However, it may appear in field annotations during reflection. - */ - kTVMFFIAny = -1, - // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) - // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, - // which is not owned by TVMFFIAny. It is required that the following - // invariant holds: - // - `Any::type_index` is never `kTVMFFIRawStr` - // - `AnyView::type_index` can be `kTVMFFIRawStr` - // - /*! \brief None/nullptr value */ - kTVMFFINone = 0, - /*! \brief POD int value */ - kTVMFFIInt = 1, - /*! 
\brief POD bool value */ - kTVMFFIBool = 2, - /*! \brief POD float value */ - kTVMFFIFloat = 3, - /*! \brief Opaque pointer object */ - kTVMFFIOpaquePtr = 4, - /*! \brief DLDataType */ - kTVMFFIDataType = 5, - /*! \brief DLDevice */ - kTVMFFIDevice = 6, - /*! \brief DLTensor* */ - kTVMFFIDLTensorPtr = 7, - /*! \brief const char* */ - kTVMFFIRawStr = 8, - /*! \brief TVMFFIByteArray* */ - kTVMFFIByteArrayPtr = 9, - /*! \brief R-value reference to ObjectRef */ - kTVMFFIObjectRValueRef = 10, - /*! \brief Small string on stack */ - kTVMFFISmallStr = 11, - /*! \brief Small bytes on stack */ - kTVMFFISmallBytes = 12, - /*! \brief Start of statically defined objects. */ - kTVMFFIStaticObjectBegin = 64, - /*! - * \brief Object, all objects starts with TVMFFIObject as its header. - * \note We will also add other fields - */ - kTVMFFIObject = 64, - /*! - * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... } - */ - kTVMFFIStr = 65, - /*! - * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... } - */ - kTVMFFIBytes = 66, - /*! \brief Error object. */ - kTVMFFIError = 67, - /*! \brief Function object. */ - kTVMFFIFunction = 68, - /*! \brief Array object. */ - kTVMFFIArray = 69, - /*! \brief Map object. */ - kTVMFFIMap = 70, - /*! - * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } - */ - kTVMFFIShape = 71, - /*! - * \brief NDArray object, layout = { TVMFFIObject, DLTensor, ... } - */ - kTVMFFINDArray = 72, - /*! \brief Runtime module object. */ - kTVMFFIModule = 73, - kTVMFFIStaticObjectEnd, - // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) - /*! \brief Start of type indices that are allocated at runtime. */ - kTVMFFIDynObjectBegin = 128 -#ifdef __cplusplus -}; -#else -} TVMFFITypeIndex; -#endif - -/*! \brief Handle to Object from C API's pov */ -typedef void* TVMFFIObjectHandle; - -/*! - * \brief C-based type of all FFI object header that allocates on heap. 
- * \note TVMFFIObject and TVMFFIAny share the common type_index header - */ -typedef struct TVMFFIObject { - /*! - * \brief type index of the object. - * \note The type index of Object and Any are shared in FFI. - */ - int32_t type_index; - /*! \brief Reference counter of the object. */ - int32_t ref_counter; - union { - /*! \brief Deleter to be invoked when reference counter goes to zero. */ - void (*deleter)(struct TVMFFIObject* self); - /*! - * \brief auxilary field to TVMFFIObject is always 8 bytes aligned. - * \note This helps us to ensure cross platform compatibility. - */ - int64_t __ensure_align; - }; -} TVMFFIObject; - -/*! - * \brief C-based type of all on stack Any value. - * - * Any value can hold on stack values like int, - * as well as reference counted pointers to object. - */ -typedef struct TVMFFIAny { - /*! - * \brief type index of the object. - * \note The type index of Object and Any are shared in FFI. - */ - int32_t type_index; - union { // 4 bytes - /*! \brief padding, must set to zero for values other than small string. */ - uint32_t zero_padding; - /*! - * \brief Length of small string, with a max value of 7. - * - * We keep small str to start at next 4 bytes to ensure alignment - * when accessing the small str content. - */ - uint32_t small_str_len; - }; - union { // 8 bytes - int64_t v_int64; // integers - double v_float64; // floating-point numbers - void* v_ptr; // typeless pointers - const char* v_c_str; // raw C-string - TVMFFIObject* v_obj; // ref counted objects - DLDataType v_dtype; // data type - DLDevice v_device; // device - char v_bytes[8]; // small string - char32_t v_char32[2]; // small UCS4 string and Unicode - uint64_t v_uint64; // uint64 repr mainly used for hashing - }; -} TVMFFIAny; - -/*! - * \brief Byte array data structure used by String and Bytes. - * - * String and Bytes object layout = { TVMFFIObject, TVMFFIByteArray, ... } - * - * \note This byte array data structure layout differs in 32/64 bit platforms. 
- * as size_t equals to the size of the pointer, use this convetion to - * be consistent with std::string and also avoid need to calculate padding - * for the size field on 32-bit platforms. - * The FFI binding should be careful when treating this ABI. - */ -typedef struct { - const char* data; - size_t size; -} TVMFFIByteArray; - -/*! - * \brief Shape cell used in shape object following header. - */ -typedef struct { - const int64_t* data; - size_t size; -} TVMFFIShapeCell; - -/*! - * \brief Error cell used in error object following header. - */ -typedef struct { - /*! \brief The kind of the error. */ - TVMFFIByteArray kind; - /*! \brief The message of the error. */ - TVMFFIByteArray message; - /*! - * \brief The traceback of the error. - */ - TVMFFIByteArray traceback; - /*! - * \brief Function handle to update the traceback of the error. - * \param self The self object handle. - * \param traceback The traceback to update. - */ - void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback); -} TVMFFIErrorCell; - -/*! - * \brief Type that defines C-style safe call convention - * - * Safe call explicitly catches exception on function boundary. - * - * \param handle The function handle - * \param num_args Number of input arguments - * \param args The input arguments to the call. - * \param result Store output result. - * - * IMPORTANT: caller must initialize result->type_index to be kTVMFFINone, - * or any other value smaller than kTVMFFIStaticObjectBegin. - * - * \return The call returns 0 if call is successful. - * It returns non-zero value if there is an error. - * - * Possible return error of the API functions: - * * 0: success - * * -1: error happens, can be retrieved by TVMFFIErrorMoveFromRaised - * * -2: a frontend error occurred and recorded in the frontend. - * - * \note We decided to leverage TVMFFIErrorMoveFromRaised and TVMFFIErrorSetRaised - * for C function error propagation. 
This design choice, while - * introducing a dependency for TLS runtime, simplifies error - * propgation in chains of calls in compiler codegen. - * As we do not need to propagate error through argument but simply - * set them in the runtime environment. - * - * \sa TVMFFIErrorMoveFromRaised - * \sa TVMFFIErrorSetRaised - * \sa TVMFFIErrorSetRaisedFromCStr - */ -typedef int (*TVMFFISafeCallType)(void* handle, const TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result); - -/*! - * \brief Object cell for function object following header. - */ -typedef struct { - /*! \brief A C API compatible call with exception catching. */ - TVMFFISafeCallType safe_call; -} TVMFFIFunctionCell; - -//------------------------------------------------------------ -// Section: Basic object API -//------------------------------------------------------------ -/*! - * \brief Free an object handle by decreasing reference - * \param obj The object handle. - * \note Internally we decrease the reference counter of the object. - * The object will be freed when every reference to the object are removed. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_tindex the corresponding type index. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); - -//----------------------------------------------------------------------- -// Section: Basic function calling API for function implementation -//----------------------------------------------------------------------- -/*! - * \brief Create a FFIFunc by passing in callbacks from C callback. - * - * The registered function then can be pulled by the backend by the name. - * - * \param self The resource handle of the C callback. 
- * \param safe_call The C callback implementation - * \param deleter deleter to recycle - * \param out The output of the function. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, - void (*deleter)(void* self), TVMFFIObjectHandle* out); - -/*! - * \brief Get a global function registered in system. - * - * \param name The name of the function. - * \param out the result function pointer, NULL if it does not exist. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out); - -/*! - * \brief Convert a AnyView to an owned Any. - * \param any The AnyView to convert. - * \param out The output Any, must be an empty object - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out); - -/*! - * \brief Call a FFIFunc by passing in arguments. - * - * \param func The resource handle of the C callback. - * \param args The input arguments to the call. - * \param num_args The number of input arguments. - * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result); - -/*! - * \brief Move the last error from the environment to result. - * - * \param result The result error. - * - * \note This function clears the error stored in the TLS. - */ -TVM_FFI_DLL void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result); - -/*! - * \brief Set raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. - * - * \param error The error object handle - */ -TVM_FFI_DLL void TVMFFIErrorSetRaised(TVMFFIObjectHandle error); - -/*! 
- * \brief Set raised error in TLS, which can be fetched by TVMFFIMoveFromRaised. - * - * \param kind The kind of the error. - * \param message The error message. - * \note This is a convenient method for C API side to set error directly from string. - */ -TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStr(const char* kind, const char* message); - -/*! - * \brief Create an initial error object. - * - * \param kind The kind of the error. - * \param message The error message. - * \param traceback The traceback of the error. - * \return The created error object handle. - * \note This function is different from other functions as it is used in error handling loop. - * So we do not follow normal error handling patterns via returning error code. - */ -TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, - const TVMFFIByteArray* message, - const TVMFFIByteArray* traceback); - -//------------------------------------------------------------ -// Section: DLPack support APIs -//------------------------------------------------------------ -/*! - * \brief Produce a managed NDArray from a DLPack tensor. - * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \param out The output NDArray handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t require_alignment, - int32_t require_contiguous, TVMFFIObjectHandle* out); - -/*! - * \brief Produce a DLMangedTensor from the array that shares data memory with the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out); - -/*! - * \brief Produce a managed NDArray from a DLPack tensor. 
- * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \param out The output NDArray handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, - int32_t require_alignment, - int32_t require_contiguous, - TVMFFIObjectHandle* out); - -/*! - * \brief Produce a DLMangedTensor from the array that shares data memory with the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, - DLManagedTensorVersioned** out); - -//--------------------------------------------------------------- -// Section: dtype string support APIs. -// These APIs are used to simplify the dtype printings during FFI -//--------------------------------------------------------------- - -/*! - * \brief Convert a string to a DLDataType. - * \param str The string to convert. - * \param out The output DLDataType. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out); - -/*! -* \brief Convert a DLDataType to a string. -* \param dtype The DLDataType to convert. -* \param out The output string. -* \return 0 when success, nonzero when failure happens -* \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree. -The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. - -* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. 
-*/ -TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out); - -//------------------------------------------------------------ -// Section: Type reflection support APIs -// -// The reflec -//------------------------------------------------------------ -/*! - * \brief Getter that can take address of a field and set the result. - * \param field The raw address of the field. - * \param result Stores the result. - * \return 0 when success, nonzero when failure happens - */ -typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result); - -/*! - * \brief Getter that can take address of a field and set to value. - * \param field The raw address of the field. - * \param value The value to set. - * \return 0 when success, nonzero when failure happens - */ -typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value); - -/*! - * \brief Function that create a new instance of the type. - * \param result The new object handle - * \return 0 when success, nonzero when failure happens - */ -typedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result); - -/*! - * \brief bitmask of the field. - */ -#ifdef __cplusplus -enum TVMFFIFieldFlagBitMask : int32_t { -#else -typedef enum { -#endif - /*! \brief The field is writable. */ - kTVMFFIFieldFlagBitMaskWritable = 1 << 0, - /*! \brief The field has default value. */ - kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1, - /*! \brief The field is a static method. */ - kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2, - /*! - * \brief The field should be ignored when performing structural eq/hash - * - * This is an optional meta-data for structural eq/hash. - */ - kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3, - /*! - * \brief The field enters a def region where var can be defined/matched. - * - * This is an optional meta-data for structural eq/hash. - */ - kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4, -#ifdef __cplusplus -}; -#else -} TVMFFIFieldFlagBitMask; -#endif - -/*! 
- * \brief Optional meta-data for structural eq/hash. - * - * This meta-data is only useful when we want to leverage the information - * to perform richer semantics aware structural comparison and hash. - * It can be safely ignored if such information is not needed. - * - * The meta-data record comparison method in tree node and DAG node. - * - * \code - * x = VarNode() - * v0 = AddNode(x, 1) - * v1 = AddNode(x, 1) - * v2 = AddNode(v0, v0) - * v3 = AddNode(v1, v0) - * \endcode - * - * Consider the construct sequence of AddNode below, - * if AddNode is treated as a tree node, then v2 and v3 - * structural equals to each other, but if AddNode is - * treated as a DAG node, then v2 and v3 does not - * structural equals to each other. - */ -#ifdef __cplusplus -enum TVMFFISEqHashKind : int32_t { -#else -typedef enum { -#endif - /*! \brief Do not support structural eq/hash. */ - kTVMFFISEqHashKindUnsupported = 0, - /*! - * \brief The object be compared as a tree node. - */ - kTVMFFISEqHashKindTreeNode = 1, - /*! - * \brief The object is treated as a free variable that can be mapped - * to another free variable in the definition region. - */ - kTVMFFISEqHashKindFreeVar = 2, - /*! - * \brief The field should be compared as a DAG node. - */ - kTVMFFISEqHashKindDAGNode = 3, - /*! - * \brief The object is treated as a constant tree node. - * - * Same as tree node, but the object does not contain free var - * as any of its nested children. - * - * That means we can use pointer equality for equality. - */ - kTVMFFISEqHashKindConstTreeNode = 4, - /*! - * \brief One can simply use pointer equality for equality. - * - * This is useful for "singleton"-style object that can - * is only an unique copy of each value. - */ - kTVMFFISEqHashKindUniqueInstance = 5, -#ifdef __cplusplus -}; -#else -} TVMFFISEqHashKind; -#endif - -/*! - * \brief Information support for optional object reflection. - */ -typedef struct { - /*! \brief The name of the field. */ - TVMFFIByteArray name; - /*! 
\brief The docstring about the field. */ - TVMFFIByteArray doc; - /*! \brief The type schema of the field in JSON string. */ - TVMFFIByteArray type_schema; - /*! - * \brief bitmask flags of the field. - */ - int64_t flags; - /*! \brief The size of the field. */ - int64_t size; - /*! \brief The alignment of the field. */ - int64_t alignment; - /*! \brief The offset of the field. */ - int64_t offset; - /*! \brief The getter to access the field. */ - TVMFFIFieldGetter getter; - /*! - * \brief The setter to access the field. - * \note The setter is set even if the field is readonly for serialization. - */ - TVMFFIFieldSetter setter; - /*! - * \brief The default value of the field, this field hold AnyView, - * valid when flags set kTVMFFIFieldFlagBitMaskHasDefault - */ - TVMFFIAny default_value; - /*! - * \brief Records the static type kind of the field. - * - * Possible values: - * - * - TVMFFITypeIndex::kTVMFFIObject for general objects - * - The value is nullable when kTVMFFIObject is chosen - * - static object type kinds such as Map, Dict, String - * - POD type index, note it does not give information about storage size of the field. - * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info - * about the field. - * - * When the value is a type index of Object type, the field is storaged as an ObjectRef. - * - * \note This information maybe helpful in designing serializer. - * As it helps to narrow down the field type so we don't have to - * print type_key for cases like POD types. - * It also helps to provide opportunities to enable short-cut getter to ObjectRef fields. - */ - int32_t field_static_type_index; -} TVMFFIFieldInfo; - -/*! - * \brief Method information that can appear in reflection table. - */ -typedef struct { - /*! \brief The name of the field. */ - TVMFFIByteArray name; - /*! \brief The docstring about the method. */ - TVMFFIByteArray doc; - /*! \brief Optional type schema of the method in JSON string. 
*/ - TVMFFIByteArray type_schema; - /*! \brief bitmask flags of the method. */ - int64_t flags; - /*! - * \brief The method wrapped as ffi::Function, stored as AnyView. - * \note The first argument to the method is always the self for instance methods. - */ - TVMFFIAny method; -} TVMFFIMethodInfo; - -/*! - * \brief Extra information of object type that can be used for reflection. - * - * \note This information is optional and can be used to enable reflection based - * creation of the object. - */ -typedef struct { - /*! \brief The docstring about the object. */ - TVMFFIByteArray doc; - /*! - * \brief An optional function that can create a new empty instance of the type. - * - * When known_fixed_size is non-zero, creator can be called - * with nullptr passed to optional_bytes. - * - * \note Caller must call setter for each field to initialize the object for - * the final object to be in valid state. - * - * \note This field is optional to enable reflection based creation. - */ - TVMFFIObjectCreator creator; - /*! - * \brief Total size of the object struct, if it is fixed and known. - * - * This field is set optional and set to 0 if not registered. - */ - int32_t total_size; - /*! - * \brief Optional meta-data for structural eq/hash. - */ - TVMFFISEqHashKind structural_eq_hash_kind; -} TVMFFITypeMetadata; - -/* - * \brief Column array that stores extra attributes about types - * - * The attributes stored in a column array that can be looked up by type index. - * Note that the TypeAttr behaves like type_traits so column[T] so not contain - * attributes from base classes. - * - * \note - * \sa TVMFFIRegisterTypeAttr - */ -typedef struct { - /*! \brief The data of the column. */ - const TVMFFIAny* data; - /*! \brief The size of the column. */ - size_t size; -} TVMFFITypeAttrColumn; - -/*! - * \brief Runtime type information for object type checking. - */ -typedef struct TVMFFITypeInfo { - /*! 
- *\brief The runtime type index, - * It can be allocated during runtime if the type is dynamic. - */ - int32_t type_index; - /*! \brief number of parent types in the type hierachy. */ - int32_t type_depth; - /*! \brief the unique type key to identify the type. */ - TVMFFIByteArray type_key; - /*! - * \brief type_acenstors[depth] stores the type_index of the acenstors at depth level - * \note To keep things simple, we do not allow multiple inheritance so the - * hieracy stays as a tree - */ - const struct TVMFFITypeInfo** type_acenstors; - // The following fields are used for reflection - /*! \brief Cached hash value of the type key, used for consistent structural hashing. */ - uint64_t type_key_hash; - /*! \brief number of reflection accessible fields. */ - int32_t num_fields; - /*! \brief number of reflection acccesible methods. */ - int32_t num_methods; - /*! \brief The reflection field information. */ - const TVMFFIFieldInfo* fields; - /*! \brief The reflection method. */ - const TVMFFIMethodInfo* methods; - /*! \brief The extra information of the type. */ - const TVMFFITypeMetadata* metadata; -} TVMFFITypeInfo; - -/*! - * \brief Register the function to runtime's global table. - * - * The registered function then can be pulled by the backend by the name. - * - * \param name The name of the function. - * \param f The function to be registered. - * \param override Whether allow override already registered function. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, - int override); - -/*! - * \brief Register the function to runtime's global table with method info. - * - * This is same as TVMFFIFunctionSetGlobal but with method info that can provide extra - * metadata used in the runtime. - * - * \param method_info The method info to be registered. - * \param override Whether allow override already registered function. 
- * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo* method_info, - int override); - -/*! - * \brief Register type field information for runtime reflection. - * \param type_index The type index - * \param info The field info to be registered. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info); - -/*! - * \brief Register type method information for runtime reflection. - * \param type_index The type index - * \param info The method info to be registered. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info); - -/*! - * \brief Register type creator information for runtime reflection. - * \param type_index The type index - * \param metadata The extra information to be registered. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata); - -/*! - * \brief Register extra type attributes that can be looked up during runtime. - * \param type_index The type index - * \param attr_value The attribute value to be registered. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* attr_name, - const TVMFFIAny* attr_value); - -/*! - * \brief Get the type attribute column by name. - * \param attr_name The name of the attribute. - * \return The pointer to the type attribute column. 
- * \return NULL if the attribute was not registered in the system - */ -TVM_FFI_DLL const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* attr_name); - -//------------------------------------------------------------ -// Section: Backend noexcept functions for internal use -// -// These functions are used internally and do not throw error -// instead the error will be logged and abort the process -// These are function are being called in startup or exit time -// so exception handling do not apply -//------------------------------------------------------------ -/*! - * \brief Get stack traceback in a string. - * \param filename The current file name. - * \param lineno The current line number - * \param func The current function - * \return The traceback string - * - * \note filename func and lino are only used as a backup info, most cases they are not needed. - * The return value is set to const char* to be more compatible across dll boundaries. - */ -TVM_FFI_DLL const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, - const char* func); - -/*! - * \brief Initialize the type info during runtime. - * - * When the function is first time called for a type, - * it will register the type to the type table in the runtime. - * - * If the static_tindex is non-negative, the function will - * allocate a runtime type index. - * Otherwise, we will populate the type table and return the static index. - * - * \param type_key The type key. - * \param static_type_index Static type index if any, can be -1, which means this is a dynamic index - * \param num_child_slots Number of slots reserved for its children. - * \param child_slots_can_overflow Whether to allow child to overflow the slots. - * \param parent_type_index Parent type index, pass in -1 if it is root. 
- * \param result The output type index - * - * \return 0 if success, -1 if error occured - */ -TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, - int32_t static_type_index, int32_t type_depth, - int32_t num_child_slots, - int32_t child_slots_can_overflow, - int32_t parent_type_index); - -/*! - * \brief Get dynamic type info by type index. - * - * \param type_index The type index - * \param result The output type information - * \return The type info - */ -TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index); - -#ifdef __cplusplus -} // TVM_FFI_EXTERN_C -#endif - -//--------------------------------------------------------------- -// The following API defines static object attribute accessors -// for language bindings. -// -// They are defined in C++ inline functions for cleaner code. -// Note that they only have to do with address offset computation. -// So they can always be reimplemented in bindings when c++ is -// not available or when binding only wants to refer to the dll. -//---------------------------------------------------------------- -#ifdef __cplusplus -/*! - * \brief Get the type index of an object. - * \param obj The object handle. - * \return The type index. - */ -inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) { - return static_cast(obj)->type_index; -} - -/*! - * \brief Get the content of a small string in bytearray format. - * \param obj The object handle. - * \return The content of the small string in bytearray format. - */ -inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) { - return TVMFFIByteArray{value->v_bytes, static_cast(value->small_str_len)}; -} - -/*! - * \brief Get the data pointer of a bytearray from a string or bytes object. - * \param obj The object handle. - * \return The data pointer. 
- */ -inline TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a ErrorInfo from an Error object. - * \param obj The object handle. - * \return The data pointer. - */ -inline TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a function cell from a function object. - * \param obj The object handle. - * \return The data pointer. - */ -inline TVMFFIFunctionCell* TVMFFIFunctionGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a shape array from a shape object. - * \param obj The object handle. - * \return The data pointer. - */ -inline TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the DLTensor pointer from an NDArray object. - * \param obj The object handle. - * \return The DLTensor pointer. - */ -inline DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Create a DLDevice from a device type and device id. - * \param device_type The device type. - * \param device_id The device id. - * \return The DLDevice. 
- */ -inline DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) { - return DLDevice{static_cast(device_type), device_id}; -} -#endif // __cplusplus -#endif // TVM_FFI_C_API_H_ diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h deleted file mode 100644 index c75d4a075f97..000000000000 --- a/ffi/include/tvm/ffi/cast.h +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/cast.h - * \brief Extra value casting helpers - */ -#ifndef TVM_FFI_CAST_H_ -#define TVM_FFI_CAST_H_ - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Get a reference type from a raw object ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the object alive beyond the scope of the function. 
- * - * \param ptr The object pointer - * \tparam RefType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template -inline RefType GetRef(const ObjectType* ptr) { - static_assert(std::is_base_of_v, - "Can only cast to the ref of same container type"); - - if constexpr (is_optional_type_v || RefType::_type_is_nullable) { - if (ptr == nullptr) { - return RefType(ObjectPtr(nullptr)); - } - } else { - TVM_FFI_ICHECK_NOTNULL(ptr); - } - return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned( - const_cast(static_cast(ptr)))); -} - -/*! - * \brief Get an object ptr type from a raw object ptr. - * - * \param ptr The object pointer - * \tparam BaseType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template -inline ObjectPtr GetObjectPtr(ObjectType* ptr) { - static_assert(std::is_base_of::value, - "Can only cast to the ref of same container type"); - return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); -} -} // namespace ffi - -using ffi::GetObjectPtr; -using ffi::GetRef; -} // namespace tvm -#endif // TVM_FFI_CAST_H_ diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h deleted file mode 100644 index 180c870ccbb6..000000000000 --- a/ffi/include/tvm/ffi/container/array.h +++ /dev/null @@ -1,1086 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/array.h - * \brief Array type. - * - * tvm::ffi::Array is an erased type that contains list of content - */ -#ifndef TVM_FFI_CONTAINER_ARRAY_H_ -#define TVM_FFI_CONTAINER_ARRAY_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! \brief array node content in array */ -class ArrayObj : public Object, public details::InplaceArrayBase { - public: - ~ArrayObj() { - Any* begin = MutableBegin(); - for (int64_t i = 0; i < size_; ++i) { - (begin + i)->Any::~Any(); - } - if (data_deleter_ != nullptr) { - data_deleter_(data_); - } - } - - /*! \return The size of the array */ - size_t size() const { return this->size_; } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const Any& at(int64_t i) const { return this->operator[](i); } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const Any& operator[](int64_t i) const { - if (i >= size_) { - TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; - } - return static_cast(data_)[i]; - } - - /*! \return begin constant iterator */ - const Any* begin() const { return static_cast(data_); } - - /*! \return end constant iterator */ - const Any* end() const { return begin() + size_; } - - /*! \brief Release reference to all the elements */ - void clear() { ShrinkBy(size_); } - - /*! 
- * \brief Set i-th element of the array in-place - * \param i The index - * \param item The value to be set - */ - void SetItem(int64_t i, Any item) { - if (i >= size_) { - TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; - } - static_cast(data_)[i] = std::move(item); - } - - /*! - * \brief Constructs a container and copy from another - * \param cap The capacity of the container - * \param from Source of the copy - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr CopyFrom(int64_t cap, ArrayObj* from) { - int64_t size = from->size_; - if (size > cap) { - TVM_FFI_THROW(ValueError) << "not enough capacity"; - } - ObjectPtr p = ArrayObj::Empty(cap); - Any* write = p->MutableBegin(); - Any* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) Any(*read++); - } - return p; - } - - /*! - * \brief Constructs a container and move from another - * \param cap The capacity of the container - * \param from Source of the move - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr MoveFrom(int64_t cap, ArrayObj* from) { - int64_t size = from->size_; - if (size > cap) { - TVM_FFI_THROW(RuntimeError) << "not enough capacity"; - } - ObjectPtr p = ArrayObj::Empty(cap); - Any* write = p->MutableBegin(); - Any* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) Any(std::move(*read++)); - } - from->size_ = 0; - return p; - } - - /*! - * \brief Constructs a container with n elements. 
Each element is a copy of val - * \param n The size of the container - * \param val The init value - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr CreateRepeated(int64_t n, const Any& val) { - ObjectPtr p = ArrayObj::Empty(n); - Any* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < n; ++i) { - new (itr++) Any(val); - } - return p; - } - - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIArray; - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIArray; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ArrayObj, Object); - - private: - /*! \return Size of initialized memory, used by InplaceArrayBase. */ - size_t GetSize() const { return this->size_; } - - /*! \return begin mutable iterator */ - Any* MutableBegin() const { return static_cast(this->data_); } - - /*! \return end mutable iterator */ - Any* MutableEnd() const { return MutableBegin() + size_; } - - /*! - * \brief Emplace a new element at the back of the array - * \param args The arguments to construct the new element - */ - template - void EmplaceInit(size_t idx, Args&&... args) { - Any* itr = MutableBegin() + idx; - new (itr) Any(std::forward(args)...); - } - - /*! - * \brief Create an ArrayObj with the given capacity. - * \param n Required capacity - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr Empty(int64_t n = kInitSize) { - ObjectPtr p = make_inplace_array_object(n); - p->capacity_ = n; - p->size_ = 0; - p->data_ = p->AddressOf(0); - return p; - } - - /*! 
- * \brief Inplace-initialize the elements starting idx from [first, last) - * \param idx The starting point - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return Self - */ - template - ArrayObj* InitRange(int64_t idx, IterType first, IterType last) { - Any* itr = MutableBegin() + idx; - for (; first != last; ++first) { - Any ref = *first; - new (itr++) Any(std::move(ref)); - } - return this; - } - - /*! - * \brief Move elements from right to left, requires src_begin > dst - * \param dst Destination - * \param src_begin The start point of copy (inclusive) - * \param src_end The end point of copy (exclusive) - * \return Self - */ - ArrayObj* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { - Any* from = MutableBegin() + src_begin; - Any* to = MutableBegin() + dst; - while (src_begin++ != src_end) { - *to++ = std::move(*from++); - } - return this; - } - - /*! - * \brief Move elements from left to right, requires src_begin < dst - * \param dst Destination - * \param src_begin The start point of move (inclusive) - * \param src_end The end point of move (exclusive) - * \return Self - */ - ArrayObj* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { - Any* from = MutableBegin() + src_end; - Any* to = MutableBegin() + (src_end - src_begin + dst); - while (src_begin++ != src_end) { - *--to = std::move(*--from); - } - return this; - } - - /*! - * \brief Enlarges the size of the array - * \param delta Size enlarged, should be positive - * \param val Default value - * \return Self - */ - ArrayObj* EnlargeBy(int64_t delta, const Any& val = Any()) { - Any* itr = MutableEnd(); - while (delta-- > 0) { - new (itr++) Any(val); - ++size_; - } - return this; - } - - /*! 
- * \brief Shrinks the size of the array - * \param delta Size shrinked, should be positive - * \return Self - */ - ArrayObj* ShrinkBy(int64_t delta) { - Any* itr = MutableEnd(); - while (delta-- > 0) { - (--itr)->Any::~Any(); - --size_; - } - return this; - } - - /*! \brief Data pointer to the first element of the array */ - void* data_; - /*! \brief Number of elements used */ - int64_t size_; - /*! \brief Number of elements allocated */ - int64_t capacity_; - /*! - * \brief Optional data deleter when data is allocated separately - * and its deletion is not managed by ArrayObj::deleter_. - */ - void (*data_deleter_)(void*) = nullptr; - - /*! \brief Initial size of ArrayObj */ - static constexpr int64_t kInitSize = 4; - - /*! \brief Expansion factor of the Array */ - static constexpr int64_t kIncFactor = 2; - - // CRTP parent class - friend InplaceArrayBase; - - // Reference class - template - friend class Array; - - template - friend class Tuple; - - template - friend struct TypeTraits; - - // To specialize make_object - friend ObjectPtr make_object<>(); -}; - -/*! \brief Helper struct for type-checking - * - * is_valid_iterator::value will be true if IterType can - * be dereferenced into a type that can be stored in an Array, and - * false otherwise. - */ -template -struct is_valid_iterator - : std::bool_constant< - std::is_same_v< - T, std::remove_cv_t())>>> || - std::is_base_of_v< - T, std::remove_cv_t())>>>> { -}; - -template -struct is_valid_iterator, IterType> : is_valid_iterator {}; - -template -struct is_valid_iterator : std::true_type {}; - -template -inline constexpr bool is_valid_iterator_v = is_valid_iterator::value; - -/*! - * \brief Array, container representing a contiguous sequence of ObjectRefs. - * - * Array implements in-place copy-on-write semantics. - * - * As in typical copy-on-write, a method which would typically mutate the array - * instead opaquely copies the underlying container, and then acts on its copy. 
- * - * If the array has reference count equal to one, we directly update the - * container in place without copying. This is optimization is sound because - * when the reference count is equal to one this reference is guranteed to be - * the sole pointer to the container. - * - * - * operator[] only provides const access, use Set to mutate the content. - * \tparam T The content Value type, must be compatible with tvm::ffi::Any - */ -template >> -class Array : public ObjectRef { - public: - using value_type = T; - // constructors - /*! - * \brief default constructor - */ - Array() { data_ = ArrayObj::Empty(); } - Array(Array&& other) : ObjectRef(std::move(other.data_)) {} - Array(const Array& other) : ObjectRef(other.data_) {} - template >> - Array(Array&& other) : ObjectRef(std::move(other.data_)) {} - template >> - Array(const Array& other) : ObjectRef(other.data_) {} - - TVM_FFI_INLINE Array& operator=(Array&& other) { - data_ = std::move(other.data_); - return *this; - } - TVM_FFI_INLINE Array& operator=(const Array& other) { - data_ = other.data_; - return *this; - } - template >> - TVM_FFI_INLINE Array& operator=(Array&& other) { - data_ = std::move(other.data_); - return *this; - } - template >> - TVM_FFI_INLINE Array& operator=(const Array& other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Array(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief Constructor from iterator - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType first, IterType last) { - static_assert(is_valid_iterator_v, - "IterType cannot be inserted into a tvm::Array"); - Assign(first, last); - } - - /*! - * \brief constructor from initializer list - * \param init The initializer list - */ - Array(std::initializer_list init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! 
- * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector& init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - */ - explicit Array(const size_t n, const T& val) { data_ = ArrayObj::CreateRepeated(n, val); } - - public: - // iterators - struct ValueConverter { - using ResultType = T; - static T convert(const Any& n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck(n); } - }; - - using iterator = details::IterAdapter; - using reverse_iterator = details::ReverseIterAdapter; - - /*! \return begin iterator */ - iterator begin() const { return iterator(GetArrayObj()->begin()); } - - /*! \return end iterator */ - iterator end() const { return iterator(GetArrayObj()->end()); } - - /*! \return rbegin iterator */ - reverse_iterator rbegin() const { - // ArrayObj::end() is never nullptr - return reverse_iterator(GetArrayObj()->end() - 1); - } - - /*! \return rend iterator */ - reverse_iterator rend() const { - // ArrayObj::begin() is never nullptr - return reverse_iterator(GetArrayObj()->begin() - 1); - } - - public: - // const methods in std::vector - /*! - * \brief Immutably read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const T operator[](int64_t i) const { - ArrayObj* p = GetArrayObj(); - if (p == nullptr) { - TVM_FFI_THROW(IndexError) << "cannot index a null array"; - } - if (i < 0 || i >= p->size_) { - TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin() + i)); - } - - /*! \return The size of the array */ - size_t size() const { - ArrayObj* p = GetArrayObj(); - return p == nullptr ? 0 : GetArrayObj()->size_; - } - - /*! 
\return The capacity of the array */ - size_t capacity() const { - ArrayObj* p = GetArrayObj(); - return p == nullptr ? 0 : GetArrayObj()->capacity_; - } - - /*! \return Whether array is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the array */ - const T front() const { - ArrayObj* p = GetArrayObj(); - if (p == nullptr || p->size_ == 0) { - TVM_FFI_THROW(IndexError) << "cannot index a empty array"; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin())); - } - - /*! \return The last element of the array */ - const T back() const { - ArrayObj* p = GetArrayObj(); - if (p == nullptr || p->size_ == 0) { - TVM_FFI_THROW(IndexError) << "cannot index a empty array"; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->end() - 1)); - } - - public: - // mutation in std::vector, implements copy-on-write - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - void push_back(const T& item) { - ArrayObj* p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, item); - } - - template - void emplace_back(Args&&... args) { - ArrayObj* p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, std::forward(args)...); - } - - /*! - * \brief Insert an element into the given position - * \param position An iterator pointing to the insertion point - * \param val The element to insert - */ - void insert(iterator position, const T& val) { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; - } - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - auto addr = CopyOnWrite(1) // - ->EnlargeBy(1) // - ->MoveElementsRight(idx + 1, idx, size) // - ->MutableBegin(); - new (addr + idx) Any(val); - } - - /*! 
- * \brief Insert a range of elements into the given position - * \param position An iterator pointing to the insertion point - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - template - void insert(iterator position, IterType first, IterType last) { - static_assert(is_valid_iterator_v, - "IterType cannot be inserted into a tvm::Array"); - - if (first == last) { - return; - } - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; - } - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - int64_t numel = std::distance(first, last); - CopyOnWrite(numel) - ->EnlargeBy(numel) - ->MoveElementsRight(idx + numel, idx, size) - ->InitRange(idx, first, last); - } - - /*! \brief Remove the last item of the list */ - void pop_back() { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null array"; - } - int64_t size = GetArrayObj()->size_; - if (size == 0) { - TVM_FFI_THROW(RuntimeError) << "cannot pop_back an empty array"; - } - CopyOnWrite()->ShrinkBy(1); - } - - /*! - * \brief Erase an element on the given position - * \param position An iterator pointing to the element to be erased - */ - void erase(iterator position) { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; - } - int64_t st = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - if (st < 0 || st >= size) { - TVM_FFI_THROW(RuntimeError) << "cannot erase at index " << st << ", because Array size is " - << size; - } - CopyOnWrite() // - ->MoveElementsLeft(st, st + 1, size) // - ->ShrinkBy(1); - } - - /*! 
- * \brief Erase a given range of elements - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - void erase(iterator first, iterator last) { - if (first == last) { - return; - } - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; - } - int64_t size = GetArrayObj()->size_; - int64_t st = std::distance(begin(), first); - int64_t ed = std::distance(begin(), last); - if (st >= ed) { - TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")"; - } - if (st < 0 || st > size || ed < 0 || ed > size) { - TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")" - << ", because array size is " << size; - } - CopyOnWrite() // - ->MoveElementsLeft(st, ed, size) // - ->ShrinkBy(ed - st); - } - - /*! - * \brief Resize the array. - * \param n The new size. - */ - void resize(int64_t n) { - if (n < 0) { - TVM_FFI_THROW(ValueError) << "cannot resize an Array to negative size"; - } - if (data_ == nullptr) { - SwitchContainer(n); - return; - } - int64_t size = GetArrayObj()->size_; - if (size < n) { - CopyOnWrite(n - size)->EnlargeBy(n - size); - } else if (size > n) { - CopyOnWrite()->ShrinkBy(size - n); - } - } - - /*! - * \brief Make sure the list has the capacity of at least n - * \param n lower bound of the capacity - */ - void reserve(int64_t n) { - if (data_ == nullptr || n > GetArrayObj()->capacity_) { - SwitchContainer(n); - } - } - - /*! \brief Release reference to all the elements */ - void clear() { - if (data_ != nullptr) { - ArrayObj* p = CopyOnWrite(); - p->clear(); - } - } - - template - static size_t CalcCapacityImpl() { - return 0; - } - - template - static size_t CalcCapacityImpl(Array value, Args... args) { - return value.size() + CalcCapacityImpl(args...); - } - - template - static size_t CalcCapacityImpl(T value, Args... 
args) { - return 1 + CalcCapacityImpl(args...); - } - - template - static void AgregateImpl(Array& dest) {} // NOLINT(*) - - template - static void AgregateImpl(Array& dest, Array value, Args... args) { // NOLINT(*) - dest.insert(dest.end(), value.begin(), value.end()); - AgregateImpl(dest, args...); - } - - template - static void AgregateImpl(Array& dest, T value, Args... args) { // NOLINT(*) - dest.push_back(value); - AgregateImpl(dest, args...); - } - - public: - // Array's own methods - - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - void Set(int64_t i, T value) { - ArrayObj* p = this->CopyOnWrite(); - if (i < 0 || i >= p->size_) { - TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; - } - *(p->MutableBegin() + i) = std::move(value); - } - - /*! \return The underlying ArrayObj */ - ArrayObj* GetArrayObj() const { return static_cast(data_.get()); } - - /*! - * \brief Helper function to apply a map function onto the array. - * - * \param fmap The transformation function T -> U. - * - * \tparam F The type of the mutation function. - * - * \tparam U The type of the returned array, inferred from the - * return type of F. If overridden by the user, must be something - * that is convertible from the return type of F. - * - * \note This function performs copy on write optimization. If - * `fmap` returns an object of type `T`, and all elements of the - * array are mapped to themselves, then the returned array will be - * the same as the original, and reference counts of the elements in - * the array will not be incremented. - * - * \return The transformed array. - */ - template > - Array Map(F fmap) const { - return Array(MapHelper(data_, fmap)); - } - - /*! - * \brief Helper function to apply fmutate to mutate an array. - * \param fmutate The transformation function T -> T. - * \tparam F the type of the mutation function. 
- * \note This function performs copy on write optimization. - */ - template >>> - void MutateByApply(F fmutate) { - data_ = MapHelper(std::move(data_), fmutate); - } - - /*! - * \brief reset the array to content from iterator. - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - void Assign(IterType first, IterType last) { - int64_t cap = std::distance(first, last); - if (cap < 0) { - TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative size"; - } - ArrayObj* p = GetArrayObj(); - if (p != nullptr && data_.unique() && p->capacity_ >= cap) { - // do not have to make new space - p->clear(); - } else { - // create new space - data_ = ArrayObj::Empty(cap); - p = GetArrayObj(); - } - // To ensure exception safety, size is only incremented after the initialization succeeds - Any* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { - new (itr) Any(*first); - } - } - - /*! - * \brief Copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - ArrayObj* CopyOnWrite() { - if (data_ == nullptr) { - return SwitchContainer(ArrayObj::kInitSize); - } - if (!data_.unique()) { - return SwitchContainer(capacity()); - } - return static_cast(data_.get()); - } - - /*! \brief specify container node */ - using ContainerType = ArrayObj; - - /*! - * \brief Agregate arguments into a single Array - * \param args sequence of T or Array elements - * \return Agregated Array - */ - template - static Array Agregate(Args... args) { - Array result; - result.reserve(CalcCapacityImpl(args...)); - AgregateImpl(result, args...); - return result; - } - - private: - /*! - * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. 
- * \param reserve_extra Number of extra slots needed - * \return ArrayObj pointer to the unique copy - */ - ArrayObj* CopyOnWrite(int64_t reserve_extra) { - ArrayObj* p = GetArrayObj(); - if (p == nullptr) { - // necessary to get around the constexpr address issue before c++17 - const int64_t kInitSize = ArrayObj::kInitSize; - return SwitchContainer(std::max(kInitSize, reserve_extra)); - } - if (p->capacity_ >= p->size_ + reserve_extra) { - return CopyOnWrite(); - } - int64_t cap = p->capacity_ * ArrayObj::kIncFactor; - cap = std::max(cap, p->size_ + reserve_extra); - return SwitchContainer(cap); - } - - /*! - * \brief Move or copy the ArrayObj to new address with the given capacity - * \param capacity The capacity requirement of the new address - */ - ArrayObj* SwitchContainer(int64_t capacity) { - if (data_ == nullptr) { - data_ = ArrayObj::Empty(capacity); - } else if (data_.unique()) { - data_ = ArrayObj::MoveFrom(capacity, GetArrayObj()); - } else { - data_ = ArrayObj::CopyFrom(capacity, GetArrayObj()); - } - return static_cast(data_.get()); - } - - /*! \brief Helper method for mutate/map - * - * A helper function used internally by both `Array::Map` and - * `Array::MutateInPlace`. Given an array of data, apply the - * mapping function to each element, returning the collected array. - * Applies both mutate-in-place and copy-on-write optimizations, if - * possible. - * - * \param data A pointer to the ArrayObj containing input data. - * Passed by value to allow for mutate-in-place optimizations. - * - * \param fmap The mapping function - * - * \tparam F The type of the mutation function. - * - * \tparam U The output type of the mutation function. Inferred - * from the callable type given. Must inherit from ObjectRef. - * - * \return The mapped array. Depending on whether mutate-in-place - * or copy-on-write optimizations were applicable, may be the same - * underlying array as the `data` parameter. 
- */ - template > - static ObjectPtr MapHelper(ObjectPtr data, F fmap) { - if (data == nullptr) { - return nullptr; - } - - TVM_FFI_ICHECK(data->IsInstance()); - - constexpr bool is_same_output_type = std::is_same_v; - - if constexpr (is_same_output_type) { - if (data.unique()) { - // Mutate-in-place path. Only allowed if the output type U is - // the same as type T, we have a mutable this*, and there are - // no other shared copies of the array. - auto arr = static_cast(data.get()); - for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { - T value = details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it); - // reset the original value to nullptr, to ensure unique ownership - it->reset(); - T mapped = fmap(std::move(value)); - *it = std::move(mapped); - } - return data; - } - } - - constexpr bool compatible_types = is_valid_iterator_v || is_valid_iterator_v; - - ObjectPtr output = nullptr; - auto arr = static_cast(data.get()); - - auto it = arr->begin(); - if constexpr (compatible_types) { - // Copy-on-write path, if the output Array might be - // represented by the same underlying array as the existing - // Array. Typically, this is for functions that map `T` to - // `T`, but can also apply to functions that map `T` to - // `Optional`, or that map `T` to a subclass or superclass of - // `T`. - bool all_identical = true; - for (; it != arr->end(); it++) { - U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); - if (!(*it).same_as(mapped)) { - // At least one mapped element is different than the - // original. Therefore, prepare the output array, - // consisting of any previous elements that had mapped to - // themselves (if any), and the element that didn't map to - // itself. - // - // We cannot use `U()` as the default object, as `U` may be - // a non-nullable type. Since the default `Any()` - // will be overwritten before returning, all objects will be - // of type `U` for the calling scope. 
- all_identical = false; - output = ArrayObj::CreateRepeated(arr->size(), Any()); - output->InitRange(0, arr->begin(), it); - output->SetItem(it - arr->begin(), std::move(mapped)); - it++; - break; - } - } - if (all_identical) { - return data; - } - } else { - // Path for incompatible types. The constexpr check for - // compatible types isn't strictly necessary, as the first - // (*it).same_as(mapped) would return false, but we might as well - // avoid it altogether. - // - // We cannot use `U()` as the default object, as `U` may be a - // non-nullable type. Since the default `Any()` will be - // overwritten before returning, all objects will be of type `U` - // for the calling scope. - output = ArrayObj::CreateRepeated(arr->size(), Any()); - } - - // Normal path for incompatible types, or post-copy path for - // copy-on-write instances. - // - // If the types are incompatible, then at this point `output` is - // empty, and `it` points to the first element of the input. - // - // If the types were compatible, then at this point `output` - // contains zero or more elements that mapped to themselves - // followed by the first element that does not map to itself, and - // `it` points to the element just after the first element that - // does not map to itself. Because at least one element has been - // changed, we no longer have the opportunity to avoid a copy, so - // we don't need to check the result. - // - // In both cases, `it` points to the next element to be processed, - // so we can either start or resume the iteration from that point, - // with no further checks on the result. - for (; it != arr->end(); it++) { - U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); - output->SetItem(it - arr->begin(), std::move(mapped)); - } - - return output; - } - template - friend class Array; -}; - -/*! - * \brief Concat two Arrays. - * \param lhs first Array to be concatenated. - * \param rhs second Array to be concatenated. 
- * \return The concatenated Array. Original Arrays are kept unchanged. - */ -template || - TypeTraits::convert_enabled>> -inline Array Concat(Array lhs, const Array& rhs) { - for (const auto& x : rhs) { - lhs.push_back(x); - } - return std::move(lhs); -} - -// Specialize make_object to make sure it is correct. -template <> -inline ObjectPtr make_object() { - return ArrayObj::Empty(); -} - -// Traits for Array -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray; - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - if constexpr (!std::is_same_v) { - const ArrayObj* n = reinterpret_cast(src->v_obj); - for (size_t i = 0; i < n->size(); i++) { - const Any& any_v = (*n)[i]; - // CheckAnyStrict is cheaper than try_cast - if (details::AnyUnsafe::CheckAnyStrict(any_v)) continue; - // try see if p is convertible to T - if (any_v.try_cast()) continue; - // now report the accurate mismatch information - return "Array[index " + std::to_string(i) + ": " + - details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; - } - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) return false; - if constexpr (std::is_same_v) { - return true; - } else { - const ArrayObj* n = reinterpret_cast(src->v_obj); - for (size_t i = 0; i < n->size(); i++) { - const Any& any_v = (*n)[i]; - if (!details::AnyUnsafe::CheckAnyStrict(any_v)) return false; - } - return true; - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // try to run conversion. 
- if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; - if constexpr (!std::is_same_v) { - const ArrayObj* n = reinterpret_cast(src->v_obj); - bool storage_check = [&]() { - for (size_t i = 0; i < n->size(); i++) { - const Any& any_v = (*n)[i]; - if (!details::AnyUnsafe::CheckAnyStrict(any_v)) return false; - } - return true; - }(); - // fast path, if storage check passes, we can return the array directly. - if (storage_check) { - return CopyFromAnyViewAfterCheck(src); - } - // slow path, try to run a conversion to Array - Array result; - result.reserve(n->size()); - for (size_t i = 0; i < n->size(); i++) { - const Any& any_v = (*n)[i]; - if (auto opt_v = any_v.try_cast()) { - result.push_back(*std::move(opt_v)); - } else { - return std::nullopt; - } - } - return result; - } else { - return CopyFromAnyViewAfterCheck(src); - } - } - - TVM_FFI_INLINE static std::string TypeStr() { return "Array<" + details::Type2Str::v() + ">"; } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Array> = type_contains_v; -} // namespace details - -} // namespace ffi - -// Expose to the tvm namespace -// Rationale: convinience and no ambiguity -using ffi::Array; -} // namespace tvm -#endif // TVM_FFI_CONTAINER_ARRAY_H_ diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h deleted file mode 100644 index bb29a14f7cb8..000000000000 --- a/ffi/include/tvm/ffi/container/container_details.h +++ /dev/null @@ -1,356 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/container_details.h - * \brief Common utilities for typed container types. - */ -#ifndef TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ -#define TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ - -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base template for classes with array like memory layout. - * - * It provides general methods to access the memory. The memory - * layout is ArrayType + [ElemType]. The alignment of ArrayType - * and ElemType is handled by the memory allocator. - * - * \tparam ArrayType The array header type, contains object specific metadata. - * \tparam ElemType The type of objects stored in the array right after - * ArrayType. - * - * \code - * // Example usage of the template to define a simple array wrapper - * class ArrayObj : public tvm::ffi::details::InplaceArrayBase { - * public: - * // Wrap EmplaceInit to initialize the elements - * template - * void Init(Iterator begin, Iterator end) { - * size_t num_elems = std::distance(begin, end); - * auto it = begin; - * this->size = 0; - * for (size_t i = 0; i < num_elems; ++i) { - * InplaceArrayBase::EmplaceInit(i, *it++); - * this->size++; - * } - * } - * } - * - * void test_function() { - * vector fields; - * auto ptr = make_inplace_array_object(fields.size()); - * ptr->Init(fields.begin(), fields.end()); - * - * // Access the 0th element in the array. 
- * assert(ptr->operator[](0) == fields[0]); - * } - * - * \endcode - */ -template -class InplaceArrayBase { - public: - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Const reference to ElemType at the index. - */ - const ElemType& operator[](size_t idx) const { - size_t size = Self()->GetSize(); - if (idx > size) { - TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; - } - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Reference to ElemType at the index. - */ - ElemType& operator[](size_t idx) { - size_t size = Self()->GetSize(); - if (idx > size) { - TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; - } - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Destroy the Inplace Array Base object - */ - ~InplaceArrayBase() { - if constexpr (!(std::is_standard_layout::value && std::is_trivial::value)) { - size_t size = Self()->GetSize(); - for (size_t i = 0; i < size; ++i) { - ElemType* fp = reinterpret_cast(AddressOf(i)); - fp->ElemType::~ElemType(); - } - } - } - - protected: - /*! - * \brief Construct a value in place with the arguments. - * - * \tparam Args Type parameters of the arguments. - * \param idx Index of the element. - * \param args Arguments to construct the new value. - * - * \note Please make sure ArrayType::GetSize returns 0 before first call of - * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. - */ - template - void EmplaceInit(size_t idx, Args&&... args) { - void* field_ptr = AddressOf(idx); - new (field_ptr) ElemType(std::forward(args)...); - } - - /*! - * \brief Return the self object for the array. - * - * \return Pointer to ArrayType. - */ - inline ArrayType* Self() const { - return static_cast(const_cast(this)); - } - - /*! - * \brief Return the raw pointer to the element at idx. 
- * - * \param idx The index of the element. - * \return Raw pointer to the element. - */ - void* AddressOf(size_t idx) const { - static_assert( - alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); - - size_t kDataStart = sizeof(ArrayType); - ArrayType* self = Self(); - char* data_start = reinterpret_cast(self) + kDataStart; - return data_start + idx * sizeof(ElemType); - } -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class IterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit IterAdapter(TIter iter) : iter_(iter) {} - IterAdapter& operator++() { - ++iter_; - return *this; - } - IterAdapter& operator--() { - --iter_; - return *this; - } - IterAdapter operator++(int) { - IterAdapter copy = *this; - ++iter_; - return copy; - } - IterAdapter operator--(int) { - IterAdapter copy = *this; - --iter_; - return copy; - } - - IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - - IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } - - IterAdapter& operator+=(difference_type offset) { - iter_ += offset; - return *this; - } - - IterAdapter& operator-=(difference_type offset) { - iter_ -= offset; - return *this; - } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const IterAdapter& rhs) const { - return iter_ - rhs.iter_; - } - - bool operator==(IterAdapter other) const { 
return iter_ == other.iter_; } - bool operator!=(IterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class ReverseIterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} - ReverseIterAdapter& operator++() { - --iter_; - return *this; - } - ReverseIterAdapter& operator--() { - ++iter_; - return *this; - } - ReverseIterAdapter operator++(int) { - ReverseIterAdapter copy = *this; - --iter_; - return copy; - } - ReverseIterAdapter operator--(int) { - ReverseIterAdapter copy = *this; - ++iter_; - return copy; - } - ReverseIterAdapter operator+(difference_type offset) const { - return ReverseIterAdapter(iter_ - offset); - } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const ReverseIterAdapter& rhs) const { - return rhs.iter_ - iter_; - } - - bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! - * \brief Check if T is compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. 
- */ -template -inline constexpr bool storage_enabled_v = std::is_same_v || TypeTraits::storage_enabled; - -/*! - * \brief Check if all T are compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. - */ -template -inline constexpr bool all_storage_enabled_v = (storage_enabled_v && ...); - -/*! - * \brief Check if all T are compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. - */ -template -inline constexpr bool all_object_ref_v = (std::is_base_of_v && ...); -/** - * \brief Check if Any storage of Derived can always be directly used as Base. - * - * \tparam Base The base type. - * \tparam Derived The derived type. - * \return True if Derived's storage can be used as Base's storage, false otherwise. - */ -template -inline constexpr bool type_contains_v = - std::is_base_of_v || std::is_same_v; -// special case for Any -template -inline constexpr bool type_contains_v = true; - -/*! - * \brief Create a string of the container type. - * \tparam V The types of the elements in the container. - * \param name The name of the container type. - * \return A string of the container type. - */ -template -std::string ContainerTypeStr(const char* name) { - std::stringstream ss; - // helper to construct concated string of TypeStr - class TypeStrHelper { - public: - TypeStrHelper(std::stringstream& stream) : stream_(stream) {} // NOLINT(*) - - TypeStrHelper& operator<<(const std::string& str) { - if (counter_ > 0) { - stream_ << ", "; - } - stream_ << str; - counter_++; - return *this; - } - - private: - std::stringstream& stream_; // NOLINT(*) - int counter_ = 0; - }; - TypeStrHelper helper(ss); - ss << name << '<'; - (helper << ... 
<< Type2Str::v()); - ss << '>'; - return ss.str(); -} - -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h deleted file mode 100644 index b1ca4f805edd..000000000000 --- a/ffi/include/tvm/ffi/container/map.h +++ /dev/null @@ -1,1709 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/map.h - * \brief Runtime Map container types. - */ -#ifndef TVM_FFI_CONTAINER_MAP_H_ -#define TVM_FFI_CONTAINER_MAP_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE -#define TVM_FFI_MAP_FAIL_IF_CHANGED() \ - TVM_FFI_ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; -#else -#define TVM_FFI_MAP_FAIL_IF_CHANGED() -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - -/*! \brief Shared content of all specializations of hash map */ -class MapObj : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = Any; - /*! 
\brief Type of the values in the hash map */ - using mapped_type = Any; - /*! \brief Type of value stored in the hash map */ - using KVType = std::pair; - /*! \brief Type of raw storage of the key-value pair in the hash map */ - struct KVRawStorageType { - TVMFFIAny first; - TVMFFIAny second; - }; - /*! \brief Iterator class */ - class iterator; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 32, "sizeof(KVType) incorrect"); - - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap; - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIMap; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(MapObj, Object); - - /*! - * \brief Number of elements in the SmallMapObj - * \return The result - */ - size_t size() const { return size_; } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key); - /*! \return begin iterator */ - iterator begin() const; - /*! \return end iterator */ - iterator end() const; - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const; - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position); - /*! 
- * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { erase(find(key)); } - - class iterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = int64_t; - using value_type = KVType; - using pointer = KVType*; - using reference = KVType&; -/*! \brief Default constructor */ -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - iterator() : state_marker(0), index(0), self(nullptr) {} -#else - iterator() : index(0), self(nullptr) {} -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - return index == other.index && self == other.self; - } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return !(*this == other); } - /*! \brief De-reference iterators */ - pointer operator->() const; - /*! \brief De-reference iterators */ - reference operator*() const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - return *((*this).operator->()); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++(); - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--(); - /*! \brief Suffix self increment */ - iterator operator++(int) { - TVM_FFI_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - ++(*this); - return copy; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - TVM_FFI_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - --(*this); - return copy; - } - - protected: -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; - /*! \brief Construct by value */ - iterator(uint64_t index, const MapObj* self) - : state_marker(self->state_marker), index(index), self(self) {} - -#else - iterator(uint64_t index, const MapObj* self) : index(index), self(self) {} -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! \brief The position on the array */ - uint64_t index; - /*! 
\brief The container it points to */ - const MapObj* self; - - friend class DenseMapObj; - friend class SmallMapObj; - }; - /*! - * \brief Create an empty container - * \return The object created - */ - static inline ObjectPtr Empty(); - - protected: -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static inline ObjectPtr CreateFromRange(IterType first, IterType last); - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static inline void InsertMaybeReHash(KVType&& kv, ObjectPtr* map); - /*! - * \brief Create an empty container with elements copying from another SmallMapObj - * \param from The source container - * \return The object created - */ - static inline ObjectPtr CopyFrom(MapObj* from); - /*! - * \brief data pointer to the data region of the map. - * \note For immutable inplace small map we do not need data_, - * but we keep it here for future compact with mutable container. - */ - void* data_; - /*! \brief number of entries in the container */ - uint64_t size_; - /*! \brief number of slots */ - uint64_t slots_; - /*! - * \brief Small layout tag mask - * \note The most significant bit is used to indicate the small map layout. - */ - static constexpr uint64_t kSmallTagMask = static_cast(1) << 63; - /*! - * \brief Check if the map is a small map - * \return True if the map is a small map - */ - bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; } - /*! - * \brief Optional data deleter when data is allocated separately - * and its deletion is not managed by MapObj::deleter_. 
- */ - void (*data_deleter_)(void*) = nullptr; - // Reference class - template - friend class Map; -}; - -/*! \brief A specialization of small-sized hash map */ -class SmallMapObj : public MapObj, - public details::InplaceArrayBase { - private: - static constexpr uint64_t kInitSize = 2; - static constexpr uint64_t kMaxSize = 4; - - public: - using MapObj::iterator; - using MapObj::KVType; - - // Return the number of usable slots for Small layout (mask off tag). - /*! - * \brief Return the number of usable slots for Small layout (mask off tag). - * \return The number of usable slots - */ - uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; } - - ~SmallMapObj() { - KVType* begin = static_cast(data_); - for (uint64_t index = 0; index < size_; ++index) { - // call destructor to destroy the item in `begin + index` - // Explicit call Any::~Any() to destroy the Any object - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (begin + index)->first.Any::~Any(); - (begin + index)->second.Any::~Any(); - } - if (data_deleter_ != nullptr) { - data_deleter_(data_); - } - } - /*! - * \brief Count the number of times a key exists in the SmallMapObj - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return find(key).index < size_; } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { - iterator itr = find(key); - if (itr.index >= size_) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; - } - return itr->second; - } - /*! 
- * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { - iterator itr = find(key); - if (itr.index >= size_) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; - } - return itr->second; - } - /*! \return begin iterator */ - iterator begin() const { return iterator(0, this); } - /*! \return end iterator */ - iterator end() const { return iterator(size_, this); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - KVType* ptr = static_cast(data_); - for (uint64_t i = 0; i < size_; ++i, ++ptr) { - if (AnyEqual()(ptr->first, key)) { - return iterator(i, this); - } - } - return iterator(size_, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { Erase(position.index); } - - private: - /*! - * \brief Set the number of slots and attach tags bit. - * \param n The number of slots - */ - void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; } - /*! - * \brief Remove a position in SmallMapObj - * \param index The position to be removed - */ - void Erase(const uint64_t index) { - if (index >= size_) { - return; - } - KVType* begin = static_cast(data_); - // call destructor to destroy the item in `begin + index` - // Explicit call Any::~Any() to destroy the Any object - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (begin + index)->first.Any::~Any(); - (begin + index)->second.Any::~Any(); - // IMPORTANT: We do direct raw memmove to bring later items to the current position - // to preserve the order of insertion. 
- // This works because direct memory copy preserves the Any's move semantics. - if (index + 1 < size_) { - std::memmove(reinterpret_cast(begin + index), - reinterpret_cast(begin + index + 1), - (size_ - index - 1) * sizeof(KVType)); - } - size_ -= 1; - } - /*! - * \brief Create an empty container - * \param n Number of empty slots - * \return The object created - */ - static ObjectPtr Empty(uint64_t n = kInitSize) { - using ::tvm::ffi::make_inplace_array_object; - ObjectPtr p = make_inplace_array_object(n); - p->data_ = p->AddressOf(0); - p->size_ = 0; - p->SetSlotsAndSmallLayoutTag(n); - return p; - } - /*! - * \brief Create an empty container initialized with a given range - * \param n Number of empty slots - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - * \return The object created - */ - template - static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { - ObjectPtr p = Empty(n); - KVType* ptr = static_cast(p->data_); - for (; first != last; ++first, ++p->size_) { - new (ptr++) KVType(*first); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another SmallMapObj - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(SmallMapObj* from) { - KVType* first = static_cast(from->data_); - KVType* last = first + from->size_; - return CreateFromRange(from->size_, first, last); - } - /*! 
- * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - SmallMapObj* map_node = static_cast(map->get()); - iterator itr = map_node->find(kv.first); - if (itr.index < map_node->size_) { - itr->second = kv.second; - return; - } - if (map_node->size_ < map_node->NumSlots()) { - KVType* ptr = static_cast(map_node->data_) + map_node->size_; - new (ptr) KVType(std::move(kv)); - ++map_node->size_; - return; - } - uint64_t next_size = std::max(map_node->NumSlots() * 2, uint64_t(kInitSize)); - next_size = std::min(next_size, uint64_t(kMaxSize)); - TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots()); - ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); - InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return static_cast(data_) + index; } - /*! \brief A size function used by InplaceArrayBase */ - uint64_t GetSize() const { return size_; } - - protected: - friend class MapObj; - friend class DenseMapObj; - friend class details::InplaceArrayBase; -}; - -/*! \brief A specialization of hash map that implements the idea of array-based hash map. - * Another reference implementation can be found [1]. - * - * A. 
Overview - * - * DenseMapObj did several improvements over traditional separate chaining hash, - * in terms of cache locality, memory footprints and data organization. - * - * A1. Implicit linked list. For better cache locality, instead of using linked list - * explicitly for each bucket, we store list data into a single array that spans contiguously - * in memory, and then carefully design access patterns to make sure most of them fall into - * a single cache line. - * - * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and - * traversal. This can be divided in 3 parts. - * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, - * which means the slot is empty but not allowed to be written. - * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is - * head of a linked list. - * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit - * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when - * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are - * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to - * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, - * then x must be one of the 126 pre-defined values. - * - * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. - * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. - * 16 key-value pairs. - * - * B. Implementation details - * - * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid - * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, - * we use the Fibonacci Hashing [2] trick. - * - * B2. 
Traverse a linked list in the array. - * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i - * indicates that it is list head, then we found the head; otherwise the list is empty. No probing - * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we - * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of - * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). - * - * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this - * element is in the linked list, and if not, we put it at the end by probing the next empty - * position in one of the 126 candidate positions. If the linked list does not even exist, but the - * slot for list head has been occupied by another linked list, we should find this intruder another - * place. - * - * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing - * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the - * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list - * head. - * - * [1] https://github.com/skarupke/flat_hash_map - * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ - * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - */ -class DenseMapObj : public MapObj { - private: - /*! \brief The number of elements in a memory block */ - static constexpr int kBlockCap = 16; - /*! \brief Maximum load factor of the hash map */ - static constexpr double kMaxLoadFactor = 0.99; - /*! \brief Binary representation of the metadata of an empty slot */ - static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); - /*! \brief Binary representation of the metadata of a protected slot */ - static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); - /*! 
\brief Number of probing choices available */ - static constexpr int kNumJumpDists = 126; - /*! \brief Index indicator to indicate an invalid index */ - static constexpr uint64_t kInvalidIndex = std::numeric_limits::max(); - /*! \brief Head of the implicit linked list */ - struct ListNode; - /*! \brief item type of the dense map, including a kv data and prev/next pointer */ - struct ItemType { - KVType data; - uint64_t prev = kInvalidIndex; - uint64_t next = kInvalidIndex; - - explicit ItemType(KVType&& data) : data(std::move(data)) {} - explicit ItemType(key_type key, mapped_type value) : data(key, value) {} - }; - /*! \brief POD type of a block of memory */ - struct Block { - uint8_t bytes[kBlockCap + kBlockCap * sizeof(ItemType)]; - }; - static_assert(sizeof(Block) == kBlockCap * (sizeof(ItemType) + 1), "sizeof(Block) incorrect"); - static_assert(std::is_standard_layout::value, "Block is not standard layout"); - - /*! - * \brief Deleter for the Block - * \param data The pointer to the Block - */ - static void BlockDeleter(void* data) { delete[] static_cast(data); } - - public: - using MapObj::iterator; - - /*! - * \brief Return the number of usable slots for Dense layout (MSB clear => identity). - * \return The number of usable slots - */ - uint64_t NumSlots() const { return slots_; } - - /*! - * \brief Destroy the DenseMapObj - */ - ~DenseMapObj() { this->Reset(); } - /*! \return The number of elements of the key */ - size_t count(const key_type& key) const { return !Search(key).IsNone(); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return At(key); } - /*! 
- * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return At(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - ListNode node = Search(key); - return node.IsNone() ? end() : iterator(node.index, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { - uint64_t index = position.index; - if (position.self != nullptr && index <= this->NumSlots()) { - Erase(ListNode(index, this)); - } - } - /*! \return begin iterator */ - iterator begin() const { return iterator(iter_list_head_, this); } - /*! \return end iterator */ - iterator end() const { return iterator(kInvalidIndex, this); } - - private: - Block* GetBlock(size_t index) const { return static_cast(data_) + index; } - /*! - * \brief Unlink the entry from iterator list - * \param node The node to be unlinked - * \note This function is usually used before deletion, - * and it does not change data content of the node. - */ - void IterListUnlink(ListNode node) { - // update head and tail of iterator list if needed - if (node.Item().prev == kInvalidIndex) { - iter_list_head_ = node.Item().next; - } else { - ListNode prev_node(node.Item().prev, this); - prev_node.Item().next = node.Item().next; - } - if (node.Item().next == kInvalidIndex) { - iter_list_tail_ = node.Item().prev; - } else { - ListNode next_node(node.Item().next, this); - next_node.Item().prev = node.Item().prev; - } - } - /*! - * \brief Insert the entry into tail of iterator list - * \param node The node to be inserted - * \note this function does not change data content of the node. 
- */ - void IterListPushBack(ListNode node) { - node.Item().prev = iter_list_tail_; - node.Item().next = kInvalidIndex; - if (iter_list_tail_ != kInvalidIndex) { - ListNode prev_node(iter_list_tail_, this); - prev_node.Item().next = node.index; - } - if (iter_list_head_ == kInvalidIndex) { - iter_list_head_ = node.index; - } - iter_list_tail_ = node.index; - } - /*! - * \brief Replace node src by dst in the iter list - * \param src The source node - * \param dst The destination node, must be empty - * \note This function does not change data content of the nodes, - * which needs to be updated by the caller. - */ - void IterListReplaceNodeBy(ListNode src, ListNode dst) { - // set link correctly on the dst - dst.Item().prev = src.Item().prev; - dst.Item().next = src.Item().next; - // update prev and next of dst - if (dst.Item().prev == kInvalidIndex) { - iter_list_head_ = dst.index; - } else { - ListNode prev_node(dst.Item().prev, this); - prev_node.Item().next = dst.index; - } - if (dst.Item().next == kInvalidIndex) { - iter_list_tail_ = dst.index; - } else { - ListNode next_node(dst.Item().next, this); - next_node.Item().prev = dst.index; - } - } - /*! - * \brief Search for the given key - * \param key The key - * \return ListNode that associated with the key - */ - ListNode Search(const key_type& key) const { - if (this->size_ == 0) { - return ListNode(); - } - for (ListNode iter = GetListHead(AnyHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { - if (AnyEqual()(key, iter.Key())) { - return iter; - } - } - return ListNode(); - } - /*! - * \brief Search for the given key, throw exception if not exists - * \param key The key - * \return ListNode that associated with the key - */ - mapped_type& At(const key_type& key) const { - ListNode iter = Search(key); - if (iter.IsNone()) { - TVM_FFI_THROW(IndexError) << "key is not in Map"; - } - return iter.Val(); - } - /*! 
- * \brief Try to insert a key, or do nothing if already exists - * \param key The indexing key - * \param result The linked-list entry found or just constructed - * \return A boolean, indicating if actual insertion happens - */ - bool TryInsert(const key_type& key, ListNode* result) { - if (slots_ == 0) { - return false; - } - // required that `iter` to be the head of a linked list through which we can iterator - ListNode iter = IndexFromHash(AnyHash()(key)); - // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list - // Case 1: empty - if (iter.IsEmpty()) { - iter.NewHead(ItemType(key, Any(nullptr))); - this->size_ += 1; - *result = iter; - return true; - } - // Case 2: body of an irrelevant list - if (!iter.IsHead()) { - // we move the elements around and construct the single-element linked list - return IsFull() ? false : TrySpareListHead(iter, key, result); - } - // Case 3: head of the relevant list - // we iterate through the linked list until the end - // make sure `iter` is the previous element of `next` - ListNode next = iter; - do { - // find equal item, do not insert - if (AnyEqual()(key, next.Key())) { - // we plan to take next, so we need to unlink it from iterator list - IterListUnlink(next); - *result = next; - return true; - } - // make sure `iter` is the previous element of `next` - iter = next; - } while (next.MoveToNext(this)); - // `iter` is the tail of the linked list - // always check capacity before insertion - if (IsFull()) { - return false; - } - // find the next empty slot - uint8_t jump; - if (!iter.GetNextEmpty(this, &jump, result)) { - return false; - } - result->NewTail(ItemType(key, Any(nullptr))); - // link `iter` to `empty`, and move forward - iter.SetJump(jump); - this->size_ += 1; - return true; - } - /*! - * \brief Spare an entry to be the head of a linked list. 
- * As described in B3, during insertion, it is possible that the entire linked list does not - * exist, but the slot of its head has been occupied by other linked lists. In this case, we need - * to spare the slot by moving away the elements to another valid empty one to make insertion - * possible. - * \param target The given entry to be spared - * \param key The indexing key - * \param result The linked-list entry constructed as the head - * \return A boolean, if actual insertion happens - */ - bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { - // `target` is not the head of the linked list - // move the original item of `target` (if any) - // and construct new item on the position `target` - // To make `target` empty, we - // 1) find `w` the previous element of `target` in the linked list - // 2) copy the linked list starting from `r = target` - // 3) paste them after `w` - // read from the linked list after `r` - ListNode r = target; - // write to the tail of `w` - ListNode w = target.FindPrev(this); - // after `target` is moved, we disallow writing to the slot - bool is_first = true; - uint8_t r_meta, jump; - ListNode empty; - do { - // `jump` describes how `w` is jumped to `empty` - // rehash if there is no empty space after `w` - if (!w.GetNextEmpty(this, &jump, &empty)) { - return false; - } - // move `r` to `empty` - // first move the data over - empty.NewTail(ItemType(std::move(r.Data()))); - // then move link list chain of r to empty - // this needs to happen after NewTail so empty's prev/next get updated - IterListReplaceNodeBy(r, empty); - // explicit call destructor to destroy the item in `r` - r.DestructData(); - // clear the metadata of `r` - r_meta = r.Meta(); - if (is_first) { - is_first = false; - r.SetProtected(); - } else { - r.SetEmpty(); - } - // link `w` to `empty`, and move forward - w.SetJump(jump); - w = empty; - // move `r` forward as well - } while (r.MoveToNext(this, r_meta)); - // finally we have done 
moving the linked list - // fill data_ into `target` - target.NewHead(ItemType(key, Any(nullptr))); - this->size_ += 1; - *result = target; - return true; - } - /*! - * \brief Remove a ListNode - * \param iter The node to be removed - */ - void Erase(const ListNode& iter) { - this->size_ -= 1; - if (!iter.HasNext()) { - // `iter` is the last - if (!iter.IsHead()) { - // cut the link if there is any - iter.FindPrev(this).SetJump(0); - } - // unlink the node from iterator list - IterListUnlink(iter); - // IMPORTANT: must explicit call destructor `iter` to avoid memory leak - // This is because we need to recycle iter's data - iter.DestructData(); - // set the meta data to be empty - iter.SetEmpty(); - } else { - ListNode last = iter, prev = iter; - for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { - } - // needs to first unlink iter from the list - IterListUnlink(iter); - // move data from last to iter - iter.Data() = std::move(last.Data()); - // Move link chain of iter to last as we stores last node to the new iter loc. - IterListReplaceNodeBy(last, iter); - // IMPORTANT: must explicit call destructor `last` to avoid memory leak - // likely we don't need this in this particular case because Any move behavior - // keep it here to be safe so code do not depend on specific move behavior of KVType - last.DestructData(); - // set the meta data to be empty - last.SetEmpty(); - prev.SetJump(0); - } - } - /*! 
\brief Clear the container to empty, release all entries and memory acquired */ - void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->NumSlots()); - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = GetBlock(bi)->bytes; - ItemType* data_ptr = reinterpret_cast(GetBlock(bi)->bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - data_ptr->ItemType::~ItemType(); - } - } - } - ReleaseMemory(); - } - /*! \brief Release the memory acquired by the container without deleting its entries stored inside - */ - void ReleaseMemory() { - if (data_ != nullptr) { - TVM_FFI_ICHECK(data_deleter_ != nullptr); - data_deleter_(data_); - } - data_ = nullptr; - data_deleter_ = nullptr; - slots_ = 0; - size_ = 0; - fib_shift_ = 63; - } - /*! - * \brief Create an empty container - * \param fib_shift The fib shift provided - * \param n_slots Number of slots required, should be power-of-two - * \return The object created - */ - static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { - TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize)); - // Ensure even slot count (power-of-two expected by callers; this guard - // makes the method robust if a non-even value slips through). - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(n_slots); - Block* block = new Block[n_blocks]; - p->data_ = block; - // assign block deleter so even if we take re-alloc data - // in another shared-lib that may have different malloc/free behavior - // it will still be safe. 
- p->data_deleter_ = BlockDeleter; - p->SetSlotsAndDenseLayoutTag(n_slots); - p->size_ = 0; - p->fib_shift_ = fib_shift; - p->iter_list_head_ = kInvalidIndex; - p->iter_list_tail_ = kInvalidIndex; - for (uint64_t i = 0; i < n_blocks; ++i, ++block) { - std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another DenseMapObj - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(DenseMapObj* from) { - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->NumSlots()); - p->data_ = new Block[n_blocks]; - // assign block deleter so even if we take re-alloc data - // in another shared-lib that may have different malloc/free behavior - // it will still be safe. - p->data_deleter_ = BlockDeleter; - p->SetSlotsAndDenseLayoutTag(from->NumSlots()); - p->size_ = from->size_; - p->fib_shift_ = from->fib_shift_; - p->iter_list_head_ = from->iter_list_head_; - p->iter_list_tail_ = from->iter_list_tail_; - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr_from = from->GetBlock(bi)->bytes; - ItemType* data_ptr_from = reinterpret_cast(from->GetBlock(bi)->bytes + kBlockCap); - uint8_t* meta_ptr_to = p->GetBlock(bi)->bytes; - ItemType* data_ptr_to = reinterpret_cast(p->GetBlock(bi)->bytes + kBlockCap); - for (int j = 0; j < kBlockCap; - ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { - uint8_t& meta = *meta_ptr_to = *meta_ptr_from; - TVM_FFI_ICHECK(meta != kProtectedSlot); - if (meta != uint8_t(kEmptySlot)) { - new (data_ptr_to) ItemType(*data_ptr_from); - } - } - } - return p; - } - /*! 
- * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - DenseMapObj* map_node = static_cast(map->get()); - ListNode iter; - // Try to insert. If succeed, we simply return - if (map_node->TryInsert(kv.first, &iter)) { - iter.Val() = std::move(kv.second); - // update the iter list relation - map_node->IterListPushBack(iter); - return; - } - TVM_FFI_ICHECK(!map_node->IsSmallMap()); - // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2); - - // need to insert in the same order as the original map - for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) { - ListNode node(index, map_node); - // now try move src_data into the new map, note that src may still not - // be fully consumed into the call, but destructor will be called. - InsertMaybeReHash(std::move(node.Data()), &p); - // Important, needs to explicit call destructor in case move did remove - // node's internal item - index = node.Item().next; - // IMPORTANT: must explicit call destructor `node` to avoid memory leak - // We must call node.DestructData() here. - // This is because std::move() arguments in IterMaybeReHash may or may not - // explicitly move out the node.Data() - // Remove this call will cause memory leak very likely. - node.DestructData(); - } - InsertMaybeReHash(std::move(kv), &p); - map_node->ReleaseMemory(); - *map = p; - } - /*! - * \brief Check whether the hash table is full - * \return A boolean indicating whether hash table is full - */ - bool IsFull() const { return size_ + 1 > NumSlots() * kMaxLoadFactor; } - /*! 
- * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { - // keep at the end of iterator - if (index == kInvalidIndex) { - return index; - } - ListNode node(index, this); - return node.Item().next; - } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { - // this is the end iterator, we need to return tail. - if (index == kInvalidIndex) { - return iter_list_tail_; - } - // circle around the iterator list, which is OK - ListNode node(index, this); - return node.Item().prev; - } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } - /*! \brief Construct from hash code */ - ListNode IndexFromHash(uint64_t hash_value) const { - return ListNode(FibHash(hash_value, fib_shift_), this); - } - /*! \brief Construct from hash code if the position is head of list */ - ListNode GetListHead(uint64_t hash_value) const { - ListNode node = IndexFromHash(hash_value); - return node.IsHead() ? node : ListNode(); - } - /*! \brief Construct the number of blocks in the hash table */ - static uint64_t CalcNumBlocks(uint64_t n_slots) { return (n_slots + kBlockCap - 1) / kBlockCap; } - /*! - * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. 
- * \param cap The lower-bound of the required capacity - * \param fib_shift The result shift for Fibonacci Hashing - * \param n_slots The result number of slots - */ - static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { - uint32_t shift = 64; - uint64_t slots = 1; - for (uint64_t c = cap; c; c >>= 1) { - shift -= 1; - slots <<= 1; - } - TVM_FFI_ICHECK_GT(slots, cap); - if (slots < cap * 2) { - *fib_shift = shift - 1; - *n_slots = slots << 1; - } else { - *fib_shift = shift; - *n_slots = slots; - } - } - /*! - * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. - * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. - * \param hash_value The raw hash value - * \param fib_shift The shift in Fibonacci Hashing - * \return An index calculated using Fibonacci Hashing - */ - static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { - constexpr uint64_t coeff = 11400714819323198485ull; - return (coeff * hash_value) >> fib_shift; - } - /*! \brief The implicit in-place linked list used to index a chain */ - struct ListNode { - /*! \brief Construct None */ - ListNode() : index(0), block(nullptr) {} - /*! \brief Construct from position */ - ListNode(uint64_t index, const DenseMapObj* self) - : index(index), block(self->GetBlock(index / kBlockCap)) {} - /*! \brief Metadata on the entry */ - uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } - /*! \brief Data on the entry */ - ItemType& Item() const { - return *(reinterpret_cast(block->bytes + kBlockCap + - (index % kBlockCap) * sizeof(ItemType))); - } - /*! \brief Data on the entry */ - KVType& Data() const { return Item().data; } - /*! \brief Key on the entry */ - key_type& Key() const { return Data().first; } - /*! \brief Value on the entry */ - mapped_type& Val() const { return Data().second; } - /*! 
\brief If the entry is head of linked list */ - bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } - /*! \brief If the entry is none */ - bool IsNone() const { return block == nullptr; } - /*! \brief If the entry is empty slot */ - bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } - /*! \brief If the entry is protected slot */ - bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } - /*! \brief Set the entry to be empty */ - void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } - /*! \brief Destruct the item in the entry */ - void DestructData() const { - // explicit call destructor to destroy the item - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (&Data())->first.Any::~Any(); - (&Data())->second.Any::~Any(); - } - /*! \brief Set the entry to be protected */ - void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } - /*! \brief Set the entry's jump to its next entry */ - void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } - /*! \brief Construct a head of linked list in-place */ - void NewHead(ItemType v) const { - Meta() = 0b00000000; - new (&Item()) ItemType(std::move(v)); - } - /*! \brief Construct a tail of linked list in-place */ - void NewTail(ItemType v) const { - Meta() = 0b10000000; - new (&Item()) ItemType(std::move(v)); - } - - /*! \brief If the entry has next entry on the linked list */ - bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapObj* self, uint8_t meta) { - uint64_t offset = NextProbeLocation(meta & 0b01111111); - if (offset == 0) { - index = 0; - block = nullptr; - return false; - } - // the probing will go to next position and round back to stay within the - // correct range of the slots - index = (index + offset) % self->NumSlots(); - block = self->GetBlock(index / kBlockCap); - return true; - } - /*! 
\brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapObj* self) { return MoveToNext(self, Meta()); } - /*! \brief Get the previous entry on the linked list */ - ListNode FindPrev(const DenseMapObj* self) const { - // start from the head of the linked list, which must exist - ListNode next = self->IndexFromHash(AnyHash()(Key())); - // `prev` is always the previous item of `next` - ListNode prev = next; - for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { - } - return prev; - } - /*! \brief Get the next empty jump */ - bool GetNextEmpty(const DenseMapObj* self, uint8_t* jump, ListNode* result) const { - for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { - // the probing will go to next position and round back to stay within the - // correct range of the slots - ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self); - if (candidate.IsEmpty()) { - *jump = idx; - *result = candidate; - return true; - } - } - return false; - } - /*! \brief Index on the real array */ - uint64_t index; - /*! \brief Pointer to the actual block */ - Block* block; - }; - - protected: - /*! \brief fib shift in Fibonacci Hashing */ - uint32_t fib_shift_; - /*! \brief the head of iterator list */ - uint64_t iter_list_head_ = kInvalidIndex; - /*! \brief the tail of iterator list */ - uint64_t iter_list_tail_ = kInvalidIndex; - - static uint64_t NextProbeLocation(size_t index) { - /* clang-format off */ - /*! \brief Candidates of probing distance */ - static const uint64_t kNextProbeLocation[kNumJumpDists] { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - // Quadratic probing with triangle numbers. 
See also: - // 1) https://en.wikipedia.org/wiki/Quadratic_probing - // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - // 3) https://github.com/skarupke/flat_hash_map - 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, - 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, - 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, - 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, - 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, - 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, - 2211, 2278, 2346, 2415, 2485, 2556, 2628, - // larger triangle numbers - 8515, 19110, 42778, 96141, 216153, - 486591, 1092981, 2458653, 5532801, 12442566, - 27993903, 62983476, 141717030, 318844378, 717352503, - 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, - 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, - 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, - 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, - 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, - 457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435, - }; - /* clang-format on */ - return kNextProbeLocation[index]; - } - friend class MapObj; - - private: - /*! - * \brief Set the number of slots and attach tags bit. 
- * \param n The number of slots - */ - void SetSlotsAndDenseLayoutTag(uint64_t n) { - TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear"; - slots_ = n; - } -}; - -#define TVM_FFI_DISPATCH_MAP(base, var, body) \ - { \ - using TSmall = SmallMapObj*; \ - using TDense = DenseMapObj*; \ - if (base->IsSmallMap()) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -#define TVM_FFI_DISPATCH_MAP_CONST(base, var, body) \ - { \ - using TSmall = const SmallMapObj*; \ - using TDense = const DenseMapObj*; \ - if (base->IsSmallMap()) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -inline MapObj::iterator::pointer MapObj::iterator::operator->() const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); -} - -inline MapObj::iterator& MapObj::iterator::operator++() { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { - index = p->IncItr(index); - return *this; - }); -} - -inline MapObj::iterator& MapObj::iterator::operator--() { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { - index = p->DecItr(index); - return *this; - }); -} - -inline size_t MapObj::count(const key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); -} - -inline const MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); -} - -inline MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) { - TVM_FFI_DISPATCH_MAP(this, p, { return p->at(key); }); -} - -inline MapObj::iterator MapObj::begin() const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); -} - -inline MapObj::iterator MapObj::end() const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->end(); }); -} - -inline MapObj::iterator MapObj::find(const 
MapObj::key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); -} - -inline void MapObj::erase(const MapObj::iterator& position) { - TVM_FFI_DISPATCH_MAP(this, p, { return p->erase(position); }); -} - -#undef TVM_FFI_DISPATCH_MAP -#undef TVM_FFI_DISPATCH_MAP_CONST - -inline ObjectPtr MapObj::Empty() { return SmallMapObj::Empty(); } - -inline ObjectPtr MapObj::CopyFrom(MapObj* from) { - if (from->IsSmallMap()) { - return SmallMapObj::CopyFrom(static_cast(from)); - } else { - return DenseMapObj::CopyFrom(static_cast(from)); - } -} - -template -inline ObjectPtr MapObj::CreateFromRange(IterType first, IterType last) { - int64_t _cap = std::distance(first, last); - if (_cap < 0) { - return SmallMapObj::Empty(); - } - uint64_t cap = static_cast(_cap); - if (cap < SmallMapObj::kMaxSize) { - if (cap < 2) { - return SmallMapObj::CreateFromRange(cap, first, last); - } - // need to insert to avoid duplicate keys - ObjectPtr obj = SmallMapObj::Empty(cap); - for (; first != last; ++first) { - KVType kv(*first); - SmallMapObj::InsertMaybeReHash(std::move(kv), &obj); - } - return obj; - } else { - uint32_t fib_shift; - uint64_t n_slots; - DenseMapObj::CalcTableSize(cap, &fib_shift, &n_slots); - ObjectPtr obj = DenseMapObj::Empty(fib_shift, n_slots); - for (; first != last; ++first) { - KVType kv(*first); - DenseMapObj::InsertMaybeReHash(std::move(kv), &obj); - } - return obj; - } -} - -inline void MapObj::InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - MapObj* base = static_cast(map->get()); -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - base->state_marker++; -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - if (base->IsSmallMap()) { - SmallMapObj* sm = static_cast(base); - if (sm->NumSlots() < SmallMapObj::kMaxSize) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else if (sm->NumSlots() == SmallMapObj::kMaxSize) { - if (base->size_ < sm->NumSlots()) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else { - ObjectPtr new_map = 
MapObj::CreateFromRange(base->begin(), base->end()); - DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); - } - } - } else { - DenseMapObj::InsertMaybeReHash(std::move(kv), map); - } -} - -template <> -inline ObjectPtr make_object<>() = delete; - -/*! - * \brief Map container of NodeRef->NodeRef in DSL graph. - * Map implements copy on write semantics, which means map is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam K The key NodeRef type. - * \tparam V The value NodeRef type. - */ -template && - details::storage_enabled_v>> -class Map : public ObjectRef { - public: - using key_type = K; - using mapped_type = V; - class iterator; - /*! - * \brief default constructor - */ - Map() { data_ = MapObj::Empty(); } - /*! - * \brief move constructor - * \param other source - */ - Map(Map&& other) : ObjectRef(std::move(other.data_)) {} - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map& other) : ObjectRef(other.data_) {} - - template && - details::type_contains_v>> - Map(Map&& other) : ObjectRef(std::move(other.data_)) {} - - template && - details::type_contains_v>> - Map(const Map& other) : ObjectRef(other.data_) {} - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - - template && - details::type_contains_v>> - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - - template && - details::type_contains_v>> - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(ObjectPtr n) : ObjectRef(n) {} - /*! 
- * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Map(IterType begin, IterType end) { - data_ = MapObj::CreateFromRange(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Map(std::initializer_list> init) { - data_ = MapObj::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief constructor from unordered_map - * \param init The unordered_map - */ - template - Map(const std::unordered_map& init) { // NOLINT(*) - data_ = MapObj::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V at(const K& key) const { - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(GetMapObj()->at(key)); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V operator[](const K& key) const { return this->at(key); } - /*! \return The size of the array */ - size_t size() const { - MapObj* n = GetMapObj(); - return n == nullptr ? 0 : n->size(); - } - /*! \return The number of elements of the key */ - size_t count(const K& key) const { - MapObj* n = GetMapObj(); - return n == nullptr ? 0 : GetMapObj()->count(key); - } - /*! \return whether array is empty */ - bool empty() const { return size() == 0; } - /*! \brief Release reference to all the elements */ - void clear() { - MapObj* n = GetMapObj(); - if (n != nullptr) { - data_ = MapObj::Empty(); - } - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - void Set(const K& key, const V& value) { - CopyOnWrite(); - MapObj::InsertMaybeReHash(MapObj::KVType(key, value), &data_); - } - /*! \return begin iterator */ - iterator begin() const { return iterator(GetMapObj()->begin()); } - /*! 
\return end iterator */ - iterator end() const { return iterator(GetMapObj()->end()); } - /*! \return find the key and returns the associated iterator */ - iterator find(const K& key) const { return iterator(GetMapObj()->find(key)); } - /*! \return The value associated with the key, std::nullopt if not found */ - std::optional Get(const K& key) const { - MapObj::iterator iter = GetMapObj()->find(key); - if (iter == GetMapObj()->end()) { - return std::nullopt; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(iter->second); - } - void erase(const K& key) { CopyOnWrite()->erase(key); } - - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which guarantees to be unique) - */ - MapObj* CopyOnWrite() { - if (data_.get() == nullptr) { - data_ = MapObj::Empty(); - } else if (!data_.unique()) { - data_ = MapObj::CopyFrom(GetMapObj()); - } - return GetMapObj(); - } - /*! \brief specify container node */ - using ContainerType = MapObj; - - /*! \brief Iterator of the hash map */ - class iterator { - public: - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = int64_t; - using value_type = const std::pair; - using pointer = value_type*; - using reference = value_type; - - iterator() : itr() {} - - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { return itr == other.itr; } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return itr != other.itr; } - /*! \brief De-reference iterators is not allowed */ - pointer operator->() const = delete; - /*! 
\brief De-reference iterators */ - reference operator*() const { - auto& kv = *itr; - return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.first), - details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.second)); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++() { - ++itr; - return *this; - } - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--() { - --itr; - return *this; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - iterator copy = *this; - --(*this); - return copy; - } - - private: - iterator(const MapObj::iterator& itr) // NOLINT(*) - : itr(itr) {} - - template - friend class Map; - - MapObj::iterator itr; - }; - - private: - /*! \brief Return data_ as type of pointer of MapObj */ - MapObj* GetMapObj() const { return static_cast(data_.get()); } - - template - friend class Map; -}; - -/*! - * \brief Merge two Maps. - * \param lhs the first Map to merge. - * \param rhs the second Map to merge. - * @return The merged Array. Original Maps are kept unchanged. 
- */ -template && - details::storage_enabled_v>> -inline Map Merge(Map lhs, const Map& rhs) { - for (const auto& p : rhs) { - lhs.Set(p.first, p.second); - } - return std::move(lhs); -} - -// Traits for Map -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap; - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIMap) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - if constexpr (!std::is_same_v || !std::is_same_v) { - const MapObj* n = reinterpret_cast(src->v_obj); - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first) && - !kv.first.try_cast().has_value()) { - return "Map[some key is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.first) + - ", V]"; - } - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second) && - !kv.second.try_cast().has_value()) { - return "Map[K, some value is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.second) + - "]"; - } - } - } - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIMap) return false; - if constexpr (std::is_same_v && std::is_same_v) { - return true; - } else { - const MapObj* n = reinterpret_cast(src->v_obj); - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; - } - } - return true; - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index != 
TypeIndex::kTVMFFIMap) return std::nullopt; - if constexpr (!std::is_same_v || !std::is_same_v) { - const MapObj* n = reinterpret_cast(src->v_obj); - bool storage_check = [&]() { - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; - } - } - return true; - }(); - // fast path, if storage check passes, we can return the array directly. - if (storage_check) return CopyFromAnyViewAfterCheck(src); - // slow path, we need to create a new map and convert to the target type. - Map ret; - for (const auto& kv : *n) { - auto k = kv.first.try_cast(); - auto v = kv.second.try_cast(); - if (!k.has_value() || !v.has_value()) return std::nullopt; - ret.Set(*std::move(k), *std::move(v)); - } - return ret; - } else { - return CopyFromAnyViewAfterCheck(src); - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "Map<" + details::Type2Str::v() + ", " + details::Type2Str::v() + ">"; - } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Map> = - type_contains_v && type_contains_v; -} // namespace details - -} // namespace ffi - -// Expose to the tvm namespace -// Rationale: convinience and no ambiguity -using ffi::Map; -} // namespace tvm -#endif // TVM_FFI_CONTAINER_MAP_H_ diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h deleted file mode 100644 index 6acdbc3a2692..000000000000 --- a/ffi/include/tvm/ffi/container/ndarray.h +++ /dev/null @@ -1,337 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/ndarray.h - * \brief Container to store an NDArray. - */ -#ifndef TVM_FFI_CONTAINER_NDARRAY_H_ -#define TVM_FFI_CONTAINER_NDARRAY_H_ - -#include -#include -#include -#include - -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief check if a DLTensor is contiguous. - * \param arr The input DLTensor. - * \return The check result. - */ -inline bool IsContiguous(const DLTensor& arr) { - if (arr.strides == nullptr) return true; - int64_t expected_stride = 1; - for (int32_t i = arr.ndim; i != 0; --i) { - int32_t k = i - 1; - if (arr.shape[k] == 1) { - // Skip stride check if shape[k] is 1, where the dimension is contiguous - // regardless of the value of stride. - // - // For example, PyTorch will normalize stride to 1 if shape is 1 when exporting - // to DLPack. - // More context: https://github.com/pytorch/pytorch/pull/83158 - continue; - } - if (arr.strides[k] != expected_stride) return false; - expected_stride *= arr.shape[k]; - } - return true; -} - -/** - * \brief Check if the data in the DLTensor is aligned to the given alignment. - * \param arr The input DLTensor. - * \param alignment The alignment to check. - * \return True if the data is aligned to the given alignment, false otherwise. 
- */ -inline bool IsAligned(const DLTensor& arr, size_t alignment) { - // whether the device uses direct address mapping instead of indirect buffer - bool direct_address = arr.device.device_type <= kDLCUDAHost || - arr.device.device_type == kDLCUDAManaged || - arr.device.device_type == kDLROCM || arr.device.device_type == kDLROCMHost; - if (direct_address) { - return (reinterpret_cast(static_cast(arr.data) + arr.byte_offset) % alignment == - 0); - } else { - return arr.byte_offset % alignment == 0; - } -} - -/*! - * \brief return the total number bytes needs to store packed data - * - * \param numel the number of elements in the array - * \param dtype the data type of the array - * \return the total number bytes needs to store packed data - */ -inline size_t GetDataSize(int64_t numel, DLDataType dtype) { - // compatible handling sub-byte uint1(bool), which usually stored as uint8_t - // TODO(tqchen): revisit and switch to kDLBool - if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) { - return numel; - } - // for other sub-byte types, packing is preferred - return (numel * dtype.bits * dtype.lanes + 7) / 8; -} - -/*! - * \brief return the size of data the DLTensor hold, in term of number of bytes - * - * \param arr the input DLTensor - * \return number of bytes of data in the DLTensor. - */ -inline size_t GetDataSize(const DLTensor& arr) { - size_t size = 1; - for (int i = 0; i < arr.ndim; ++i) { - size *= static_cast(arr.shape[i]); - } - return GetDataSize(size, arr.dtype); -} - -/*! \brief An object representing an NDArray. */ -class NDArrayObj : public Object, public DLTensor { - public: - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFINDArray; - static constexpr const char* _type_key = StaticTypeKey::kTVMFFINDArray; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(NDArrayObj, Object); - - /*! - * \brief Move NDArray to a DLPack managed tensor. - * \return The converted DLPack managed tensor. 
- */ - DLManagedTensor* ToDLPack() const { - DLManagedTensor* ret = new DLManagedTensor(); - NDArrayObj* from = const_cast(this); - ret->dl_tensor = *static_cast(from); - ret->manager_ctx = from; - ret->deleter = DLManagedTensorDeleter; - details::ObjectUnsafe::IncRefObjectHandle(from); - return ret; - } - - /*! - * \brief Move NDArray to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensorVersioned* ToDLPackVersioned() const { - DLManagedTensorVersioned* ret = new DLManagedTensorVersioned(); - NDArrayObj* from = const_cast(this); - ret->version.major = DLPACK_MAJOR_VERSION; - ret->version.minor = DLPACK_MINOR_VERSION; - ret->dl_tensor = *static_cast(from); - ret->manager_ctx = from; - ret->deleter = DLManagedTensorVersionedDeleter; - ret->flags = 0; - details::ObjectUnsafe::IncRefObjectHandle(from); - return ret; - } - - protected: - // backs up the shape of the NDArray - Optional shape_data_; - - static void DLManagedTensorDeleter(DLManagedTensor* tensor) { - NDArrayObj* obj = static_cast(tensor->manager_ctx); - details::ObjectUnsafe::DecRefObjectHandle(obj); - delete tensor; - } - - static void DLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) { - NDArrayObj* obj = static_cast(tensor->manager_ctx); - details::ObjectUnsafe::DecRefObjectHandle(obj); - delete tensor; - } - - friend class NDArray; -}; - -namespace details { -/*! - *\brief Helper class to create an NDArrayObj from an NDAllocator - * - * The underlying allocator needs to be implemented by user. - */ -template -class NDArrayObjFromNDAlloc : public NDArrayObj { - public: - template - NDArrayObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, - ExtraArgs&&... 
extra_args) - : alloc_(alloc) { - this->device = device; - this->ndim = static_cast(shape.size()); - this->dtype = dtype; - this->shape = const_cast(shape.data()); - this->strides = nullptr; - this->byte_offset = 0; - this->shape_data_ = std::move(shape); - alloc_.AllocData(static_cast(this), std::forward(extra_args)...); - } - - ~NDArrayObjFromNDAlloc() { alloc_.FreeData(static_cast(this)); } - - private: - TNDAlloc alloc_; -}; - -/*! \brief helper class to import from DLPack legacy DLManagedTensor */ -template -class NDArrayObjFromDLPack : public NDArrayObj { - public: - explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { - *static_cast(this) = tensor_->dl_tensor; - // set strides to nullptr if the tensor is contiguous. - if (IsContiguous(tensor->dl_tensor)) { - this->strides = nullptr; - } - } - - ~NDArrayObjFromDLPack() { - // run DLPack deleter if needed. - if (tensor_->deleter != nullptr) { - (*tensor_->deleter)(tensor_); - } - } - - private: - TDLPackManagedTensor* tensor_; -}; -} // namespace details - -/*! - * \brief Managed NDArray. - * The array is backed by reference counted blocks. - * - * \note This class can be subclassed to implement downstream customized - * NDArray types that are backed by the same NDArrayObj storage type. - */ -class NDArray : public ObjectRef { - public: - /*! - * \brief Get the shape of the NDArray. - * \return The shape of the NDArray. - */ - tvm::ffi::Shape shape() const { - NDArrayObj* obj = get_mutable(); - if (!obj->shape_data_.has_value()) { - obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim); - } - return *(obj->shape_data_); - } - /*! - * \brief Get the data type of the NDArray. - * \return The data type of the NDArray. - */ - DLDataType dtype() const { return (*this)->dtype; } - /*! - * \brief Check if the NDArray is contiguous. - * \return True if the NDArray is contiguous, false otherwise. 
- */ - bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); } - /*! - * \brief Create a NDArray from a NDAllocator. - * \param alloc The NDAllocator. - * \param shape The shape of the NDArray. - * \param dtype The data type of the NDArray. - * \param device The device of the NDArray. - * \return The created NDArray. - * \tparam TNDAlloc The type of the NDAllocator, impelments Alloc and Free. - * \tparam ExtraArgs Extra arguments to be passed to Alloc. - */ - template - static NDArray FromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, - ExtraArgs&&... extra_args) { - return NDArray(make_object>( - alloc, shape, dtype, device, std::forward(extra_args)...)); - } - - /*! - * \brief Create a NDArray from a DLPack managed tensor, pre v1.0 API. - * \param tensor The input DLPack managed tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \note This function will not run any checks on flags. - * \return The created NDArray. - */ - static NDArray FromDLPack(DLManagedTensor* tensor, size_t require_alignment = 0, - bool require_contiguous = false) { - if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment - << " bytes."; - } - if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; - } - return NDArray(make_object>(tensor)); - } - - /*! - * \brief Create a NDArray from a DLPack managed tensor, post v1.0 API. - * \param tensor The input DLPack managed tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \return The created NDArray. 
- */ - static NDArray FromDLPackVersioned(DLManagedTensorVersioned* tensor, size_t require_alignment = 0, - bool require_contiguous = false) { - if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment - << " bytes."; - } - if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; - } - if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) { - TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported"; - } - return NDArray(make_object>(tensor)); - } - - /*! - * \brief Convert the NDArray to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensor* ToDLPack() const { return get_mutable()->ToDLPack(); } - - /*! - * \brief Convert the NDArray to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensorVersioned* ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS(NDArray, ObjectRef, NDArrayObj); - - protected: - /*! - * \brief Get mutable internal container pointer. - * \return a mutable container pointer. - */ - NDArrayObj* get_mutable() const { return const_cast(get()); } -}; - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_CONTAINER_NDARRAY_H_ diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h deleted file mode 100644 index 2fccc028a5b3..000000000000 --- a/ffi/include/tvm/ffi/container/shape.h +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/shape.h - * \brief Container to store shape of an NDArray. - */ -#ifndef TVM_FFI_CONTAINER_SHAPE_H_ -#define TVM_FFI_CONTAINER_SHAPE_H_ - -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! \brief An object representing a shape tuple. */ -class ShapeObj : public Object, public TVMFFIShapeCell { - public: - using index_type = int64_t; - - /*! 
\brief Get "numel", meaning the number of elements of an array if the array has this shape */ - int64_t Product() const { - int64_t product = 1; - for (size_t i = 0; i < this->size; ++i) { - product *= this->data[i]; - } - return product; - } - - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape; - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIShape; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ShapeObj, Object); -}; - -namespace details { - -class ShapeObjStdImpl : public ShapeObj { - public: - explicit ShapeObjStdImpl(std::vector other) : data_{other} { - this->data = data_.data(); - this->size = static_cast(data_.size()); - } - - private: - std::vector data_; -}; - -TVM_FFI_INLINE ObjectPtr MakeEmptyShape(size_t length, int64_t** mutable_data) { - ObjectPtr p = make_inplace_array_object(length); - static_assert(alignof(ShapeObj) % alignof(int64_t) == 0); - static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0); - int64_t* data = reinterpret_cast(reinterpret_cast(p.get()) + sizeof(ShapeObj)); - if (mutable_data) { - *mutable_data = data; - } - p->data = data; - p->size = length; - return p; -} - -// inplace shape allocation -template -TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end) { - size_t length = std::distance(begin, end); - int64_t* mutable_data; - ObjectPtr p = MakeEmptyShape(length, &mutable_data); - std::copy(begin, end, mutable_data); - return p; -} - -} // namespace details - -/*! - * \brief Reference to shape object. - */ -class Shape : public ObjectRef { - public: - /*! \brief The type of shape index element. */ - using index_type = ShapeObj::index_type; - - /*! \brief Default constructor */ - Shape() : ObjectRef(details::MakeEmptyShape(0, nullptr)) {} - - /*! 
- * \brief Constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {} - - /** - * \brief Constructor from Array - * \param shape The Array - * - * \note This constructor will copy the data content. - */ - Shape(Array shape) // NOLINT(*) - : Shape(shape.begin(), shape.end()) {} - - /*! - * \brief constructor from initializer list - * \param shape The initializer list - */ - Shape(std::initializer_list shape) : Shape(shape.begin(), shape.end()) {} - - /*! - * \brief constructor from int64_t [N] - * - * \param other a int64_t array. - */ - Shape(std::vector other) // NOLINT(*) - : ObjectRef(make_object(std::move(other))) {} - - /*! - * \brief Return the data pointer - * - * \return const index_type* data pointer - */ - const int64_t* data() const { return get()->data; } - - /*! - * \brief Return the size of the shape tuple - * - * \return size_t shape tuple size - */ - size_t size() const { return get()->size; } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - int64_t operator[](size_t idx) const { - if (idx >= this->size()) { - TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size(); - } - return this->data()[idx]; - } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - int64_t at(size_t idx) const { return this->operator[](idx); } - - /*! \return Whether shape tuple is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the shape tuple */ - int64_t front() const { return this->at(0); } - - /*! \return The last element of the shape tuple */ - int64_t back() const { return this->at(this->size() - 1); } - - /*! 
\return begin iterator */ - const int64_t* begin() const { return get()->data; } - - /*! \return end iterator */ - const int64_t* end() const { return (get()->data + size()); } - - /*! \return The product of the shape tuple */ - int64_t Product() const { return get()->Product(); } - - TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj); -}; - -inline std::ostream& operator<<(std::ostream& os, const Shape& shape) { - os << '['; - for (size_t i = 0; i < shape.size(); ++i) { - if (i != 0) { - os << ", "; - } - os << shape[i]; - } - os << ']'; - return os; -} - -// Shape -template <> -inline constexpr bool use_default_type_traits_v = false; - -// Allow auto conversion from Array to Shape, but not from Shape to Array -template <> -struct TypeTraits : public ObjectRefWithFallbackTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape; - TVM_FFI_INLINE static Shape ConvertFallbackValue(Array src) { return Shape(src); } -}; - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_CONTAINER_SHAPE_H_ diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h deleted file mode 100644 index 332f78a2fe78..000000000000 --- a/ffi/include/tvm/ffi/container/tuple.h +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/tuple.h - * \brief Typed tuple like std::tuple backed by ArrayObj container. - */ -#ifndef TVM_FFI_CONTAINER_TUPLE_H_ -#define TVM_FFI_CONTAINER_TUPLE_H_ - -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Typed tuple like std::tuple backed by ArrayObj container. - * - * Tuple implements in-place copy-on-write semantics. - * - * \tparam Types The types of the tuple elements - */ -template -class Tuple : public ObjectRef { - public: - static_assert(details::all_storage_enabled_v, - "All types used in Tuple<...> must be compatible with Any"); - - Tuple() : ObjectRef(MakeDefaultTupleNode()) {} - Tuple(const Tuple& other) : ObjectRef(other) {} - Tuple(Tuple&& other) : ObjectRef(std::move(other)) {} - template && ...), int>> - Tuple(const Tuple& other) : ObjectRef(other) {} - template && ...), int>> - Tuple(Tuple&& other) : ObjectRef(std::move(other)) {} - - template , Tuple> && - ...))>> - explicit Tuple(UTypes&&... 
args) : ObjectRef(MakeTupleNode(std::forward(args)...)) {} - - TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Tuple& operator=(Tuple&& other) { - data_ = std::move(other.data_); - return *this; - } - - template && ...)>> - TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { - data_ = other.data_; - return *this; - } - - template && ...)>> - TVM_FFI_INLINE Tuple& operator=(Tuple&& other) { - data_ = std::move(other.data_); - return *this; - } - - explicit Tuple(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief Get I-th element of the tuple - * - * \tparam I The index of the element to get - * \return The I-th element of the tuple - * \note We use stl style since get usually is like a getter. - */ - template - auto get() const { - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using ReturnType = std::tuple_element_t>; - const Any* ptr = GetArrayObj()->begin() + I; - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*ptr); - } - - /*! - * \brief Set I-th element of the tuple - * - * \param item The item to set - * \tparam I The index of the element to set - * \tparam U The type of the item - * - * \note This function will perform copy on write if underlying - * container is not uniquely owned. - * We use CamelCase since Set can cause copy on write - * and is more complicated than simple field setter. - */ - template - void Set(U&& item) { - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using T = std::tuple_element_t>; - this->CopyIfNotUnique(); - Any* ptr = GetArrayObj()->MutableBegin() + I; - *ptr = T(std::forward(item)); - } - - /*! 
\brief specify container node */ - using ContainerType = ArrayObj; - - private: - static ObjectPtr MakeDefaultTupleNode() { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - // increase size after each new to ensure exception safety - ((new (itr++) Any(Types()), p->size_++), ...); - return p; - } - - template - static ObjectPtr MakeTupleNode(UTypes&&... args) { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - // increase size after each new to ensure exception safety - ((new (itr++) Any(Types(std::forward(args))), p->size_++), ...); - return p; - } - - /*! \brief Copy on write */ - void CopyIfNotUnique() { - if (!data_.unique()) { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - const Any* read = GetArrayObj()->begin(); - // increase size after each new to ensure exception safety - for (size_t i = 0; i < sizeof...(Types); ++i) { - new (itr++) Any(*read++); - p->size_++; - } - data_ = std::move(p); - } - } - - /*! 
\return The underlying ArrayObj */ - ArrayObj* GetArrayObj() const { return static_cast(data_.get()); } - - template - friend class Tuple; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) { - return "Array[size=" + std::to_string(n->size()) + "]"; - } - return GetMismatchTypeInfoHelper<0, Types...>(n->begin()); - } - - template - TVM_FFI_INLINE static std::string GetMismatchTypeInfoHelper(const Any* arr) { - if constexpr (!std::is_same_v) { - const Any& any_v = arr[I]; - if (!details::AnyUnsafe::CheckAnyStrict(any_v) && !(any_v.try_cast().has_value())) { - // now report the accurate mismatch information - return "Array[index " + std::to_string(I) + ": " + - details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; - } - } - if constexpr (sizeof...(Rest) > 0) { - return GetMismatchTypeInfoHelper(arr); - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) return false; - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) return false; - const TVMFFIAny* ffi_any_arr = reinterpret_cast(n->begin()); - return CheckAnyStrictHelper<0, Types...>(ffi_any_arr); - } - - template - TVM_FFI_INLINE static bool CheckAnyStrictHelper(const TVMFFIAny* src_arr) { - if constexpr (!std::is_same_v) { - if (!TypeTraits::CheckAnyStrict(src_arr + I)) { - return false; - } - } - if constexpr (sizeof...(Rest) > 0) { - return CheckAnyStrictHelper(src_arr); - } - return true; - } - - 
TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src // - ) { - if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) return std::nullopt; - // fast path, storage is already in the right type - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); - } - // slow path, try to convert to each type to match the tuple storage need. - Array arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); - Any* ptr = arr.CopyOnWrite()->MutableBegin(); - if (TryConvertElements<0, Types...>(ptr)) { - return Tuple(details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); - } - return std::nullopt; - } - - template - TVM_FFI_INLINE static bool TryConvertElements(Any* arr) { - if constexpr (!std::is_same_v) { - if (auto opt_convert = arr[I].try_cast()) { - arr[I] = *std::move(opt_convert); - } else { - return false; - } - } - if constexpr (sizeof...(Rest) > 0) { - return TryConvertElements(std::move(arr)); - } else { - return true; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return details::ContainerTypeStr("Tuple"); - } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Tuple> = (type_contains_v && ...); -} // namespace details - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_TUPLE_H_ diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h deleted file mode 100644 index ee1f8316d80c..000000000000 --- a/ffi/include/tvm/ffi/container/variant.h +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/variant.h - * \brief Runtime variant container types. - */ -#ifndef TVM_FFI_CONTAINER_VARIANT_H_ -#define TVM_FFI_CONTAINER_VARIANT_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base class for Variant. - * - * \tparam all_storage_object Whether all types are derived from ObjectRef. - */ -template -class VariantBase { - public: - TVM_FFI_INLINE bool same_as(const VariantBase& other) const { - return data_.same_as(other.data_); - } - - protected: - template - explicit VariantBase(T other) : data_(std::move(other)) {} - - TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); } - - TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); } - - TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); } - - Any data_; -}; - -// Specialization for all object ref case, backed by ObjectRef. 
-template <> -class VariantBase : public ObjectRef { - protected: - template - explicit VariantBase(const T& other) : ObjectRef(other) {} - template - explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {} - explicit VariantBase(ObjectPtr ptr) : ObjectRef(ptr) {} - explicit VariantBase(Any other) - : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} - - TVM_FFI_INLINE void SetData(ObjectPtr other) { data_ = std::move(other); } - - TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); } - - TVM_FFI_INLINE AnyView ToAnyView() const { - TVMFFIAny any_data; - if (data_ == nullptr) { - any_data.type_index = TypeIndex::kTVMFFINone; - any_data.zero_padding = 0; - any_data.v_int64 = 0; - } else { - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); - any_data.type_index = data_->type_index(); - any_data.zero_padding = 0; - any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr(data_); - } - return AnyView::CopyFromTVMFFIAny(any_data); - } -}; -} // namespace details - -/*! - * \brief A typed variant container. - * - * When all values are ObjectRef, Variant is backed by ObjectRef, - * otherwise it is backed by Any. 
- */ -template -class Variant : public details::VariantBase> { - public: - using TParent = details::VariantBase>; - static_assert(details::all_storage_enabled_v, - "All types used in Variant<...> must be compatible with Any"); - /* - * \brief Helper utility to check if the type can be contained in the variant - */ - template - static constexpr bool variant_contains_v = (details::type_contains_v || ...); - /* \brief Helper utility for SFINAE if the type is part of the variant */ - template - using enable_if_variant_contains_t = std::enable_if_t>; - - Variant(const Variant& other) : TParent(other.data_) {} - Variant(Variant&& other) : TParent(std::move(other.data_)) {} - - TVM_FFI_INLINE Variant& operator=(const Variant& other) { - this->SetData(other.data_); - return *this; - } - - TVM_FFI_INLINE Variant& operator=(Variant&& other) { - this->SetData(std::move(other.data_)); - return *this; - } - - template > - Variant(T other) : TParent(std::move(other)) {} // NOLINT(*) - - template > - TVM_FFI_INLINE Variant& operator=(T other) { - return operator=(Variant(std::move(other))); - } - - template > - TVM_FFI_INLINE std::optional as() const { - return this->TParent::ToAnyView().template as(); - } - - /* - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. 
- */ - template >> - TVM_FFI_INLINE const T* as() const { - return this->TParent::ToAnyView().template as().value_or(nullptr); - } - - template > - TVM_FFI_INLINE T get() const& { - return this->TParent::ToAnyView().template cast(); - } - - template > - TVM_FFI_INLINE T get() && { - return std::move(*this).TParent::MoveToAny().template cast(); - } - - TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); } - - private: - friend struct TypeTraits>; - friend struct ObjectPtrHash; - friend struct ObjectPtrEqual; - // constructor from any - explicit Variant(Any data) : TParent(std::move(data)) {} - /*! - * \brief Get the object pointer from the variant - * \note This function is only available if all types used in Variant<...> are derived from - * ObjectRef - */ - TVM_FFI_INLINE Object* GetObjectPtrForHashEqual() const { - constexpr bool all_object_v = (std::is_base_of_v && ...); - static_assert(all_object_v, - "All types used in Variant<...> must be derived from ObjectRef " - "to enable ObjectPtrHash/ObjectPtrEqual"); - return this->data_.get(); - } - // rexpose to friend class - using TParent::MoveToAny; - using TParent::ToAnyView; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const Variant& src, TVMFFIAny* result) { - *result = src.ToAnyView().CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(Variant src, TVMFFIAny* result) { - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny()); - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return (TypeTraits::CheckAnyStrict(src) || ...); - } - - TVM_FFI_INLINE static Variant CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return 
Variant(Any(AnyView::CopyFromTVMFFIAny(*src))); - } - - TVM_FFI_INLINE static Variant MoveFromAnyAfterCheck(TVMFFIAny* src) { - return Variant(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src))); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // fast path, storage is already in the right type - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); - } - // More expensive path, try to convert to each type, in order of declaration - return TryVariantTypes(src); - } - - template - TVM_FFI_INLINE static std::optional> TryVariantTypes(const TVMFFIAny* src) { - if (auto opt_convert = TypeTraits::TryCastFromAnyView(src)) { - return Variant(*std::move(opt_convert)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryVariantTypes(src); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return details::ContainerTypeStr("Variant"); } -}; - -template -TVM_FFI_INLINE size_t ObjectPtrHash::operator()(const Variant& a) const { - return std::hash()(a.GetObjectPtrForHashEqual()); -} - -template -TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant& a, - const Variant& b) const { - return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual(); -} - -namespace details { -template -inline constexpr bool type_contains_v, T> = (type_contains_v || ...); -} // namespace details -} // namespace ffi - -// Expose to the tvm namespace -// Rationale: convinience and no ambiguity -using ffi::Variant; -} // namespace tvm -#endif // TVM_FFI_CONTAINER_VARIANT_H_ diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h deleted file mode 100644 index c153d71cb70a..000000000000 --- a/ffi/include/tvm/ffi/dtype.h +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/dtype.h - * \brief Data type handling. - */ -#ifndef TVM_FFI_DTYPE_H_ -#define TVM_FFI_DTYPE_H_ - -#include -#include -#include -#include -#include - -#include - -namespace tvm { -namespace ffi { -/*! - * \brief Extension code beyond the DLDataType. - * - * This class is always consistent with the DLPack. - * - * TOTO(tvm-team): update to latest DLPack types. - */ -enum DLExtDataTypeCode { kDLExtCustomBegin = 129 }; - -namespace details { - -/* - * \brief Convert a DLDataTypeCode to a string. - * \param os The output stream. - * \param type_code The DLDataTypeCode to convert. 
- */ -inline const char* DLDataTypeCodeAsCStr(DLDataTypeCode type_code) { // NOLINT(*) - switch (static_cast(type_code)) { - case kDLInt: { - return "int"; - } - case kDLUInt: { - return "uint"; - } - case kDLFloat: { - return "float"; - } - case kDLOpaqueHandle: { - return "handle"; - } - case kDLBfloat: { - return "bfloat"; - } - case kDLFloat8_e3m4: { - return "float8_e3m4"; - } - case kDLFloat8_e4m3: { - return "float8_e4m3"; - } - case kDLFloat8_e4m3b11fnuz: { - return "float8_e4m3b11fnuz"; - } - case kDLFloat8_e4m3fn: { - return "float8_e4m3fn"; - } - case kDLFloat8_e4m3fnuz: { - return "float8_e4m3fnuz"; - } - case kDLFloat8_e5m2: { - return "float8_e5m2"; - } - case kDLFloat8_e5m2fnuz: { - return "float8_e5m2fnuz"; - } - case kDLFloat8_e8m0fnu: { - return "float8_e8m0fnu"; - } - case kDLFloat6_e2m3fn: { - return "float6_e2m3fn"; - } - case kDLFloat6_e3m2fn: { - return "float6_e3m2fn"; - } - case kDLFloat4_e2m1fn: { - return "float4_e2m1fn"; - } - default: { - if (static_cast(type_code) >= static_cast(DLExtDataTypeCode::kDLExtCustomBegin)) { - return "custom"; - } else { - TVM_FFI_THROW(ValueError) << "DLDataType contains unknown type_code=" - << static_cast(type_code); - } - TVM_FFI_UNREACHABLE(); - } - } -} -} // namespace details - -inline DLDataType StringToDLDataType(const String& str) { - DLDataType out; - TVMFFIByteArray data{str.data(), str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(&data, &out)); - return out; -} - -inline String DLDataTypeToString(DLDataType dtype) { - TVMFFIAny out; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out)); - return TypeTraits::MoveFromAnyAfterCheck(&out); -} - -// DLDataType -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; - - TVM_FFI_INLINE static void CopyToAnyView(const DLDataType& src, TVMFFIAny* result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - 
result->v_uint64 = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->zero_padding = 0; - result->v_dtype = src; - } - - TVM_FFI_INLINE static void MoveToAny(DLDataType src, TVMFFIAny* result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->zero_padding = 0; - result->v_dtype = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIDataType; - } - - TVM_FFI_INLINE static DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return src->v_dtype; - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIDataType) { - return src->v_dtype; - } - // enable string to dtype auto conversion - if (auto opt_str = TypeTraits::TryCastFromAnyView(src)) { - return StringToDLDataType(*opt_str); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; } -}; -} // namespace ffi -} // namespace tvm - -// define DLDataType comparison and printing in root namespace -inline std::ostream& operator<<(std::ostream& os, DLDataType dtype) { // NOLINT(*) - return os << tvm::ffi::DLDataTypeToString(dtype); -} - -inline bool operator==(const DLDataType& lhs, const DLDataType& rhs) { - return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; -} - -inline bool operator!=(const DLDataType& lhs, const DLDataType& rhs) { return !(lhs == rhs); } -#endif // TVM_FFI_DTYPE_H_ diff --git a/ffi/include/tvm/ffi/endian.h b/ffi/include/tvm/ffi/endian.h deleted file mode 100644 index 4a73b82e6c30..000000000000 --- a/ffi/include/tvm/ffi/endian.h +++ /dev/null @@ -1,89 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file tvm/ffi/endian.h - * \brief Endian detection and handling - */ -#ifndef TVM_FFI_ENDIAN_H_ -#define TVM_FFI_ENDIAN_H_ - -#include -#include - -#ifndef TVM_FFI_IO_USE_LITTLE_ENDIAN -#define TVM_FFI_IO_USE_LITTLE_ENDIAN 1 -#endif - -#ifdef TVM_FFI_CMAKE_LITTLE_ENDIAN -// If compiled with CMake, use CMake's endian detection logic -#define TVM_FFI_LITTLE_ENDIAN TVM_FFI_CMAKE_LITTLE_ENDIAN -#else -#if defined(__APPLE__) || defined(_WIN32) -#define TVM_FFI_LITTLE_ENDIAN 1 -#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || defined(__RISCV__) -#include -#define TVM_FFI_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) -#elif defined(__FreeBSD__) || defined(__OpenBSD__) -#include -#define TVM_FFI_LITTLE_ENDIAN (_BYTE_ORDER == _LITTLE_ENDIAN) -#elif defined(__QNX__) -#include -#define TVM_FFI_LITTLE_ENDIAN (BYTE_ORDER == LITTLE_ENDIAN) -#elif defined(__EMSCRIPTEN__) || defined(__hexagon__) -#define TVM_FFI_LITTLE_ENDIAN 1 -#elif defined(__sun) || defined(sun) -#include -#if defined(_LITTLE_ENDIAN) -#define TVM_FFI_LITTLE_ENDIAN 1 -#else -#define TVM_FFI_LITTLE_ENDIAN 0 -#endif -#else -#error "Unable to determine endianness of your machine; use CMake to compile" -#endif -#endif - -/*! 
\brief whether serialize using little endian */ -#define TVM_FFI_IO_NO_ENDIAN_SWAP (TVM_FFI_LITTLE_ENDIAN == TVM_FFI_IO_USE_LITTLE_ENDIAN) - -namespace tvm { -namespace ffi { -/*! - * \brief A generic inplace byte swapping function. - * \param data The data pointer. - * \param elem_bytes The number of bytes of the data elements - * \param num_elems Number of elements in the data. - * \note Always try pass in constant elem_bytes to enable - * compiler optimization - */ -inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) { - for (size_t i = 0; i < num_elems; ++i) { - uint8_t* bptr = reinterpret_cast(data) + elem_bytes * i; - for (size_t j = 0; j < elem_bytes / 2; ++j) { - uint8_t v = bptr[elem_bytes - 1 - j]; - bptr[elem_bytes - 1 - j] = bptr[j]; - bptr[j] = v; - } - } -} -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_ENDIAN_H_ diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h deleted file mode 100644 index 77a7fe9c2e68..000000000000 --- a/ffi/include/tvm/ffi/error.h +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file tvm/ffi/error.h - * \brief Error handling component. 
- */ -#ifndef TVM_FFI_ERROR_H_ -#define TVM_FFI_ERROR_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -/*! - * \brief Macro defines whether we enable libbacktrace - */ -#ifndef TVM_FFI_USE_LIBBACKTRACE -#define TVM_FFI_USE_LIBBACKTRACE 1 -#endif - -/*! - * \brief Macro defines whether to install signal handler - * and print backtrace during segfault - */ -#ifndef TVM_FFI_BACKTRACE_ON_SEGFAULT -#define TVM_FFI_BACKTRACE_ON_SEGFAULT 1 -#endif - -#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW -#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0 -#endif - -namespace tvm { -namespace ffi { - -/*! - * \brief Error already set in frontend env. - * - * This error can be thrown by EnvCheckSignals to indicate - * that there is an error set in the frontend environment(e.g. - * python interpreter). The TVM FFI should catch this error - * and return a proper code tell the frontend caller about - * this fact. - * - * \code - * - * void ExampleLongRunningFunction() { - * if (TVMFFIEnvCheckSignals() != 0) { - * throw ::tvm::ffi::EnvErrorAlreadySet(); - * } - * // do work here - * } - * - * \endcode - */ -struct EnvErrorAlreadySet : public std::exception {}; - -/*! - * \brief Error object class. 
- */ -class ErrorObj : public Object, public TVMFFIErrorCell { - public: - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError; - static constexpr const char* _type_key = "ffi.Error"; - - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ErrorObj, Object); -}; - -namespace details { -class ErrorObjFromStd : public ErrorObj { - public: - ErrorObjFromStd(std::string kind, std::string message, std::string traceback) - : kind_data_(kind), message_data_(message), traceback_data_(traceback) { - this->kind = TVMFFIByteArray{kind_data_.data(), kind_data_.length()}; - this->message = TVMFFIByteArray{message_data_.data(), message_data_.length()}; - this->traceback = TVMFFIByteArray{traceback_data_.data(), traceback_data_.length()}; - this->update_traceback = UpdateTraceback; - } - - private: - /*! - * \brief Update the traceback of the error object. - * \param traceback The traceback to update. - */ - static void UpdateTraceback(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback_str) { - ErrorObjFromStd* obj = static_cast(self); - obj->traceback_data_ = std::string(traceback_str->data, traceback_str->size); - obj->traceback = TVMFFIByteArray{obj->traceback_data_.data(), obj->traceback_data_.length()}; - } - - std::string kind_data_; - std::string message_data_; - std::string traceback_data_; -}; -} // namespace details - -/*! 
- * \brief Managed reference to ErrorObj - * \sa Error Object - */ -class Error : public ObjectRef, public std::exception { - public: - Error(std::string kind, std::string message, std::string traceback) { - data_ = make_object(kind, message, traceback); - } - - Error(std::string kind, std::string message, const TVMFFIByteArray* traceback) - : Error(kind, message, std::string(traceback->data, traceback->size)) {} - - std::string kind() const { - ErrorObj* obj = static_cast(data_.get()); - return std::string(obj->kind.data, obj->kind.size); - } - - std::string message() const { - ErrorObj* obj = static_cast(data_.get()); - return std::string(obj->message.data, obj->message.size); - } - - std::string traceback() const { - ErrorObj* obj = static_cast(data_.get()); - return std::string(obj->traceback.data, obj->traceback.size); - } - - void UpdateTraceback(const TVMFFIByteArray* traceback_str) { - ErrorObj* obj = static_cast(data_.get()); - obj->update_traceback(obj, traceback_str); - } - - const char* what() const noexcept(true) override { - thread_local std::string what_data; - ErrorObj* obj = static_cast(data_.get()); - what_data = (std::string("Traceback (most recent call last):\n") + - std::string(obj->traceback.data, obj->traceback.size) + - std::string(obj->kind.data, obj->kind.size) + std::string(": ") + - std::string(obj->message.data, obj->message.size) + '\n'); - return what_data.c_str(); - } - - TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj); -}; - -namespace details { - -class ErrorBuilder { - public: - explicit ErrorBuilder(std::string kind, std::string traceback, bool log_before_throw) - : kind_(kind), traceback_(traceback), log_before_throw_(log_before_throw) {} - - explicit ErrorBuilder(std::string kind, const TVMFFIByteArray* traceback, bool log_before_throw) - : ErrorBuilder(kind, std::string(traceback->data, traceback->size), log_before_throw) {} - -// MSVC disable warning in error builder as it is exepected -#ifdef 
_MSC_VER -#pragma disagnostic push -#pragma warning(disable : 4722) -#endif - // avoid inline to reduce binary size, error throw path do not need to be fast - [[noreturn]] ~ErrorBuilder() noexcept(false) { - ::tvm::ffi::Error error(std::move(kind_), stream_.str(), std::move(traceback_)); - if (log_before_throw_) { - std::cerr << error.what(); - } - throw error; - } -#ifdef _MSC_VER -#pragma disagnostic pop -#endif - - std::ostringstream& stream() { return stream_; } - - protected: - std::string kind_; - std::ostringstream stream_; - std::string traceback_; - bool log_before_throw_; -}; - -// define traceback here as call into traceback function -#define TVM_FFI_TRACEBACK_HERE TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG) -} // namespace details - -/*! - * \brief Helper macro to throw an error with traceback and message - * - * \code - * - * void ThrowError() { - * TVM_FFI_THROW(RuntimeError) << "error message"; - * } - * - * \endcode - */ -#define TVM_FFI_THROW(ErrorKind) \ - ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, \ - TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \ - .stream() - -/*! - * \brief Explicitly log error in stderr and then throw the error. - * - * \note This is only necessary on startup functions where we know error - * cannot be caught, and it is better to have a clear log message. - * In most cases, we should use use TVM_FFI_THROW. - */ -#define TVM_FFI_LOG_AND_THROW(ErrorKind) \ - ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, true).stream() - -// Glog style checks with TVM_FFI prefix -// NOTE: we explicitly avoid glog style generic macros (LOG/CHECK) in tvm ffi -// to avoid potential conflict of downstream users who might have their own GLOG style macros -namespace details { - -template -TVM_FFI_INLINE std::unique_ptr LogCheckFormat(const X& x, const Y& y) { - std::ostringstream os; - os << " (" << x << " vs. " << y << ") "; // CHECK_XX(x, y) requires x and y can be serialized to - // string. 
Use CHECK(x OP y) otherwise. - return std::make_unique(os.str()); -} - -#define TVM_FFI_CHECK_FUNC(name, op) \ - template \ - TVM_FFI_INLINE std::unique_ptr LogCheck##name(const X& x, const Y& y) { \ - if (x op y) return nullptr; \ - return LogCheckFormat(x, y); \ - } \ - TVM_FFI_INLINE std::unique_ptr LogCheck##name(int x, int y) { \ - return LogCheck##name(x, y); \ - } - -// Inline _Pragma in macros does not work reliably on old version of MSVC and -// GCC. We wrap all comparisons in a function so that we can use #pragma to -// silence bad comparison warnings. -#if defined(__GNUC__) || defined(__clang__) // GCC and Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wsign-compare" -#elif defined(_MSC_VER) // MSVC -#pragma warning(push) -#pragma warning(disable : 4389) // '==' : signed/unsigned mismatch -#endif - -TVM_FFI_CHECK_FUNC(_LT, <) -TVM_FFI_CHECK_FUNC(_GT, >) -TVM_FFI_CHECK_FUNC(_LE, <=) -TVM_FFI_CHECK_FUNC(_GE, >=) -TVM_FFI_CHECK_FUNC(_EQ, ==) -TVM_FFI_CHECK_FUNC(_NE, !=) - -#if defined(__GNUC__) || defined(__clang__) // GCC and Clang -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) // MSVC -#pragma warning(pop) -#endif -} // namespace details - -#define TVM_FFI_ICHECK_BINARY_OP(name, op, x, y) \ - if (auto __tvm__log__err = ::tvm::ffi::details::LogCheck##name(x, y)) \ - TVM_FFI_THROW(InternalError) << "Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": " - -#define TVM_FFI_ICHECK(x) \ - if (!(x)) TVM_FFI_THROW(InternalError) << "Check failed: (" #x << ") is false: " - -#define TVM_FFI_ICHECK_LT(x, y) TVM_FFI_ICHECK_BINARY_OP(_LT, <, x, y) -#define TVM_FFI_ICHECK_GT(x, y) TVM_FFI_ICHECK_BINARY_OP(_GT, >, x, y) -#define TVM_FFI_ICHECK_LE(x, y) TVM_FFI_ICHECK_BINARY_OP(_LE, <=, x, y) -#define TVM_FFI_ICHECK_GE(x, y) TVM_FFI_ICHECK_BINARY_OP(_GE, >=, x, y) -#define TVM_FFI_ICHECK_EQ(x, y) TVM_FFI_ICHECK_BINARY_OP(_EQ, ==, x, y) -#define TVM_FFI_ICHECK_NE(x, y) TVM_FFI_ICHECK_BINARY_OP(_NE, !=, x, y) -#define 
TVM_FFI_ICHECK_NOTNULL(x) \ - ((x) == nullptr ? TVM_FFI_THROW(InternalError) << "Check not null: " #x << ' ', \ - (x) : (x)) // NOLINT(*) -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_ERROR_H_ diff --git a/ffi/include/tvm/ffi/extra/base.h b/ffi/include/tvm/ffi/extra/base.h deleted file mode 100644 index b09b3540a83e..000000000000 --- a/ffi/include/tvm/ffi/extra/base.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/base.h - * \brief Base header for Extra API. - * - * The extra APIs contains a minmal set of extra APIs that are not - * required to support essential core functionality. - */ -#ifndef TVM_FFI_EXTRA_BASE_H_ -#define TVM_FFI_EXTRA_BASE_H_ - -#include - -/*! - * \brief Marks the API as extra c++ api that is defined in cc files. - * - * They are implemented in cc files to reduce compile-time overhead. - * The input/output only uses POD/Any/ObjectRef for ABI stability. 
- * However, these extra APIs may have an issue across MSVC/Itanium ABI, - * - * Related features are also available through reflection based function - * that is fully based on C API - * - * The project aims to minimize the number of extra C++ APIs to keep things - * lightweight and restrict the use to non-core functionalities. - */ -#ifndef TVM_FFI_EXTRA_CXX_API -#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL -#endif - -#endif // TVM_FFI_EXTRA_BASE_H_ diff --git a/ffi/include/tvm/ffi/extra/base64.h b/ffi/include/tvm/ffi/extra/base64.h deleted file mode 100644 index 136fec2e7f84..000000000000 --- a/ffi/include/tvm/ffi/extra/base64.h +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * - * \file tvm/ffi/extra/base64.h - * \brief Base64 encoding and decoding utilities - */ -#ifndef TVM_FFI_EXTRA_BASE64_H_ -#define TVM_FFI_EXTRA_BASE64_H_ - -#include - -#include - -namespace tvm { -namespace ffi { -/*! 
- * \brief Encode a byte array into a base64 string - * \param bytes The byte array to encode - * \return The base64 encoded string - */ -inline String Base64Encode(TVMFFIByteArray bytes) { - // encoding every 3 bytes into 4 characters - constexpr const char kEncodeTable[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string encoded; - encoded.reserve(4 * (bytes.size + 2) / 3); - - for (size_t i = 0; i < (bytes.size / 3) * 3; i += 3) { - int32_t buf[3]; - buf[0] = static_cast(bytes.data[i]); - buf[1] = static_cast(bytes.data[i + 1]); - buf[2] = static_cast(bytes.data[i + 2]); - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); - encoded.push_back(kEncodeTable[((buf[1] << 2) | (buf[2] >> 6)) & 0x3F]); - encoded.push_back(kEncodeTable[buf[2] & 0x3F]); - } - if (bytes.size % 3 == 1) { - int32_t buf[1] = {static_cast(bytes.data[bytes.size - 1])}; - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[(buf[0] << 4) & 0x3F]); - encoded.push_back('='); - encoded.push_back('='); - } else if (bytes.size % 3 == 2) { - int32_t buf[2] = {static_cast(bytes.data[bytes.size - 2]), - static_cast(bytes.data[bytes.size - 1])}; - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); - encoded.push_back(kEncodeTable[(buf[1] << 2) & 0x3F]); - encoded.push_back('='); - } - return String(encoded); -} - -/*! - * \brief Encode a bytes object into a base64 string - * \param data The bytes object to encode - * \return The base64 encoded string - */ -inline String Base64Encode(const Bytes& data) { - return Base64Encode(TVMFFIByteArray{data.data(), data.size()}); -} - -/*! 
- * \brief Decode a base64 string into a byte array - * \param data The base64 encoded string to decode - * \return The decoded byte array - */ -inline Bytes Base64Decode(TVMFFIByteArray bytes) { - constexpr const char kDecodeTable[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' - }; - std::string decoded; - decoded.reserve(bytes.size * 3 / 4); - if (bytes.size == 0) return Bytes(); - TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding"; - // leverage this property to simplify decoding - static_assert('=' < sizeof(kDecodeTable) && kDecodeTable[static_cast('=')] == 0); - // base64 is always multiple of 4 bytes - for (size_t i = 0; i < bytes.size; i += 4) { - // decode every 4 characters into 24bits, each character contains 6 bits - // note that = is also decoded as 0, which is safe to skip - int32_t buf[4] = { - static_cast(bytes.data[i]), - static_cast(bytes.data[i + 1]), - static_cast(bytes.data[i + 2]), - static_cast(bytes.data[i + 3]), - }; - int32_t value_i24 = (static_cast(kDecodeTable[buf[0]]) << 18) | - (static_cast(kDecodeTable[buf[1]]) << 12) | - (static_cast(kDecodeTable[buf[2]]) << 6) | - static_cast(kDecodeTable[buf[3]]); - // unpack 24bits into 3 bytes, each contains 8 bits - decoded.push_back(static_cast((value_i24 >> 16) & 0xFF)); - if (buf[2] != '=') { - decoded.push_back(static_cast((value_i24 >> 8) & 0xFF)); - } - if (buf[3] != '=') { - decoded.push_back(static_cast(value_i24 & 0xFF)); - } - } - return Bytes(decoded); -} - -/*! 
- * \brief Decode a base64 string into a byte array - * \param data The base64 encoded string to decode - * \return The decoded byte array - */ -inline Bytes Base64Decode(const String& data) { - return Base64Decode(TVMFFIByteArray{data.data(), data.size()}); -} - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_BASE64_H_ diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h deleted file mode 100644 index 17cb3af6d0eb..000000000000 --- a/ffi/include/tvm/ffi/extra/c_env_api.h +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/c_env_api.h - * \brief Extra environment API. - */ -#ifndef TVM_FFI_EXTRA_C_ENV_API_H_ -#define TVM_FFI_EXTRA_C_ENV_API_H_ - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// ---------------------------------------------------------------------------- -// Stream context -// Focusing on minimalistic thread-local context recording stream being used. -// We explicitly not handle allocation/de-allocation of stream here. -// ---------------------------------------------------------------------------- -typedef void* TVMFFIStreamHandle; - -/*! 
- * \brief FFI function to set the current stream for a device - * - * \param device_type The type of the device. - * \param device_id The id of the device. - * \param stream The stream to set. - * \param opt_out_original_stream Output original stream if the address is not nullptr. - * \note The stream is a weak reference that is cached/owned by the module. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, - TVMFFIStreamHandle stream, - TVMFFIStreamHandle* opt_out_original_stream); - -/*! - * \brief FFI function to get the current stream for a device - * - * \param device_type The type of the device. - * \param device_id The id of the device. - * \return The current stream of the device. - */ -TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id); - -/*! - * \brief Check if there are any signals raised in the surrounding env. - * \return 0 when success, nonzero when failure happens - * \note Under python this function redirects to PyErr_CheckSignals - */ -TVM_FFI_DLL int TVMFFIEnvCheckSignals(); - -/*! - * \brief Register a symbol into the from the surrounding env such as python - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const char* name, void* symbol); - -// ---------------------------------------------------------------------------- -// Module symbol management in callee side -// ---------------------------------------------------------------------------- -/*! - * \brief FFI function to lookup a function from a module's imports. - * - * This is a helper function that is used by generated code. - * - * \param library_ctx The library context module handle. - * \param func_name The name of the function. - * \param out The result function. 
- * \note The returned function is a weak reference that is cached/owned by the module. - * \return 0 when no error is thrown, -1 when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, - TVMFFIObjectHandle* out); - -/* - * \brief Register a symbol value that will be initialized when a library with the symbol is loaded. - * - * This function can be used to make context functions to be available in the library - * module that wants to avoid an explicit link dependency - * - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModRegisterContextSymbol(const char* name, void* symbol); - -/*! - * \brief Register a symbol that will be initialized when a system library is loaded. - * - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* symbol); - -#ifdef __cplusplus -} // extern "C" -#endif -#endif // TVM_FFI_EXTRA_C_ENV_API_H_ diff --git a/ffi/include/tvm/ffi/extra/json.h b/ffi/include/tvm/ffi/extra/json.h deleted file mode 100644 index 409f7aa52560..000000000000 --- a/ffi/include/tvm/ffi/extra/json.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/json.h - * \brief Minimal lightweight JSON parsing and serialization utilities - */ -#ifndef TVM_FFI_EXTRA_JSON_H_ -#define TVM_FFI_EXTRA_JSON_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace json { - -/*! - * \brief alias Any as json Value. - * - * To keep things lightweight, we simply reuse the ffi::Any system. - */ -using Value = Any; - -/*! - * \brief alias Map as json Object. - * \note We use Map instead of Map to avoid - * the overhead of key checking when doing as conversion, - * the check will be performed at runtime when we read each key - */ -using Object = ffi::Map; - -/*! \brief alias Array as json Array. */ -using Array = ffi::Array; - -/*! - * \brief Parse a JSON string into an Any value. - * - * Besides the standard JSON syntax, this function also supports: - * - Infinity/NaN as javascript syntax - * - int64 integer value - * - * If error_msg is not nullptr, the error message will be written to it - * and no exception will be thrown when parsing fails. - * - * \param json_str The JSON string to parse. - * \param error_msg The output error message, can be nullptr. - * - * \return The parsed Any value. - */ -TVM_FFI_EXTRA_CXX_API json::Value Parse(const String& json_str, String* error_msg = nullptr); - -/*! - * \brief Serialize an Any value into a JSON string. - * - * \param value The Any value to serialize. - * \param indent The number of spaces to indent the output. - * If not specified, the output will be compact. - * \return The output JSON string. 
- */ -TVM_FFI_EXTRA_CXX_API String Stringify(const json::Value& value, - Optional indent = std::nullopt); - -} // namespace json -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_JSON_H_ diff --git a/ffi/include/tvm/ffi/extra/module.h b/ffi/include/tvm/ffi/extra/module.h deleted file mode 100644 index f220c582a91f..000000000000 --- a/ffi/include/tvm/ffi/extra/module.h +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/module.h - * \brief A managed dynamic module in the TVM FFI. - */ -#ifndef TVM_FFI_EXTRA_MODULE_H_ -#define TVM_FFI_EXTRA_MODULE_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -// forward declare Module -class Module; - -/*! - * \brief A module that can dynamically load ffi::Functions or exportable source code. - */ -class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { - public: - /*! - * \return The per module type key. - * \note This key is used to for serializing custom modules. - */ - virtual const char* kind() const = 0; - /*! - * \brief Get the property mask of the module. - * \return The property mask of the module. 
- * - * \sa Module::ModulePropertyMask - */ - virtual int GetPropertyMask() const { return 0b000; } - /*! - * \brief Get a ffi::Function from the module. - * \param name The name of the function. - * \return The function. - */ - virtual Optional GetFunction(const String& name) = 0; - /*! - * \brief Returns true if this module has a definition for a function of \p name. - * - * Note that even if this function returns true the corresponding \p GetFunction result - * may be nullptr if the function is not yet callable without further compilation. - * - * The default implementation just checks if \p GetFunction is non-null. - * \param name The name of the function. - * \return True if the module implements the function, false otherwise. - */ - virtual bool ImplementsFunction(const String& name) { return GetFunction(name).defined(); } - /*! - * \brief Write the current module to file with given format (for further compilation). - * - * \param file_name The file to be saved to. - * \param format The format of the file. - * - * \note This function is mainly used by modules that - */ - virtual void WriteToFile(const String& file_name, const String& format) const { - TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support WriteToFile"; - } - /*! - * \brief Get the possible write formats of the module, when available. - * \return Possible write formats when available. - */ - virtual Array GetWriteFormats() const { return Array(); } - /*! - * \brief Serialize the the module to bytes. - * \return The serialized module. - */ - virtual Bytes SaveToBytes() const { - TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support SaveToBytes"; - TVM_FFI_UNREACHABLE(); - } - /*! - * \brief Get the source code of module, when available. - * \param format Format of the source code, can be empty by default. - * \return Possible source code when available, or empty string if not available. 
- */ - virtual String InspectSource(const String& format = "") const { return String(); } - /*! - * \brief Import another module. - * \param other The module to import. - */ - virtual void ImportModule(const Module& other); - /*! - * \brief Clear all imported modules. - */ - virtual void ClearImports(); - /*! - * \brief Overloaded function to optionally query from imports. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return The function. - */ - Optional GetFunction(const String& name, bool query_imports); - /*! - * \brief Overloaded function to optionally query from imports. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return True if the module implements the function, false otherwise. - */ - bool ImplementsFunction(const String& name, bool query_imports); - /*! - * \brief Get the imports of the module. - * \return The imports of the module. - * \note Note the signature is not part of the public API. - */ - const Array& imports() const { return this->imports_; } - - struct InternalUnsafe; - - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule; - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIModule; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ModuleObj, Object); - - protected: - friend struct InternalUnsafe; - - /*! - * \brief The modules that this module depends on. - * \note Use ObjectRef to avoid circular dep on Module. - */ - Array imports_; - - private: - /*! - * \brief cache used by TVMFFIModuleLookupFromImports - */ - Map import_lookup_cache_; -}; - -/*! - * \brief Reference to module object. - */ -class Module : public ObjectRef { - public: - /*! - * \brief Property of ffi::Module - */ - enum ModulePropertyMask : int { - /*! - * \brief The module can be serialized to bytes. - * - * This prooperty indicates that module implements SaveToBytes. 
- * The system also registers a GlobalDef function - * `ffi.Module.load_from_bytes.` with signature (Bytes) -> Module. - */ - kBinarySerializable = 0b001, - /*! - * \brief The module can directly get runnable functions. - * - * This property indicates that module implements GetFunction that returns - * runnable ffi::Functions. - */ - kRunnable = 0b010, - /*! - * \brief The module can be exported to a object file or source file that then be compiled. - * - * This property indicates that module implements WriteToFile with a given format - * that can be queried by GetLibExportFormat. - * - * Examples include modules that can be exported to .o, .cc, .cu files. - * - * Such modules can be exported, compiled and loaded back as a dynamic library module. - */ - kCompilationExportable = 0b100 - }; - - /*! - * \brief Load a module from file. - * \param file_name The name of the host function module. - * \param format The format of the file. - * \note This function won't load the import relationship. - * Re-create import relationship by calling Import. - */ - TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String& file_name); - /* - * \brief Query context symbols that is registered via TVMEnvRegisterSymbols. - * \param callback The callback to be called with the symbol name and address. - * \note This helper can be used to implement custom Module that needs to access context symbols. - */ - TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols( - const ffi::TypedFunction& callback); - - TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Module, ObjectRef, ModuleObj); -}; - -/* - * \brief Symbols for library module. - */ -namespace symbol { -/*! \brief Global variable to store context pointer for a library module. */ -constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; -/*! \brief Global variable to store binary data alongside a library module. */ -constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; -/*! 
\brief Default entry function of a library module. */ -constexpr const char* tvm_ffi_main = "__tvm_ffi_main__"; -} // namespace symbol -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_MODULE_H_ diff --git a/ffi/include/tvm/ffi/extra/serialization.h b/ffi/include/tvm/ffi/extra/serialization.h deleted file mode 100644 index c08ad81cc363..000000000000 --- a/ffi/include/tvm/ffi/extra/serialization.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/serialization.h - * \brief Reflection-based serialization utilities - */ -#ifndef TVM_FFI_EXTRA_SERIALIZATION_H_ -#define TVM_FFI_EXTRA_SERIALIZATION_H_ - -#include -#include - -namespace tvm { -namespace ffi { - -/** - * \brief Serialize ffi::Any to a JSON that stores the object graph. - * - * The JSON graph structure is stored as follows: - * - * ```json - * { - * "root_index": , // Index of root node in nodes array - * "nodes": [, ...], // Array of serialized nodes - * "metadata": // Optional metadata - * } - * ``` - * - * Each node has the format: `{"type": "", "data": }` - * For object types and strings, the data may contain indices to other nodes. 
- * For object fields whose static type is known as a primitive type, it is stored directly, - * otherwise, it is stored as a reference to the nodes array by an index. - * - * This function preserves the type and multiple references to the same object, - * which is useful for debugging and serialization. - * - * \param value The ffi::Any value to serialize. - * \param metadata Extra metadata attached to "metadata" field of the JSON object. - * \return The serialized JSON value. - */ -TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any& value, const Any& metadata = Any(nullptr)); - -/** - * \brief Deserialize a JSON that stores the object graph to an ffi::Any value. - * - * This function can be used to implement deserialization - * and debugging. - * - * \param value The JSON value to deserialize. - * \return The deserialized object graph. - */ -TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value& value); - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_SERIALIZATION_H_ diff --git a/ffi/include/tvm/ffi/extra/structural_equal.h b/ffi/include/tvm/ffi/extra/structural_equal.h deleted file mode 100644 index 9727940297ed..000000000000 --- a/ffi/include/tvm/ffi/extra/structural_equal.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/structural_equal.h - * \brief Structural equal implementation - */ -#ifndef TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ -#define TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -/* - * \brief Structural equality comparators - */ -class StructuralEqual { - public: - /** - * \brief Compare two Any values for structural equality. - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \param map_free_vars Whether to map free variables. - * \param skip_ndarray_content Whether to skip comparingn darray data content, - * useful for cases where we don't care about parameters content - * \return True if the two Any values are structurally equal, false otherwise. - */ - TVM_FFI_EXTRA_CXX_API static bool Equal(const Any& lhs, const Any& rhs, - bool map_free_vars = false, - bool skip_ndarray_content = false); - /** - * \brief Get the first mismatch AccessPath pair when running - * structural equal comparison between two Any values. - * - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \param map_free_vars Whether to map free variables. - * \param skip_ndarray_content Whether to skip comparing ndarray data content, - * useful for cases where we don't care about parameters content - * \return If comparison fails, return the first mismatch AccessPath pair, - * otherwise return std::nullopt. - */ - TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( - const Any& lhs, const Any& rhs, bool map_free_vars = false, - bool skip_ndarray_content = false); - - /* - * \brief Compare two Any values for structural equality. - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \return True if the two Any values are structurally equal, false otherwise. 
- */ - TVM_FFI_INLINE bool operator()(const Any& lhs, const Any& rhs) const { - return Equal(lhs, rhs, false, true); - } -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ diff --git a/ffi/include/tvm/ffi/extra/structural_hash.h b/ffi/include/tvm/ffi/extra/structural_hash.h deleted file mode 100644 index 9cb08a1c0fc8..000000000000 --- a/ffi/include/tvm/ffi/extra/structural_hash.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/structural_hash.h - * \brief Structural hash - */ -#ifndef TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ -#define TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ - -#include -#include - -namespace tvm { -namespace ffi { - -/* - * \brief Structural hash - */ -class StructuralHash { - public: - /*! - * \brief Hash an Any value. - * \param value The Any value to hash. - * \param map_free_vars Whether to map free variables. - * \param skip_ndarray_content Whether to skip comparingn darray data content, - * useful for cases where we don't care about parameters content. - * \return The hash value. 
- */ - TVM_FFI_EXTRA_CXX_API static uint64_t Hash(const Any& value, bool map_free_vars = false, - bool skip_ndarray_content = false); - /*! - * \brief Hash an Any value. - * \param value The Any value to hash. - * \return The hash value. - */ - TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return Hash(value); } -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h deleted file mode 100644 index 5a30f25a7b5b..000000000000 --- a/ffi/include/tvm/ffi/function.h +++ /dev/null @@ -1,819 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/function.h - * \brief A managed function in the TVM FFI. - */ -#ifndef TVM_FFI_FUNCTION_H_ -#define TVM_FFI_FUNCTION_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/** - * Helper macro to construct a safe call - * - * \brief Marks the begining of the safe call that catches exception explicitly - * - */ -#define TVM_FFI_SAFE_CALL_BEGIN() \ - try { \ - (void)0 - -/*! - * \brief Marks the end of safe call. 
- */ -#define TVM_FFI_SAFE_CALL_END() \ - return 0; \ - } \ - catch (const ::tvm::ffi::Error& err) { \ - ::tvm::ffi::details::SetSafeCallRaised(err); \ - return -1; \ - } \ - catch (const ::tvm::ffi::EnvErrorAlreadySet&) { \ - return -2; \ - } \ - catch (const std::exception& ex) { \ - ::tvm::ffi::details::SetSafeCallRaised(::tvm::ffi::Error("InternalError", ex.what(), "")); \ - return -1; \ - } \ - TVM_FFI_UNREACHABLE() - -#define TVM_FFI_CHECK_SAFE_CALL(func) \ - { \ - int ret_code = (func); \ - if (ret_code != 0) { \ - if (ret_code == -2) { \ - throw ::tvm::ffi::EnvErrorAlreadySet(); \ - } \ - throw ::tvm::ffi::details::MoveFromSafeCallRaised(); \ - } \ - } - -/*! - * \brief Object container class that backs ffi::Function - * \note Do not use this function directly, use ffi::Function - */ -class FunctionObj : public Object, public TVMFFIFunctionCell { - public: - typedef void (*FCall)(const FunctionObj*, const AnyView*, int32_t, Any*); - using TVMFFIFunctionCell::safe_call; - /*! \brief A C++ style call implementation, with exception propagation in c++ style. */ - FCall call; - - TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { - this->call(this, args, num_args, result); - } - - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction; - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIFunction; - - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(FunctionObj, Object); - - protected: - /*! \brief Make default constructor protected. */ - FunctionObj() {} - - // Implementing safe call style - static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin); - FunctionObj* self = static_cast(func); - self->call(self, reinterpret_cast(args), num_args, - reinterpret_cast(result)); - TVM_FFI_SAFE_CALL_END(); - } - - friend class Function; -}; - -namespace details { -/*! 
- * \brief Derived object class for constructing FunctionObj backed by a TCallable - * - * This is a helper class that - */ -template -class FunctionObjImpl : public FunctionObj { - public: - using TStorage = typename std::remove_cv::type>::type; - /*! \brief The type of derived object class */ - using TSelf = FunctionObjImpl; - /*! - * \brief Derived object class for constructing ffi::FunctionObj. - * \param callable The type-erased callable object. - */ - explicit FunctionObjImpl(TCallable callable) : callable_(callable) { - this->safe_call = SafeCall; - this->call = Call; - } - - private: - // implementation of call - static void Call(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* result) { - (static_cast(func))->callable_(args, num_args, result); - } - - /*! \brief Type-erased filed for storing callable object*/ - mutable TStorage callable_; -}; - -/*! - * \brief Base class to provide a common implementation to redirect call to safecall - * \tparam Derived The derived class in CRTP-idiom - */ -template -struct RedirectCallToSafeCall { - static void Call(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* rv) { - Derived* self = static_cast(const_cast(func)); - TVM_FFI_CHECK_SAFE_CALL(self->RedirectSafeCall(reinterpret_cast(args), - num_args, reinterpret_cast(rv))); - } - - static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* rv) { - Derived* self = reinterpret_cast(func); - return self->RedirectSafeCall(args, num_args, rv); - } -}; - -/*! - * \brief FunctionObj specialization that leverages C-style callback definitions. 
- */ -class ExternCFunctionObjImpl : public FunctionObj, - public RedirectCallToSafeCall { - public: - using RedirectCallToSafeCall::SafeCall; - - ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self)) - : self_(self), safe_call_(safe_call), deleter_(deleter) { - this->call = RedirectCallToSafeCall::Call; - this->safe_call = RedirectCallToSafeCall::SafeCall; - } - - ~ExternCFunctionObjImpl() { deleter_(self_); } - - TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t num_args, - TVMFFIAny* rv) const { - return safe_call_(self_, args, num_args, rv); - } - - private: - void* self_; - TVMFFISafeCallType safe_call_; - void (*deleter_)(void* self); -}; - -/*! - * \brief FunctionObj specialization that wraps an external function. - */ -class ImportedFunctionObjImpl : public FunctionObj, - public RedirectCallToSafeCall { - public: - using RedirectCallToSafeCall::SafeCall; - - explicit ImportedFunctionObjImpl(ObjectPtr data) : data_(data) { - this->call = RedirectCallToSafeCall::Call; - this->safe_call = RedirectCallToSafeCall::SafeCall; - } - - TVM_FFI_INLINE int32_t RedirectSafeCall(const TVMFFIAny* args, int32_t num_args, - TVMFFIAny* rv) const { - FunctionObj* func = const_cast(static_cast(data_.get())); - return func->safe_call(func, args, num_args, rv); - } - - private: - ObjectPtr data_; -}; - -// Helper class to set packed arguments -class PackedArgsSetter { - public: - explicit PackedArgsSetter(AnyView* args) : args_(args) {} - - // NOTE: setter needs to be very carefully designed - // such that we do not have temp variable conversion(eg. convert from lvalue to rvalue) - // that is why we need T&& and std::forward here - template - TVM_FFI_INLINE void operator()(size_t i, T&& value) const { - args_[i].operator=(std::forward(value)); - } - - private: - AnyView* args_; -}; -} // namespace details - -/*! 
- * \brief Represents arguments packed in AnyView array - * \note This class represent packed arguments to ffi::Function - */ -class PackedArgs { - public: - /*! - * \brief Constructor - * \param data The arguments - * \param size The number of arguments - */ - PackedArgs(const AnyView* data, int32_t size) : data_(data), size_(size) {} - - /*! \return size of the arguments */ - int size() const { return size_; } - - /*! \return The arguments */ - const AnyView* data() const { return data_; } - - /*! - * \brief Slice the arguments - * \param begin The begin index - * \param end The end index - * \return The sliced arguments - */ - PackedArgs Slice(int begin, int end = -1) const { - if (end == -1) { - end = size_; - } - return PackedArgs(data_ + begin, end - begin); - } - - /*! - * \brief Get i-th argument - * \param i the index. - * \return the ith argument. - */ - AnyView operator[](int i) const { return data_[i]; } - - /*! - * \brief Fill the arguments into the AnyView array - * \param data The AnyView array to store the packed arguments - * \param args The arguments to be packed - * \note Caller must ensure all args are alive during lifetime of data. - * A common pitfall is to pass in local variables that are immediately - * destroyed after calling Fill. - */ - template - TVM_FFI_INLINE static void Fill(AnyView* data, Args&&... args) { - details::for_each(details::PackedArgsSetter(data), std::forward(args)...); - } - - private: - /*! \brief The arguments */ - const AnyView* data_; - /*! \brief The number of arguments */ - int32_t size_; -}; - -/*! - * \brief ffi::Function is a type-erased function. - * The arguments are passed by "packed format" via AnyView - */ -class Function : public ObjectRef { - public: - /*! \brief Constructor from null */ - Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) - /*! 
- * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - * \note legacy purpose, should change to Function::FromPacked for mostfuture use. - */ - template - explicit Function(TCallable packed_call) { - *this = FromPacked(packed_call); - } - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - */ - template - static Function FromPacked(TCallable packed_call) { - static_assert( - std::is_convertible_v> || - std::is_convertible_v>, - "tvm::ffi::Function::FromPacked requires input function signature to match packed func " - "format"); - if constexpr (std::is_convertible_v>) { - auto wrapped_call = [packed_call](const AnyView* args, int32_t num_args, - Any* rv) mutable -> void { - PackedArgs args_pack(args, num_args); - packed_call(args_pack, rv); - }; - return FromPackedInternal(wrapped_call); - } else { - return FromPackedInternal(packed_call); - } - } - /*! - * \brief Import a possibly externally defined function to this dll - * \param other Function defined in another dynamic library. - * - * \note This function will redirect the call to safe_call in other. - * It will try to detect if the function is already from the same DLL - * and directly return the original function if so. - * - * \return The imported function. 
- */ - static Function ImportFromExternDLL(Function other) { - const FunctionObj* other_func = static_cast(other.get()); - // the other function comes from the same dll, no action needed - if (other_func->safe_call == &(FunctionObj::SafeCall) || - other_func->safe_call == &(details::ImportedFunctionObjImpl::SafeCall) || - other_func->safe_call == &(details::ExternCFunctionObjImpl::SafeCall)) { - return other; - } - // the other function coems from a different library - Function func; - func.data_ = make_object(std::move(other.data_)); - return func; - } - /*! - * \brief Create ffi::Function from a C style callbacks. - * \param self Resource handle to the function - * \param safe_call The safe_call definition in C. - * \param deleter The deleter to release the resource of self. - * \return The created function. - */ - static Function FromExternC(void* self, TVMFFISafeCallType safe_call, - void (*deleter)(void* self)) { - // the other function coems from a different library - Function func; - func.data_ = make_object(self, safe_call, deleter); - return func; - } - /*! - * \brief Get global function by name - * \param name The function name - * \return The global function. - * \note This function will return std::nullopt if the function is not found. 
- */ - static std::optional GetGlobal(std::string_view name) { - TVMFFIObjectHandle handle; - TVMFFIByteArray name_arr{name.data(), name.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle)); - if (handle != nullptr) { - return Function( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); - } else { - return std::nullopt; - } - } - - static std::optional GetGlobal(const std::string& name) { - return GetGlobal(std::string_view(name.data(), name.length())); - } - - static std::optional GetGlobal(const String& name) { - return GetGlobal(std::string_view(name.data(), name.length())); - } - - static std::optional GetGlobal(const char* name) { - return GetGlobal(std::string_view(name)); - } - /*! - * \brief Get global function by name and throw an error if it is not found. - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. - */ - static Function GetGlobalRequired(std::string_view name) { - std::optional res = GetGlobal(name); - if (!res.has_value()) { - TVM_FFI_THROW(ValueError) << "Function " << name << " not found"; - } - return *res; - } - - static Function GetGlobalRequired(const std::string& name) { - return GetGlobalRequired(std::string_view(name.data(), name.length())); - } - - static Function GetGlobalRequired(const String& name) { - return GetGlobalRequired(std::string_view(name.data(), name.length())); - } - - static Function GetGlobalRequired(const char* name) { - return GetGlobalRequired(std::string_view(name)); - } - /*! - * \brief Set global function by name - * \param name The name of the function - * \param func The function - * \param override Whether to override when there is duplication. 
- */ - static void SetGlobal(std::string_view name, Function func, bool override = false) { - TVMFFIByteArray name_arr{name.data(), name.size()}; - TVM_FFI_CHECK_SAFE_CALL( - TVMFFIFunctionSetGlobal(&name_arr, details::ObjectUnsafe::GetHeader(func.get()), override)); - } - /*! - * \brief List all global names - * \return A vector of all global names - * \note This function do not depend on Array so core do not have container dep. - */ - static std::vector ListGlobalNames() { - Function fname_functor = - GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")().cast(); - std::vector names; - int len = fname_functor(-1).cast(); - for (int i = 0; i < len; ++i) { - names.push_back(fname_functor(i).cast()); - } - return names; - } - /** - * \brief Remove a global function by name - * \param name The name of the function - */ - static void RemoveGlobal(const String& name) { - static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal"); - fremove(name); - } - /*! - * \brief Constructing a packed function from a normal function. - * - * \param callable the internal container of packed function. - */ - template - static Function FromTyped(TCallable callable) { - using FuncInfo = details::FunctionInfo; - auto call_packed = [callable](const AnyView* args, int32_t num_args, Any* rv) mutable -> void { - details::unpack_call( - std::make_index_sequence{}, nullptr, callable, args, num_args, rv); - }; - return FromPackedInternal(call_packed); - } - /*! - * \brief Constructing a packed function from a normal function. - * - * \param callable the internal container of packed function. - * \param name optional name attacked to the function. 
- */ - template - static Function FromTyped(TCallable callable, std::string name) { - using FuncInfo = details::FunctionInfo; - auto call_packed = [callable, name](const AnyView* args, int32_t num_args, - Any* rv) mutable -> void { - details::unpack_call( - std::make_index_sequence{}, &name, callable, args, num_args, rv); - }; - return FromPackedInternal(call_packed); - } - /*! - * \brief Call function by directly passing in unpacked arguments. - * - * \param args Arguments to be passed. - * \tparam Args arguments to be passed. - * - * \code - * // Example code on how to call packed function - * void CallFFIFunction(tvm::ffi::Function f) { - * // call like normal functions by pass in arguments - * // return value is automatically converted back - * int rvalue = f(1, 2.0); - * } - * \endcode - */ - template - TVM_FFI_INLINE Any operator()(Args&&... args) const { - const int kNumArgs = sizeof...(Args); - const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; - AnyView args_pack[kArraySize]; - PackedArgs::Fill(args_pack, std::forward(args)...); - Any result; - static_cast(data_.get())->CallPacked(args_pack, kNumArgs, &result); - return result; - } - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param rv The return value. - */ - TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { - static_cast(data_.get())->CallPacked(args, num_args, result); - } - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param result The return value. - */ - TVM_FFI_INLINE void CallPacked(PackedArgs args, Any* result) const { - static_cast(data_.get())->CallPacked(args.data(), args.size(), result); - } - - /*! \return Whether the packed function is nullptr */ - TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; } - /*! 
\return Whether the packed function is not nullptr */ - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS(Function, ObjectRef, FunctionObj); - - class Registry; - - private: - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - */ - template - static Function FromPackedInternal(TCallable packed_call) { - using ObjType = typename details::FunctionObjImpl; - Function func; - func.data_ = make_object(std::forward(packed_call)); - return func; - } -}; - -/*! - * \brief Please refer to \ref TypedFunctionAnchor "TypedFunction" - */ -template -class TypedFunction; - -/*! - * \anchor TypedFunctionAnchor - * \brief A ffi::Function wrapper to provide typed function signature. - * It is backed by a ffi::Function internally. - * - * TypedFunction enables compile time type checking. - * TypedFunction works with the runtime system: - * - It can be passed as an argument of ffi::Function. - * - It can be assigned to ffi::Any. - * - It can be directly converted to a type-erased ffi::Function. - * - * Developers should prefer TypedFunction over ffi::Function in C++ code - * as it enables compile time checking. - * We can construct a TypedFunction from a lambda function - * with the same signature. - * - * \code - * // user defined lambda function. - * auto addone = [](int x)->int { - * return x + 1; - * }; - * // We can directly convert - * // lambda function to TypedFunction - * TypedFunction ftyped(addone); - * // invoke the function. - * int y = ftyped(1); - * // Can be directly converted to ffi::Function - * ffi::Function packed = ftype; - * \endcode - * \tparam R The return value of the function. - * \tparam Args The argument signature of the function. - */ -template -class TypedFunction { - public: - /*! \brief short hand for this function type */ - using TSelf = TypedFunction; - /*! 
\brief default constructor */ - TypedFunction() {} - /*! \brief constructor from null */ - TypedFunction(std::nullptr_t null) {} // NOLINT(*) - /*! - * \brief constructor from a function - * \param packed The function - */ - TypedFunction(Function packed) : packed_(packed) {} // NOLINT(*) - /*! - * \brief construct from a lambda function with the same signature. - * - * Example usage: - * \code - * auto typed_lambda = [](int x)->int { return x + 1; } - * // construct from packed function - * TypedFunction ftyped(typed_lambda, "add_one"); - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \param name the name of the lambda function. - * \tparam FLambda the type of the lambda function. - */ - template >::value>::type> - TypedFunction(FLambda typed_lambda, std::string name) { // NOLINT(*) - packed_ = Function::FromTyped(typed_lambda, name); - } - /*! - * \brief construct from a lambda function with the same signature. - * - * This version does not take a name. It is highly recommend you use the - * version that takes a name for the lambda. - * - * Example usage: - * \code - * auto typed_lambda = [](int x)->int { return x + 1; } - * // construct from packed function - * TypedFunction ftyped(typed_lambda); - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \tparam FLambda the type of the lambda function. - */ - template >::value>::type> - TypedFunction(const FLambda& typed_lambda) { // NOLINT(*) - packed_ = Function::FromTyped(typed_lambda); - } - /*! - * \brief copy assignment operator from typed lambda - * - * Example usage: - * \code - * // construct from packed function - * TypedFunction ftyped; - * ftyped = [](int x) { return x + 1; } - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \tparam FLambda the type of the lambda function. 
- * \returns reference to self. - */ - template >::value>::type> - TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) - packed_ = Function::FromTyped(typed_lambda); - return *this; - } - /*! - * \brief copy assignment operator from ffi::Function. - * \param packed The packed function. - * \returns reference to self. - */ - TSelf& operator=(Function packed) { - packed_ = std::move(packed); - return *this; - } - /*! - * \brief Invoke the operator. - * \param args The arguments - * \returns The return value. - */ - TVM_FFI_INLINE R operator()(Args... args) const { - if constexpr (std::is_same_v) { - packed_(std::forward(args)...); - } else { - Any res = packed_(std::forward(args)...); - if constexpr (std::is_same_v) { - return res; - } else { - return std::move(res).cast(); - } - } - } - /*! - * \brief convert to ffi::Function - * \return the internal ffi::Function - */ - operator Function() const { return packed(); } - /*! - * \return reference the internal ffi::Function - */ - const Function& packed() const& { return packed_; } - /*! - * \return r-value reference the internal ffi::Function - */ - constexpr Function&& packed() && { return std::move(packed_); } - /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } - /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } - - private: - /*! 
\brief The internal packed function */ - Function packed_; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction; - - TVM_FFI_INLINE static void CopyToAnyView(const TypedFunction& src, TVMFFIAny* result) { - TypeTraits::CopyToAnyView(src.packed(), result); - } - - TVM_FFI_INLINE static void MoveToAny(TypedFunction src, TVMFFIAny* result) { - TypeTraits::MoveToAny(std::move(src.packed()), result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIFunction; - } - - TVM_FFI_INLINE static TypedFunction CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return TypedFunction(TypeTraits::CopyFromAnyViewAfterCheck(src)); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView( - const TVMFFIAny* src) { - std::optional opt = TypeTraits::TryCastFromAnyView(src); - if (opt.has_value()) { - return TypedFunction(*std::move(opt)); - } else { - return std::nullopt; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo::Sig(); } -}; - -/*! - * \brief helper function to get type index from key - */ -inline int32_t TypeKeyToIndex(std::string_view type_key) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - return type_index; -} - -/*! - * \brief Export typed function as a SafeCallType symbol. - * - * \param ExportName The symbol name to be exported. - * \param Function The typed function. - * \note ExportName and Function must be different, - * see code examples below. 
- * - * \sa ffi::TypedFunction - * - * \code - * - * int AddOne_(int x) { - * return x + 1; - * } - * - * // Expose the function as "AddOne" - * TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); - * - * // Expose the function as "SubOne" - * TVM_FFI_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { - * return x - 1; - * }); - * - * // The following code will cause compilation error. - * // Because the same Function and ExportName - * // TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_); - * - * // The following code is OK, assuming the macro - * // is in a different namespace from xyz - * // TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_); - * - * \endcode - */ -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_FFI_DLL_EXPORT int ExportName(void* self, TVMFFIAny* args, int32_t num_args, \ - TVMFFIAny* result) { \ - TVM_FFI_SAFE_CALL_BEGIN(); \ - using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ - static std::string name = #ExportName; \ - ::tvm::ffi::details::unpack_call( \ - std::make_index_sequence{}, &name, Function, \ - reinterpret_cast(args), num_args, \ - reinterpret_cast<::tvm::ffi::Any*>(result)); \ - TVM_FFI_SAFE_CALL_END(); \ - } \ - } -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_FUNCTION_H_ diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h deleted file mode 100644 index d029c19dd107..000000000000 --- a/ffi/include/tvm/ffi/function_details.h +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/function_details.h - * \brief Implements the funciton signature reflection - */ -#ifndef TVM_FFI_FUNCTION_DETAILS_H_ -#define TVM_FFI_FUNCTION_DETAILS_H_ - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { - -template -struct Arg2Str { - template - TVM_FFI_INLINE static void Apply(std::ostream& os) { - using Arg = std::tuple_element_t; - if constexpr (i != 0) { - os << ", "; - } - os << i << ": " << Type2Str::v(); - } - template - TVM_FFI_INLINE static void Run(std::ostream& os, std::index_sequence) { - using TExpander = int[]; - (void)TExpander{0, (Apply(os), 0)...}; - } -}; - -template -static constexpr bool ArgSupported = - (std::is_same_v>, Any> || - std::is_same_v>, AnyView> || - TypeTraitsNoCR::convert_enabled); - -// NOTE: return type can only support non-reference managed returns -template -static constexpr bool RetSupported = - (std::is_same_v || std::is_void_v || TypeTraits::convert_enabled); - -template -struct FuncFunctorImpl { - using FType = R(Args...); - using ArgType = std::tuple; - using RetType = R; - /*! \brief total number of arguments*/ - static constexpr size_t num_args = sizeof...(Args); - // MSVC is not that friendly to in-template nested bool evaluation -#ifndef _MSC_VER - /*! \brief Whether this function can be converted to ffi::Function via FromTyped */ - static constexpr bool unpacked_supported = (ArgSupported && ...) 
&& (RetSupported); -#endif - - TVM_FFI_INLINE static std::string Sig() { - using IdxSeq = std::make_index_sequence; - std::ostringstream ss; - ss << "("; - Arg2Str>::Run(ss, IdxSeq{}); - ss << ") -> " << Type2Str::v(); - return ss.str(); - } -}; - -template -struct FunctionInfoHelper; - -template -struct FunctionInfoHelper : FuncFunctorImpl {}; -template -struct FunctionInfoHelper : FuncFunctorImpl {}; - -/*! - * \brief Template class to get function signature of a function or functor. - * \tparam T The function/functor type. - * \note We need a decltype redirection because this helps lambda types. - */ -template -struct FunctionInfo : FunctionInfoHelper {}; - -template -struct FunctionInfo : FuncFunctorImpl {}; -template -struct FunctionInfo : FuncFunctorImpl {}; - -/*! \brief Using static function to output typed function signature */ -typedef std::string (*FGetFuncSignature)(); - -/*! - * \brief Auxilary argument value with context for error reporting - */ -class ArgValueWithContext { - public: - /*! - * \brief move constructor from another return value. - * \param args The argument list - * \param arg_index In a function call, this argument is at index arg_index (0-indexed). - * \param optional_name Name of the function being called. Can be nullptr if the function is not. - * \param f_sig Pointer to static function outputting signature of the function being called. - * named. 
- */ - TVM_FFI_INLINE ArgValueWithContext(const AnyView* args, int32_t arg_index, - const std::string* optional_name, FGetFuncSignature f_sig) - : args_(args), arg_index_(arg_index), optional_name_(optional_name), f_sig_(f_sig) {} - - template - TVM_FFI_INLINE operator Type() { - using TypeWithoutCR = std::remove_const_t>; - - if constexpr (std::is_same_v) { - return args_[arg_index_]; - } else if constexpr (std::is_same_v) { - return Any(args_[arg_index_]); - } else { - std::optional opt = args_[arg_index_].try_cast(); - if (!opt.has_value()) { - TVMFFIAny any_data = args_[arg_index_].CopyToTVMFFIAny(); - TVM_FFI_THROW(TypeError) << "Mismatched type on argument #" << arg_index_ - << " when calling: `" - << (optional_name_ == nullptr ? "" : *optional_name_) - << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. Expected `" - << Type2Str::v() << "` but got `" - << TypeTraits::GetMismatchTypeInfo(&any_data) - << '`'; - } - return *std::move(opt); - } - } - - private: - const AnyView* args_; - int32_t arg_index_; - const std::string* optional_name_; - FGetFuncSignature f_sig_; -}; - -template -TVM_FFI_INLINE void unpack_call(std::index_sequence, const std::string* optional_name, - const F& f, [[maybe_unused]] const AnyView* args, - [[maybe_unused]] int32_t num_args, [[maybe_unused]] Any* rv) { - using FuncInfo = FunctionInfo; - FGetFuncSignature f_sig = FuncInfo::Sig; - - // somehow MSVC does not support the static constexpr member in this case, function is fine -#ifndef _MSC_VER - static_assert(FuncInfo::unpacked_supported, "The function signature do not support unpacked"); -#endif - constexpr size_t nargs = sizeof...(Is); - if (nargs != num_args) { - TVM_FFI_THROW(TypeError) << "Mismatched number of arguments when calling: `" - << (optional_name == nullptr ? "" : *optional_name) - << (f_sig == nullptr ? "" : (*f_sig)()) << "`. 
Expected " << nargs - << " but got " << num_args << " arguments"; - } - // use index sequence to do recursive-less unpacking - if constexpr (std::is_same_v) { - f(ArgValueWithContext(args, Is, optional_name, f_sig)...); - } else { - *rv = R(f(ArgValueWithContext(args, Is, optional_name, f_sig)...)); - } -} - -/*! - * \brief Move the safe call raised error to the caller - * \return The error - */ -TVM_FFI_INLINE static Error MoveFromSafeCallRaised() { - TVMFFIObjectHandle handle; - TVMFFIErrorMoveFromRaised(&handle); - // handle is owned by caller - return Error( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); -} - -/*! - * \brief Set the safe call raised error - * \param error The error - */ -TVM_FFI_INLINE static void SetSafeCallRaised(const Error& error) { - TVMFFIErrorSetRaised(details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error)); -} -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_FUNCTION_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h deleted file mode 100644 index 02537df79cb4..000000000000 --- a/ffi/include/tvm/ffi/memory.h +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -/*! - * \file tvm/ffi/memory.h - * \brief Runtime memory management to allocate on heap object. - */ -#ifndef TVM_FFI_MEMORY_H_ -#define TVM_FFI_MEMORY_H_ - -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! \brief Deleter function for obeject */ -typedef void (*FObjectDeleter)(TVMFFIObject* obj); - -/*! - * \brief Allocate an object using default allocator. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The ObjectPtr to the allocated object. - */ -template -inline ObjectPtr make_object(Args&&... args); - -// Detail implementations after this -// -// The current design allows swapping the -// allocator pattern when necessary. -// -// Possible future allocator optimizations: -// - Arena allocator that gives ownership of memory to arena (deleter = nullptr) -// - Thread-local object pools: one pool per size and alignment requirement. -// - Can specialize by type of object to give the specific allocator to each object. - -/*! - * \brief Base class of object allocators that implements make. - * Use curiously recurring template pattern. - * - * \tparam Derived The derived class. - */ -template -class ObjAllocatorBase { - public: - /*! - * \brief Make a new object using the allocator. - * \tparam T The type to be allocated. - * \tparam Args The constructor signature. - * \param args The arguments. - */ - template - ObjectPtr make_object(Args&&... args) { - using Handler = typename Derived::template Handler; - static_assert(std::is_base_of::value, "make can only be used to create Object"); - T* ptr = Handler::New(static_cast(this), std::forward(args)...); - TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->ref_counter = 1; - ffi_ptr->type_index = T::RuntimeTypeIndex(); - ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); - } - - /*! - * \tparam ArrayType The type to be allocated. - * \tparam ElemType The type of array element. 
- * \tparam Args The constructor signature. - * \param num_elems The number of array elements. - * \param args The arguments. - */ - template - ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { - using Handler = typename Derived::template ArrayHandler; - static_assert(std::is_base_of::value, - "make_inplace_array can only be used to create Object"); - ArrayType* ptr = - Handler::New(static_cast(this), num_elems, std::forward(args)...); - TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->ref_counter = 1; - ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); - ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); - } -}; - -// Simple allocator that uses new/delete. -class SimpleObjAllocator : public ObjAllocatorBase { - public: - template - class Handler { - public: - struct alignas(T) StorageType { - char data[sizeof(T)]; - }; - - template - static T* New(SimpleObjAllocator*, Args&&... args) { - // NOTE: the first argument is not needed for SimpleObjAllocator - // It is reserved for special allocators that needs to recycle - // the object to itself (e.g. in the case of object pool). - // - // In the case of an object pool, an allocator needs to create - // a special chunk memory that hides reference to the allocator - // and call allocator's release function in the deleter. - - // NOTE2: Use inplace new to allocate - // This is used to get rid of warning when deleting a virtual - // class with non-virtual destructor. - // We are fine here as we captured the right deleter during construction. - // This is also the right way to get storage type for an object pool. 
- StorageType* data = new StorageType(); - new (data) T(std::forward(args)...); - return reinterpret_cast(data); - } - - static FObjectDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(TVMFFIObject* objptr) { - T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(objptr); - // It is important to do tptr->T::~T(), - // so that we explicitly call the specific destructor - // instead of tptr->~T(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->T::~T(); - delete reinterpret_cast(tptr); - } - }; - - // Array handler that uses new/delete. - template - class ArrayHandler { - public: - using StorageType = typename std::aligned_storage::type; - // for now only support elements that aligns with array header. - static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, - "element alignment constraint"); - - template - static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { - // NOTE: the first argument is not needed for ArrayObjAllocator - // It is reserved for special allocators that needs to recycle - // the object to itself (e.g. in the case of object pool). - // - // In the case of an object pool, an allocator needs to create - // a special chunk memory that hides reference to the allocator - // and call allocator's release function in the deleter. - // NOTE2: Use inplace new to allocate - // This is used to get rid of warning when deleting a virtual - // class with non-virtual destructor. - // We are fine here as we captured the right deleter during construction. - // This is also the right way to get storage type for an object pool. 
- size_t unit = sizeof(StorageType); - size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType); - size_t num_storage_slots = (requested_size + unit - 1) / unit; - StorageType* data = new StorageType[num_storage_slots]; - new (data) ArrayType(std::forward(args)...); - return reinterpret_cast(data); - } - - static FObjectDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(TVMFFIObject* objptr) { - ArrayType* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(objptr); - // It is important to do tptr->ArrayType::~ArrayType(), - // so that we explicitly call the specific destructor - // instead of tptr->~ArrayType(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->ArrayType::~ArrayType(); - StorageType* p = reinterpret_cast(tptr); - delete[] p; - } - }; -}; - -template -inline ObjectPtr make_object(Args&&... args) { - return SimpleObjAllocator().make_object(std::forward(args)...); -} - -template -inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { - return SimpleObjAllocator().make_inplace_array(num_elems, - std::forward(args)...); -} - -} // namespace ffi - -// Export the make_object function -// rationale: ease of use, and no ambiguity -using ffi::make_object; -} // namespace tvm -#endif // TVM_FFI_MEMORY_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h deleted file mode 100644 index abf7f489038b..000000000000 --- a/ffi/include/tvm/ffi/object.h +++ /dev/null @@ -1,837 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/object.h - * \brief A managed object in the TVM FFI. - */ -#ifndef TVM_FFI_OBJECT_H_ -#define TVM_FFI_OBJECT_H_ - -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -using TypeIndex = TVMFFITypeIndex; -using TypeInfo = TVMFFITypeInfo; - -/*! - * \brief Known type keys for pre-defined types. - */ -struct StaticTypeKey { - static constexpr const char* kTVMFFIAny = "Any"; - static constexpr const char* kTVMFFINone = "None"; - static constexpr const char* kTVMFFIBool = "bool"; - static constexpr const char* kTVMFFIInt = "int"; - static constexpr const char* kTVMFFIFloat = "float"; - static constexpr const char* kTVMFFIOpaquePtr = "void*"; - static constexpr const char* kTVMFFIDataType = "DataType"; - static constexpr const char* kTVMFFIDevice = "Device"; - static constexpr const char* kTVMFFIRawStr = "const char*"; - static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; - static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef"; - static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr"; - static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes"; - static constexpr const char* kTVMFFIBytes = "ffi.Bytes"; - static constexpr const char* kTVMFFIStr = "ffi.String"; - static constexpr const char* kTVMFFIShape = "ffi.Shape"; - static constexpr const char* kTVMFFINDArray = "ffi.NDArray"; - static constexpr const char* kTVMFFIObject = "ffi.Object"; - static constexpr const char* kTVMFFIFunction = "ffi.Function"; - static constexpr const 
char* kTVMFFIArray = "ffi.Array"; - static constexpr const char* kTVMFFIMap = "ffi.Map"; - static constexpr const char* kTVMFFIModule = "ffi.Module"; -}; - -/*! - * \brief Get type key from type index - * \param type_index The input type index - * \return the type key - */ -inline std::string TypeIndexToTypeKey(int32_t type_index) { - const TypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - return std::string(type_info->type_key.data, type_info->type_key.size); -} - -namespace details { -// Helper to perform -// unsafe operations related to object -struct ObjectUnsafe; - -/*! - * Check if the type_index is an instance of TargetObjectType. - * - * \tparam TargetType The target object type to be checked. - * - * \param object_type_index The type index to be checked, caller - * ensures that the index is already within the object index range. - * - * \return Whether the target type is true. - */ -template -TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index); -} // namespace details - -/*! - * \brief base class of all object containers. - * - * Sub-class of objects should declare the following static constexpr fields: - * - * - _type_index: - * Static type index of the object, if assigned to TypeIndex::kTVMFFIDynObject - * the type index will be assigned during runtime. - * Runtime type index can be accessed by ObjectType::TypeIndex(); - * - _type_key: - * The unique string identifier of the type. - * - _type_final: - * Whether the type is terminal type(there is no subclass of the type in the object system). - * This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO - * It is still OK to sub-class a terminal object type T and construct it using make_object. - * But IsInstance check will only show that the object type is T(instead of the sub-class). - * - _type_mutable: - * Whether we would like to expose cast to non-constant pointer - * ObjectType* from Any/AnyView. By default, we set to false so it is not exposed. 
- * - * The following two fields are necessary for base classes that can be sub-classed. - * - * - _type_child_slots: - * Number of reserved type index slots for child classes. - * Used for runtime optimization for type checking in IsInstance. - * If an object's type_index is within range of [type_index, type_index + _type_child_slots] - * Then the object can be quickly decided as sub-class of the current object class. - * If not, a fallback mechanism is used to check the global type table. - * Recommendation: set to estimate number of children needed. - * - * - _type_child_slots_can_overflow: - * Whether we can add additional child classes even if the number of child classes - * exceeds the _type_child_slots. A fallback mechanism to check type table will be used. - * Recommendation: set to false for optimal runtime speed if we know exact number of children. - * - * Two macros are used to declare helper functions in the object: - * - Use TVM_FFI_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. - * - Use TVM_FFI_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed. - * - * New objects can be created using make_object function. - * Which will automatically populate the type_index and deleter of the object. - */ -class Object { - protected: - /*! \brief header field that is the common prefix of all objects */ - TVMFFIObject header_; - - public: - Object() { - header_.ref_counter = 0; - header_.deleter = nullptr; - } - /*! - * Check if the object is an instance of TargetType. - * \tparam TargetType The target type to be checked. - * \return Whether the target type is true. - */ - template - bool IsInstance() const { - return details::IsObjectInstance(header_.type_index); - } - - /*! \return The internal runtime type index of the object. */ - int32_t type_index() const { return header_.type_index; } - - /*! - * \return the type key of the object. - * \note this operation is expensive, can be used for error reporting. 
- */ - std::string GetTypeKey() const { - // the function checks that the info exists - const TypeInfo* type_info = TVMFFIGetTypeInfo(header_.type_index); - return std::string(type_info->type_key.data, type_info->type_key.size); - } - - /*! - * \return A hash value of the return of GetTypeKey. - */ - uint64_t GetTypeKeyHash() const { - // the function checks that the info exists - const TypeInfo* type_info = TVMFFIGetTypeInfo(header_.type_index); - return type_info->type_key_hash; - } - - /*! - * \brief Get the type key of the corresponding index from runtime. - * \param tindex The type index. - * \return the result. - */ - static std::string TypeIndex2Key(int32_t tindex) { - const TypeInfo* type_info = TVMFFIGetTypeInfo(tindex); - return std::string(type_info->type_key.data, type_info->type_key.size); - } - - bool unique() const { return use_count() == 1; } - - /*! - * \return The usage count of the cell. - * \note We use stl style naming to be consistent with known API in shared_ptr. - */ - int32_t use_count() const { - // only need relaxed load of counters -#ifdef _MSC_VER - return (reinterpret_cast(&header_.ref_counter))[0]; // NOLINT(*) -#else - return __atomic_load_n(&(header_.ref_counter), __ATOMIC_RELAXED); -#endif - } - - // Information about the object - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIObject; - - // Default object type properties for sub-classes - static constexpr bool _type_final = false; - static constexpr bool _type_mutable = false; - static constexpr uint32_t _type_child_slots = 0; - static constexpr bool _type_child_slots_can_overflow = true; - // NOTE: static type index field of the class - static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject; - // the static type depth of the class - static constexpr int32_t _type_depth = 0; - // the structural equality and hash kind of the type - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported; - // The following functions are 
provided by macro - // TVM_FFI_DECLARE_BASE_OBJECT_INFO and TVM_DECLARE_FINAL_OBJECT_INFO - /*! - * \brief Get the runtime allocated type index of the type - * \note Getting this information may need dynamic calls into a global table. - */ - static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } - /*! - * \brief Internal function to get or allocate a runtime index. - */ - static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } - - private: - /*! \brief increase reference count */ - void IncRef() { -#ifdef _MSC_VER - _InterlockedIncrement(reinterpret_cast(&header_.ref_counter)); // NOLINT(*) -#else - __atomic_fetch_add(&(header_.ref_counter), 1, __ATOMIC_RELAXED); -#endif - } - - /*! \brief decrease reference count and delete the object */ - void DecRef() { -#ifdef _MSC_VER - if (_InterlockedDecrement( // - reinterpret_cast(&header_.ref_counter)) == 0) { // NOLINT(*) - // full barrrier is implicit in InterlockedDecrement - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_)); - } - } -#else - // first do a release, note we only need to acquire for deleter - if (__atomic_fetch_sub(&(header_.ref_counter), 1, __ATOMIC_RELEASE) == 1) { - // only acquire when we need to call deleter - // in this case we need to ensure all previous writes are visible - __atomic_thread_fence(__ATOMIC_ACQUIRE); - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_)); - } - } -#endif - } - - // friend classes - template - friend class ObjectPtr; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -/*! - * \brief A custom smart pointer for Object. - * \tparam T the content data type. - * \sa make_object - */ -template -class ObjectPtr { - public: - /*! \brief default constructor */ - ObjectPtr() {} - /*! \brief default constructor */ - ObjectPtr(std::nullptr_t) {} // NOLINT(*) - /*! 
- * \brief copy constructor - * \param other The value to be moved - */ - ObjectPtr(const ObjectPtr& other) // NOLINT(*) - : ObjectPtr(other.data_) {} - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - ObjectPtr(const ObjectPtr& other) // NOLINT(*) - : ObjectPtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - ObjectPtr(ObjectPtr&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - ObjectPtr(ObjectPtr&& other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class ObjectPtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~ObjectPtr() { this->reset(); } - /*! - * \brief Swap this array with another Object - * \param other The other Object - */ - void swap(ObjectPtr& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \return Get the content of the pointer - */ - T* get() const { return static_cast(data_); } - /*! - * \return The pointer - */ - T* operator->() const { return get(); } - /*! - * \return The reference - */ - T& operator*() const { // NOLINT(*) - return *get(); - } - /*! - * \brief copy assignment - * \param other The value to be assigned. - * \return reference to self. - */ - ObjectPtr& operator=(const ObjectPtr& other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - ObjectPtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignment - * \param other The value to be assigned. - * \return reference to self. - */ - ObjectPtr& operator=(ObjectPtr&& other) { // NOLINT(*) - // copy-and-swap idiom - ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! 
- * \brief nullptr check - * \return result of comparison of internal pointer with nullptr. - */ - explicit operator bool() const { return get() != nullptr; } - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } - } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } - /*! \return whether the reference is unique */ - bool unique() const { return data_ != nullptr && data_->use_count() == 1; } - /*! \return Whether two ObjectPtr do not equal each other */ - bool operator==(const ObjectPtr& other) const { return data_ == other.data_; } - /*! \return Whether two ObjectPtr equals each other */ - bool operator!=(const ObjectPtr& other) const { return data_ != other.data_; } - /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t) const { return data_ == nullptr; } - /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - - private: - /*! \brief internal pointer field */ - Object* data_{nullptr}; - /*! - * \brief constructor from Object - * \param data The data pointer - */ - explicit ObjectPtr(Object* data) : data_(data) { - if (data_ != nullptr) { - data_->IncRef(); - } - } - // friend classes - friend class Object; - friend class ObjectRef; - friend struct ObjectPtrHash; - template - friend class ObjectPtr; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -/*! - * \brief Optional data type in FFI. - * \tparam T The underlying type of the optional. - * - * \note Compared to std::optional, Optional - * akes less storage as it used nullptr to represent nullopt. - */ -template -class Optional; - -/*! \brief Base class of all object reference */ -class ObjectRef { - public: - /*! \brief default constructor */ - ObjectRef() = default; - /*! 
\brief copy constructor */ - ObjectRef(const ObjectRef& other) = default; - /*! \brief move constructor */ - ObjectRef(ObjectRef&& other) = default; - /*! \brief copy assignment */ - ObjectRef& operator=(const ObjectRef& other) = default; - /*! \brief move assignment */ - ObjectRef& operator=(ObjectRef&& other) = default; - /*! \brief Constructor from existing object ptr */ - explicit ObjectRef(ObjectPtr data) : data_(data) {} - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool same_as(const ObjectRef& other) const { return data_ == other.data_; } - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool operator==(const ObjectRef& other) const { return data_ == other.data_; } - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } - /*! - * \brief Comparator - * \param other Another object ref by address. - * \return the compare result. - */ - bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } - /*! - * \return whether the object is defined. - */ - bool defined() const { return data_ != nullptr; } - /*! \return the internal object pointer */ - const Object* get() const { return data_.get(); } - /*! \return the internal object pointer */ - const Object* operator->() const { return get(); } - /*! \return whether the reference is unique */ - bool unique() const { return data_.unique(); } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_.use_count(); } - - /*! - * \brief Try to downcast the internal Object to a - * raw pointer of a corresponding type. - * - * The function will return a nullptr if the cast failed. 
- * - * if (const AddNode *ptr = node_ref.as()) { - * // This is an add node - * } - * - * \tparam ObjectType the target type, must be a subtype of Object - * \return The pointer to the requested type. - */ - template >> - const ObjectType* as() const { - if (data_ != nullptr && data_->IsInstance()) { - return static_cast(data_.get()); - } else { - return nullptr; - } - } - - /*! - * \brief Try to downcast the ObjectRef to Optional of the requested type. - * - * The function will return a std::nullopt if the cast or if the pointer is nullptr. - * - * \tparam ObjectRefType the target type, must be a subtype of ObjectRef' - * \return The optional value of the requested type. - */ - template >> - TVM_FFI_INLINE std::optional as() const { - if (data_ != nullptr) { - if (data_->IsInstance()) { - return ObjectRefType(data_); - } else { - return std::nullopt; - } - } else { - return std::nullopt; - } - } - /*! - * \brief Get the type index of the ObjectRef - * \return The type index of the ObjectRef - */ - int32_t type_index() const { - return data_ != nullptr ? data_->type_index() : TypeIndex::kTVMFFINone; - } - - /*! - * \brief Get the type key of the ObjectRef - * \return The type key of the ObjectRef - */ - std::string GetTypeKey() const { - return data_ != nullptr ? data_->GetTypeKey() : StaticTypeKey::kTVMFFINone; - } - - /*! \brief type indicate the container type. */ - using ContainerType = Object; - // Default type properties for the reference class. - static constexpr bool _type_is_nullable = true; - - protected: - /*! \brief Internal pointer that backs the reference. */ - ObjectPtr data_; - /*! \return return a mutable internal ptr, can be used by sub-classes. */ - Object* get_mutable() const { return data_.get(); } - // friend classes. - friend struct ObjectPtrHash; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -// forward delcare variant -template -class Variant; - -/*! 
\brief ObjectRef hash functor */ -struct ObjectPtrHash { - size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } - - template - size_t operator()(const ObjectPtr& a) const { - return std::hash()(a.get()); - } - - template - TVM_FFI_INLINE size_t operator()(const Variant& a) const; -}; - -/*! \brief ObjectRef equal functor */ -struct ObjectPtrEqual { - bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); } - - template - bool operator()(const ObjectPtr& a, const ObjectPtr& b) const { - return a == b; - } - - template - TVM_FFI_INLINE bool operator()(const Variant& a, const Variant& b) const; -}; - -// If dynamic type is enabled, we still need to register the runtime type of parent -#define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) \ - static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ - static_assert(!ParentType::_type_final, "ParentType marked as final"); \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ - TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - TVMFFIByteArray type_key{TypeName::_type_key, \ - std::char_traits::length(TypeName::_type_key)}; \ - static int32_t tindex = TVMFFITypeGetOrAllocIndex( \ - &type_key, TypeName::_type_index, TypeName::_type_depth, TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ - return tindex; \ - } \ - static inline int32_t _register_type_index = _GetOrAllocRuntimeTypeIndex() - -/*! - * \brief Helper macro to declare a object that comes with static type index. - * \param TypeName The name of the current type. 
- * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType) \ - static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ - TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) - -/*! - * \brief helper macro to declare a base object type that can be inherited. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ - static_assert(!ParentType::_type_final, "ParentType marked as final"); \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ - TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - TVMFFIByteArray type_key{TypeName::_type_key, \ - std::char_traits::length(TypeName::_type_key)}; \ - static int32_t tindex = TVMFFITypeGetOrAllocIndex( \ - &type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ - return tindex; \ - } \ - static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } \ - static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex() - -/*! - * \brief helper macro to declare type information in a final class. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ - static const constexpr int _type_child_slots = 0; \ - static const constexpr bool _type_final = true; \ - TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) - -/* - * \brief Define object reference methods. 
- * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - * - * \note This macro also defines the default constructor that puts the ObjectRef - * in undefined state initially. - */ -#define TVM_FFI_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - const ObjectName* operator->() const { return static_cast(data_.get()); } \ - const ObjectName* get() const { return operator->(); } \ - using ContainerType = ObjectName - -/* - * \brief Define object reference methods do not have undefined state. - * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - */ -#define TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - const ObjectName* operator->() const { return static_cast(data_.get()); } \ - const ObjectName* get() const { return operator->(); } \ - static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName - -/* - * \brief Define object reference methods of whose content is mutable. - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - * \note We recommend making objects immutable when possible. - * This macro is only reserved for objects that stores runtime states. 
- */ -#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - ObjectName* operator->() const { return static_cast(data_.get()); } \ - using ContainerType = ObjectName - -/* - * \brief Define object reference methods that is both not nullable and mutable. - * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - */ -#define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - ObjectName* operator->() const { return static_cast(data_.get()); } \ - ObjectName* get() const { return operator->(); } \ - static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName - -namespace details { -template -TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { - static_assert(std::is_base_of_v); - // Everything is a subclass of object. - if constexpr (std::is_same::value) { - return true; - } else if constexpr (TargetType::_type_final) { - // if the target type is a final type - // then we only need to check the equivalence. - return object_type_index == TargetType::RuntimeTypeIndex(); - } else { - // Explicitly enclose in else to eliminate this branch early in compilation. - // if target type is a non-leaf type - // Check if type index falls into the range of reserved slots. - int32_t target_type_index = TargetType::RuntimeTypeIndex(); - int32_t begin = target_type_index; - // The condition will be optimized by constant-folding. 
- if constexpr (TargetType::_type_child_slots != 0) { - // total_slots = child_slots + 1 (including self) - int32_t end = begin + TargetType::_type_child_slots + 1; - if (object_type_index >= begin && object_type_index < end) return true; - } else { - if (object_type_index == begin) return true; - } - if constexpr (TargetType::_type_child_slots_can_overflow) { - // Invariance: parent index is always smaller than the child. - if (object_type_index < target_type_index) return false; - // Do a runtime lookup of type information - // the function checks that the info exists - const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index); - return (type_info->type_depth > TargetType::_type_depth && - type_info->type_acenstors[TargetType::_type_depth]->type_index == target_type_index); - } else { - return false; - } - } -} - -/*! - * \brief Namespace to internally manipulate object class. - * \note These functions are only supposed to be used by internal - * implementations and not external users of the tvm::ffi - */ -struct ObjectUnsafe { - // NOTE: get ffi header from an object - TVM_FFI_INLINE static TVMFFIObject* GetHeader(const Object* src) { - return const_cast(&(src->header_)); - } - - template - TVM_FFI_INLINE static int64_t GetObjectOffsetToSubclass() { - return (reinterpret_cast(&(static_cast(nullptr)->header_)) - - reinterpret_cast(&(static_cast(nullptr)->header_))); - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(const ObjectRef& ref) { - if constexpr (std::is_same_v) { - return ref.data_; - } else { - return tvm::ffi::ObjectPtr(ref.data_.data_); - } - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(ObjectRef&& ref) { - if constexpr (std::is_same_v) { - return std::move(ref.data_); - } else { - return tvm::ffi::ObjectPtr(std::move(ref.data_.data_)); - } - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromOwned(Object* raw_ptr) { - tvm::ffi::ObjectPtr ptr; - ptr.data_ = raw_ptr; - return ptr; - } 
- - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromOwned(TVMFFIObject* obj_ptr) { - return ObjectPtrFromOwned(reinterpret_cast(obj_ptr)); - } - - template - TVM_FFI_INLINE static T* RawObjectPtrFromUnowned(TVMFFIObject* obj_ptr) { - // NOTE: this is important to first cast to Object* - // then cast back to T* because objptr and tptr may not be the same - // depending on how sub-class allocates the space. - return static_cast(reinterpret_cast(obj_ptr)); - } - - // Create ObjectPtr from unowned ptr - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromUnowned(Object* raw_ptr) { - return tvm::ffi::ObjectPtr(raw_ptr); - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromUnowned(TVMFFIObject* obj_ptr) { - return tvm::ffi::ObjectPtr(reinterpret_cast(obj_ptr)); - } - - TVM_FFI_INLINE static void DecRefObjectHandle(TVMFFIObjectHandle handle) { - reinterpret_cast(handle)->DecRef(); - } - - TVM_FFI_INLINE static void IncRefObjectHandle(TVMFFIObjectHandle handle) { - reinterpret_cast(handle)->IncRef(); - } - - TVM_FFI_INLINE static Object* RawObjectPtrFromObjectRef(const ObjectRef& src) { - return src.data_.data_; - } - - TVM_FFI_INLINE static TVMFFIObject* TVMFFIObjectPtrFromObjectRef(const ObjectRef& src) { - return GetHeader(src.data_.data_); - } - - template - TVM_FFI_INLINE static TVMFFIObject* TVMFFIObjectPtrFromObjectPtr(const ObjectPtr& src) { - return GetHeader(src.data_); - } - - template - TVM_FFI_INLINE static TVMFFIObject* MoveObjectPtrToTVMFFIObjectPtr(ObjectPtr&& src) { - Object* obj_ptr = src.data_; - src.data_ = nullptr; - return GetHeader(obj_ptr); - } - - TVM_FFI_INLINE static TVMFFIObject* MoveObjectRefToTVMFFIObjectPtr(ObjectRef&& src) { - Object* obj_ptr = src.data_.data_; - src.data_.data_ = nullptr; - return GetHeader(obj_ptr); - } -}; -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_OBJECT_H_ diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h deleted file mode 100644 
index a52f64e483dc..000000000000 --- a/ffi/include/tvm/ffi/optional.h +++ /dev/null @@ -1,416 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/optional.h - * \brief Runtime Optional container types. - * \note Optional specializes for T is ObjectRef and used nullptr to indicate nullopt. - */ -#ifndef TVM_FFI_OPTIONAL_H_ -#define TVM_FFI_OPTIONAL_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -// Note: We place optional in tvm/ffi instead of tvm/ffi/container -// because optional itself is an inherent core component of the FFI system. - -template -inline constexpr bool is_optional_type_v = false; - -template -inline constexpr bool is_optional_type_v> = true; - -// we can safely used ptr based optional for ObjectRef types -// that do not have additional data members and virtual functions. -template -inline constexpr bool use_ptr_based_optional_v = - (std::is_base_of_v && !is_optional_type_v); - -// Specialization for non-ObjectRef types. -// simply fallback to std::optional -template -class Optional && !std::is_same_v && - !std::is_same_v>> { - public: - // default constructors. 
- Optional() = default; - Optional(const Optional& other) : data_(other.data_) {} - Optional(Optional&& other) : data_(std::move(other.data_)) {} - Optional(std::optional other) : data_(std::move(other)) {} // NOLINT(*) - Optional(std::nullopt_t) {} // NOLINT(*) - // normal value handling. - Optional(T other) // NOLINT(*) - : data_(std::move(other)) {} - - TVM_FFI_INLINE Optional& operator=(const Optional& other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional& operator=(Optional&& other) { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(T other) { - data_ = std::move(other); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(std::nullopt_t) { - data_ = std::nullopt; - return *this; - } - - TVM_FFI_INLINE const T& value() const& { - if (!data_.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return *data_; - } - - TVM_FFI_INLINE T&& value() && { - if (!data_.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return *std::move(data_); - } - - template > - TVM_FFI_INLINE T value_or(U&& default_value) const { - return data_.value_or(std::forward(default_value)); - } - - TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.has_value(); } - - TVM_FFI_INLINE bool has_value() const noexcept { return data_.has_value(); } - - TVM_FFI_INLINE bool operator==(const Optional& other) const { return data_ == other.data_; } - - TVM_FFI_INLINE bool operator!=(const Optional& other) const { return data_ != other.data_; } - - template - TVM_FFI_INLINE bool operator==(const U& other) const { - return data_ == other; - } - template - TVM_FFI_INLINE bool operator!=(const U& other) const { - return data_ != other; - } - - /*! - * \brief Direct access to the value. - * \return the xvalue reference to the stored value. 
- * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T&& operator*() && noexcept { return *std::move(data_); } - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE const T& operator*() const& noexcept { return *data_; } - - private: - std::optional data_; -}; - -// Specialization for String type, use nullptr to indicate nullopt -template -class Optional || std::is_same_v>> { - public: - // default constructors. - Optional() = default; - Optional(const Optional& other) : data_(other.data_) {} - Optional(Optional&& other) : data_(std::move(other.data_)) {} - Optional(std::nullopt_t) {} // NOLINT(*) - // normal value handling. - Optional(T other) // NOLINT(*) - : data_(std::move(other)) {} - - TVM_FFI_INLINE Optional& operator=(const Optional& other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional& operator=(Optional&& other) { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(T other) { - data_ = std::move(other); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(std::nullopt_t) { - T(details::BytesBaseCell(std::nullopt)).swap(data_); - return *this; - } - - TVM_FFI_INLINE const T& value() const& { - if (data_.data_ == std::nullopt) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return data_; - } - - TVM_FFI_INLINE String&& value() && { - if (data_.data_ == std::nullopt) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return std::move(data_); - } - - template - TVM_FFI_INLINE T value_or(U&& default_value) const { - if (data_.data_ == std::nullopt) { - return std::forward(default_value); - } - return data_; - } - - TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.data_ != std::nullopt; } - - TVM_FFI_INLINE bool has_value() const noexcept { return data_.data_ != std::nullopt; } - - 
TVM_FFI_INLINE bool operator==(const Optional& other) const { - if (data_.data_ == std::nullopt) { - return other.data_.data_ == std::nullopt; - } - if (other.data_.data_ == std::nullopt) { - return false; - } - return data_ == other.data_; - } - - TVM_FFI_INLINE bool operator!=(const Optional& other) const { return !(*this == other); } - - template - TVM_FFI_INLINE bool operator==(const U& other) const { - if constexpr (std::is_same_v) { - return data_.data_ == std::nullopt; - } else { - if (data_.data_ == std::nullopt) { - return false; - } - return data_ == other; - } - } - template - TVM_FFI_INLINE bool operator!=(const U& other) const { - if constexpr (std::is_same_v) { - return data_.data_ != std::nullopt; - } else { - if (data_.data_ == std::nullopt) { - return true; - } - return data_ != other; - } - } - - /*! - * \brief Direct access to the value. - * \return the xvalue reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T&& operator*() && noexcept { return std::move(data_); } - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE const T& operator*() const& noexcept { return data_; } - - private: - // this is a private initializer - T data_{details::BytesBaseCell(std::nullopt)}; -}; - -// Specialization for ObjectRef types. -// nullptr is treated as std::nullopt. 
-template -class Optional>> : public ObjectRef { - public: - using ContainerType = typename T::ContainerType; - Optional() = default; - Optional(const Optional& other) : ObjectRef(other.data_) {} - Optional(Optional&& other) : ObjectRef(std::move(other.data_)) {} - explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} - // nullopt hanlding - Optional(std::nullopt_t) {} // NOLINT(*) - - // handle conversion from std::optional - Optional(std::optional other) { // NOLINT(*) - if (other.has_value()) { - *this = *std::move(other); - } - } - // normal value handling. - Optional(T other) // NOLINT(*) - : ObjectRef(std::move(other)) {} - - TVM_FFI_INLINE Optional& operator=(T other) { - ObjectRef::operator=(std::move(other)); - return *this; - } - - TVM_FFI_INLINE Optional& operator=(const Optional& other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional& operator=(std::nullptr_t) { - data_ = nullptr; - return *this; - } - - TVM_FFI_INLINE Optional& operator=(Optional&& other) { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE T value() const& { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return T(data_); - } - - TVM_FFI_INLINE T value() && { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return T(std::move(data_)); - } - - template > - TVM_FFI_INLINE T value_or(U&& default_value) const { - return data_ != nullptr ? T(data_) : T(std::forward(default_value)); - } - - TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; } - - TVM_FFI_INLINE bool has_value() const { return data_ != nullptr; } - - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T operator*() const& noexcept { return T(data_); } - - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. 
- * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T operator*() && noexcept { return T(std::move(data_)); } - - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return !has_value(); } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return has_value(); } - - // operator overloadings - TVM_FFI_INLINE auto operator==(const Optional& other) const { - // support case where sub-class returns a symbolic ref type. - return EQToOptional(other); - } - TVM_FFI_INLINE auto operator!=(const Optional& other) const { return NEToOptional(other); } - - TVM_FFI_INLINE auto operator==(const std::optional& other) const { - // support case where sub-class returns a symbolic ref type. - return EQToOptional(other); - } - TVM_FFI_INLINE auto operator!=(const std::optional& other) const { - return NEToOptional(other); - } - - TVM_FFI_INLINE auto operator==(const T& other) const { - using RetType = decltype(value() == other); - if (same_as(other)) return RetType(true); - if (has_value()) return operator*() == other; - return RetType(false); - } - - TVM_FFI_INLINE auto operator!=(const T& other) const { return !(*this == other); } - - template - TVM_FFI_INLINE auto operator==(const U& other) const { - using RetType = decltype(value() == other); - if (!has_value()) return RetType(false); - return operator*() == other; - } - - template - TVM_FFI_INLINE auto operator!=(const U& other) const { - using RetType = decltype(value() != other); - if (!has_value()) return RetType(true); - return operator*() != other; - } - - /*! - * \return The internal object pointer with container type of T. - * \note This function do not perform not-null checking. - */ - TVM_FFI_INLINE const ContainerType* get() const { - return static_cast(data_.get()); - } - - private: - template - TVM_FFI_INLINE auto EQToOptional(const U& other) const { - // support case where sub-class returns a symbolic ref type. 
- using RetType = decltype(operator*() == *other); - if (same_as(other)) return RetType(true); - if (has_value() && other.has_value()) { - return operator*() == *other; - } else { - // one of them is nullptr. - return RetType(false); - } - } - - template - TVM_FFI_INLINE auto NEToOptional(const U& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(operator*() != *other); - if (same_as(other)) return RetType(false); - if (has_value() && other.has_value()) { - return operator*() != *other; - } else { - // one of them is nullptr. - return RetType(true); - } - } -}; -} // namespace ffi - -// Expose to the tvm namespace -using ffi::Optional; -} // namespace tvm -#endif // TVM_FFI_OPTIONAL_H_ diff --git a/ffi/include/tvm/ffi/reflection/access_path.h b/ffi/include/tvm/ffi/reflection/access_path.h deleted file mode 100644 index 267cb76fc1fe..000000000000 --- a/ffi/include/tvm/ffi/reflection/access_path.h +++ /dev/null @@ -1,377 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/registry.h - * \brief Registry of reflection metadata. 
- */ -#ifndef TVM_FFI_REFLECTION_ACCESS_PATH_H_ -#define TVM_FFI_REFLECTION_ACCESS_PATH_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -enum class AccessKind : int32_t { - kAttr = 0, - kArrayItem = 1, - kMapItem = 2, - // the following two are used for error reporting when - // the supposed access field is not available - kAttrMissing = 3, - kArrayItemMissing = 4, - kMapItemMissing = 5, -}; - -class AccessStep; - -/*! - * \brief Represent a single step in object field, map key, array index access. - */ -class AccessStepObj : public Object { - public: - /*! - * \brief The kind of the access pattern. - */ - AccessKind kind; - /*! - * \brief The access key - * \note for array access, it will always be integer - * for field access, it will be string - */ - Any key; - - // default constructor to enable auto-serialization - AccessStepObj() = default; - AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {} - - /*! - * \brief Deep check if two steps are equal. - * \param other The other step to compare with. - * \return True if the two steps are equal, false otherwise. - */ - inline bool StepEqual(const AccessStep& other) const; - - static constexpr const char* _type_key = "ffi.reflection.AccessStep"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessStepObj, Object); -}; - -/*! - * \brief ObjectRef class of AccessStepObj. 
- * - * \sa AccessStepObj - */ -class AccessStep : public ObjectRef { - public: - AccessStep(AccessKind kind, Any key) : ObjectRef(make_object(kind, key)) {} - - static AccessStep Attr(String field_name) { return AccessStep(AccessKind::kAttr, field_name); } - - static AccessStep AttrMissing(String field_name) { - return AccessStep(AccessKind::kAttrMissing, field_name); - } - - static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); } - - static AccessStep ArrayItemMissing(int64_t index) { - return AccessStep(AccessKind::kArrayItemMissing, index); - } - - static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, key); } - - static AccessStep MapItemMissing(Any key = nullptr) { - return AccessStep(AccessKind::kMapItemMissing, key); - } - - TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj); -}; - -inline bool AccessStepObj::StepEqual(const AccessStep& other) const { - return this->kind == other->kind && AnyEqual()(this->key, other->key); -} - -// forward declaration -class AccessPath; - -/*! - * \brief ObjectRef class of AccessPathObj. - * - * \sa AccessPathObj - */ -class AccessPathObj : public Object { - public: - /*! - * \brief The parent of the access path. - * - * This parent-pointing tree structure is more space efficient when - * representing multiple paths that share a common prefix. - * - * \note Empty for root. - */ - Optional parent; - /*! - * \brief The current of the access path. - * \note Empty for root. - */ - Optional step; - /*! - * \brief The current depth of the access path, 0 for root - */ - int32_t depth; - - // default constructor to enable auto-serialization - AccessPathObj() = default; - /*! - * \brief Constructor for the access path. - * \param parent The parent of the access path. - * \param step The current step of the access path. - * \param depth The current depth of the access path. 
- */ - AccessPathObj(Optional parent, Optional step, int32_t depth) - : parent(parent), step(step), depth(depth) {} - - /*! - * \brief Get the parent of the access path. - * \return The parent of the access path. - */ - inline Optional GetParent() const; - - /*! - * \brief Extend the access path with a new step. - * \param step The step to extend the access path with. - * \return The extended access path. - */ - inline AccessPath Extend(AccessStep step) const; - - /*! - * \brief Extend the access path with an object attribute access. - * \param field_name The name of the field to access. - * \return The extended access path. - */ - inline AccessPath Attr(String field_name) const; - - /*! - * \brief Extend the access path with an object attribute missing access. - * \param field_name The name of the field to access. - * \return The extended access path. - */ - inline AccessPath AttrMissing(String field_name) const; - - /*! - * \brief Extend the access path with an array item access. - * \param index The index of the array item to access. - * \return The extended access path. - */ - inline AccessPath ArrayItem(int64_t index) const; - - /*! - * \brief Extend the access path with an array item missing access. - * \param index The index of the array item to access. - * \return The extended access path. - */ - inline AccessPath ArrayItemMissing(int64_t index) const; - - /*! - * \brief Extend the access path with a map item access. - * \param key The key of the map item to access. - * \return The extended access path. - */ - inline AccessPath MapItem(Any key) const; - - /*! - * \brief Extend the access path with a map item missing access. - * \param key The key of the map item to access. - * \return The extended access path. - */ - inline AccessPath MapItemMissing(Any key) const; - - /*! - * \brief Get the array of steps that corresponds to the access path. - * \return The array of steps that corresponds to the access path. - */ - inline Array ToSteps() const; - - /*! 
- * \brief Check if two paths are equal by deep comparing the steps. - * \param other The other path to compare with. - * \return True if the two paths are equal, false otherwise. - */ - inline bool PathEqual(const AccessPath& other) const; - - /*! - * \brief Check if this path is a prefix of another path. - * \param other The other path to compare with. - * \return True if this path is a prefix of the other path, false otherwise. - */ - inline bool IsPrefixOf(const AccessPath& other) const; - - static constexpr const char* _type_key = "ffi.reflection.AccessPath"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessPathObj, Object); - - private: - static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) { - // fast path for same pointer - if (lhs == rhs) return true; - if (lhs->depth != rhs->depth) return false; - // do deep equality checks - while (lhs->parent.has_value()) { - TVM_FFI_ICHECK(rhs->parent.has_value()); - TVM_FFI_ICHECK(lhs->step.has_value()); - TVM_FFI_ICHECK(rhs->step.has_value()); - if (!(*lhs->step)->StepEqual(*(rhs->step))) { - return false; - } - lhs = static_cast(lhs->parent.get()); - rhs = static_cast(rhs->parent.get()); - // fast path for same pointer - if (lhs == rhs) return true; - TVM_FFI_ICHECK(lhs != nullptr); - TVM_FFI_ICHECK(rhs != nullptr); - } - return true; - } -}; - -/*! - * \brief ObjectRef class of AccessPath. - * - * \sa AccessPathObj - */ -class AccessPath : public ObjectRef { - public: - /*! - * \brief Create an access path from an iterator range of steps. - * \param begin The beginning of the iterator range. - * \param end The end of the iterator range. - * \return The access path. - */ - template - static AccessPath FromSteps(Iter begin, Iter end) { - AccessPath path = AccessPath::Root(); - for (Iter it = begin; it != end; ++it) { - path = path->Extend(*it); - } - return path; - } - /*! 
- * \brief Create an access path from an array of steps. - * \param steps The array of steps. - * \return The access path. - */ - static AccessPath FromSteps(Array steps) { - AccessPath path = AccessPath::Root(); - for (AccessStep step : steps) { - path = path->Extend(step); - } - return path; - } - - /*! - * \brief Create a root access path. - * \return The root access path. - */ - static AccessPath Root() { - return AccessPath(make_object(std::nullopt, std::nullopt, 0)); - } - - TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessPath, ObjectRef, AccessPathObj); -}; - -using AccessPathPair = Tuple; - -inline Optional AccessPathObj::GetParent() const { - if (auto opt_parent = this->parent.as()) { - return opt_parent; - } - return std::nullopt; -} - -inline AccessPath AccessPathObj::Extend(AccessStep step) const { - return AccessPath(make_object(GetRef(this), step, this->depth + 1)); -} - -inline AccessPath AccessPathObj::Attr(String field_name) const { - return this->Extend(AccessStep::Attr(field_name)); -} - -inline AccessPath AccessPathObj::AttrMissing(String field_name) const { - return this->Extend(AccessStep::AttrMissing(field_name)); -} - -inline AccessPath AccessPathObj::ArrayItem(int64_t index) const { - return this->Extend(AccessStep::ArrayItem(index)); -} - -inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const { - return this->Extend(AccessStep::ArrayItemMissing(index)); -} - -inline AccessPath AccessPathObj::MapItem(Any key) const { - return this->Extend(AccessStep::MapItem(key)); -} - -inline AccessPath AccessPathObj::MapItemMissing(Any key) const { - return this->Extend(AccessStep::MapItemMissing(key)); -} - -inline Array AccessPathObj::ToSteps() const { - std::vector reverse_steps; - reverse_steps.reserve(this->depth); - const AccessPathObj* current = this; - while (current->parent.has_value()) { - TVM_FFI_ICHECK(current->step.has_value()); - reverse_steps.push_back(*(current->step)); - current = static_cast(current->parent.get()); 
- TVM_FFI_ICHECK(current != nullptr); - } - return Array(reverse_steps.rbegin(), reverse_steps.rend()); -} - -inline bool AccessPathObj::PathEqual(const AccessPath& other) const { - return PathEqual(this, other.get()); -} - -inline bool AccessPathObj::IsPrefixOf(const AccessPath& other) const { - if (this->depth > other->depth) { - return false; - } - const AccessPathObj* rhs_path = other.get(); - while (rhs_path->depth > this->depth) { - TVM_FFI_ICHECK(rhs_path->parent.has_value()); - rhs_path = static_cast(rhs_path->parent.get()); - } - return PathEqual(this, rhs_path); -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_ diff --git a/ffi/include/tvm/ffi/reflection/accessor.h b/ffi/include/tvm/ffi/reflection/accessor.h deleted file mode 100644 index 5215444052f8..000000000000 --- a/ffi/include/tvm/ffi/reflection/accessor.h +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/accessor.h - * \brief Reflection-based accessor for object fields and methods. 
- */ -#ifndef TVM_FFI_REFLECTION_ACCESSOR_H_ -#define TVM_FFI_REFLECTION_ACCESSOR_H_ - -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -/*! - * \brief helper function to get reflection field info by type key and field name - */ -inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char* field_name) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TypeInfo* info = TVMFFIGetTypeInfo(type_index); - for (int32_t i = 0; i < info->num_fields; ++i) { - if (std::strncmp(info->fields[i].name.data, field_name, info->fields[i].name.size) == 0) { - return &(info->fields[i]); - } - } - TVM_FFI_THROW(RuntimeError) << "Cannot find field `" << field_name << "` in " << type_key; - TVM_FFI_UNREACHABLE(); -} - -/*! - * \brief helper wrapper class to obtain a getter. - */ -class FieldGetter { - public: - explicit FieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} - - explicit FieldGetter(std::string_view type_key, const char* field_name) - : FieldGetter(GetFieldInfo(type_key, field_name)) {} - - Any operator()(const Object* obj_ptr) const { - Any result; - const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->getter(const_cast(addr), reinterpret_cast(&result))); - return result; - } - - Any operator()(const ObjectPtr& obj_ptr) const { return operator()(obj_ptr.get()); } - - Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); } - - private: - const TVMFFIFieldInfo* field_info_; -}; - -/*! - * \brief helper wrapper class to obtain a setter. 
- */ -class FieldSetter { - public: - explicit FieldSetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} - - explicit FieldSetter(std::string_view type_key, const char* field_name) - : FieldSetter(GetFieldInfo(type_key, field_name)) {} - - void operator()(const Object* obj_ptr, AnyView value) const { - const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->setter(const_cast(addr), reinterpret_cast(&value))); - } - - void operator()(const ObjectPtr& obj_ptr, AnyView value) const { - operator()(obj_ptr.get(), value); - } - - void operator()(const ObjectRef& obj, AnyView value) const { operator()(obj.get(), value); } - - private: - const TVMFFIFieldInfo* field_info_; -}; - -class TypeAttrColumn { - public: - explicit TypeAttrColumn(std::string_view attr_name) { - TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()}; - column_ = TVMFFIGetTypeAttrColumn(&attr_name_array); - if (column_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name; - } - } - - AnyView operator[](int32_t type_index) const { - size_t tindex = static_cast(type_index); - if (tindex >= column_->size) { - return AnyView(); - } - const AnyView* any_view_data = reinterpret_cast(column_->data); - return any_view_data[tindex]; - } - - private: - const TVMFFITypeAttrColumn* column_; -}; - -/*! - * \brief helper function to get reflection method info by type key and method name - * - * \param type_key The type key. - * \param method_name The name of the method. - * \return The method info. 
- */ -inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const char* method_name) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TypeInfo* info = TVMFFIGetTypeInfo(type_index); - for (int32_t i = 0; i < info->num_methods; ++i) { - if (std::strncmp(info->methods[i].name.data, method_name, info->methods[i].name.size) == 0) { - return &(info->methods[i]); - } - } - TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in " << type_key; - TVM_FFI_UNREACHABLE(); -} - -/*! - * \brief helper function to get reflection method function by method info - * - * \param type_key The type key. - * \param method_name The name of the method. - * \return The method function. - */ -inline Function GetMethod(std::string_view type_key, const char* method_name) { - const TVMFFIMethodInfo* info = GetMethodInfo(type_key, method_name); - return AnyView::CopyFromTVMFFIAny(info->method).cast(); -} - -/*! - * \brief Visit each field info of the type info and run callback. - * - * \tparam Callback The callback function type. - * - * \param type_info The type info. - * \param callback The callback function. - * - * \note This function calls both the child and parent type info. - */ -template -inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) { - using ResultType = decltype(callback(type_info->fields)); - static_assert(std::is_same_v, "Callback must return void"); - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i]; - for (int j = 0; j < parent_info->num_fields; ++j) { - callback(parent_info->fields + j); - } - } - for (int i = 0; i < type_info->num_fields; ++i) { - callback(type_info->fields + i); - } -} - -/*! 
- * \brief Visit each field info of the type info and run callback which returns bool for early stop. - * - * \tparam Callback The callback function type, which returns bool for early stop. - * - * \param type_info The type info. - * \param callback_with_early_stop The callback function. - * \return true if any of early stop is triggered. - * - * \note This function calls both the child and parent type info and can be used for searching. - */ -template -inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info, - Callback callback_with_early_stop) { - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i]; - for (int j = 0; j < parent_info->num_fields; ++j) { - if (callback_with_early_stop(parent_info->fields + j)) return true; - } - } - for (int i = 0; i < type_info->num_fields; ++i) { - if (callback_with_early_stop(type_info->fields + i)) return true; - } - return false; -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_ACCESSOR_H_ diff --git a/ffi/include/tvm/ffi/reflection/creator.h b/ffi/include/tvm/ffi/reflection/creator.h deleted file mode 100644 index 983b8034a3b1..000000000000 --- a/ffi/include/tvm/ffi/reflection/creator.h +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/creator.h - * \brief Reflection-based creator to create objects from type key and fields. - */ -#ifndef TVM_FFI_REFLECTION_CREATOR_H_ -#define TVM_FFI_REFLECTION_CREATOR_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { -/*! - * \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection. - */ -class ObjectCreator { - public: - explicit ObjectCreator(std::string_view type_key) - : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {} - - explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) { - int32_t type_index = type_info->type_index; - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have reflection registered"; - } - if (type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support default constructor, " - << "as a result cannot be created via reflection"; - } - } - - /** - * \brief Create an object from a map of fields. - * \param fields The fields of the object. - * \return The created object. 
- */ - Any operator()(const Map& fields) const { - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - size_t match_field_count = 0; - ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) { - String field_name(field_info->name); - void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (fields.count(field_name) != 0) { - Any field_value = fields[field_name]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); - ++match_field_count; - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" - << String(type_info_->type_key.data, type_info_->type_key.size) - << "`"; - } - }); - if (match_field_count == fields.size()) return ObjectRef(ptr); - // report error that checks if contains extra fields that are not in the type - auto check_field_name = [&](const String& field_name) { - bool found = false; - ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) { - if (field_name.compare(field_info->name) == 0) { - found = true; - return true; - } - return false; - }); - return found; - }; - for (const auto& [field_name, _] : fields) { - if (!check_field_name(field_name)) { - TVM_FFI_THROW(TypeError) << "Type `" - << String(type_info_->type_key.data, type_info_->type_key.size) - << "` does not have field `" << field_name << "`"; - } - } - TVM_FFI_UNREACHABLE(); - } - - private: - const TVMFFITypeInfo* type_info_; -}; -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_CREATOR_H_ diff --git a/ffi/include/tvm/ffi/reflection/registry.h b/ffi/include/tvm/ffi/reflection/registry.h deleted file mode 100644 index 
107a6e77592b..000000000000 --- a/ffi/include/tvm/ffi/reflection/registry.h +++ /dev/null @@ -1,498 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/registry.h - * \brief Registry of reflection metadata. - */ -#ifndef TVM_FFI_REFLECTION_REGISTRY_H_ -#define TVM_FFI_REFLECTION_REGISTRY_H_ - -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { -/*! \brief Reflection namespace */ -namespace reflection { - -/*! \brief Trait that can be used to set field info */ -struct FieldInfoTrait {}; - -/*! - * \brief Trait that can be used to set field default value - */ -class DefaultValue : public FieldInfoTrait { - public: - explicit DefaultValue(Any value) : value_(value) {} - - TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { - info->default_value = AnyView(value_).CopyToTVMFFIAny(); - info->flags |= kTVMFFIFieldFlagBitMaskHasDefault; - } - - private: - Any value_; -}; - -/* - * \brief Trait that can be used to attach field flag - */ -class AttachFieldFlag : public FieldInfoTrait { - public: - /*! - * \brief Attach a field flag to the field - * - * \param flag The flag to be set - * - * \return The trait object. 
- */ - explicit AttachFieldFlag(int32_t flag) : flag_(flag) {} - - /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef - */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef); - } - /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore - */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore); - } - - TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags |= flag_; } - - private: - int32_t flag_; -}; - -/*! - * \brief Get the byte offset of a class member field. - * - * \tparam The original class. - * \tparam T the field type. - * - * \param field_ptr A class member pointer - * \returns The byteoffset - */ -template -TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { - int64_t field_offset_to_class = - reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); - return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); -} - -class ReflectionDefBase { - protected: - template - static int FieldGetter(void* field, TVMFFIAny* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int FieldSetter(void* field, const TVMFFIAny* value) { - TVM_FFI_SAFE_CALL_BEGIN(); - if constexpr (std::is_same_v) { - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value); - } else { - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); - } - TVM_FFI_SAFE_CALL_END(); - } - - template - static int ObjectCreatorDefault(TVMFFIObjectHandle* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - ObjectPtr obj = make_object(); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); - TVM_FFI_SAFE_CALL_END(); - } - - template - TVM_FFI_INLINE static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { - if constexpr 
(std::is_base_of_v>) { - value.Apply(info); - } - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) { - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata* info, const T& value) { - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "Class must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [func](Class target, Args... params) -> R { - // call method pointer - return (target.*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class* target, Args... params) -> R { - // call method pointer - return (const_cast(target)->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "Class must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class target, Args... params) -> R { - // call method pointer - return (target.*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class* target, Args... 
params) -> R { - // call method pointer - return (target->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) { - return ffi::Function::FromTyped(std::forward(func), name); - } -}; - -class GlobalDef : public ReflectionDefBase { - public: - /* - * \brief Define a global function. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the function. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - GlobalDef& def(const char* name, Func&& func, Extra&&... extra) { - RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), - std::forward(extra)...); - return *this; - } - - /* - * \brief Define a global function in ffi::PackedArgs format. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the function. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) { - RegisterFunc(name, ffi::Function::FromPacked(func), std::forward(extra)...); - return *this; - } - - /* - * \brief Expose a class method as a global function. - * - * An argument will be added to the first position if the function is not static. - * - * \tparam Class The class type. - * \tparam Func The function type. - * - * \param name The name of the method. - * \param func The function to be registered. - * - * \return The reflection definition. - */ - template - GlobalDef& def_method(const char* name, Func&& func, Extra&&... 
extra) { - RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), - std::forward(extra)...); - return *this; - } - - private: - template - void RegisterFunc(const char* name, ffi::Function func, Extra&&... extra) { - TVMFFIMethodInfo info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.doc = TVMFFIByteArray{nullptr, 0}; - info.type_schema = TVMFFIByteArray{nullptr, 0}; - info.flags = 0; - // obtain the method function - info.method = AnyView(func).CopyToTVMFFIAny(); - // apply method info traits - ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); - TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0)); - } -}; - -template -class ObjectDef : public ReflectionDefBase { - public: - template - explicit ObjectDef(ExtraArgs&&... extra_args) - : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) { - RegisterExtraInfo(std::forward(extra_args)...); - } - - /*! - * \brief Define a readonly field. - * - * \tparam Class The class type. - * \tparam T The field type. - * \tparam Extra The extra arguments. - * - * \param name The name of the field. - * \param field_ptr The pointer to the field. - * \param extra The extra arguments that can be docstring or default value. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::*field_ptr, Extra&&... extra) { - RegisterField(name, field_ptr, false, std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a read-write field. - * - * \tparam Class The class type. - * \tparam T The field type. - * \tparam Extra The extra arguments. - * - * \param name The name of the field. - * \param field_ptr The pointer to the field. - * \param extra The extra arguments that can be docstring or default value. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::*field_ptr, Extra&&... 
extra) { - static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); - RegisterField(name, field_ptr, true, std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a method. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) { - RegisterMethod(name, false, std::forward(func), std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a static method. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { - RegisterMethod(name, true, std::forward(func), std::forward(extra)...); - return *this; - } - - private: - template - void RegisterExtraInfo(ExtraArgs&&... extra_args) { - TVMFFITypeMetadata info; - info.total_size = sizeof(Class); - info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind; - info.creator = nullptr; - info.doc = TVMFFIByteArray{nullptr, 0}; - if constexpr (std::is_default_constructible_v) { - info.creator = ObjectCreatorDefault; - } - // apply extra info traits - ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info)); - } - - template - void RegisterField(const char* name, T BaseClass::*field_ptr, bool writable, - ExtraArgs&&... 
extra_args) { - static_assert(std::is_base_of_v, "BaseClass must be a base class of Class"); - TVMFFIFieldInfo info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.field_static_type_index = TypeToFieldStaticTypeIndex::value; - // store byte offset and setter, getter - // so the same setter can be reused for all the same type - info.offset = GetFieldByteOffsetToObject(field_ptr); - info.size = sizeof(T); - info.alignment = alignof(T); - info.flags = 0; - if (writable) { - info.flags |= kTVMFFIFieldFlagBitMaskWritable; - } - info.getter = FieldGetter; - info.setter = FieldSetter; - // initialize default value to nullptr - info.default_value = AnyView(nullptr).CopyToTVMFFIAny(); - info.doc = TVMFFIByteArray{nullptr, 0}; - info.type_schema = TVMFFIByteArray{nullptr, 0}; - // apply field info traits - ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); - // call register - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); - } - - // register a method - template - void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { - TVMFFIMethodInfo info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.doc = TVMFFIByteArray{nullptr, 0}; - info.type_schema = TVMFFIByteArray{nullptr, 0}; - info.flags = 0; - if (is_static) { - info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; - } - // obtain the method function - Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); - info.method = AnyView(method).CopyToTVMFFIAny(); - // apply method info traits - ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); - } - - int32_t type_index_; - const char* type_key_; -}; - -template >> -class TypeAttrDef : public ReflectionDefBase { - public: - template - explicit TypeAttrDef(ExtraArgs&&... 
extra_args) - : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {} - - /* - * \brief Define a function-valued type attribute. - * - * \tparam Func The function type. - * - * \param name The name of the function. - * \param func The function to be registered. - * - * \return The TypeAttrDef object. - */ - template - TypeAttrDef& def(const char* name, Func&& func) { - TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; - ffi::Function ffi_func = - GetMethod(std::string(type_key_) + "." + name, std::forward(func)); - TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny(); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); - return *this; - } - - /* - * \brief Define a constant-valued type attribute. - * - * \tparam T The type of the value. - * - * \param name The name of the attribute. - * \param value The value of the attribute. - * - * \return The TypeAttrDef object. - */ - template - TypeAttrDef& attr(const char* name, T value) { - TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; - TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny(); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); - return *this; - } - - private: - int32_t type_index_; - const char* type_key_; -}; - -/*! - * \brief Ensure the type attribute column is presented in the system. - * - * \param name The name of the type attribute. 
- */ -inline void EnsureTypeAttrColumn(std::string_view name) { - TVMFFIByteArray name_array = {name.data(), name.size()}; - AnyView any_view(nullptr); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array, - reinterpret_cast(&any_view))); -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_REGISTRY_H_ diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h deleted file mode 100644 index 7c89038cc24e..000000000000 --- a/ffi/include/tvm/ffi/rvalue_ref.h +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/rvalue_ref.h - * \brief Helper class to define rvalue reference type. - */ -#ifndef TVM_FFI_RVALUE_REF_H_ -#define TVM_FFI_RVALUE_REF_H_ - -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Helper class to define rvalue reference type. - * - * By default, FFI pass all values by lvalue reference. - * - * However, we do allow users to intentionally mark a function parameter - * as RValueRef. In such cases, the caller can choose to pass parameter - * wrapped by RValueRef to the function. 
In which case the parameter - * can be directly moved by the callee. The caller can also choose to pass - * a normal lvalue to the function, in such case a copy will be triggered. - * - * To keep FFI checking overhead minimal, we do not handle case when rvalue - * is passed, but the callee did not declare the parameter as RValueRef. - * - * This design allows us to still leverage move semantics for parameters that - * need copy on write scenarios (and requires an unique copy). - * - * \code - * - * void Example() { - * auto append = Function::FromTyped([](RValueRef> ref, int val) -> Array { - * Array arr = *std::move(ref); - * assert(arr.unique()); - * arr.push_back(val); - * return arr; - * }); - * Array a = Array({1, 2}); - * // as we use rvalue ref to move a into append - * // we keep a single copy of the Array without creating new copies during copy-on-write - * a = append(RvalueRef(std::move(a)), 3); - * assert(a.size() == 3); - * } - * - * \endcode - */ -template >> -class RValueRef { - public: - /*! \brief only allow move constructor from rvalue of T */ - explicit RValueRef(TObjRef&& data) - : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} - - /*! 
\brief return the data as rvalue */ - TObjRef operator*() && { return TObjRef(std::move(data_)); } - - private: - mutable ObjectPtr data_; - - template - friend struct TypeTraits; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const RValueRef& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIObjectRValueRef; - result->zero_padding = 0; - // store the address of the ObjectPtr, which allows us to move the value - // and set the original ObjectPtr to nullptr - result->v_ptr = &(src.data_); - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); - // object type does not match up, we need to try to convert the object - // in this case we do not move the original rvalue ref since conversion creates a copy - TVMFFIAny tmp_any; - tmp_any.type_index = rvalue_ref->get()->type_index(); - tmp_any.zero_padding = 0; - tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); - return "RValueRef<" + TypeTraits::GetMismatchTypeInfo(&tmp_any) + ">"; - } else { - return TypeTraits::GetMismatchTypeInfo(src); - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // first try rvalue conversion - if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); - TVMFFIAny tmp_any; - tmp_any.type_index = rvalue_ref->get()->type_index(); - tmp_any.zero_padding = 0; - tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); - // fast path, storage type matches, direct move the rvalue ref - if (TypeTraits::CheckAnyStrict(&tmp_any)) { - return RValueRef(TObjRef(std::move(*rvalue_ref))); - } - if (std::optional opt = TypeTraits::TryCastFromAnyView(&tmp_any)) { - // 
object type does not match up, we need to try to convert the object - // in this case we do not move the original rvalue ref since conversion creates a copy - return RValueRef(*std::move(opt)); - } - return std::nullopt; - } - // try lvalue conversion - if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { - return RValueRef(*std::move(opt)); - } else { - return std::nullopt; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "RValueRef<" + TypeTraits::TypeStr() + ">"; - } -}; -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_RVALUE_REF_H_ diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h deleted file mode 100644 index fe84b6154706..000000000000 --- a/ffi/include/tvm/ffi/string.h +++ /dev/null @@ -1,987 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/string.h - * \brief Runtime Bytes and String types. 
- */ -#ifndef TVM_FFI_STRING_H_ -#define TVM_FFI_STRING_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -// Note: We place string in tvm/ffi instead of tvm/ffi/container -// because string itself needs special handling and is an inherent -// core component for return string handling. -// The following dependency relation holds -// any -> string -> object - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base class for bytes and string objects. - */ -class BytesObjBase : public Object, public TVMFFIByteArray {}; - -/*! - * \brief An object representing bytes. - * \note We use separate object for bytes to follow python convention - * and indicate passing of raw bytes. - * Bytes can be converted from/to string. - */ -class BytesObj : public BytesObjBase { - public: - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIBytes; - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIBytes; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(BytesObj, Object); -}; - -/*! \brief An object representing string. It's POD type. */ -class StringObj : public BytesObjBase { - public: - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr; - static constexpr const char* _type_key = StaticTypeKey::kTVMFFIStr; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object); -}; - -// String moved from std::string -// without having to trigger a copy -template -class BytesObjStdImpl : public Base { - public: - explicit BytesObjStdImpl(std::string other) : data_{other} { - this->data = data_.data(); - this->size = data_.size(); - } - - private: - std::string data_; -}; - -/*! 
- * \brief Helper cell class that can be used to back small string - * \note Do not use directly, use String or Bytes instead - */ -class BytesBaseCell { - public: - BytesBaseCell() { - // initialize to none - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - explicit BytesBaseCell(std::nullopt_t) { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - BytesBaseCell(const BytesBaseCell& other) : data_(other.data_) { // NOLINT(*) - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); - } - } - - BytesBaseCell(BytesBaseCell&& other) : data_(other.data_) { // NOLINT(*) - other.data_.type_index = TypeIndex::kTVMFFINone; - } - - BytesBaseCell& operator=(const BytesBaseCell& other) { - BytesBaseCell(other).swap(*this); // NOLINT(*) - return *this; - } - - BytesBaseCell& operator=(BytesBaseCell&& other) { - BytesBaseCell(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - ~BytesBaseCell() { - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); - } - } - - /*! - * \brief Check if the cell is null - * \return true if the cell is null, false otherwise - */ - bool operator==(std::nullopt_t) const { return data_.type_index == TypeIndex::kTVMFFINone; } - - /*! - * \brief Check if the cell is not null - * \return true if the cell is not null, false otherwise - */ - bool operator!=(std::nullopt_t) const { return data_.type_index != TypeIndex::kTVMFFINone; } - - /*! 
- * \brief Swap this String with another string - * \param other The other string - */ - void swap(BytesBaseCell& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - - const char* data() const noexcept { - if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - return data_.v_bytes; - } else { - return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->data; - } - } - - size_t size() const noexcept { - if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - return data_.small_str_len; - } else { - return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->size; - } - } - - template - void InitFromStd(std::string&& other, int32_t large_type_index) { - // needs to be reset to none first for exception safety - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); - ObjectPtr ptr = make_object>(std::move(other)); - data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); - data_.type_index = large_type_index; - } - - /*! 
- * \brief Create a new empty space for a string - * \param size The size of the string - * \param small_type_index The type index for the small string - * \param large_type_index The type index for the large string - * \note always reserve one byte for \0 compactibility - * \return A pointer to the empty space - */ - template - char* InitSpaceForSize(size_t size, int32_t small_type_index, int32_t large_type_index) { - size_t kMaxSmallBytesLen = sizeof(int64_t) - 1; - // first zero the content, this is important for exception safety - data_.type_index = small_type_index; - data_.zero_padding = 0; - if (size <= kMaxSmallBytesLen) { - // set up the size accordingly - data_.small_str_len = static_cast(size); - return data_.v_bytes; - } else { - // allocate from heap - ObjectPtr ptr = make_inplace_array_object(size + 1); - char* dest_data = reinterpret_cast(ptr.get()) + sizeof(LargeObj); - ptr->data = dest_data; - ptr->size = size; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); - data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); - // now reset the type index to str - data_.type_index = large_type_index; - return dest_data; - } - } - - void InitTypeIndex(int32_t type_index) { data_.type_index = type_index; } - - void MoveToAny(TVMFFIAny* result) { - *result = data_; - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - TVMFFIAny CopyToTVMFFIAny() const { return data_; } - - static BytesBaseCell CopyFromAnyView(const TVMFFIAny* src) { - BytesBaseCell result(*src); - if (result.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(result.data_.v_obj); - } - return result; - } - - static BytesBaseCell MoveFromAny(TVMFFIAny* src) { - BytesBaseCell result(*src); - src->type_index = TypeIndex::kTVMFFINone; - src->zero_padding = 0; - src->v_int64 = 0; - return result; - } - - private: - explicit BytesBaseCell(TVMFFIAny data) : data_(data) {} - /*! 
\brief internal backing data */ - TVMFFIAny data_; -}; -} // namespace details - -/*! - * \brief Managed reference of byte array. - */ -class Bytes { - public: - /*! \brief default constructor */ - Bytes() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallBytes); } - /*! - * \brief constructor from size - * - * \param other a char array. - */ - Bytes(const char* data, size_t size) { this->InitData(data, size); } - /*! - * \brief constructor from TVMFFIByteArray - * - * \param other a char array. - */ - Bytes(TVMFFIByteArray bytes) { // NOLINT(*) - this->InitData(bytes.data, bytes.size); - } - /*! - * \brief constructor from std::string - * - * \param other a char array. - */ - Bytes(const std::string& other) { // NOLINT(*) - this->InitData(other.data(), other.size()); - } - /*! - * \brief constructor from std::string - * - * \param other a char array. - */ - Bytes(std::string&& other) { // NOLINT(*) - data_.InitFromStd(std::move(other), TypeIndex::kTVMFFIBytes); - } - /*! - * \brief Swap this String with another string - * \param other The other string - */ - void swap(Bytes& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - - template - Bytes& operator=(T&& other) { - // copy-and-swap idiom - Bytes(std::forward(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const { return data_.size(); } - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char* data() const { return data_.data(); } - /*! - * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { return std::string{data(), size()}; } - - /*! 
- * \brief Compare two char sequence - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * \return int zero if both char sequences compare equal. negative if this - * appear before other, positive otherwise. - */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { - if (lhs == rhs && lhs_count == rhs_count) return 0; - - for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { - if (lhs[i] < rhs[i]) return -1; - if (lhs[i] > rhs[i]) return 1; - } - if (lhs_count < rhs_count) { - return -1; - } else if (lhs_count > rhs_count) { - return 1; - } else { - return 0; - } - } - /*! - * \brief Compare two char sequence for equality - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * - * \return true if the two char sequences are equal, false otherwise. - */ - static bool memequal(const void* lhs, const void* rhs, size_t lhs_count, size_t rhs_count) { - return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs, lhs_count) == 0); - } - - private: - template - friend struct TypeTraits; - template - friend class Optional; - // internal backing cell - details::BytesBaseCell data_; - // create a new String from TVMFFIAny, must keep private - explicit Bytes(details::BytesBaseCell data) : data_(data) {} - char* InitSpaceForSize(size_t size) { - return data_.InitSpaceForSize(size, TypeIndex::kTVMFFISmallBytes, - TypeIndex::kTVMFFIBytes); - } - void InitData(const char* data, size_t size) { - char* dest_data = InitSpaceForSize(size); - std::memcpy(dest_data, data, size); - // mainly to be compat with string - dest_data[size] = '\0'; - } -}; - -/*! - * \brief String container class. 
- */ -class String { - public: - /*! - * \brief avoid misuse of nullptr - */ - String(std::nullptr_t) = delete; // NOLINT(*) - /*! - * \brief constructor - */ - String() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallStr); } - // constructors from Any - String(const String& other) = default; // NOLINT(*) - String(String&& other) = default; // NOLINT(*) - String& operator=(const String& other) = default; // NOLINT(*) - String& operator=(String&& other) = default; // NOLINT(*) - - /*! - * \brief Swap this String with another string - * \param other The other string - */ - void swap(String& other) noexcept { // NOLINT(*) - std::swap(data_, other.data_); - } - - String& operator=(const std::string& other) { - String(other).swap(*this); // NOLINT(*) - return *this; - } - String& operator=(std::string&& other) { - String(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - String& operator=(const char* other) { - String(other).swap(*this); // NOLINT(*) - return *this; - } - - /*! - * \brief constructor from raw string - * - * \param other a char array. - */ - String(const char* other, size_t size) { this->InitData(other, size); } - - /*! - * \brief constructor from raw string - * - * \param other a char array. - * \note This constructor is marked as explicit to avoid implicit conversion - * of nullptr value here to string, which then was used in comparison - */ - String(const char* other) { // NOLINT(*) - this->InitData(other, std::char_traits::length(other)); - } - /*! - * \brief Construct a new string object - * \param other The std::string object to be copied - */ - String(const std::string& other) { // NOLINT(*) - this->InitData(other.data(), other.size()); - } - - /*! 
- * \brief Construct a new string object - * \param other The std::string object to be moved - */ - String(std::string&& other) { // NOLINT(*) - // exception safety, first set to none so if exception is thrown - // destructor works correctly - data_.InitFromStd(std::move(other), TypeIndex::kTVMFFIStr); - } - - /*! - * \brief constructor from TVMFFIByteArray - * - * \param other a TVMFFIByteArray. - */ - explicit String(TVMFFIByteArray other) { this->InitData(other.data, other.size); } - - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char* data() const noexcept { return data_.data(); } - - /*! - * \brief Returns a pointer to the char array in the string. - * - * \return const char* - */ - const char* c_str() const noexcept { return data(); } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const noexcept { return data_.size(); } - - /*! - * \brief Compares this String object to other - * - * \param other The String to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const String& other) const { - return Bytes::memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this String object to other - * - * \param other The string to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const std::string& other) const { - return Bytes::memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this to other - * - * \param other The character array to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. 
- */ - int compare(const char* other) const { - const char* this_data = data(); - size_t this_size = size(); - for (size_t i = 0; i < this_size; ++i) { - // other is shorter than this - if (other[i] == '\0') return 1; - if (this_data[i] < other[i]) return -1; - if (this_data[i] > other[i]) return 1; - } - // other equals this - if (other[this_size] == '\0') return 0; - // other longer than this - return -1; - } - - /*! - * \brief Compares this to other - * - * \param other The TVMFFIByteArray to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const TVMFFIByteArray& other) const { - return Bytes::memncmp(data(), other.data, size(), other.size); - } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t length() const { return size(); } - - /*! - * \brief Retun if the string is empty - * - * \return true if empty, false otherwise. - */ - bool empty() const { return size() == 0; } - - /*! - * \brief Read an element. - * \param pos The position at which to read the character. - * - * \return The char at position - */ - char at(size_t pos) const { - if (pos < size()) { - return data()[pos]; - } else { - throw std::out_of_range("tvm::String index out of bounds"); - } - } - - /*! - * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { return std::string{data(), size()}; } - - private: - template - friend struct TypeTraits; - template - friend class Optional; - // internal backing cell - details::BytesBaseCell data_; - // create a new String from TVMFFIAny, must keep private - explicit String(details::BytesBaseCell data) : data_(data) {} - /*! 
- * \brief Create a new empty space for a string - * \param size The size of the string - * \return A pointer to the empty space - */ - char* InitSpaceForSize(size_t size) { - return data_.InitSpaceForSize(size, TypeIndex::kTVMFFISmallStr, - TypeIndex::kTVMFFIStr); - } - void InitData(const char* data, size_t size) { - char* dest_data = InitSpaceForSize(size); - std::memcpy(dest_data, data, size); - dest_data[size] = '\0'; - } - /*! - * \brief Concatenate two char sequences - * - * \param lhs Pointers to the lhs char array - * \param lhs_size The size of the lhs char array - * \param rhs Pointers to the rhs char array - * \param rhs_size The size of the rhs char array - * - * \return The concatenated char sequence - */ - static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { - String ret; - // disable stringop-overflow and restrict warnings - // gcc may produce false positive when we enable dest_data returned from small string path - // Because compiler is not able to detect the condition that the path is only triggered via - // size < kMaxSmallStrLen and can report it as a overflow case. -#if (__GNUC__) && !(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstringop-overflow" -#pragma GCC diagnostic ignored "-Wrestrict" -#endif - char* dest_data = ret.InitSpaceForSize(lhs_size + rhs_size); - std::memcpy(dest_data, lhs, lhs_size); - std::memcpy(dest_data + lhs_size, rhs, rhs_size); - dest_data[lhs_size + rhs_size] = '\0'; -#if (__GNUC__) && !(__clang__) -#pragma GCC diagnostic pop -#endif - return ret; - } - // Overload + operator - friend String operator+(const String& lhs, const String& rhs); - friend String operator+(const String& lhs, const std::string& rhs); - friend String operator+(const std::string& lhs, const String& rhs); - friend String operator+(const String& lhs, const char* rhs); - friend String operator+(const char* lhs, const String& rhs); -}; - -/*! 
\brief Convert TVMFFIByteArray to std::string_view */ -TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) { - return std::string_view(str.data, str.size); -} - -template <> -inline constexpr bool use_default_type_traits_v = false; - -// specialize to enable implicit conversion from TVMFFIByteArray* -template <> -struct TypeTraits : public TypeTraitsBase { - // bytes can be union type of small bytes and object, so keep it as any - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; - - TVM_FFI_INLINE static void CopyToAnyView(const Bytes& src, TVMFFIAny* result) { - *result = src.data_.CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(Bytes src, TVMFFIAny* result) { - src.data_.MoveToAny(result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFISmallBytes || - src->type_index == TypeIndex::kTVMFFIBytes; - } - - TVM_FFI_INLINE static Bytes CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); - } - - TVM_FFI_INLINE static Bytes MoveFromAnyAfterCheck(TVMFFIAny* src) { - return Bytes(details::BytesBaseCell::MoveFromAny(src)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - return Bytes(*static_cast(src->v_ptr)); - } - if (src->type_index == TypeIndex::kTVMFFISmallBytes || - src->type_index == TypeIndex::kTVMFFIBytes) { - return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "bytes"; } -}; - -template <> -inline constexpr bool use_default_type_traits_v = false; - -// specialize to enable implicit conversion from const char* -template <> -struct TypeTraits : public TypeTraitsBase { - // string can be union type of small string and object, so keep it as any - static constexpr int32_t 
field_static_type_index = TypeIndex::kTVMFFIAny; - - TVM_FFI_INLINE static void CopyToAnyView(const String& src, TVMFFIAny* result) { - *result = src.data_.CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(String src, TVMFFIAny* result) { - src.data_.MoveToAny(result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFISmallStr || - src->type_index == TypeIndex::kTVMFFIStr; - } - - TVM_FFI_INLINE static String CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return String(details::BytesBaseCell::CopyFromAnyView(src)); - } - - TVM_FFI_INLINE static String MoveFromAnyAfterCheck(TVMFFIAny* src) { - return String(details::BytesBaseCell::MoveFromAny(src)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIRawStr) { - return String(src->v_c_str); - } - if (src->type_index == TypeIndex::kTVMFFISmallStr || src->type_index == TypeIndex::kTVMFFIStr) { - return String(details::BytesBaseCell::CopyFromAnyView(src)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "str"; } -}; - -// const char*, requirement: not nullable, do not retain ownership -template -struct TypeTraits : public TypeTraitsBase { - // NOTE: only enable implicit conversion into AnyView - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const char src[N], TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = src; - } - - TVM_FFI_INLINE static void MoveToAny(const char src[N], TVMFFIAny* result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(src), result); - } -}; - -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const char* src, TVMFFIAny* result) { - 
TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = src; - } - - TVM_FFI_INLINE static void MoveToAny(const char* src, TVMFFIAny* result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(src), result); - } - // Do not allow const char* in a container, so we do not need CheckAnyStrict - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIRawStr) { - return static_cast(src->v_c_str); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "const char*"; } -}; - -// TVMFFIByteArray, requirement: not nullable, do not retain ownership -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIByteArrayPtr; - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(TVMFFIByteArray* src, TVMFFIAny* result) { - TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIByteArrayPtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static void MoveToAny(TVMFFIByteArray* src, TVMFFIAny* result) { - TypeTraits::MoveToAny(Bytes(*src), result); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - return static_cast(src->v_ptr); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIByteArrayPtr; } -}; - -template <> -inline constexpr bool use_default_type_traits_v = false; - -template <> -struct TypeTraits - : public FallbackOnlyTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const std::string& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = 
src.c_str(); - } - - TVM_FFI_INLINE static void MoveToAny(std::string src, TVMFFIAny* result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(std::move(src)), result); - } - - TVM_FFI_INLINE static std::string TypeStr() { return "std::string"; } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(const char* src) { - return std::string(src); - } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(TVMFFIByteArray* src) { - return std::string(src->data, src->size); - } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(Bytes src) { - return src.operator std::string(); - } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(String src) { - return src.operator std::string(); - } -}; - -inline String operator+(const String& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const std::string& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const std::string& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const char* lhs, const String& rhs) { - size_t lhs_size = std::strlen(lhs); - size_t rhs_size = rhs.size(); - return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const char* rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = std::strlen(rhs); - return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); -} - -// Overload < operator -inline bool operator<(std::nullptr_t, const String& rhs) = delete; -inline bool operator<(const String& lhs, std::nullptr_t) = delete; - -inline bool operator<(const String& lhs, const 
std::string& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -// Overload > operator -inline bool operator>(std::nullptr_t, const String& rhs) = delete; -inline bool operator>(const String& lhs, std::nullptr_t) = delete; - -inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -// Overload <= operator -inline bool operator<=(std::nullptr_t, const String& rhs) = delete; -inline bool operator<=(const String& lhs, std::nullptr_t) = delete; - -inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -// Overload >= operator -inline bool operator>=(std::nullptr_t, const String& rhs) = delete; -inline bool operator>=(const String& lhs, std::nullptr_t) = delete; - -inline bool operator>=(const String& lhs, const std::string& rhs) { return 
lhs.compare(rhs) >= 0; } - -inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } - -inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } - -// delete Overload == operator for nullptr -inline bool operator==(const String& lhs, std::nullptr_t) = delete; -inline bool operator==(std::nullptr_t, const String& rhs) = delete; - -inline bool operator==(const String& lhs, const std::string& rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const std::string& lhs, const String& rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const String& lhs, const String& rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } - -// Overload != operator -inline bool operator!=(const String& lhs, std::nullptr_t) = delete; -inline bool operator!=(std::nullptr_t, const String& rhs) = delete; - -inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline std::ostream& operator<<(std::ostream& out, const String& input) { - out.write(input.data(), 
input.size()); - return out; -} -} // namespace ffi - -// Expose to the tvm namespace for usability -// Rationale: no ambiguity even in root -using ffi::Bytes; -using ffi::String; -} // namespace tvm - -namespace std { - -template <> -struct hash<::tvm::ffi::Bytes> { - std::size_t operator()(const ::tvm::ffi::Bytes& bytes) const { - return std::hash()(std::string_view(bytes.data(), bytes.size())); - } -}; - -template <> -struct hash<::tvm::ffi::String> { - std::size_t operator()(const ::tvm::ffi::String& str) const { - return std::hash()(std::string_view(str.data(), str.size())); - } -}; -} // namespace std -#endif // TVM_FFI_STRING_H_ diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h deleted file mode 100644 index b019935a6cc8..000000000000 --- a/ffi/include/tvm/ffi/type_traits.h +++ /dev/null @@ -1,751 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/object.h - * \brief A managed object in the TVM FFI. - */ -#ifndef TVM_FFI_TYPE_TRAITS_H_ -#define TVM_FFI_TYPE_TRAITS_H_ - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! 
- * \brief TypeTraits that specifies the conversion behavior from/to FFI Any. - * - * The function specifications of TypeTraits - * - * - CopyToAnyView: Convert a value T to AnyView - * - MoveToAny: Move a value to Any - * - CheckAnyStrict: Check if a Any stores a result of CopyToAnyView of current T. - * - CopyFromAnyViewAfterCheck: Copy a value T from Any view after we pass CheckAnyStrict. - * - MoveFromAnyAfterCheck: Move a value T from Any storage after we pass CheckAnyStrict. - * - TryCastFromAnyView: Convert a AnyView to a T, we may apply type conversion. - * - GetMismatchTypeInfo: Get the type key of a type when TryCastFromAnyView fails. - * - TypeStr: Get the type key of a type - * - * It is possible that CheckAnyStrict is false but TryCastFromAnyView still works. - * - * For example, when Any x stores int, TypeTraits::CheckAnyStrict(x) will be false, - * but TypeTraits::TryCastFromAnyView(x) will return a corresponding float value - * via type conversion. - * - * CheckAnyStrict is mainly used in recursive container such as Array to - * decide if a new Array needed to be created via recursive conversion, - * or we can use the current container as is when converting to Array. - * - * A container array: Array satisfies the following invariant: - * - `all(TypeTraits::CheckAnyStrict(x) for x in the array)`. - */ -template -struct TypeTraits { - /*! \brief Whether the type is enabled in FFI. */ - static constexpr bool convert_enabled = false; - /*! \brief Whether the type can appear as a storage type in Container */ - static constexpr bool storage_enabled = false; -}; - -/*! - * \brief TypeTraits that removes const and reference keywords. 
- * \tparam T the original type - */ -template -using TypeTraitsNoCR = TypeTraits>>; - -template -inline constexpr bool use_default_type_traits_v = true; - -struct TypeTraitsBase { - static constexpr bool convert_enabled = true; - static constexpr bool storage_enabled = true; - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; - // get mismatched type when result mismatches the trait. - // this function is called after TryCastFromAnyView fails - // to get more detailed type information in runtime - // especially when the error involves nested container type - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* source) { - return TypeIndexToTypeKey(source->type_index); - } -}; - -template -struct TypeToFieldStaticTypeIndex { - static constexpr int32_t value = TypeIndex::kTVMFFIAny; -}; - -template -struct TypeToFieldStaticTypeIndex::convert_enabled>> { - static constexpr int32_t value = TypeTraits::field_static_type_index; -}; - -template -struct TypeToRuntimeTypeIndex { - static int32_t v() { return TypeToFieldStaticTypeIndex::value; } -}; - -template -struct TypeToRuntimeTypeIndex>> { - static int32_t v() { return T::ContainerType::RuntimeTypeIndex(); } -}; - -// None -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone; - - TVM_FFI_INLINE static void CopyToAnyView(const std::nullptr_t&, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFINone; - result->zero_padding = 0; - // invariant: the pointer field also equals nullptr - // this will simplify same_as comparisons and hash - result->v_int64 = 0; - } - - TVM_FFI_INLINE static void MoveToAny(std::nullptr_t, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFINone; - result->zero_padding = 0; - // invariant: the pointer field also equals nullptr - // this will simplify same_as comparisons and hash - result->v_int64 = 0; - } - - TVM_FFI_INLINE static bool 
CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFINone; - } - - TVM_FFI_INLINE static std::nullptr_t CopyFromAnyViewAfterCheck(const TVMFFIAny*) { - return nullptr; - } - - TVM_FFI_INLINE static std::nullptr_t MoveFromAnyAfterCheck(TVMFFIAny*) { return nullptr; } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return nullptr; - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFINone; } -}; - -/** - * \brief A type that forbids implicit conversion from int to bool - * - * This type is used to prevent implicit conversion from int to bool. - */ -class StrictBool { - public: - StrictBool(bool value) : value_(value) {} // NOLINT(*) - operator bool() const { return value_; } - - private: - bool value_; -}; - -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; - - TVM_FFI_INLINE static void CopyToAnyView(const StrictBool& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIBool; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(StrictBool src, TVMFFIAny* result) { - CopyToAnyView(src, result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIBool; - } - - TVM_FFI_INLINE static StrictBool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static StrictBool MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIBool) { - return StrictBool(static_cast(src->v_int64)); - } - return std::nullopt; - } - - TVM_FFI_INLINE 
static std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } -}; - -// Bool type, allow implicit casting from int -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; - - TVM_FFI_INLINE static void CopyToAnyView(const bool& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIBool; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(bool src, TVMFFIAny* result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIBool; - } - - TVM_FFI_INLINE static bool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static bool MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return static_cast(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } -}; - -// Integer POD values -template -struct TypeTraits>> : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; - - TVM_FFI_INLINE static void CopyToAnyView(const Int& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIInt; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(Int src, TVMFFIAny* result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIInt; - } - - TVM_FFI_INLINE static Int 
CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static Int MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return Int(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } -}; - -// Enum Integer POD values -template -struct TypeTraits && - std::is_integral_v>>> - : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; - - TVM_FFI_INLINE static void CopyToAnyView(const IntEnum& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIInt; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(IntEnum src, TVMFFIAny* result) { - CopyToAnyView(src, result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIInt; - } - - TVM_FFI_INLINE static IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static IntEnum MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return static_cast(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } -}; - -// Float POD values -template -struct TypeTraits>> - : public TypeTraitsBase { - static 
constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFloat; - - TVM_FFI_INLINE static void CopyToAnyView(const Float& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIFloat; - result->zero_padding = 0; - result->v_float64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(Float src, TVMFFIAny* result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIFloat; - } - - TVM_FFI_INLINE static Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_float64); - } - - TVM_FFI_INLINE static Float MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIFloat) { - return Float(src->v_float64); - } else if (src->type_index == TypeIndex::kTVMFFIInt || - src->type_index == TypeIndex::kTVMFFIBool) { - return Float(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIFloat; } -}; - -// void* -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIOpaquePtr; - - TVM_FFI_INLINE static void CopyToAnyView(void* src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIOpaquePtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static void MoveToAny(void* src, TVMFFIAny* result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIOpaquePtr; - } - - 
TVM_FFI_INLINE static void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return src->v_ptr; } - - TVM_FFI_INLINE static void* MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { - return static_cast(src->v_ptr); - } - if (src->type_index == TypeIndex::kTVMFFINone) { - return static_cast(nullptr); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIOpaquePtr; } -}; - -// Device -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDevice; - - TVM_FFI_INLINE static void CopyToAnyView(const DLDevice& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIDevice; - result->zero_padding = 0; - result->v_device = src; - } - - TVM_FFI_INLINE static void MoveToAny(DLDevice src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIDevice; - result->zero_padding = 0; - result->v_device = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIDevice; - } - - TVM_FFI_INLINE static DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return src->v_device; - } - - TVM_FFI_INLINE static DLDevice MoveFromAnyAfterCheck(TVMFFIAny* src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIDevice) { - return src->v_device; - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIDevice; } -}; - -// DLTensor*, requirement: not nullable, do not retain ownership -template <> -struct TypeTraits : public TypeTraitsBase { - static 
constexpr bool storage_enabled = false; - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr; - - TVM_FFI_INLINE static void CopyToAnyView(DLTensor* src, TVMFFIAny* result) { - TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIDLTensorPtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index == TypeIndex::kTVMFFIDLTensorPtr; - } - - TVM_FFI_INLINE static DLTensor* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return static_cast(src->v_ptr); - } - - TVM_FFI_INLINE static void MoveToAny(DLTensor*, TVMFFIAny*) { - TVM_FFI_THROW(RuntimeError) - << "DLTensor* cannot be held in Any as it does not retain ownership, use NDArray instead"; - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { - return static_cast(src->v_ptr); - } else if (src->type_index == TypeIndex::kTVMFFINDArray) { - // Conversion from NDArray pointer to DLTensor - // based on the assumption that NDArray always follows the TVMFFIObject header - static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 8 bytes"); - return reinterpret_cast(reinterpret_cast(src->v_obj) + - sizeof(TVMFFIObject)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "DLTensor*"; } -}; - -// Traits for ObjectRef, None to ObjectRef will always fail. -// use std::optional instead for nullable references. 
-template -struct ObjectRefTypeTraitsBase : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject; - using ContainerType = typename TObjRef::ContainerType; - - TVM_FFI_INLINE static void CopyToAnyView(const TObjRef& src, TVMFFIAny* result) { - if constexpr (TObjRef::_type_is_nullable) { - if (!src.defined()) { - TypeTraits::CopyToAnyView(nullptr, result); - return; - } - } - TVMFFIObject* obj_ptr = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static void MoveToAny(TObjRef src, TVMFFIAny* result) { - if constexpr (TObjRef::_type_is_nullable) { - if (!src.defined()) { - TypeTraits::CopyToAnyView(nullptr, result); - return; - } - } - TVMFFIObject* obj_ptr = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(src)); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) return true; - } - return (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && - details::IsObjectInstance(src->type_index)); - } - - TVM_FFI_INLINE static TObjRef CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); - } - } - return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); - } - - TVM_FFI_INLINE static TObjRef MoveFromAnyAfterCheck(TVMFFIAny* src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); - } - } - // move out the object pointer - ObjectPtr obj_ptr = 
details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); - // reset the src to nullptr - TypeTraits::MoveToAny(nullptr, src); - return TObjRef(std::move(obj_ptr)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return TObjRef(ObjectPtr(nullptr)); - } - } - if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - if (details::IsObjectInstance(src->type_index)) { - return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); - } - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return ContainerType::_type_key; } -}; - -template -struct TypeTraits && - use_default_type_traits_v>> - : public ObjectRefTypeTraitsBase {}; - -/*! - * \brief Helper class that convert to T only via the FallbackTypes - * - * The conversion will go through the FallbackTypes in the order - * specified in the template parameter. - * \tparam T The type of the target value. - * \tparam FallbackTypes The type of the fallback value. 
- * \note TypeTraits must be derived from this class and define - * ConvertFallbackValue(FallbackType)->T for each FallbackType - */ -template -struct FallbackOnlyTraitsBase : public TypeTraitsBase { - // disable container for FallbackOnlyTraitsBase - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - return TryFallbackTypes(src); - } - - template - TVM_FFI_INLINE static std::optional TryFallbackTypes(const TVMFFIAny* src) { - static_assert(!std::is_same_v, - "Using bool as FallbackType can cause bug because int will be detected as bool, " - "use tvm::ffi::StrictBool instead"); - if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { - return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryFallbackTypes(src); - } - return std::nullopt; - } -}; - -/*! - * \brief Helper class to define ObjectRef that can be auto-converted from a - * fallback type, the Traits must be derived from it - * and define a static methods named ConvertFallbackValue for each - * FallbackType - * - * The conversion will go through the FallbackTypes in the order - * specified in the template parameter. - * \tparam ObjectRefType The type of the ObjectRef. - * \tparam FallbackTypes The type of the fallback value. 
- */ -template -struct ObjectRefWithFallbackTraitsBase : public ObjectRefTypeTraitsBase { - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if (auto opt_obj = ObjectRefTypeTraitsBase::TryCastFromAnyView(src)) { - return *opt_obj; - } - // apply fallback types in TryCastFromAnyView - return TryFallbackTypes(src); - } - - template - TVM_FFI_INLINE static std::optional TryFallbackTypes(const TVMFFIAny* src) { - static_assert(!std::is_same_v, - "Using bool as FallbackType can cause bug because int will be detected as bool, " - "use tvm::ffi::StrictBool instead"); - if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { - return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryFallbackTypes(src); - } - return std::nullopt; - } -}; - -// Traits for weak pointer of object -// NOTE: we require the weak pointer cast from - -template -struct TypeTraits>> - : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(TObject* src, TVMFFIAny* result) { - TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static void MoveToAny(TObject* src, TVMFFIAny* result) { - TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - // needs to increase ref because original weak ptr do not own the code - details::ObjectUnsafe::IncRefObjectHandle(result->v_obj); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && - details::IsObjectInstance(src->type_index); - } - - TVM_FFI_INLINE static TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - if constexpr 
(!std::is_const_v) { - static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); - } - return details::ObjectUnsafe::RawObjectPtrFromUnowned(src->v_obj); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - if constexpr (!std::is_const_v) { - static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); - } - if (CheckAnyStrict(src)) return CopyFromAnyViewAfterCheck(src); - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return TObject::_type_key; } -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const Optional& src, TVMFFIAny* result) { - if (src.has_value()) { - TypeTraits::CopyToAnyView(*src, result); - } else { - TypeTraits::CopyToAnyView(nullptr, result); - } - } - - TVM_FFI_INLINE static void MoveToAny(Optional src, TVMFFIAny* result) { - if (src.has_value()) { - TypeTraits::MoveToAny(*std::move(src), result); - } else { - TypeTraits::CopyToAnyView(nullptr, result); - } - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) return true; - return TypeTraits::CheckAnyStrict(src); - } - - TVM_FFI_INLINE static Optional CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return Optional(std::nullopt); - } - return TypeTraits::CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static Optional MoveFromAnyAfterCheck(TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return Optional(std::nullopt); - } - return TypeTraits::MoveFromAnyAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFINone) return Optional(std::nullopt); - if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { - 
return Optional(*std::move(opt)); - } else { - // important to be explicit here - // because nullopt can convert to std::optional(nullopt) which indicate success - // return std::optional>(std::nullopt) to indicate failure - return std::optional>(std::nullopt); - } - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - return TypeTraits::GetMismatchTypeInfo(src); - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "Optional<" + TypeTraits::TypeStr() + ">"; - } -}; -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_TYPE_TRAITS_H_ diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py deleted file mode 100644 index 1453aa95a67c..000000000000 --- a/ffi/scripts/benchmark_dlpack.py +++ /dev/null @@ -1,411 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This script is used to benchmark the API overhead of different -python FFI API calling overhead, through DLPack API. - -Specifically, we would like to understand the overall overhead -python/C++ API calls. The general goal is to understand the overall -space and get a sense of what are the possible operations. - -We pick function f(x, y, z) where x, y, z are length 1 tensors. 
-The benchmark is running in eager mode so we can see what is possible. -It is orthogonal to other optimizations. For example cudagraph can -eliminate these overheads completely. So the goal is to get a sense -of what is possible under eager mode. - -Summary of some takeaways: -- numpy.add roughly takes 0.36 us per call, which gives roughly what can - be done in python env. -- torch.add on gpu takes about 3.7us per call, giving us an idea of what - roughly we need to get to in eager mode. -- - -""" -import os -import torch -import numpy as np -from tvm import ffi as tvm_ffi -import time - - -def print_speed(name, speed): - print(f"{name:<40} {speed} sec/call") - - -def print_error(name, error): - print(f"{name:<40} {error}") - - -def baseline_torch_add(repeat): - """Run torch.add with one element""" - - def run_bench(device): - x = torch.arange(1, device=device) - y = torch.arange(1, device=device) - z = torch.arange(1, device=device) - - torch.add(x, y, out=z) - if device == "cuda": - torch.cuda.synchronize() - start = time.time() - for i in range(repeat): - torch.add(x, y, out=z) - # note we deliberately do not use torch.cuda.synchronize() - # because we want to see the overhead of the FFI call. 
- end = time.time() - print_speed(f"torch.add[{device}]", (end - start) / repeat) - - # rough take away: add on cuda roughly takes 3e-6 sec/call - run_bench("cpu") - run_bench("cuda") - - -def baseline_numpy_add(repeat): - """Run numpy.add with one element""" - x = np.arange(1) - y = np.arange(1) - z = np.arange(1) - - np.add(x, y, out=z) - start = time.time() - for i in range(repeat): - np.add(x, y, out=z) - end = time.time() - speed = (end - start) / repeat - print_speed("numpy.add", speed) - - -def baseline_cupy_add(repeat): - """Run cupy.add with one element""" - try: - import cupy - except ImportError: - # skip if cupy is not installed - return - x = cupy.arange(1) - y = cupy.arange(1) - z = cupy.arange(1) - - cupy.add(x, y, out=z) - start = time.time() - for i in range(repeat): - cupy.add(x, y, out=z) - end = time.time() - speed = (end - start) / repeat - print_speed("cupy.add", speed) - - -def tvm_ffi_nop(repeat): - """Overhead of tvm FFI python call via calling a NOP. - - testing.nop is defined in c++ and do nothing. 
- """ - nop = tvm_ffi.get_global_func("testing.nop") - x = tvm_ffi.from_dlpack(torch.arange(1)) - y = tvm_ffi.from_dlpack(torch.arange(1)) - z = tvm_ffi.from_dlpack(torch.arange(1)) - nop(x, y, z) - start = time.time() - for i in range(repeat): - y = tvm_ffi.from_dlpack(x) - end = time.time() - print_speed("tvm.ffi.nop", (end - start) / repeat) - - -def bench_ffi_nop_from_dlpack(name, x, y, z, repeat): - """run dlpack conversion + tvm.ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - nop = tvm_ffi.get_global_func("testing.nop") - tx = tvm_ffi.from_dlpack(x) - ty = tvm_ffi.from_dlpack(y) - tz = tvm_ffi.from_dlpack(z) - nop(tx, ty, tz) - - start = time.time() - for i in range(repeat): - tx = tvm_ffi.from_dlpack(x) - ty = tvm_ffi.from_dlpack(y) - tz = tvm_ffi.from_dlpack(z) - nop(tx, ty, tz) - end = time.time() - print_speed(name, (end - start) / repeat) - - -def tvm_ffi_nop_from_torch_dlpack(repeat): - """run dlpack conversion + tvm.ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - x = torch.arange(1) - y = torch.arange(1) - z = torch.arange(1) - bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(torch)", x, y, z, repeat) - - -def tvm_ffi_nop_from_numpy_dlpack(repeat): - """run dlpack conversion + tvm.ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - x = np.arange(1) - y = np.arange(1) - z = np.arange(1) - bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(numpy)", x, y, z, repeat) - - -def tvm_ffi_self_dlpack_nop(repeat): - """run dlpack conversion + tvm.ffi.nop - - Measures overhead of running dlpack for each args then invoke - """ - x = tvm_ffi.from_dlpack(torch.arange(1)) - y = tvm_ffi.from_dlpack(torch.arange(1)) - z = tvm_ffi.from_dlpack(torch.arange(1)) - bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(tvm)", x, y, z, repeat) - - -def bench_ffi_nop_from_dlpack(name, x, y, z, repeat): - """run dlpack conversion + tvm.ffi.nop - - Measures overhead of running dlpack 
for each args then invoke - """ - nop = tvm_ffi.get_global_func("testing.nop") - tx = tvm_ffi.from_dlpack(x) - ty = tvm_ffi.from_dlpack(y) - tz = tvm_ffi.from_dlpack(z) - nop(tx, ty, tz) - - start = time.time() - for i in range(repeat): - tx = tvm_ffi.from_dlpack(x) - ty = tvm_ffi.from_dlpack(y) - tz = tvm_ffi.from_dlpack(z) - nop(tx, ty, tz) - end = time.time() - print_speed(name, (end - start) / repeat) - - -def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat): - """ - Measures overhead of running dlpack for each args then invoke - but uses the legacy torch.utils.dlpack.to_dlpack API - - This helps to measure possible implementation overhead of torch. - """ - nop = tvm_ffi.get_global_func("testing.nop") - x = torch.arange(1) - y = torch.arange(1) - z = torch.arange(1) - - tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x)) - ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y)) - tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z)) - nop(tx, ty, tz) - - start = time.time() - for i in range(repeat): - tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x)) - ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y)) - tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z)) - nop(tx, ty, tz) - end = time.time() - speed = (end - start) / repeat - print_speed("tvm.ffi.nop+from_dlpack(torch.utils)", speed) - - -def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat): - """ - Measures overhead of running dlpack via auto convert by directly - take torch.Tensor as inputs. - """ - nop = tvm_ffi.get_global_func("testing.nop") - nop(x, y, z) - start = time.time() - for i in range(repeat): - nop(x, y, z) - end = time.time() - speed = (end - start) / repeat - print_speed(name, speed) - - -def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False): - """ - Measures overhead of running dlpack via auto convert by directly - take torch.Tensor as inputs. 
- """ - # use larger to ensure alignment req is met - x = torch.arange(1, device=device) - y = torch.arange(1, device=device) - z = torch.arange(1, device=device) - if stream: - with torch.cuda.stream(torch.cuda.Stream()): - bench_tvm_ffi_nop_autodlpack( - f"tvm.ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat - ) - else: - bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat) - - -def tvm_ffi_nop_autodlpack_from_numpy(repeat): - """ - Measures overhead of running dlpack via auto convert by directly - take numpy.ndarray as inputs. - """ - # use larger to ensure alignment req is met - x = np.arange(256) - y = np.arange(256) - z = np.arange(256) - bench_tvm_ffi_nop_autodlpack("tvm.ffi.nop.autodlpack(numpy)", x, y, z, repeat) - - -def bench_to_dlpack(x, name, repeat): - x.__dlpack__() - start = time.time() - for i in range(repeat): - x.__dlpack__() - end = time.time() - speed = (end - start) / repeat - print_speed(name, speed) - - -def bench_to_dlpack_versioned(x, name, repeat, max_version=(1, 1)): - """ - Measures overhead of running dlpack with latest 1.1. 
- """ - try: - x.__dlpack__(max_version=max_version) - start = time.time() - for i in range(repeat): - x.__dlpack__(max_version=max_version) - end = time.time() - speed = (end - start) / repeat - print_speed(name, speed) - except Exception as e: - print_error(name, e) - - -def bench_torch_utils_to_dlpack(repeat): - """ - Measures overhead of running torch.utils.dlpack.to_dlpack - """ - x = torch.arange(1) - torch.utils.dlpack.to_dlpack(x) - start = time.time() - for i in range(repeat): - torch.utils.dlpack.to_dlpack(x) - end = time.time() - speed = (end - start) / repeat - print_speed("torch.utils.dlpack.to_dlpack", speed) - - -def torch_get_cuda_stream_native(device_id): - return torch.cuda.current_stream(device_id).cuda_stream - - -def load_torch_get_current_cuda_stream(): - """Create a faster get_current_cuda_stream for torch through cpp extension.""" - from torch.utils import cpp_extension - - source = """ - #include - - int64_t get_current_cuda_stream(int device_id) { - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id); - // fast invariant, default stream is always 0 - if (stream.id() == 0) return 0; - // convert to cudaStream_t - return reinterpret_cast(static_cast(stream)); - } - """ - result = cpp_extension.load_inline( - name="get_current_cuda_stream", - cpp_sources=[source], - cuda_sources=[], - extra_cflags=["-O3"], - extra_include_paths=cpp_extension.include_paths("cuda"), - functions=["get_current_cuda_stream"], - ) - return result.get_current_cuda_stream - - -def bench_torch_get_current_stream(repeat, name, func): - """ - Measures overhead of running torch.cuda.current_stream - """ - x = torch.arange(1, device="cuda") - func(0) - start = time.time() - for i in range(repeat): - func(0) - end = time.time() - speed = (end - start) / repeat - print_speed(f"torch.cuda.current_stream[{name}]", speed) - - -def main(): - repeat = 10000 - print("-----------------------------") - print("Benchmark f(x, y, z) overhead") - 
print("-----------------------------") - baseline_numpy_add(repeat) - baseline_torch_add(repeat) - baseline_cupy_add(repeat) - tvm_ffi_nop(repeat) - tvm_ffi_nop_from_torch_dlpack(repeat) - tvm_ffi_nop_from_numpy_dlpack(repeat) - tvm_ffi_self_dlpack_nop(repeat) - tvm_ffi_nop_from_torch_utils_to_dlpack(repeat) - tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu") - tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda") - tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True) - - tvm_ffi_nop_autodlpack_from_numpy(repeat) - print("-------------------------------") - print("Benchmark x.__dlpack__ overhead") - print("-------------------------------") - bench_torch_utils_to_dlpack(repeat) - bench_to_dlpack(torch.arange(1), "torch.__dlpack__", repeat) - bench_to_dlpack(np.arange(1), "numpy.__dlpack__", repeat) - bench_to_dlpack(tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__", repeat) - print("---------------------------------------------------") - print("Benchmark x.__dlpack__(max_version=(1,1)) overhead") - print("---------------------------------------------------") - bench_to_dlpack_versioned(torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat) - bench_to_dlpack_versioned(np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat) - bench_to_dlpack_versioned( - tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat - ) - print("---------------------------------------------------") - print("Benchmark torch.get_cuda_stream[default stream]") - print("---------------------------------------------------") - bench_torch_get_current_stream(repeat, "cpp-extension", load_torch_get_current_cuda_stream()) - bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) - print("---------------------------------------------------") - print("Benchmark torch.get_cuda_stream[non-default stream]") - print("---------------------------------------------------") - with torch.cuda.stream(torch.cuda.Stream()): - 
bench_torch_get_current_stream( - repeat, "cpp-extension", load_torch_get_current_cuda_stream() - ) - bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) - - -if __name__ == "__main__": - main() diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc deleted file mode 100644 index 858cbd47c771..000000000000 --- a/ffi/src/ffi/container.cc +++ /dev/null @@ -1,88 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/container.cc - */ -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -// Favor struct outside function scope as MSVC may have bug for in fn scope struct. 
-class MapForwardIterFunctor { - public: - MapForwardIterFunctor(ffi::MapObj::iterator iter, ffi::MapObj::iterator end) - : iter_(iter), end_(end) {} - // 0 get current key - // 1 get current value - // 2 move to next: return true if success, false if end - Any operator()(int command) const { - if (command == 0) { - return (*iter_).first; - } else if (command == 1) { - return (*iter_).second; - } else { - ++iter_; - if (iter_ == end_) { - return false; - } - return true; - } - } - - private: - mutable ffi::MapObj::iterator iter_; - ffi::MapObj::iterator end_; -}; - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def_packed("ffi.Array", - [](ffi::PackedArgs args, Any* ret) { - *ret = Array(args.data(), args.data() + args.size()); - }) - .def("ffi.ArrayGetItem", [](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); }) - .def("ffi.ArraySize", - [](const ffi::ArrayObj* n) -> int64_t { return static_cast(n->size()); }) - .def_packed("ffi.Map", - [](ffi::PackedArgs args, Any* ret) { - TVM_FFI_ICHECK_EQ(args.size() % 2, 0); - Map data; - for (int i = 0; i < args.size(); i += 2) { - data.Set(args[i], args[i + 1]); - } - *ret = data; - }) - .def("ffi.MapSize", - [](const ffi::MapObj* n) -> int64_t { return static_cast(n->size()); }) - .def("ffi.MapGetItem", [](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); }) - .def("ffi.MapCount", - [](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }) - .def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) -> ffi::Function { - return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end())); - }); -}); -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/dtype.cc b/ffi/src/ffi/dtype.cc deleted file mode 100644 index e119f7733044..000000000000 --- a/ffi/src/ffi/dtype.cc +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include - -#include - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Get the custom type name for a given type code. - */ -inline String DLDataTypeCodeGetCustomTypeName(DLDataTypeCode type_code) { - static Function fget_custom_type_name = Function::GetGlobalRequired("dtype.get_custom_type_name"); - return fget_custom_type_name(static_cast(type_code)).cast(); -} - -/*! - * \brief Get the custom type name for a given type code. - * \param str The string to parse. - * \param scan The scan pointer. - * \return The custom type name. 
- */ -inline int ParseCustomDataTypeCode(const std::string_view& str, const char** scan) { - TVM_FFI_ICHECK(str.substr(0, 6) == "custom") << "Not a valid custom datatype string"; - auto tmp = str.data(); - TVM_FFI_ICHECK(str.data() == tmp); - *scan = str.data() + 6; - TVM_FFI_ICHECK(str.data() == tmp); - if (**scan != '[') - TVM_FFI_THROW(ValueError) << "expected opening brace after 'custom' type in" << str; - TVM_FFI_ICHECK(str.data() == tmp); - *scan += 1; - TVM_FFI_ICHECK(str.data() == tmp); - size_t custom_name_len = 0; - TVM_FFI_ICHECK(str.data() == tmp); - while (*scan + custom_name_len <= str.data() + str.length() && - *(*scan + custom_name_len) != ']') { - ++custom_name_len; - } - TVM_FFI_ICHECK(str.data() == tmp); - if (*(*scan + custom_name_len) != ']') { - TVM_FFI_THROW(ValueError) << "expected closing brace after 'custom' type in" << str; - } - TVM_FFI_ICHECK(str.data() == tmp); - *scan += custom_name_len + 1; - TVM_FFI_ICHECK(str.data() == tmp); - auto type_name = str.substr(7, custom_name_len); - TVM_FFI_ICHECK(str.data() == tmp); - static Function fget_custom_type_code = Function::GetGlobalRequired("dtype.get_custom_type_code"); - return fget_custom_type_code(std::string(type_name)).cast(); -} - -/* - * \brief Convert a DLDataTypeCode to a string. - * \param os The output stream. - * \param type_code The DLDataTypeCode to convert. 
- */ -inline void PrintDLDataTypeCodeAsStr(std::ostream& os, DLDataTypeCode type_code) { // NOLINT(*) - switch (static_cast(type_code)) { - case kDLInt: { - os << "int"; - break; - } - case kDLUInt: { - os << "uint"; - break; - } - case kDLFloat: { - os << "float"; - break; - } - case kDLOpaqueHandle: { - os << "handle"; - break; - } - case kDLBfloat: { - os << "bfloat"; - break; - } - case kDLFloat8_e3m4: { - os << "float8_e3m4"; - break; - } - case kDLFloat8_e4m3: { - os << "float8_e4m3"; - break; - } - case kDLFloat8_e4m3b11fnuz: { - os << "float8_e4m3b11fnuz"; - break; - } - case kDLFloat8_e4m3fn: { - os << "float8_e4m3fn"; - break; - } - case kDLFloat8_e4m3fnuz: { - os << "float8_e4m3fnuz"; - break; - } - case kDLFloat8_e5m2: { - os << "float8_e5m2"; - break; - } - case kDLFloat8_e5m2fnuz: { - os << "float8_e5m2fnuz"; - break; - } - case kDLFloat8_e8m0fnu: { - os << "float8_e8m0fnu"; - break; - } - case kDLFloat6_e2m3fn: { - os << "float6_e2m3fn"; - break; - } - case kDLFloat6_e3m2fn: { - os << "float6_e3m2fn"; - break; - } - case kDLFloat4_e2m1fn: { - os << "float4_e2m1fn"; - break; - } - default: { - if (static_cast(type_code) >= static_cast(DLExtDataTypeCode::kDLExtCustomBegin)) { - os << "custom[" << details::DLDataTypeCodeGetCustomTypeName(type_code) << "]"; - } else { - TVM_FFI_THROW(ValueError) << "DLDataType contains unknown type_code=" - << static_cast(type_code); - } - TVM_FFI_UNREACHABLE(); - } - } -} -} // namespace details - -/*! - * \brief Printer function for DLDataType. - * \param os The output stream. - * \param dtype The DLDataType to print. - * \return The output stream. 
- */ -inline std::string DLDataTypeToString_(DLDataType dtype) { // NOLINT(*) - if (dtype.bits == 1 && dtype.lanes == 1 && dtype.code == kDLUInt) { - return "bool"; - } - // specially handle void - if (dtype.code == kDLOpaqueHandle && dtype.lanes == 0 && dtype.bits == 0) { - return ""; - } - - std::ostringstream os; - if (dtype.code >= kDLExtCustomBegin) { - os << "custom[" - << details::DLDataTypeCodeGetCustomTypeName(static_cast(dtype.code)) << "]"; - } else { - os << details::DLDataTypeCodeAsCStr(static_cast(dtype.code)); - } - if (dtype.code == kDLOpaqueHandle) return os.str(); - int16_t lanes = static_cast(dtype.lanes); - if (dtype.code < kDLFloat8_e3m4) { - os << static_cast(dtype.bits); - } - if (lanes > 1) { - os << 'x' << lanes; - } else if (lanes < -1) { - os << "xvscalex" << -lanes; - } - return os.str(); -} - -/*! - * \brief Parse a string to a DLDataType. - * \param str The string to convert. - * \return The corresponding DLDataType. - */ -inline DLDataType StringViewToDLDataType_(std::string_view str) { - DLDataType dtype; - // handle void type - if (str.length() == 0 || str == "void") { - dtype.code = kDLOpaqueHandle; - dtype.bits = 0; - dtype.lanes = 0; - return dtype; - } - // set the default values; - dtype.bits = 32; - dtype.lanes = 1; - const char* scan; - - auto parse_float = [&](const std::string_view& str, int offset, int code, int bits) { - dtype.code = static_cast(code); - dtype.bits = static_cast(bits); - scan = str.data() + offset; - char* endpt = nullptr; - if (*scan == 'x') { - dtype.lanes = static_cast(strtoul(scan + 1, &endpt, 10)); - scan = endpt; - } - if (scan != str.data() + str.length()) { - TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; - } - return dtype; - }; - - if (str.compare(0, 3, "int") == 0) { - dtype.code = kDLInt; - scan = str.data() + 3; - } else if (str.compare(0, 4, "uint") == 0) { - dtype.code = kDLUInt; - scan = str.data() + 4; - } else if (str.compare(0, 5, "float") == 0) { - if (str.compare(5, 2, 
"8_") == 0) { - if (str.compare(7, 4, "e3m4") == 0) { - return parse_float(str, 11, kDLFloat8_e3m4, 8); - } else if (str.compare(7, 4, "e4m3") == 0) { - if (str.compare(11, 7, "b11fnuz") == 0) { - return parse_float(str, 18, kDLFloat8_e4m3b11fnuz, 8); - } else if (str.compare(11, 2, "fn") == 0) { - if (str.compare(13, 2, "uz") == 0) { - return parse_float(str, 15, kDLFloat8_e4m3fnuz, 8); - } else { - return parse_float(str, 13, kDLFloat8_e4m3fn, 8); - } - } else { - return parse_float(str, 11, kDLFloat8_e4m3, 8); - } - } else if (str.compare(7, 8, "e5m2fnuz") == 0) { - return parse_float(str, 15, kDLFloat8_e5m2fnuz, 8); - } else if (str.compare(7, 4, "e5m2") == 0) { - return parse_float(str, 11, kDLFloat8_e5m2, 8); - } else if (str.compare(7, 7, "e8m0fnu") == 0) { - return parse_float(str, 14, kDLFloat8_e8m0fnu, 8); - } else { - TVM_FFI_THROW(ValueError) << "unknown float8 type `" << str << '`'; - TVM_FFI_UNREACHABLE(); - } - } else if (str.compare(5, 2, "6_") == 0) { - if (str.compare(7, 6, "e2m3fn") == 0) { - return parse_float(str, 13, kDLFloat6_e2m3fn, 6); - } else if (str.compare(7, 6, "e3m2fn") == 0) { - return parse_float(str, 13, kDLFloat6_e3m2fn, 6); - } else { - TVM_FFI_THROW(ValueError) << "unknown float6 type `" << str << '`'; - TVM_FFI_UNREACHABLE(); - } - } else if (str.compare(5, 2, "4_") == 0) { - // kFloat4_e2m1fn - if (str.compare(7, 6, "e2m1fn") == 0) { - return parse_float(str, 13, kDLFloat4_e2m1fn, 4); - } else { - TVM_FFI_THROW(ValueError) << "unknown float4 type `" << str << '`'; - TVM_FFI_UNREACHABLE(); - } - } else { - dtype.code = kDLFloat; - scan = str.data() + 5; - } - } else if (str.compare(0, 6, "handle") == 0) { - dtype.code = kDLOpaqueHandle; - dtype.bits = 64; // handle uses 64 bit by default. 
- scan = str.data() + 6; - } else if (str == "bool") { - dtype.code = kDLUInt; - dtype.bits = 1; - dtype.lanes = 1; - return dtype; - } else if (str.compare(0, 6, "bfloat") == 0) { - dtype.code = kDLBfloat; - dtype.bits = 16; - scan = str.data() + 6; - } else if (str.compare(0, 6, "custom") == 0) { - dtype.code = static_cast(details::ParseCustomDataTypeCode(str, &scan)); - } else { - scan = str.data(); - TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; - } - char* xdelim; // emulate sscanf("%ux%u", bits, lanes) - uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); - if (bits != 0) dtype.bits = bits; - int scalable_multiplier = 1; - if (strncmp(xdelim, "xvscale", 7) == 0) { - scalable_multiplier = -1; - xdelim += 7; - } - char* endpt = xdelim; - if (*xdelim == 'x') { - dtype.lanes = static_cast(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10)); - } - if (endpt != str.data() + str.length()) { - TVM_FFI_THROW(ValueError) << "unknown dtype `" << str << '`'; - } - return dtype; -} - -} // namespace ffi -} // namespace tvm - -int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - *out = tvm::ffi::StringViewToDLDataType_(std::string_view(str->data, str->size)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(*dtype)); - tvm::ffi::TypeTraits::MoveToAny(std::move(out_str), out); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/error.cc b/ffi/src/ffi/error.cc deleted file mode 100644 index 9fd81c47890a..000000000000 --- a/ffi/src/ffi/error.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/error.cc - * \brief Error handling implementation - */ -#include -#include - -namespace tvm { -namespace ffi { - -class SafeCallContext { - public: - void SetRaised(TVMFFIObjectHandle error) { - last_error_ = - details::ObjectUnsafe::ObjectPtrFromUnowned(static_cast(error)); - } - - void SetRaisedByCstr(const char* kind, const char* message, const TVMFFIByteArray* traceback) { - Error error(kind, message, traceback); - last_error_ = details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(error)); - } - - void MoveFromRaised(TVMFFIObjectHandle* result) { - result[0] = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(last_error_)); - } - - static SafeCallContext* ThreadLocal() { - static thread_local SafeCallContext ctx; - return &ctx; - } - - private: - ObjectPtr last_error_; -}; - -} // namespace ffi -} // namespace tvm - -void TVMFFIErrorSetRaisedFromCStr(const char* kind, const char* message) { - // NOTE: run traceback here to simplify the depth of tracekback - tvm::ffi::SafeCallContext::ThreadLocal()->SetRaisedByCstr(kind, message, TVM_FFI_TRACEBACK_HERE); -} - -void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) { - tvm::ffi::SafeCallContext::ThreadLocal()->SetRaised(error); -} - -void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) { - tvm::ffi::SafeCallContext::ThreadLocal()->MoveFromRaised(result); -} - -TVMFFIObjectHandle 
TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* message, - const TVMFFIByteArray* traceback) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - tvm::ffi::Error error(std::string(kind->data, kind->size), - std::string(message->data, message->size), - std::string(traceback->data, traceback->size)); - TVMFFIObjectHandle out = - tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(error)); - return out; - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIErrorCreate); -} diff --git a/ffi/src/ffi/extra/buffer_stream.h b/ffi/src/ffi/extra/buffer_stream.h deleted file mode 100644 index f6f162676607..000000000000 --- a/ffi/src/ffi/extra/buffer_stream.h +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file buffer_stream.h - * \brief Internal minimal stream helper to read from a buffer. - */ -#ifndef TVM_FFI_EXTRA_BUFFER_STREAM_H_ -#define TVM_FFI_EXTRA_BUFFER_STREAM_H_ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Lightweight stream helper to read from a buffer. - */ -class BufferInStream { - public: - /*! - * \brief constructor - * \param p_buffer the head pointer of the memory region. 
- * \param buffer_size the size of the memorybuffer - */ - BufferInStream(const void* data, size_t size) - : data_(reinterpret_cast(data)), size_(size) {} - /*! - * \brief Reads raw from stream. - * \param ptr pointer to the data to be read - * \param size the size of the data to be read - * \return the number of bytes read - */ - size_t Read(void* ptr, size_t size) { - size_t nread = std::min(size_ - curr_ptr_, size); - if (nread != 0) std::memcpy(ptr, data_ + curr_ptr_, nread); - curr_ptr_ += nread; - return nread; - } - /*! - * \brief Reads arithmetic data from stream in endian-aware manner. - * \param data data to be read - * \tparam T the data type to be read - * \return whether the read was successful - */ - template >> - bool Read(T* data) { - bool ret = Read(static_cast(data), sizeof(T)) == sizeof(T); // NOLINT(*) - if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { - ByteSwap(&data, sizeof(T), 1); - } - return ret; - } - /*! - * \brief Reads an array of data from stream in endian-aware manner. - * \param data data to be read - * \param size the size of the data to be read - * \return whether the read was successful - */ - template >> - bool ReadArray(T* data, size_t size) { - bool ret = - this->Read(static_cast(data), sizeof(T) * size) == sizeof(T) * size; // NOLINT(*) - if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { - ByteSwap(data, sizeof(T), size); - } - return ret; - } - /*! - * \brief Reads a string from stream. - * \param data data to be read - * \return whether the read was successful - */ - bool Read(std::string* data) { - // use uint64_t to ensure platform independent size - uint64_t size = 0; - if (!this->Read(&size)) return false; - data->resize(size); - if (!this->Read(data->data(), size)) return false; - return true; - } - /*! - * \brief Reads a vector of data from stream in endian-aware manner. 
- * \param data data to be read - * \return whether the read was successful - */ - template >> - bool Read(std::vector* data) { - uint64_t size = 0; - if (!this->Read(&size)) return false; - data->resize(size); - return this->ReadArray(data->data(), size); - } - - private: - /*! \brief in memory buffer */ - const char* data_; - /*! \brief size of the buffer */ - size_t size_; - /*! \brief current pointer */ - size_t curr_ptr_{0}; -}; // class BytesInStream - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_BUFFER_STREAM_H_ diff --git a/ffi/src/ffi/extra/env_c_api.cc b/ffi/src/ffi/extra/env_c_api.cc deleted file mode 100644 index 121cc9a3ccde..000000000000 --- a/ffi/src/ffi/extra/env_c_api.cc +++ /dev/null @@ -1,148 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/env_c_api.cc - * \brief Environment C API implementation. - */ -#include -#include - -namespace tvm { -namespace ffi { -/*! - * \brief Execution environment specific API registry. - * - * This registry stores C API function pointers about - * execution environment(e.g. python) specific API function that - * we need for specific low-level handling(e.g. signal checking). 
- * - * We only stores the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). Always consider use the Function FFI when possible - * in other cases. - */ -class EnvCAPIRegistry { - public: - /*! - * \brief Callback to check if signals have been sent to the process and - * if so invoke the registered signal handler in the frontend environment. - * - * When running FFI in another language (Python), the signal handler - * may not be immediately executed, but instead the signal is marked - * in the interpreter state (to ensure non-blocking of the signal handler). - * - * \return 0 if no error happens, -1 if error happens. - */ - typedef int (*F_PyErr_CheckSignals)(); - - /*! \brief Callback to increment/decrement the python ref count */ - typedef void (*F_Py_IncDefRef)(void*); - - /*! - * \brief PyErr_CheckSignal function - */ - F_PyErr_CheckSignals pyerr_check_signals = nullptr; - - /*! - \brief PyGILState_Ensure function - */ - void* (*py_gil_state_ensure)() = nullptr; - - /*! - \brief PyGILState_Release function - */ - void (*py_gil_state_release)(void*) = nullptr; - - static EnvCAPIRegistry* Global() { - static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); - return inst; - } - - // register environment(e.g. 
python) specific api functions - void Register(const String& symbol_name, void* fptr) { - if (symbol_name == "PyErr_CheckSignals") { - Update(symbol_name, &pyerr_check_signals, fptr); - } else if (symbol_name == "PyGILState_Ensure") { - Update(symbol_name, &py_gil_state_ensure, fptr); - } else if (symbol_name == "PyGILState_Release") { - Update(symbol_name, &py_gil_state_release, fptr); - } else { - TVM_FFI_THROW(ValueError) << "Unknown env API " + symbol_name; - } - } - - int EnvCheckSignals() { - // check python signal to see if there are exception raised - if (pyerr_check_signals != nullptr) { - // The C++ env comes without gil, so we need to grab gil here - WithGIL context(this); - if ((*pyerr_check_signals)() != 0) { - // The error will let FFI know that the frontend environment - // already set an error. - return -1; - } - } - return 0; - } - - private: - // update the internal API table - template - void Update(const String& symbol_name, FType* target, void* ptr) { - FType ptr_casted = reinterpret_cast(ptr); - target[0] = ptr_casted; - } - - struct WithGIL { - explicit WithGIL(EnvCAPIRegistry* self) : self(self) { - TVM_FFI_ICHECK(self->py_gil_state_ensure); - TVM_FFI_ICHECK(self->py_gil_state_release); - gil_state = self->py_gil_state_ensure(); - } - ~WithGIL() { - if (self && gil_state) { - self->py_gil_state_release(gil_state); - } - } - WithGIL(const WithGIL&) = delete; - WithGIL(WithGIL&&) = delete; - WithGIL& operator=(const WithGIL&) = delete; - WithGIL& operator=(WithGIL&&) = delete; - - EnvCAPIRegistry* self = nullptr; - void* gil_state = nullptr; - }; -}; -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvCheckSignals() { return tvm::ffi::EnvCAPIRegistry::Global()->EnvCheckSignals(); } - -/*! - * \brief Register a symbol into the from the surrounding env. - * \param name The name of the symbol. - * \param symbol The symbol to register. 
- * \return 0 when success, nonzero when failure happens - */ -int TVMFFIEnvRegisterCAPI(const char* name, void* symbol) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String s_name(name); - tvm::ffi::EnvCAPIRegistry::Global()->Register(s_name, symbol); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/extra/json_parser.cc b/ffi/src/ffi/extra/json_parser.cc deleted file mode 100644 index 8bd372699dad..000000000000 --- a/ffi/src/ffi/extra/json_parser.cc +++ /dev/null @@ -1,731 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/json/parser.cc - * - * \brief A minimalistic JSON parser based on ffi values. - */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { -namespace json { - -/*! - * \brief Helper class to parse a JSON string. - * - * Keep leaf level string/number parse also in context. - */ -class JSONParserContext { - public: - JSONParserContext(const char* begin, const char* end) : begin_(begin), cur_(begin), end_(end) { - last_line_begin_ = cur_; - } - - /*! - * \brief Peek the current character. - * \return The current character, or -1 if the end of the string is reached. 
- */ - int Peek() const { - return (cur_ != end_ ? static_cast(*reinterpret_cast(cur_)) : -1); - } - - /*! - * \brief Skip the next char that we know is not a space - * - * \note Caller must explicitly call SkipSpaces first or use - * Peek already that confirms char is not any space char. - */ - void SkipNextAssumeNoSpace() { ++cur_; } - - /*! - * \brief Get the current position. - * \return The current position. - */ - const char* GetCurrentPos() const { return cur_; } - - /*! - * \brief Set the current position for better error message - * \param pos The new position. - * \note implementation can do it as no-op if needed - */ - void SetCurrentPosForBetterErrorMsg(const char* pos) { cur_ = pos; } - - /*! - * \brief Skip the space characters. - * \note This function does not check if the end of the string is reached. - */ - void SkipSpaces() { - while (cur_ != end_) { - if (!(*cur_ == ' ' || *cur_ == '\t' || *cur_ == '\n' || *cur_ == '\r')) { - break; - } - if (*cur_ == '\n') { - ++line_counter_; - last_line_begin_ = cur_ + 1; - } - ++cur_; - } - } - - /*! - * \brief Check if the next characters match the given string. - * \param str The string to match. - * \param len The length of the string. - * \return True if the next characters match the given string, false otherwise. - */ - bool MatchLiteral(const char* pattern, int len) { - const char* pend = pattern + len; - const char* ptr = pattern; - for (; ptr != pend && cur_ != end_; ++ptr, ++cur_) { - if (*ptr != *cur_) { - return false; - } - } - // we get to the end of the pattern and match is successful - return ptr == pend; - } - - /* - * \brief Parse the next strin starting with a double quote. - * \param out The output string. - * \return Whether the next string parsing is successful. 
- */ - bool NextString(json::Value* out) { - // NOTE: we keep string parsing logic here to allow some special - // optimizations for simple string that do not e - const char* start_pos = cur_; - TVM_FFI_ICHECK(*cur_ == '\"'); - // skip first double quote - ++cur_; - // the loop focuses on simple string without escape characters - for (; cur_ != end_; ++cur_) { - if (*cur_ == '\"') { - *out = String(start_pos + 1, cur_ - start_pos - 1); - ++cur_; - return true; - } - if (*cur_ < ' ' || *cur_ == '\\') { - // fallback to full string handling - return this->NextStringWithFullHandling(out, start_pos); - } - } - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorUnterminatedString(); - return false; - } - - /*! - * \brief Parse the next number. - * \param out The output number. - * \return Whether the next number parsing is successful. - */ - bool NextNumber(json::Value* out) { - const char* start_pos = cur_; - if (cur_ == end_) { - this->SetErrorExpectingValue(); - return false; - } - // JSON number grammar: - // - // number = [ minus ] int [ frac ] [ exp ] - // decimal-point = %x2E ; . 
- // digit1-9 = %x31-39 ; 1-9 - // e = %x65 / %x45 ; e E - // exp = e [ minus / plus ] 1*DIGIT - // frac = decimal-point 1*DIGIT - std::string temp_buffer; - bool maybe_int = true; - // parse [minus], cross check for Infinity/NaN/-Infinity - if (*cur_ == '-') { - temp_buffer.push_back('-'); - ++cur_; - if (cur_ != end_ && *cur_ == 'I') { - if (this->MatchLiteral("Infinity", 8)) { - *out = FastMathSafeNegInf(); - return true; - } else { - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorExpectingValue(); - return false; - } - } - } else if (*cur_ == 'I') { - if (this->MatchLiteral("Infinity", 8)) { - *out = FastMathSafePosInf(); - return true; - } else { - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorExpectingValue(); - return false; - } - } else if (*cur_ == 'N') { - if (this->MatchLiteral("NaN", 3)) { - *out = FastMathSafeNaN(); - return true; - } else { - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorExpectingValue(); - return false; - } - } - // read in all parts that are possibly part of a number - while (cur_ != end_) { - char next_char = *cur_; - if ((next_char >= '0' && next_char <= '9') || next_char == 'e' || next_char == 'E' || - next_char == '+' || next_char == '-' || next_char == '.') { - temp_buffer.push_back(next_char); - if (next_char == '.' 
|| next_char == 'e' || next_char == 'E') { - maybe_int = false; - } - ++cur_; - } else { - break; - } - } - if (temp_buffer.empty()) { - this->SetErrorExpectingValue(); - return false; - } - // parse from temp_buffer_ - if (maybe_int) { - // now try to parse the number as int64 - char* end_ptr; - errno = 0; - intmax_t int_val = strtoimax(temp_buffer.data(), &end_ptr, 10); - if (errno == 0 && int_val >= std::numeric_limits::min() && - int_val <= std::numeric_limits::max() && - end_ptr == temp_buffer.data() + temp_buffer.size()) { - *out = static_cast(int_val); - return true; - } - } - { - // now try to parse number as double - char* end_ptr; - errno = 0; - double double_val = strtod(temp_buffer.data(), &end_ptr); - if (errno == 0 && end_ptr == temp_buffer.data() + temp_buffer.size()) { - *out = double_val; - return true; - } else { - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorExpectingValue(); - return false; - } - } - } - - /*! - * \brief Get the current line context. - * \return The current line context. 
- */ - String GetSyntaxErrorContext(std::string err_prefix) const { - int64_t column = static_cast(cur_ - last_line_begin_) + 1; - int64_t char_pos = static_cast(cur_ - begin_); - if (err_prefix.empty()) { - err_prefix = "Syntax error"; - } - err_prefix += ": line " + std::to_string(line_counter_) + " column " + std::to_string(column) + - " (char " + std::to_string(char_pos) + ")"; - return String(err_prefix); - } - - std::string FinalizeErrorMsg() { - if (error_msg_.empty()) { - SetErrorDefault(); - } - return std::string(error_msg_); - } - - void SetErrorDefault() { error_msg_ = GetSyntaxErrorContext("Syntax error near"); } - - void SetErrorExpectingValue() { error_msg_ = GetSyntaxErrorContext("Expecting value"); } - - void SetErrorInvalidControlCharacter() { - error_msg_ = GetSyntaxErrorContext("Invalid control character at"); - } - - void SetErrorUnterminatedString() { - error_msg_ = GetSyntaxErrorContext("Unterminated string starting at"); - } - - void SetErrorInvalidUnicodeEscape() { - error_msg_ = GetSyntaxErrorContext("Invalid \\uXXXX escape"); - } - - void SetErrorInvalidSurrogatePair() { - error_msg_ = GetSyntaxErrorContext("Invalid surrogate pair of \\uXXXX escapes"); - } - - void SetErrorInvalidEscape() { error_msg_ = GetSyntaxErrorContext("Invalid \\escape"); } - - void SetErrorExtraData() { error_msg_ = GetSyntaxErrorContext("Extra data"); } - - void SetErrorExpectingPropertyName() { - error_msg_ = GetSyntaxErrorContext("Expecting property name enclosed in double quotes"); - } - - void SetErrorExpectingColon() { error_msg_ = GetSyntaxErrorContext("Expecting \':\' delimiter"); } - - void SetErrorExpectingComma() { error_msg_ = GetSyntaxErrorContext("Expecting \',\' delimiter"); } - - private: - static double FastMathSafePosInf() { -#ifdef __FAST_MATH__ - union { - uint64_t from; - double to; - } u; - u.from = 0x7FF0000000000000ULL; // write "from", read "to" - return u.to; -#else - return std::numeric_limits::infinity(); -#endif - } - - static double 
FastMathSafeNegInf() { -#ifdef __FAST_MATH__ - union { - uint64_t from; - double to; - } u; - u.from = 0xFFF0000000000000ULL; // write "from", read "to" - return u.to; -#else - return -std::numeric_limits::infinity(); -#endif - } - - static double FastMathSafeNaN() { -#ifdef __FAST_MATH__ - union { - uint64_t from; - double to; - } u; - u.from = 0x7FF8000000000000ULL; // write "from", read "to" - return u.to; -#else - return std::numeric_limits::quiet_NaN(); -#endif - } - - // Full string parsing with escape and unicode handling - bool NextStringWithFullHandling(Any* out, const char* start_pos) { - // copy over the prefix that was already parsed - std::string out_str(start_pos + 1, cur_ - start_pos - 1); - while (cur_ != end_) { - if (*cur_ < ' ') { - this->SetErrorInvalidControlCharacter(); - return false; - } - if (*cur_ == '\"') { - *out = String(std::move(out_str)); - ++cur_; - return true; - } - if (*cur_ == '\\') { - ++cur_; - switch (*cur_) { - // handle escape characters per JSON spec(RFC 8259) -#define HANDLE_ESCAPE_CHAR(pattern, val) \ - case pattern: \ - ++cur_; \ - out_str.push_back(val); \ - break - HANDLE_ESCAPE_CHAR('\"', '\"'); - HANDLE_ESCAPE_CHAR('\\', '\\'); - HANDLE_ESCAPE_CHAR('/', '/'); - HANDLE_ESCAPE_CHAR('b', '\b'); - HANDLE_ESCAPE_CHAR('f', '\f'); - HANDLE_ESCAPE_CHAR('n', '\n'); - HANDLE_ESCAPE_CHAR('r', '\r'); - HANDLE_ESCAPE_CHAR('t', '\t'); -#undef HANDLE_ESCAPE_CHAR - case 'u': { - const char* escape_pos = cur_; - // handle unicode code point - ++cur_; - int32_t first_i16, code_point = 0; - if (!Parse4Hex(&first_i16)) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidUnicodeEscape(); - return false; - } - // Check if the first i16 is a UTF-16 surrogate pair - // - // Surrogate pair encoding rule: - // U' = yyyyyyyyyyxxxxxxxxxx // U - 0x10000 - // W1 = 110110yyyyyyyyyy // 0xD800 + yyyyyyyyyy - // W2 = 110111xxxxxxxxxx // 0xDC00 + xxxxxxxxxx - // - // Range of W1 and W2: - // 0xD800–0xDBFF for W1 - // 
0xDC00–0xDFFF for W2 - // both W1 and W2 fit into 0xD800–0xDFFF - // Detect if the first i16 fit into range of W1/W2 - if (first_i16 >= 0xD800 && first_i16 <= 0xDFFF) { - // we are in the surrogate pair range - if (first_i16 >= 0xDC00) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidSurrogatePair(); - // we need to return false instead because this range is for W2 - return false; - } - if (!this->MatchLiteral("\\u", 2)) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidSurrogatePair(); - return false; - } - escape_pos = cur_; - // get the value of the W2 (second i16) - int32_t second_i16; - if (!Parse4Hex(&second_i16)) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidUnicodeEscape(); - return false; - } - if (!(second_i16 >= 0xDC00 && second_i16 <= 0xDFFF)) { - this->SetCurrentPosForBetterErrorMsg(escape_pos); - this->SetErrorInvalidSurrogatePair(); - return false; - } - // recover the code point - code_point = ((first_i16 - 0xD800) << 10) + (second_i16 - 0xDC00) + 0x10000; - } else { - // not a surrogate case, just assign as code point - code_point = first_i16; - } - // now need to push back the string based on UTF-8 encoding - // UTF-8 encoding rule: four cases - // ------------------------------------------------------------ - // Pattern | code point range - // ------------------------------------------------------------ - // 0xxxxxxx | 0x0 - 0x7F - // 110xxxxx 10xxxxxx | 0x80 - 0x7FF - // 1110xxxx 10xxxxxx 10xxxxxx | 0x800 - 0xFFFF - // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx | 0x10000 - end - // ------------------------------------------------------------ - if (code_point < 0x80) { - out_str.push_back(code_point); - } else if (code_point < 0x800) { - // first byte: 110xxxxx (5 effective bits) - // second byte: 10xxxxxx (6 effecive bits) - // shift by 6 bits to get the first bytes - out_str.push_back(0xC0 | (code_point >> 6)); - // mask by 6 effective bits - out_str.push_back(0x80 | 
(code_point & 0x3F)); - } else if (code_point < 0x10000) { - // first byte: 1110xxxx (4 effective bits) - // second byte: 10xxxxxx (6 effecive bits) - // third byte: 10xxxxxx (6 effecive bits) - // shift by 12 bits to get the first bytes - out_str.push_back(0xE0 | (code_point >> 12)); - // shift by 6 bits to get the second bytes, mask by 6 effective bits - out_str.push_back(0x80 | ((code_point >> 6) & 0x3F)); - // mask by 6 effective bits - out_str.push_back(0x80 | (code_point & 0x3F)); - } else { - // first byte: 11110xxx (3 effective bits) - // second byte: 10xxxxxx (6 effecive bits) - // third byte: 10xxxxxx (6 effecive bits) - // fourth byte: 10xxxxxx (6 effecive bits) - // shift by 18 bits to get the first bytes - out_str.push_back(0xF0 | (code_point >> 18)); - // shift by 12 bits to get the second bytes, mask by 6 effective bits - out_str.push_back(0x80 | ((code_point >> 12) & 0x3F)); - // shift by 6 bits to get the third bytes, mask by 6 effective bits - out_str.push_back(0x80 | ((code_point >> 6) & 0x3F)); - // mask by 6 effective bits - out_str.push_back(0x80 | (code_point & 0x3F)); - } - break; - } - default: { - this->SetErrorInvalidEscape(); - return false; - } - } - } else { - out_str.push_back(*cur_); - ++cur_; - } - } - this->SetCurrentPosForBetterErrorMsg(start_pos); - this->SetErrorUnterminatedString(); - return false; - } - /*! - * \brief Parse the four hex digits of a unicode code point per json spec. - * \param out_i16 The output i16 number - * \return True if four hex digits are parsed successfully, false otherwise. 
- */ - bool Parse4Hex(int32_t* out_i16) { - int32_t result = 0; - for (int i = 0; i < 4; ++i, ++cur_) { - int hex_val = *reinterpret_cast(cur_); - if (hex_val >= '0' && hex_val <= '9') { - hex_val -= '0'; - } else if (hex_val >= 'a' && hex_val <= 'f') { - hex_val -= 'a' - 0xa; - } else if (hex_val >= 'A' && hex_val <= 'F') { - hex_val -= 'A' - 0xa; - } else { - return false; - } - result = result * 16 + hex_val; - } - *out_i16 = result; - return true; - } - - /*! \brief The beginning of the string */ - const char* begin_; - /*! \brief The current pointer */ - const char* cur_; - /*! \brief End of the string */ - const char* end_; - /*! \brief The beginning of the last line */ - const char* last_line_begin_; - /*! \brief The error message */ - std::string error_msg_; - /*! \brief The line counter */ - int64_t line_counter_{1}; -}; - -class JSONParser { - public: - static json::Value Parse(const String& json_str, String* error_msg) { - JSONParser parser(json_str); - json::Value result; - if (parser.ParseValue(&result) && parser.ParseTail()) { - if (error_msg != nullptr) { - *error_msg = String(""); - } - return result; - } - if (error_msg != nullptr) { - *error_msg = parser.ctx_.FinalizeErrorMsg(); - TVM_FFI_ICHECK(!error_msg->empty()); - } else { - TVM_FFI_THROW(ValueError) << parser.ctx_.FinalizeErrorMsg(); - } - // note that when we don't throw, error msg is set to indicate - // an error happens - return nullptr; - } - - private: - explicit JSONParser(String json_str) : ctx_(json_str.data(), json_str.data() + json_str.size()) {} - - bool ParseTail() { - ctx_.SkipSpaces(); - // there are extra data in the tail - if (ctx_.Peek() != -1) { - ctx_.SetErrorExtraData(); - return false; - } - return true; - } - - bool ParseValue(json::Value* out) { - ctx_.SkipSpaces(); - // record start pos for cases where we might need to reset - // current position for better error message - auto start_pos = ctx_.GetCurrentPos(); - // check if the end of the string is reached - switch 
(ctx_.Peek()) { - case -1: { - ctx_.SetErrorExpectingValue(); - return false; - } - case '{': { - return ParseObject(out); - } - case '[': { - return ParseArray(out); - } - case '\"': { - return ctx_.NextString(out); - } - case 't': { - ctx_.SkipNextAssumeNoSpace(); - if (ctx_.MatchLiteral("rue", 3)) { - *out = true; - return true; - } else { - ctx_.SetCurrentPosForBetterErrorMsg(start_pos); - ctx_.SetErrorExpectingValue(); - return false; - } - } - case 'f': { - ctx_.SkipNextAssumeNoSpace(); - if (ctx_.MatchLiteral("alse", 4)) { - *out = false; - return true; - } else { - ctx_.SetCurrentPosForBetterErrorMsg(start_pos); - ctx_.SetErrorExpectingValue(); - return false; - } - } - case 'n': { - ctx_.SkipNextAssumeNoSpace(); - if (ctx_.MatchLiteral("ull", 3)) { - *out = nullptr; - return true; - } else { - ctx_.SetCurrentPosForBetterErrorMsg(start_pos); - ctx_.SetErrorExpectingValue(); - return false; - } - } - default: { - return ctx_.NextNumber(out); - } - } - return false; - } - - bool ParseObject(json::Value* out) { - size_t stack_top = object_temp_stack_.size(); - json::Object result; - ctx_.SkipNextAssumeNoSpace(); - ctx_.SkipSpaces(); - int next_char = ctx_.Peek(); - if (next_char == -1) { - ctx_.SetErrorExpectingPropertyName(); - return false; - } - // empty object - if (next_char == '}') { - ctx_.SkipNextAssumeNoSpace(); - *out = json::Object(); - return true; - } - // non-empty object - while ((next_char = ctx_.Peek()) != -1) { - if (next_char != '\"') { - ctx_.SetErrorExpectingPropertyName(); - return false; - } - json::Value key; - if (!ctx_.NextString(&key)) return false; - ctx_.SkipSpaces(); - if (ctx_.Peek() != ':') { - ctx_.SetErrorExpectingColon(); - return false; - } - ctx_.SkipNextAssumeNoSpace(); - json::Value value; - if (!ParseValue(&value)) return false; - object_temp_stack_.emplace_back(key, value); - // result.Set(key, value); - ctx_.SkipSpaces(); - if (ctx_.Peek() == '}') { - ctx_.SkipNextAssumeNoSpace(); - *out = 
json::Object(object_temp_stack_.begin() + stack_top, object_temp_stack_.end()); - // recover the stack to original state - object_temp_stack_.resize(stack_top); - return true; - } else if (ctx_.Peek() == ',') { - ctx_.SkipNextAssumeNoSpace(); - // must skip space so next iteration do not have to do so - ctx_.SkipSpaces(); - } else { - ctx_.SetErrorExpectingComma(); - return false; - } - } - return false; - } - - bool ParseArray(json::Value* out) { - size_t stack_top = array_temp_stack_.size(); - ctx_.SkipNextAssumeNoSpace(); - ctx_.SkipSpaces(); - int next_char = ctx_.Peek(); - if (next_char == -1) { - ctx_.SetErrorExpectingValue(); - return false; - } - // empty array - if (next_char == ']') { - ctx_.SkipNextAssumeNoSpace(); - *out = json::Array(); - return true; - } - // non-empty array - while ((next_char = ctx_.Peek()) != -1) { - json::Value value; - // no need to skip space here because we already skipped space - // at the beginning or in previous iteration - if (!ParseValue(&value)) return false; - array_temp_stack_.emplace_back(std::move(value)); - ctx_.SkipSpaces(); - next_char = ctx_.Peek(); - if (next_char == ',') { - ctx_.SkipNextAssumeNoSpace(); - // must skip space so next iteration do not have to do so - ctx_.SkipSpaces(); - } else if (next_char == ']') { - ctx_.SkipNextAssumeNoSpace(); - *out = json::Array(array_temp_stack_.begin() + stack_top, array_temp_stack_.end()); - // recover the stack - array_temp_stack_.resize(stack_top); - return true; - } else { - ctx_.SetErrorExpectingComma(); - return false; - } - } - return false; - } - - JSONParserContext ctx_; - // Temp stack for intermediate values - // we first create a persistent stack to store the parsed values - // then create the final array/object object with the precise size - std::vector array_temp_stack_; - std::vector> object_temp_stack_; -}; - -json::Value Parse(const String& json_str, String* error_msg) { - return JSONParser::Parse(json_str, error_msg); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - 
namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.json.Parse", - [](const String& json_str) { return json::Parse(json_str); }); -}); - -} // namespace json -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/json_writer.cc b/ffi/src/ffi/extra/json_writer.cc deleted file mode 100644 index c2cd3f2f36d3..000000000000 --- a/ffi/src/ffi/extra/json_writer.cc +++ /dev/null @@ -1,307 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/json/writer.cc - * - * \brief A minimalistic JSON writer based on ffi values. 
- */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#ifdef _MSC_VER -#define TVM_FFI_SNPRINTF _snprintf_s -#pragma warning(push) -#pragma warning(disable : 4244) -#pragma warning(disable : 4127) -#pragma warning(disable : 4702) -#else -#define TVM_FFI_SNPRINTF snprintf -#endif - -namespace tvm { -namespace ffi { -namespace json { - -class JSONWriter { - public: - static String Stringify(const json::Value& value, Optional indent) { - JSONWriter writer(indent.value_or(0)); - writer.WriteValue(value); - return String(std::move(writer.result_)); - } - - private: - explicit JSONWriter(int indent) : indent_(indent), out_iter_(result_) {} - - static bool FastMathSafeIsNaN(double x) { -#ifdef __FAST_MATH__ - // Bit-level NaN detection (IEEE 754 double) - // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 - // NaN is encoded as all 1s in the exponent and non-zero in the mantissa - static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); - union { - double from; - uint64_t to; - } u; - u.from = x; // write "from", read "to" - uint64_t bits = u.to; - uint64_t exponent = (bits >> 52) & 0x7FF; - uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; - return (exponent == 0x7FF) && (mantissa != 0); -#else - // Safe to use std::isnan when fast-math is off - return std::isnan(x); -#endif - } - - static bool FastMathSafeIsInf(double x) { -#ifdef __FAST_MATH__ - // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 - // Inf is encoded as all 1s in the exponent and zero in the mantissa - static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); - union { - double from; - uint64_t to; - } u; - u.from = x; // write "from", read "to" - uint64_t bits = u.to; - uint64_t exponent = (bits >> 52) & 0x7FF; - uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; - // inf is encoded as all 1s in the exponent and zero in the mantissa - return (exponent == 0x7FF) && (mantissa == 
0); -#else - return std::isinf(x); -#endif - } - - void WriteValue(const json::Value& value) { - switch (value.type_index()) { - case TypeIndex::kTVMFFINone: { - WriteLiteral("null", 4); - break; - } - case TypeIndex::kTVMFFIBool: { - bool bool_value = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - if (bool_value) { - WriteLiteral("true", 4); - } else { - WriteLiteral("false", 5); - } - break; - } - case TypeIndex::kTVMFFIInt: { - WriteInt(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIFloat: { - WriteFloat(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFISmallStr: - case TypeIndex::kTVMFFIStr: { - WriteString(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIArray: { - WriteArray(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIMap: { - WriteObject(details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - default: { - TVM_FFI_THROW(ValueError) << "Unsupported type: `" << value.GetTypeKey() << "`"; - TVM_FFI_UNREACHABLE(); - } - } - } - - void WriteLiteral(const char* literal, int size) { - for (int i = 0; i < size; ++i) { - *out_iter_++ = literal[i]; - } - } - - void WriteInt(int64_t value) { - // the biggest possible string representation of -INT64_MIN - char buffer[sizeof("-9223372036854775808") + 1]; - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "%" PRId64, value); - WriteLiteral(buffer, size); - } - - void WriteFloat(double value) { - // largest possible string representation of a double is around 24 chars plus - // one null terminator keep 32 to be safe - char buffer[32]; - if (FastMathSafeIsNaN(value)) { - WriteLiteral("NaN", 3); - } else if (FastMathSafeIsInf(value)) { - if (value < 0) { - WriteLiteral("-Infinity", 9); - } else { - WriteLiteral("Infinity", 8); - } - } else { - double int_part; - // if the value can be represented as integer - if 
(std::fabs(value) < (1ULL << 53) && std::modf(value, &int_part) == 0) { - // always print an extra .0 for integer so integer numbers are printed as floats - // this helps us to distinguish between integer and float, which is not necessary - // but helps to ensure roundtrip property of the parser/printer in terms of int/float types - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "%.1f", int_part); - WriteLiteral(buffer, size); - } else { - // Save 17 decimal digits to avoid loss during loading JSON - // this is the maximum precision that can be represented in a double - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "%.17g", value); - WriteLiteral(buffer, size); - } - } - } - - void WriteString(const String& value) { - *out_iter_++ = '"'; - const char* data = value.data(); - const size_t size = value.size(); - for (size_t i = 0; i < size; ++i) { - switch (data[i]) { -// handle escape characters per JSON spec(RFC 8259) -#define HANDLE_ESCAPE_CHAR(pattern, val) \ - case pattern: \ - WriteLiteral(val, std::char_traits::length(val)); \ - break - HANDLE_ESCAPE_CHAR('\"', "\\\""); - HANDLE_ESCAPE_CHAR('\\', "\\\\"); - HANDLE_ESCAPE_CHAR('/', "\\/"); - HANDLE_ESCAPE_CHAR('\b', "\\b"); - HANDLE_ESCAPE_CHAR('\f', "\\f"); - HANDLE_ESCAPE_CHAR('\n', "\\n"); - HANDLE_ESCAPE_CHAR('\r', "\\r"); - HANDLE_ESCAPE_CHAR('\t', "\\t"); -#undef HANDLE_ESCAPE_CHAR - default: { - uint8_t u8_val = static_cast(data[i]); - // this is a control character, print as \uXXXX - if (u8_val < 0x20 || u8_val == 0x7f) { - char buffer[8]; - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "\\u%04x", - static_cast(data[i]) & 0xff); - WriteLiteral(buffer, size); - } else { - *out_iter_++ = data[i]; - } - break; - } - } - } - *out_iter_++ = '"'; - } - - void WriteArray(const json::Array& value) { - *out_iter_++ = '['; - if (indent_ != 0) { - total_indent_ += indent_; - } - for (size_t i = 0; i < value.size(); ++i) { - if (i != 0) { - *out_iter_++ = ','; - } - if (indent_ != 0) { - 
WriteIndent(); - } - WriteValue(value[i]); - } - if (indent_ != 0) { - total_indent_ -= indent_; - WriteIndent(); - } - *out_iter_++ = ']'; - } - - void WriteObject(const json::Object& value) { - *out_iter_++ = '{'; - if (indent_ != 0) { - total_indent_ += indent_; - } - int counter = 0; - for (const auto& [key, value] : value) { - if (counter++ != 0) { - *out_iter_++ = ','; - } - if (indent_ != 0) { - WriteIndent(); - } - auto opt_key = key.as(); - if (!opt_key.has_value()) { - TVM_FFI_THROW(ValueError) << "Expect key to be string, got `" << key.GetTypeKey() << "`"; - } - WriteString(*opt_key); - *out_iter_++ = ':'; - if (indent_ != 0) { - *out_iter_++ = ' '; - } - WriteValue(value); - } - if (indent_ != 0) { - total_indent_ -= indent_; - WriteIndent(); - } - *out_iter_++ = '}'; - } - - // Write a newline and indent the current level - void WriteIndent() { - *out_iter_++ = '\n'; - for (int i = 0; i < total_indent_; ++i) { - *out_iter_++ = ' '; - } - } - - int indent_ = 0; - int total_indent_ = 0; - std::string result_; - std::back_insert_iterator out_iter_; -}; - -String Stringify(const json::Value& value, Optional indent) { - return JSONWriter::Stringify(value, indent); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.json.Stringify", Stringify); -}); - -} // namespace json -} // namespace ffi -} // namespace tvm - -#undef TVM_FFI_SNPRINTF diff --git a/ffi/src/ffi/extra/library_module.cc b/ffi/src/ffi/extra/library_module.cc deleted file mode 100644 index 71c6da6f7cc4..000000000000 --- a/ffi/src/ffi/extra/library_module.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/library_module.cc - * - * \brief Library module implementation. - */ -#include -#include -#include - -#include "buffer_stream.h" -#include "module_internal.h" - -namespace tvm { -namespace ffi { - -class LibraryModuleObj final : public ModuleObj { - public: - explicit LibraryModuleObj(ObjectPtr lib) : lib_(lib) {} - - const char* kind() const final { return "library"; } - - /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { return Module::kBinarySerializable | Module::kRunnable; }; - - Optional GetFunction(const String& name) final { - TVMFFISafeCallType faddr; - faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); - // ensure the function keeps the Library Module alive - Module self_strong_ref = GetRef(this); - if (faddr != nullptr) { - return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, - ffi::Any* rv) { - TVM_FFI_ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); - TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), - args.size(), reinterpret_cast(rv))); - }); - } - return std::nullopt; - } - - private: - ObjectPtr lib_; -}; - -Module LoadModuleFromBytes(const std::string& kind, const Bytes& bytes) { - std::string loader_key = "ffi.Module.load_from_bytes." 
+ kind; - const auto floader = tvm::ffi::Function::GetGlobal(loader_key); - if (!floader.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Library binary was created using {" << kind - << "} but a loader of that name is not registered. " - << "Make sure to have runtime that registers " << loader_key; - } - return (*floader)(bytes).cast(); -} - -/*! - * \brief Process libary binary to recover binary-serialized modules - * \param library_bin The binary embedded in the library. - * \param opt_lib The library, can be nullptr in which case we expect to deserialize - * all binary-serialized modules - * \param library_ctx_addr the pointer to library module as ctx addr - * \return the root module - * - */ -Module ProcessLibraryBin(const char* library_bin, ObjectPtr opt_lib, - void** library_ctx_addr = nullptr) { - // Layout of the library binary: - // ... - // key can be: "_lib", or a module kind - // - "_lib" indicate this location places the library module - // - other keys are module kinds - // Import tree structure (CSR structure of child indices): - // = > > - TVM_FFI_ICHECK(library_bin != nullptr); - uint64_t nbytes = 0; - for (size_t i = 0; i < sizeof(nbytes); ++i) { - uint64_t c = library_bin[i]; - nbytes |= (c & 0xffUL) << (i * 8); - } - - BufferInStream stream(library_bin + sizeof(nbytes), static_cast(nbytes)); - std::vector import_tree_indptr; - std::vector import_tree_child_indices; - TVM_FFI_ICHECK(stream.Read(&import_tree_indptr)); - TVM_FFI_ICHECK(stream.Read(&import_tree_child_indices)); - size_t num_modules = import_tree_indptr.size() - 1; - std::vector modules; - modules.reserve(num_modules); - - for (uint64_t i = 0; i < num_modules; ++i) { - std::string kind; - TVM_FFI_ICHECK(stream.Read(&kind)); - // "_lib" serves as a placeholder in the module import tree to indicate where - // to place the DSOModule - if (kind == "_lib") { - TVM_FFI_ICHECK(opt_lib != nullptr) << "_lib is not allowed during module serialization"; - auto lib_mod_ptr = 
make_object(opt_lib); - if (library_ctx_addr) { - *library_ctx_addr = lib_mod_ptr.get(); - } - modules.emplace_back(Module(lib_mod_ptr)); - } else { - std::string module_bytes; - TVM_FFI_ICHECK(stream.Read(&module_bytes)); - Module m = LoadModuleFromBytes(kind, Bytes(module_bytes)); - modules.emplace_back(m); - } - } - for (size_t i = 0; i < modules.size(); ++i) { - for (size_t j = import_tree_indptr[i]; j < import_tree_indptr[i + 1]; ++j) { - Array* module_imports = ModuleObj::InternalUnsafe::GetImports(modules[i].operator->()); - auto child_index = import_tree_child_indices[j]; - TVM_FFI_ICHECK(child_index < modules.size()); - module_imports->emplace_back(modules[child_index]); - } - } - return modules[0]; -} - -// registry to store context symbols -class ContextSymbolRegistry { - public: - void InitContextSymbols(ObjectPtr lib) { - for (const auto& [name, symbol] : context_symbols_) { - if (void** symbol_addr = reinterpret_cast(lib->GetSymbol(name.c_str()))) { - *symbol_addr = symbol; - } - } - } - - void VisitContextSymbols(const ffi::TypedFunction& callback) { - for (const auto& [name, symbol] : context_symbols_) { - callback(name, symbol); - } - } - - void Register(String name, void* symbol) { context_symbols_.emplace_back(name, symbol); } - - static ContextSymbolRegistry* Global() { - static ContextSymbolRegistry* inst = new ContextSymbolRegistry(); - return inst; - } - - private: - std::vector> context_symbols_; -}; - -void Module::VisitContextSymbols(const ffi::TypedFunction& callback) { - ContextSymbolRegistry::Global()->VisitContextSymbols(callback); -} - -Module CreateLibraryModule(ObjectPtr lib) { - const char* library_bin = - reinterpret_cast(lib->GetSymbol(ffi::symbol::tvm_ffi_library_bin)); - void** library_ctx_addr = - reinterpret_cast(lib->GetSymbol(ffi::symbol::tvm_ffi_library_ctx)); - - ContextSymbolRegistry::Global()->InitContextSymbols(lib); - if (library_bin != nullptr) { - // we have embedded binaries that needs to be deserialized - return 
ProcessLibraryBin(library_bin, lib, library_ctx_addr); - } else { - // Only have one single DSO Module - auto lib_mod_ptr = make_object(lib); - Module root_mod = Module(lib_mod_ptr); - if (library_ctx_addr) { - *library_ctx_addr = root_mod.operator->(); - } - return root_mod; - } -} - -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvModRegisterContextSymbol(const char* name, void* symbol) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String s_name(name); - tvm::ffi::ContextSymbolRegistry::Global()->Register(s_name, symbol); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/extra/library_module_dynamic_lib.cc b/ffi/src/ffi/extra/library_module_dynamic_lib.cc deleted file mode 100644 index 25463a7e5f92..000000000000 --- a/ffi/src/ffi/extra/library_module_dynamic_lib.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file library_module_dynamic_lib.cc - * \brief Create library module to load from dynamic shared library. 
- */ -#include -#include -#include - -#include "module_internal.h" - -#if defined(_WIN32) -#include -#else -#include -#endif - -#if defined(__hexagon__) -extern "C" { -#include -} -#endif - -namespace tvm { -namespace ffi { - -class DSOLibrary final : public Library { - public: - explicit DSOLibrary(const String& name) { Load(name); } - ~DSOLibrary() { - if (lib_handle_) Unload(); - } - - void* GetSymbol(const char* name) final { return GetSymbol_(name); } - - private: - // private system dependent implementation - void* GetSymbol_(const char* name); - void Load(const String& name); - void Unload(); - -#if defined(_WIN32) - //! \brief Windows library handle - HMODULE lib_handle_{nullptr}; -#else - // \brief Linux library handle - void* lib_handle_{nullptr}; -#endif -}; - -#if defined(_WIN32) - -void* DSOLibrary::GetSymbol_(const char* name) { - return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) -} - -void DSOLibrary::Load(const String& name) { - // use wstring version that is needed by LLVM. 
- std::wstring wname(name.data(), name.data() + name.size()); - lib_handle_ = LoadLibraryW(wname.c_str()); - TVM_FFI_ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; -} - -void DSOLibrary::Unload() { - FreeLibrary(lib_handle_); - lib_handle_ = nullptr; -} - -#else - -void DSOLibrary::Load(const String& name) { - lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - TVM_FFI_ICHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name << " " << dlerror(); -#if defined(__hexagon__) - int p; - int rc = dlinfo(lib_handle_, RTLD_DI_LOAD_ADDR, &p); - if (rc) - FARF(ERROR, "error getting model .so start address : %u", rc); - else - FARF(ALWAYS, "Model .so Start Address : %x", p); -#endif -} - -void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } - -void DSOLibrary::Unload() { - dlclose(lib_handle_); - lib_handle_ = nullptr; -} -#endif - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.Module.load_from_file.so", [](String library_path, String) { - return CreateLibraryModule(make_object(library_path)); - }); -}); -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/library_module_system_lib.cc b/ffi/src/ffi/extra/library_module_system_lib.cc deleted file mode 100644 index cdc932cba292..000000000000 --- a/ffi/src/ffi/extra/library_module_system_lib.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file system_library.cc - * \brief Create library module that directly get symbol from the system lib. - */ -#include -#include -#include -#include -#include - -#include - -#include "module_internal.h" - -namespace tvm { -namespace ffi { - -class SystemLibSymbolRegistry { - public: - void RegisterSymbol(const std::string& name, void* ptr) { - auto it = symbol_table_.find(name); - if (it != symbol_table_.end() && ptr != (*it).second) { - std::cerr << "Warning:SystemLib symbol " << name << " get overriden to a different address " - << ptr << "->" << (*it).second << std::endl; - } - symbol_table_.Set(name, ptr); - } - - void* GetSymbol(const char* name) { - auto it = symbol_table_.find(name); - if (it != symbol_table_.end()) { - return (*it).second; - } else { - return nullptr; - } - } - - static SystemLibSymbolRegistry* Global() { - static SystemLibSymbolRegistry* inst = new SystemLibSymbolRegistry(); - return inst; - } - - private: - // Internal symbol table - Map symbol_table_; -}; - -class SystemLibrary final : public Library { - public: - explicit SystemLibrary(const String& symbol_prefix) : symbol_prefix_(symbol_prefix) {} - - void* GetSymbol(const char* name) { - if (symbol_prefix_.length() != 0) { - String name_with_prefix = symbol_prefix_ + name; - void* symbol = reg_->GetSymbol(name_with_prefix.c_str()); - if (symbol != nullptr) return symbol; - } - return reg_->GetSymbol(name); - } - - private: - SystemLibSymbolRegistry* reg_ = SystemLibSymbolRegistry::Global(); - String symbol_prefix_; -}; - -class SystemLibModuleRegistry { 
- public: - Module GetOrCreateModule(String symbol_prefix) { - std::lock_guard lock(mutex_); - auto it = lib_map_.find(symbol_prefix); - if (it != lib_map_.end()) { - return (*it).second; - } else { - Module mod = CreateLibraryModule(make_object(symbol_prefix)); - lib_map_.Set(symbol_prefix, mod); - return mod; - } - } - - static SystemLibModuleRegistry* Global() { - static SystemLibModuleRegistry* inst = new SystemLibModuleRegistry(); - return inst; - } - - private: - // Internal mutex - std::mutex mutex_; - // maps prefix to the library module - // we need to make sure each lib map have an unique - // copy through out the entire lifetime of the process - Map lib_map_; -}; - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("ffi.SystemLib", [](ffi::PackedArgs args, ffi::Any* rv) { - String symbol_prefix = ""; - if (args.size() != 0) { - symbol_prefix = args[0].cast(); - } - *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); - }); -}); -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* ptr) { - tvm::ffi::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr); - return 0; -} diff --git a/ffi/src/ffi/extra/module.cc b/ffi/src/ffi/extra/module.cc deleted file mode 100644 index d8ec77f98c97..000000000000 --- a/ffi/src/ffi/extra/module.cc +++ /dev/null @@ -1,139 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -#include -#include - -#include "module_internal.h" - -namespace tvm { -namespace ffi { - -Optional ModuleObj::GetFunction(const String& name, bool query_imports) { - if (auto opt_func = this->GetFunction(name)) { - return opt_func; - } - if (query_imports) { - for (const Any& import : imports_) { - if (auto opt_func = import.cast()->GetFunction(name, query_imports)) { - return *opt_func; - } - } - } - return std::nullopt; -} - -void ModuleObj::ImportModule(const Module& other) { - std::unordered_set visited{other.operator->()}; - std::vector stack{other.operator->()}; - while (!stack.empty()) { - const ModuleObj* n = stack.back(); - stack.pop_back(); - for (const Any& m : n->imports_) { - const ModuleObj* next = m.cast(); - if (visited.count(next)) continue; - visited.insert(next); - stack.push_back(next); - } - } - if (visited.count(this)) { - TVM_FFI_THROW(RuntimeError) << "Cyclic dependency detected during import"; - } - imports_.push_back(other); -} - -void ModuleObj::ClearImports() { imports_.clear(); } - -bool ModuleObj::ImplementsFunction(const String& name, bool query_imports) { - if (this->ImplementsFunction(name)) { - return true; - } - if (query_imports) { - for (const Any& import : imports_) { - if (import.cast()->ImplementsFunction(name, query_imports)) { - return true; - } - } - } - return false; -} - -Module Module::LoadFromFile(const String& file_name) { - String format = [&file_name]() -> String { - const char* data = file_name.data(); - for (size_t i = file_name.size(); i > 0; i--) { 
- if (data[i - 1] == '.') { - return String(data + i, file_name.size() - i); - } - } - TVM_FFI_THROW(RuntimeError) << "Failed to get file format from " << file_name; - TVM_FFI_UNREACHABLE(); - }(); - - if (format == "dll" || format == "dylib" || format == "dso") { - format = "so"; - } - String loader_name = "ffi.Module.load_from_file." + format; - const auto floader = tvm::ffi::Function::GetGlobal(loader_name); - if (!floader.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Loader for `." << format << "` files is not registered," - << " resolved to (" << loader_name << ") in the global registry." - << "Ensure that you have loaded the correct runtime code, and" - << "that you are on the correct hardware architecture."; - } - return (*floader)(file_name, format).cast(); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - ModuleObj::InternalUnsafe::RegisterReflection(); - - refl::GlobalDef() - .def("ffi.ModuleLoadFromFile", &Module::LoadFromFile) - .def_method("ffi.ModuleImplementsFunction", - [](Module mod, String name, bool query_imports) { - return mod->ImplementsFunction(name, query_imports); - }) - .def_method("ffi.ModuleGetFunction", - [](Module mod, String name, bool query_imports) { - return mod->GetFunction(name, query_imports); - }) - .def_method("ffi.ModuleGetPropertyMask", &ModuleObj::GetPropertyMask) - .def_method("ffi.ModuleInspectSource", &ModuleObj::InspectSource) - .def_method("ffi.ModuleGetKind", [](const Module& mod) -> String { return mod->kind(); }) - .def_method("ffi.ModuleGetWriteFormats", &ModuleObj::GetWriteFormats) - .def_method("ffi.ModuleWriteToFile", &ModuleObj::WriteToFile) - .def_method("ffi.ModuleImportModule", &ModuleObj::ImportModule) - .def_method("ffi.ModuleClearImports", &ModuleObj::ClearImports); -}); -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, - TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - *out = 
tvm::ffi::ModuleObj::InternalUnsafe::GetFunctionFromImports( - reinterpret_cast(library_ctx), func_name); - TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/extra/module_internal.h b/ffi/src/ffi/extra/module_internal.h deleted file mode 100644 index 472d531f4b51..000000000000 --- a/ffi/src/ffi/extra/module_internal.h +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file library_module.h - * \brief Module that builds from a libary of symbols. - */ -#ifndef TVM_FFI_EXTRA_MODULE_INTERNAL_H_ -#define TVM_FFI_EXTRA_MODULE_INTERNAL_H_ - -#include -#include - -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Library is the common interface - * for storing data in the form of shared libaries. - * - * \sa src/ffi/extra/dso_library.cc - * \sa src/ffi/extra/system_library.cc - */ -class Library : public Object { - public: - // destructor. - virtual ~Library() {} - /*! - * \brief Get the symbol address for a given name. - * \param name The name of the symbol. - * \return The symbol. - */ - virtual void* GetSymbol(const char* name) = 0; - // NOTE: we do not explicitly create an type index and type_key here for libary. 
- // This is because we do not need dynamic type downcasting and only need to use the refcounting -}; - -struct ModuleObj::InternalUnsafe { - static Array* GetImports(ModuleObj* module) { return &(module->imports_); } - - static void* GetFunctionFromImports(ModuleObj* module, const char* name) { - // backend implementation for TVMFFIEnvModLookupFromImports - static std::mutex mutex_; - std::lock_guard lock(mutex_); - String s_name(name); - auto it = module->import_lookup_cache_.find(s_name); - if (it != module->import_lookup_cache_.end()) { - return const_cast((*it).second.operator->()); - } - - auto opt_func = [&]() -> std::optional { - for (const Any& import : module->imports_) { - if (auto opt_func = import.cast()->GetFunction(s_name, true)) { - return *opt_func; - } - } - // try global at last - return tvm::ffi::Function::GetGlobal(s_name); - }(); - if (!opt_func.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Cannot find function " << name - << " in the imported modules or global registry."; - } - module->import_lookup_cache_.Set(s_name, *opt_func); - return const_cast((*opt_func).operator->()); - } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("imports_", &ModuleObj::imports_); - } -}; - -/*! - * \brief Create a library module from a given library. - * - * \param lib The library. - * - * \return The corresponding loaded module. - */ -Module CreateLibraryModule(ObjectPtr lib); - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_MODULE_INTERNAL_H_ diff --git a/ffi/src/ffi/extra/reflection_extra.cc b/ffi/src/ffi/extra/reflection_extra.cc deleted file mode 100644 index 698be6337698..000000000000 --- a/ffi/src/ffi/extra/reflection_extra.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/reflection_extra.cc - * - * \brief Extra reflection registrations. * - */ -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { - int32_t type_index; - if (auto opt_type_index = args[0].try_cast()) { - type_index = *opt_type_index; - } else { - String type_key = args[0].cast(); - TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - } - - TVM_FFI_ICHECK(args.size() % 2 == 1); - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support reflection creation"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - - std::vector keys; - std::vector keys_found; - - for (int i = 1; i < args.size(); i += 2) { - keys.push_back(args[i].cast()); - } - keys_found.resize(keys.size(), false); - - auto search_field = [&](const 
TVMFFIByteArray& field_name) { - for (size_t i = 0; i < keys.size(); ++i) { - if (keys_found[i]) continue; - if (keys[i].compare(field_name) == 0) { - return i; - } - } - return keys.size(); - }; - - auto update_fields = [&](const TVMFFITypeInfo* tinfo) { - for (int i = 0; i < tinfo->num_fields; ++i) { - const TVMFFIFieldInfo* field_info = tinfo->fields + i; - size_t arg_index = search_field(field_info->name); - void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (arg_index < keys.size()) { - AnyView field_value = args[arg_index * 2 + 2]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); - keys_found[arg_index] = true; - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" << TypeIndexToTypeKey(type_index) << "`"; - } - } - }; - - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - update_fields(type_info->type_acenstors[i]); - } - update_fields(type_info); - - for (size_t i = 0; i < keys.size(); ++i) { - if (!keys_found[i]) { - TVM_FFI_THROW(TypeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have field `" << keys[i] << "`"; - } - } - *ret = ObjectRef(ptr); -} - -inline void AccessStepRegisterReflection() { - // register access step reflection here since it is only needed for bindings - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("kind", &AccessStepObj::kind) - .def_ro("key", &AccessStepObj::key); -} - -inline void AccessPathRegisterReflection() { - // register access path reflection here since it is only needed for bindings - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("parent", &AccessPathObj::parent) - .def_ro("step", 
&AccessPathObj::step) - .def_ro("depth", &AccessPathObj::depth) - .def_static("_root", &AccessPath::Root) - .def("_extend", &AccessPathObj::Extend) - .def("_attr", &AccessPathObj::Attr) - .def("_array_item", &AccessPathObj::ArrayItem) - .def("_map_item", &AccessPathObj::MapItem) - .def("_attr_missing", &AccessPathObj::AttrMissing) - .def("_array_item_missing", &AccessPathObj::ArrayItemMissing) - .def("_map_item_missing", &AccessPathObj::MapItemMissing) - .def("_is_prefix_of", &AccessPathObj::IsPrefixOf) - .def("_to_steps", &AccessPathObj::ToSteps) - .def("_path_equal", - [](const AccessPath& self, const AccessPath& other) { return self->PathEqual(other); }); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - AccessStepRegisterReflection(); - AccessPathRegisterReflection(); - refl::GlobalDef().def_packed("ffi.MakeObjectFromPackedArgs", MakeObjectFromPackedArgs); -}); - -} // namespace reflection -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/serialization.cc b/ffi/src/ffi/extra/serialization.cc deleted file mode 100644 index ea9a96b696ec..000000000000 --- a/ffi/src/ffi/extra/serialization.cc +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/serialization.cc - * - * \brief Reflection-based serialization utilities. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -class ObjectGraphSerializer { - public: - static json::Value Serialize(const Any& value, Any metadata) { - ObjectGraphSerializer serializer; - json::Object result; - result.Set("root_index", serializer.GetOrCreateNodeIndex(value)); - result.Set("nodes", std::move(serializer.nodes_)); - if (metadata != nullptr) { - result.Set("metadata", metadata); - } - return result; - } - - private: - ObjectGraphSerializer() = default; - - int64_t GetOrCreateNodeIndex(const Any& value) { - // already mapped value, return the index - auto it = node_index_map_.find(value); - if (it != node_index_map_.end()) { - return (*it).second; - } - json::Object node; - switch (value.type_index()) { - case TypeIndex::kTVMFFINone: { - node.Set("type", ffi::StaticTypeKey::kTVMFFINone); - break; - } - case TypeIndex::kTVMFFIBool: { - node.Set("type", ffi::StaticTypeKey::kTVMFFIBool); - node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIInt: { - node.Set("type", ffi::StaticTypeKey::kTVMFFIInt); - node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIFloat: { - node.Set("type", ffi::StaticTypeKey::kTVMFFIFloat); - node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); - break; - } - case TypeIndex::kTVMFFIDataType: { - DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIDataType); - node.Set("data", DLDataTypeToString(dtype)); - break; - } - case TypeIndex::kTVMFFIDevice: { - DLDevice device = 
details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIDevice); - node.Set("data", json::Array{ - static_cast(device.device_type), - static_cast(device.device_id), - }); - break; - } - case TypeIndex::kTVMFFISmallStr: - case TypeIndex::kTVMFFIStr: { - String str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIStr); - node.Set("data", str); - break; - } - case TypeIndex::kTVMFFISmallBytes: - case TypeIndex::kTVMFFIBytes: { - Bytes bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIBytes); - node.Set("data", Base64Encode(bytes)); - break; - } - case TypeIndex::kTVMFFIArray: { - Array array = details::AnyUnsafe::CopyFromAnyViewAfterCheck>(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIArray); - node.Set("data", CreateArrayData(array)); - break; - } - case TypeIndex::kTVMFFIMap: { - Map map = details::AnyUnsafe::CopyFromAnyViewAfterCheck>(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIMap); - node.Set("data", CreateMapData(map)); - break; - } - case TypeIndex::kTVMFFIShape: { - ffi::Shape shape = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); - node.Set("type", ffi::StaticTypeKey::kTVMFFIShape); - node.Set("data", Array(shape->data, shape->data + shape->size)); - break; - } - default: { - if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) { - // serialize type key since type index is runtime dependent - node.Set("type", value.GetTypeKey()); - node.Set("data", CreateObjectData(value)); - } else { - TVM_FFI_THROW(RuntimeError) << "Cannot serialize type `" << value.GetTypeKey() << "`"; - TVM_FFI_UNREACHABLE(); - } - } - } - int64_t node_index = nodes_.size(); - nodes_.push_back(node); - node_index_map_.Set(value, node_index); - return node_index; - } - - json::Array CreateArrayData(const Array& value) { - json::Array data; - data.reserve(value.size()); - for (const Any& item : value) { 
- data.push_back(GetOrCreateNodeIndex(item)); - } - return data; - } - - json::Array CreateMapData(const Map& value) { - json::Array data; - data.reserve(value.size() * 2); - for (const auto& [key, value] : value) { - data.push_back(GetOrCreateNodeIndex(key)); - data.push_back(GetOrCreateNodeIndex(value)); - } - return data; - } - - // create the data for the object, if the type has a custom data to json function, - // use it. otherwise, we go over the fields and create the data. - json::Value CreateObjectData(const Any& value) { - static reflection::TypeAttrColumn data_to_json = reflection::TypeAttrColumn("__data_to_json__"); - if (data_to_json[value.type_index()] != nullptr) { - return data_to_json[value.type_index()].cast()(value); - } - // NOTE: invariant: lhs and rhs are already the same type - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index()); - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" - << String(type_info->type_key) - << "`, so ToJSONGraph is not supported for this type"; - } - const Object* obj = value.cast(); - json::Object data; - // go over the content and hash the fields - reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - // get the field value from both side - reflection::FieldGetter getter(field_info); - Any field_value = getter(obj); - int field_static_type_index = field_info->field_static_type_index; - String field_name(field_info->name); - // for static field index that are known, we can directly set the field value. 
- switch (field_static_type_index) { - case TypeIndex::kTVMFFINone: { - data.Set(field_name, nullptr); - break; - } - case TypeIndex::kTVMFFIBool: { - data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); - break; - } - case TypeIndex::kTVMFFIInt: { - data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); - break; - } - case TypeIndex::kTVMFFIFloat: { - data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); - break; - } - case TypeIndex::kTVMFFIDataType: { - DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value); - data.Set(field_name, DLDataTypeToString(dtype)); - break; - } - default: { - // for dynamic field index, we need need to put them onto nodes - int64_t node_index = GetOrCreateNodeIndex(field_value); - data.Set(field_name, node_index); - break; - } - } - }); - return data; - } - - // maps the original value to the index of the node in the nodes_ array - Map node_index_map_; - // records nodes that are serialized - json::Array nodes_; -}; - -json::Value ToJSONGraph(const Any& value, const Any& metadata) { - return ObjectGraphSerializer::Serialize(value, metadata); -} - -class ObjectGraphDeserializer { - public: - static Any Deserialize(const json::Value& value) { - ObjectGraphDeserializer deserializer(value); - return deserializer.GetOrDecodeNode(deserializer.root_index_); - } - - Any GetOrDecodeNode(int64_t node_index) { - // already decoded null index - if (node_index == decoded_null_index_) { - return Any(nullptr); - } - // already decoded - if (decoded_nodes_[node_index] != nullptr) { - return decoded_nodes_[node_index]; - } - // now decode the node - Any value = DecodeNode(nodes_[node_index].cast()); - decoded_nodes_[node_index] = value; - if (value == nullptr) { - decoded_null_index_ = node_index; - } - return value; - } - - private: - Any DecodeNode(const json::Object& node) { - String type_key = node["type"].cast(); - TVMFFIByteArray 
type_key_arr{type_key.data(), type_key.length()}; - int32_t type_index; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); - - switch (type_index) { - case TypeIndex::kTVMFFINone: { - return nullptr; - } - case TypeIndex::kTVMFFIBool: { - return node["data"].cast(); - } - case TypeIndex::kTVMFFIInt: { - return node["data"].cast(); - } - case TypeIndex::kTVMFFIFloat: { - return node["data"].cast(); - } - case TypeIndex::kTVMFFIDataType: { - return StringToDLDataType(node["data"].cast()); - } - case TypeIndex::kTVMFFIDevice: { - Array data = node["data"].cast>(); - return DLDevice{static_cast(data[0]), data[1]}; - } - case TypeIndex::kTVMFFIStr: { - return node["data"].cast(); - } - case TypeIndex::kTVMFFIBytes: { - return Base64Decode(node["data"].cast()); - } - case TypeIndex::kTVMFFIMap: { - return DecodeMapData(node["data"].cast()); - } - case TypeIndex::kTVMFFIArray: { - return DecodeArrayData(node["data"].cast()); - } - case TypeIndex::kTVMFFIShape: { - Array data = node["data"].cast>(); - return ffi::Shape(data); - } - default: { - return DecodeObjectData(type_index, node["data"]); - } - } - } - - Array DecodeArrayData(const json::Array& data) { - Array array; - array.reserve(data.size()); - for (size_t i = 0; i < data.size(); i++) { - array.push_back(GetOrDecodeNode(data[i].cast())); - } - return array; - } - - Map DecodeMapData(const json::Array& data) { - Map map; - for (size_t i = 0; i < data.size(); i += 2) { - int64_t key_index = data[i].cast(); - int64_t value_index = data[i + 1].cast(); - map.Set(GetOrDecodeNode(key_index), GetOrDecodeNode(value_index)); - } - return map; - } - - Any DecodeObjectData(int32_t type_index, const json::Value& data) { - static reflection::TypeAttrColumn data_from_json = - reflection::TypeAttrColumn("__data_from_json__"); - if (data_from_json[type_index] != nullptr) { - return data_from_json[type_index].cast()(data); - } - // otherwise, we go over the fields and create the data. 
- const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support default constructor" - << ", so ToJSONGraph is not supported for this type"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - - auto decode_field_value = [&](const TVMFFIFieldInfo* field_info, json::Value data) -> Any { - switch (field_info->field_static_type_index) { - case TypeIndex::kTVMFFINone: { - return nullptr; - } - case TypeIndex::kTVMFFIBool: { - return data.cast(); - } - case TypeIndex::kTVMFFIInt: { - return data.cast(); - } - case TypeIndex::kTVMFFIFloat: { - return data.cast(); - } - case TypeIndex::kTVMFFIDataType: { - return StringToDLDataType(data.cast()); - } - default: { - return GetOrDecodeNode(data.cast()); - } - } - }; - - json::Object data_object = data.cast(); - reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - String field_name(field_info->name); - void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (data_object.count(field_name) != 0) { - Any field_value = decode_field_value(field_info, data_object[field_name]); - field_info->setter(field_addr, reinterpret_cast(&field_value)); - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" << TypeIndexToTypeKey(type_index) << "`"; - } - }); - return ObjectRef(ptr); - } - - explicit ObjectGraphDeserializer(json::Value serialized) { - if (!serialized.as()) { - TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected an object"; - } - 
json::Object encoded_object = serialized.cast(); - if (encoded_object.count("root_index") == 0 || !encoded_object["root_index"].as()) { - TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `root_index` integer field"; - } - if (encoded_object.count("nodes") == 0 || !encoded_object["nodes"].as()) { - TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `nodes` array field"; - } - root_index_ = encoded_object["root_index"].cast(); - nodes_ = encoded_object["nodes"].cast(); - decoded_nodes_.resize(nodes_.size(), Any(nullptr)); - } - // nodes - json::Array nodes_; - // root index - int64_t root_index_; - // null index if already created - int64_t decoded_null_index_{-1}; - // decoded nodes - std::vector decoded_nodes_; -}; - -Any FromJSONGraph(const json::Value& value) { return ObjectGraphDeserializer::Deserialize(value); } - -// string version of the api -Any FromJSONGraphString(const String& value) { return FromJSONGraph(json::Parse(value)); } - -String ToJSONGraphString(const Any& value, const Any& metadata) { - return json::Stringify(ToJSONGraph(value, metadata)); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("ffi.ToJSONGraph", ToJSONGraph) - .def("ffi.ToJSONGraphString", ToJSONGraphString) - .def("ffi.FromJSONGraph", FromJSONGraph) - .def("ffi.FromJSONGraphString", FromJSONGraphString); - refl::EnsureTypeAttrColumn("__data_to_json__"); - refl::EnsureTypeAttrColumn("__data_from_json__"); -}); - -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/stream_context.cc b/ffi/src/ffi/extra/stream_context.cc deleted file mode 100644 index d063efdef579..000000000000 --- a/ffi/src/ffi/extra/stream_context.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/stream_context.cc - * - * \brief A minimalistic stream context based on ffi values. - */ - -#include -#include - -#include - -namespace tvm { -namespace ffi { - -class StreamContext { - public: - void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, - TVMFFIStreamHandle* out_original_stream) { - if (static_cast(device_type) >= stream_table_.size()) { - stream_table_.resize(device_type + 1); - } - if (static_cast(device_id) >= stream_table_[device_type].size()) { - stream_table_[device_type].resize(device_id + 1, nullptr); - } - if (out_original_stream != nullptr) { - *out_original_stream = stream_table_[device_type][device_id]; - } - stream_table_[device_type][device_id] = stream; - } - - TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) { - if (static_cast(device_type) < stream_table_.size() && - static_cast(device_id) < stream_table_[device_type].size()) { - return stream_table_[device_type][device_id]; - } - return nullptr; - } - - static StreamContext* ThreadLocal() { - static thread_local StreamContext inst; - return &inst; - } - - private: - std::vector> stream_table_; -}; - -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, - TVMFFIStreamHandle* out_original_stream) { - TVM_FFI_SAFE_CALL_BEGIN(); - 
tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream, - out_original_stream); - TVM_FFI_SAFE_CALL_END(); -} - -TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::StreamContext::ThreadLocal()->GetStream(device_type, device_id); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetCurrentStream); -} diff --git a/ffi/src/ffi/extra/structural_equal.cc b/ffi/src/ffi/extra/structural_equal.cc deleted file mode 100644 index 171fa2f750a0..000000000000 --- a/ffi/src/ffi/extra/structural_equal.cc +++ /dev/null @@ -1,439 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/reflection/structural_equal.cc - * - * \brief Structural equal implementation. - */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { - -/** - * \brief Internal Handler class for structural equal comparison. 
- */ -class StructEqualHandler { - public: - StructEqualHandler() = default; - - bool CompareAny(ffi::Any lhs, ffi::Any rhs) { - using ffi::details::AnyUnsafe; - const TVMFFIAny* lhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(lhs); - const TVMFFIAny* rhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(rhs); - if (lhs_data->type_index != rhs_data->type_index) { - // type_index mismatch, if index is not string, return false - if (lhs_data->type_index != kTVMFFIStr && lhs_data->type_index != kTVMFFISmallStr && - lhs_data->type_index != kTVMFFISmallBytes && lhs_data->type_index != kTVMFFIBytes) { - return false; - } - // small string and normal string comparison - if (lhs_data->type_index == kTVMFFIStr && rhs_data->type_index == kTVMFFISmallStr) { - const details::BytesObjBase* lhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_str->data, rhs_data->v_bytes, lhs_str->size, - rhs_data->small_str_len); - } - if (lhs_data->type_index == kTVMFFISmallStr && rhs_data->type_index == kTVMFFIStr) { - const details::BytesObjBase* rhs_str = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_data->v_bytes, rhs_str->data, lhs_data->small_str_len, - rhs_str->size); - } - if (lhs_data->type_index == kTVMFFIBytes && rhs_data->type_index == kTVMFFISmallBytes) { - const details::BytesObjBase* lhs_bytes = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_bytes->data, rhs_data->v_bytes, lhs_bytes->size, - rhs_data->small_str_len); - } - if (lhs_data->type_index == kTVMFFISmallBytes && rhs_data->type_index == kTVMFFIBytes) { - const details::BytesObjBase* rhs_bytes = - details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_data->v_bytes, rhs_bytes->data, lhs_data->small_str_len, - rhs_bytes->size); - } - return false; - } - - if (lhs_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - // specially handle nan for float, as there can be multiple representations of nan 
- if (lhs_data->type_index == TypeIndex::kTVMFFIFloat && std::isnan(lhs_data->v_float64)) { - return std::isnan(rhs_data->v_float64); - } - // this is POD data, we can just compare the value - return lhs_data->zero_padding == rhs_data->zero_padding && - lhs_data->v_int64 == rhs_data->v_int64; - } - switch (lhs_data->type_index) { - case TypeIndex::kTVMFFIStr: - case TypeIndex::kTVMFFIBytes: { - // compare bytes - const details::BytesObjBase* lhs_str = - AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - const details::BytesObjBase* rhs_str = - AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); - } - case TypeIndex::kTVMFFIArray: { - return CompareArray(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck>(std::move(rhs))); - } - case TypeIndex::kTVMFFIMap: { - return CompareMap(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck>(std::move(rhs))); - } - case TypeIndex::kTVMFFIShape: { - return CompareShape(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); - } - case TypeIndex::kTVMFFINDArray: { - return CompareNDArray(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); - } - default: { - return CompareObject(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), - AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); - } - } - } - - bool CompareObject(ObjectRef lhs, ObjectRef rhs) { - // NOTE: invariant: lhs and rhs are already the same type - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index()); - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" - << String(type_info->type_key) - << "`, so StructuralHash is not supported for this type"; - } - if (type_info->metadata->structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { - TVM_FFI_THROW(TypeError) << 
"_type_s_eq_hash_kind is not set for type `" - << String(type_info->type_key) - << "`, so StructuralHash is not supported for this type"; - } - - auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind; - if (structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) { - // use pointer comparison - return lhs.same_as(rhs); - } - if (structural_eq_hash_kind == kTVMFFISEqHashKindConstTreeNode) { - // fast path: constant tree node, pointer equality indicate equality and avoid content - // comparison if false, we should still run content comparison - if (lhs.same_as(rhs)) return true; - } - // check recorded mapping for DAG and fre var - if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode || - structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { - // if there is pre-recorded mapping, need to cross check the pointer equality after mapping - auto it = equal_map_lhs_.find(lhs); - if (it != equal_map_lhs_.end()) { - return it->second.same_as(rhs); - } - // if rhs is mapped but lhs is not, it means lhs is a free var, return false - if (equal_map_rhs_.count(rhs)) { - return false; - } - } - - static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__"); - - bool success = true; - if (custom_s_equal[type_info->type_index] == nullptr) { - // We recursively compare the fields the object - reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) { - // skip fields that are marked as structural eq hash ignore - if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) return false; - // get the field value from both side - reflection::FieldGetter getter(field_info); - Any lhs_value = getter(lhs); - Any rhs_value = getter(rhs); - // field is in def region, enable free var mapping - if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - success = CompareAny(lhs_value, rhs_value); - 
std::swap(allow_free_var, map_free_vars_); - } else { - success = CompareAny(lhs_value, rhs_value); - } - if (!success) { - // record the first mismatching field if we sub-rountine compare failed - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back( - reflection::AccessStep::Attr(String(field_info->name))); - mismatch_rhs_reverse_path_->emplace_back( - reflection::AccessStep::Attr(String(field_info->name))); - } - // return true to indicate early stop - return true; - } else { - // return false to continue checking other fields - return false; - } - }); - } else { - // run custom equal function defined via __s_equal__ type attribute - if (s_equal_callback_ == nullptr) { - s_equal_callback_ = ffi::Function::FromTyped( - [this](AnyView lhs, AnyView rhs, bool def_region, AnyView field_name) { - // NOTE: we explicitly make field_name as AnyView to avoid copy overhead initially - // and only cast to string if mismatch happens - bool success = true; - if (def_region) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - success = CompareAny(lhs, rhs); - std::swap(allow_free_var, map_free_vars_); - } else { - success = CompareAny(lhs, rhs); - } - if (!success) { - if (mismatch_lhs_reverse_path_ != nullptr) { - String field_name_str = field_name.cast(); - mismatch_lhs_reverse_path_->emplace_back( - reflection::AccessStep::Attr(field_name_str)); - mismatch_rhs_reverse_path_->emplace_back( - reflection::AccessStep::Attr(field_name_str)); - } - } - return success; - }); - } - success = custom_s_equal[type_info->type_index] - .cast()(lhs, rhs, s_equal_callback_) - .cast(); - } - - if (success) { - if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { - // we are in a free var case that is not yet mapped. 
- // in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be - // set - if (lhs.same_as(rhs) || map_free_vars_) { - // record the equality - equal_map_lhs_[lhs] = rhs; - equal_map_rhs_[rhs] = lhs; - return true; - } else { - return false; - } - } - // if we have a success mapping and in graph/var mode, record the equality mapping - if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) { - // record the equality - equal_map_lhs_[lhs] = rhs; - equal_map_rhs_[rhs] = lhs; - } - return true; - } else { - return false; - } - } - - bool CompareMap(Map lhs, Map rhs) { - if (lhs.size() != rhs.size()) { - // size mismatch, and there is no path tracing - // return false since we don't need informative error message - if (mismatch_lhs_reverse_path_ == nullptr) return false; - } - // compare key and value pair by pair - for (auto kv : lhs) { - Any rhs_key = this->MapLhsToRhs(kv.first); - auto it = rhs.find(rhs_key); - if (it == rhs.end()) { - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(rhs_key)); - } - return false; - } - // now recursively compare value - if (!CompareAny(kv.second, (*it).second)) { - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(rhs_key)); - } - return false; - } - } - // fast path, all contents equals to each other - if (lhs.size() == rhs.size()) return true; - // slow path, cross check every key from rhs in lhs to find the missing - // key for better error reporting - for (auto kv : rhs) { - Any lhs_key = this->MapRhsToLhs(kv.first); - auto it = lhs.find(lhs_key); - if (it == lhs.end()) { - if (mismatch_lhs_reverse_path_ != nullptr) { - 
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(lhs_key)); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first)); - } - return false; - } - } - return false; - } - - bool CompareArray(ffi::Array lhs, ffi::Array rhs) { - if (lhs.size() != rhs.size()) { - // fast path, size mismatch, and there is no path tracing - // return false since we don't need informative error message - if (mismatch_lhs_reverse_path_ == nullptr) return false; - } - for (size_t i = 0; i < std::min(lhs.size(), rhs.size()); ++i) { - if (!CompareAny(lhs[i], rhs[i])) { - if (mismatch_lhs_reverse_path_ != nullptr) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i)); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i)); - } - return false; - } - } - if (lhs.size() == rhs.size()) return true; - if (mismatch_lhs_reverse_path_ != nullptr) { - if (lhs.size() > rhs.size()) { - mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(rhs.size())); - mismatch_rhs_reverse_path_->emplace_back( - reflection::AccessStep::ArrayItemMissing(rhs.size())); - } else { - mismatch_lhs_reverse_path_->emplace_back( - reflection::AccessStep::ArrayItemMissing(lhs.size())); - mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(lhs.size())); - } - } - return false; - } - - bool CompareShape(Shape lhs, Shape rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - for (size_t i = 0; i < lhs.size(); ++i) { - if (lhs[i] != rhs[i]) { - return false; - } - } - return true; - } - - bool CompareNDArray(NDArray lhs, NDArray rhs) { - if (lhs.same_as(rhs)) return true; - if (lhs->ndim != rhs->ndim) return false; - for (int i = 0; i < lhs->ndim; ++i) { - if (lhs->shape[i] != rhs->shape[i]) return false; - } - if (lhs->dtype != rhs->dtype) return false; - if (!skip_ndarray_content_) { - TVM_FFI_ICHECK_EQ(lhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; 
- TVM_FFI_ICHECK_EQ(rhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; - TVM_FFI_ICHECK(lhs.IsContiguous()) << "Can only compare contiguous tensor"; - TVM_FFI_ICHECK(rhs.IsContiguous()) << "Can only compare contiguous tensor"; - size_t data_size = GetDataSize(*(lhs.operator->())); - return std::memcmp(lhs->data, rhs->data, data_size) == 0; - } else { - return true; - } - } - - Any MapLhsToRhs(Any lhs) const { - if (lhs.type_index() < TypeIndex::kTVMFFIStaticObjectBegin) { - return lhs; - } - ObjectRef lhs_obj = ffi::details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)); - auto it = equal_map_lhs_.find(lhs_obj); - if (it != equal_map_lhs_.end()) { - return it->second; - } - return lhs_obj; - } - - Any MapRhsToLhs(Any rhs) const { - if (rhs.type_index() < TypeIndex::kTVMFFIStaticObjectBegin) { - return rhs; - } - ObjectRef rhs_obj = ffi::details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs)); - auto it = equal_map_rhs_.find(rhs_obj); - if (it != equal_map_rhs_.end()) { - return it->second; - } - return rhs_obj; - } - // whether we map free variables that are not defined - bool map_free_vars_{false}; - // whether we compare ndarray data - bool skip_ndarray_content_{false}; - // the root lhs for result printing - std::vector* mismatch_lhs_reverse_path_ = nullptr; - std::vector* mismatch_rhs_reverse_path_ = nullptr; - // lazily initialize custom equal function - ffi::Function s_equal_callback_ = nullptr; - // map from lhs to rhs - std::unordered_map equal_map_lhs_; - // map from rhs to lhs - std::unordered_map equal_map_rhs_; -}; - -bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars, - bool skip_ndarray_content) { - StructEqualHandler handler; - handler.map_free_vars_ = map_free_vars; - handler.skip_ndarray_content_ = skip_ndarray_content; - return handler.CompareAny(lhs, rhs); -} - -Optional StructuralEqual::GetFirstMismatch(const Any& lhs, - const Any& rhs, - bool map_free_vars, - bool skip_ndarray_content) { - 
StructEqualHandler handler; - handler.map_free_vars_ = map_free_vars; - handler.skip_ndarray_content_ = skip_ndarray_content; - std::vector lhs_reverse_path; - std::vector rhs_reverse_path; - handler.mismatch_lhs_reverse_path_ = &lhs_reverse_path; - handler.mismatch_rhs_reverse_path_ = &rhs_reverse_path; - if (handler.CompareAny(lhs, rhs)) { - return std::nullopt; - } - using reflection::AccessPath; - reflection::AccessPath lhs_path = - AccessPath::FromSteps(lhs_reverse_path.rbegin(), lhs_reverse_path.rend()); - reflection::AccessPath rhs_path = - AccessPath::FromSteps(rhs_reverse_path.rbegin(), rhs_reverse_path.rend()); - return reflection::AccessPathPair(lhs_path, rhs_path); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.GetFirstStructuralMismatch", StructuralEqual::GetFirstMismatch); - // ensure the type attribute column is presented in the system even if it is empty. - refl::EnsureTypeAttrColumn("__s_equal__"); -}); - -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/structural_hash.cc b/ffi/src/ffi/extra/structural_hash.cc deleted file mode 100644 index 9f245c1d174d..000000000000 --- a/ffi/src/ffi/extra/structural_hash.cc +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/reflection/structural_equal.cc - * - * \brief Structural equal implementation. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -/** - * \brief Internal Handler class for structural hash. - */ -class StructuralHashHandler { - public: - StructuralHashHandler() = default; - - uint64_t HashAny(ffi::Any src) { - using ffi::details::AnyUnsafe; - const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src); - - if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - // specially handle nan for float, as there can be multiple representations of nan - // make sure they map to the same hash value - if (src_data->type_index == TypeIndex::kTVMFFIFloat && std::isnan(src_data->v_float64)) { - TVMFFIAny temp = *src_data; - temp.v_float64 = std::numeric_limits::quiet_NaN(); - return details::StableHashCombine(temp.type_index, temp.v_uint64); - } - if (src_data->type_index == TypeIndex::kTVMFFISmallStr) { - // for small string, we use the same type key hash as normal string - // so heap allocated string and on stack string will have the same hash - return details::StableHashCombine(TypeIndex::kTVMFFIStr, - details::StableHashSmallStrBytes(src_data)); - } - // this is POD data, we can just hash the value - return details::StableHashCombine(src_data->type_index, src_data->v_uint64); - } - - switch (src_data->type_index) { - case TypeIndex::kTVMFFIStr: - case TypeIndex::kTVMFFIBytes: { - // return same hash as AnyHash - const details::BytesObjBase* src_str = - AnyUnsafe::CopyFromAnyViewAfterCheck(src); - return details::StableHashCombine(src_data->type_index, - details::StableHashBytes(src_str->data, src_str->size)); - } - case TypeIndex::kTVMFFIArray: { - return HashArray(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(src))); - 
} - case TypeIndex::kTVMFFIMap: { - return HashMap(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(src))); - } - case TypeIndex::kTVMFFIShape: { - return HashShape(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); - } - case TypeIndex::kTVMFFINDArray: { - return HashNDArray(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); - } - default: { - return HashObject(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); - } - } - } - - uint64_t HashObject(ObjectRef obj) { - // NOTE: invariant: lhs and rhs are already the same type - const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" - << String(type_info->type_key) - << "`, so StructuralHash is not supported for this type"; - } - if (type_info->metadata->structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { - TVM_FFI_THROW(TypeError) << "_type_s_eq_hash_kind is not set for type `" - << String(type_info->type_key) - << "`, so StructuralHash is not supported for this type"; - } - - auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind; - if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { - // Fallback to pointer hash - return std::hash()(obj.get()); - } - // return recored hash value if it is already computed - auto it = hash_memo_.find(obj); - if (it != hash_memo_.end()) { - return it->second; - } - - static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__"); - - // compute the hash value - uint64_t hash_value = obj->GetTypeKeyHash(); - if (custom_s_hash[type_info->type_index] == nullptr) { - // go over the content and hash the fields - reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { - // skip fields that are marked as structural eq hash ignore - if (!(field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore)) { - // get the field value from both side - reflection::FieldGetter getter(field_info); 
- Any field_value = getter(obj); - // field is in def region, enable free var mapping - if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - hash_value = details::StableHashCombine(hash_value, HashAny(field_value)); - std::swap(allow_free_var, map_free_vars_); - } else { - hash_value = details::StableHashCombine(hash_value, HashAny(field_value)); - } - } - }); - } else { - if (s_hash_callback_ == nullptr) { - s_hash_callback_ = - ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, bool def_region) { - if (def_region) { - bool allow_free_var = true; - std::swap(allow_free_var, map_free_vars_); - uint64_t hash_value = HashAny(val); - std::swap(allow_free_var, map_free_vars_); - return details::StableHashCombine(init_hash, hash_value); - } else { - return details::StableHashCombine(init_hash, HashAny(val)); - } - }); - } - hash_value = custom_s_hash[type_info->type_index] - .cast()(obj, hash_value, s_hash_callback_) - .cast(); - } - - if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { - if (map_free_vars_) { - // use lexical order of free var and its type - hash_value = details::StableHashCombine(hash_value, free_var_counter_++); - } else { - // Fallback to pointer hash, we are not mapping free var. - hash_value = std::hash()(obj.get()); - } - } - // if it is a DAG node, also record the lexical order of graph counter - // this helps to distinguish DAG from trees. 
- if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) { - hash_value = details::StableHashCombine(hash_value, graph_node_counter_++); - } - // record the hash value for this object - hash_memo_[obj] = hash_value; - return hash_value; - } - - uint64_t HashArray(Array arr) { - uint64_t hash_value = details::StableHashCombine(arr->GetTypeKeyHash(), arr.size()); - for (size_t i = 0; i < arr.size(); ++i) { - hash_value = details::StableHashCombine(hash_value, HashAny(arr[i])); - } - return hash_value; - } - - // Find an order independent hash value for a given Any. - // Order independent hash value means the hash value will remain stable independent - // of the order we hash the content at the current context. - // This property is needed to support stable hash for map. - std::optional FindOrderIndependentHash(Any src) { - using ffi::details::AnyUnsafe; - const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src); - - if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - if (src_data->type_index == TypeIndex::kTVMFFISmallStr) { - // for small string, we use the same type key hash as normal string - // so heap allocated string and on stack string will have the same hash - return details::StableHashCombine( - TypeIndex::kTVMFFIStr, - details::StableHashBytes(src_data->v_bytes, src_data->small_str_len)); - } - // this is POD data, we can just hash the value - return details::StableHashCombine(src_data->type_index, src_data->v_uint64); - } else { - if (src_data->type_index == TypeIndex::kTVMFFIStr || - src_data->type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase* src_str = - AnyUnsafe::CopyFromAnyViewAfterCheck(src); - // return same hash as AnyHash - return details::StableHashCombine(src_data->type_index, - details::StableHashBytes(src_str->data, src_str->size)); - } else { - // if the hash of the object is already computed, return it - auto it = hash_memo_.find(src.cast()); - if (it != hash_memo_.end()) { - return it->second; - 
} - return std::nullopt; - } - } - } - - uint64_t HashMap(Map map) { - // Compute a deterministic hash value for the map. - uint64_t hash_value = details::StableHashCombine(map->GetTypeKeyHash(), map.size()); - std::vector> items; - for (auto [key, value] : map) { - // if we cannot find order independent hash, we skip the key - if (auto hash_key = FindOrderIndependentHash(key)) { - items.emplace_back(*hash_key, value); - } - } - // sort the items by the hash key, so the hash value is deterministic - // and independent of the order of insertion - std::sort(items.begin(), items.end(), - [](const auto& a, const auto& b) { return a.first < b.first; }); - - for (size_t i = 0; i < items.size();) { - size_t k = i + 1; - for (; k < items.size() && items[k].first == items[i].first; ++k) { - } - // detect ties, which are rare, but we need to skip value hash during ties - // to make sure that the hash value is deterministic. - if (k == i + 1) { - // no ties, we just hash the key and value - hash_value = details::StableHashCombine(hash_value, items[i].first); - hash_value = details::StableHashCombine(hash_value, HashAny(items[i].second)); - } else { - // ties occur, we skip the value hash to make sure that the hash value is deterministic. 
- hash_value = details::StableHashCombine(hash_value, items[i].first); - } - i = k; - } - return hash_value; - } - - uint64_t HashShape(Shape shape) { - uint64_t hash_value = details::StableHashCombine(shape->GetTypeKeyHash(), shape.size()); - for (size_t i = 0; i < shape.size(); ++i) { - hash_value = details::StableHashCombine(hash_value, shape[i]); - } - return hash_value; - } - - uint64_t HashNDArray(NDArray ndarray) { - uint64_t hash_value = details::StableHashCombine(ndarray->GetTypeKeyHash(), ndarray->ndim); - for (int i = 0; i < ndarray->ndim; ++i) { - hash_value = details::StableHashCombine(hash_value, ndarray->shape[i]); - } - TVMFFIAny temp; - temp.v_uint64 = 0; - temp.v_dtype = ndarray->dtype; - hash_value = details::StableHashCombine(hash_value, temp.v_int64); - - if (!skip_ndarray_content_) { - TVM_FFI_ICHECK_EQ(ndarray->device.device_type, kDLCPU) << "can only hash CPU tensor"; - TVM_FFI_ICHECK(ndarray.IsContiguous()) << "Can only hash contiguous tensor"; - size_t data_size = GetDataSize(*(ndarray.operator->())); - uint64_t data_hash = - details::StableHashBytes(static_cast(ndarray->data), data_size); - hash_value = details::StableHashCombine(hash_value, data_hash); - } - return hash_value; - } - - bool map_free_vars_{false}; - bool skip_ndarray_content_{false}; - // free var counter. - uint32_t free_var_counter_{0}; - // graph node counter. 
- uint32_t graph_node_counter_{0}; - // lazily initialize custom hash function - ffi::Function s_hash_callback_ = nullptr; - // map from lhs to rhs - std::unordered_map hash_memo_; -}; - -uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool skip_ndarray_content) { - StructuralHashHandler handler; - handler.map_free_vars_ = map_free_vars; - handler.skip_ndarray_content_ = skip_ndarray_content; - return handler.HashAny(value); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.StructuralHash", StructuralHash::Hash); - refl::EnsureTypeAttrColumn("__s_hash__"); -}); - -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/extra/testing.cc b/ffi/src/ffi/extra/testing.cc deleted file mode 100644 index 3d27d5ccb6a4..000000000000 --- a/ffi/src/ffi/extra/testing.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -// This file is used for testing the FFI API. 
-#include -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { - -class TestObjectBase : public Object { - public: - int64_t v_i64; - double v_f64; - String v_str; - - int64_t AddI64(int64_t other) const { return v_i64 + other; } - - // declare as one slot, with float as overflow - static constexpr bool _type_mutable = true; - static constexpr uint32_t _type_child_slots = 1; - static constexpr const char* _type_key = "testing.TestObjectBase"; - TVM_FFI_DECLARE_BASE_OBJECT_INFO(TestObjectBase, Object); -}; - -class TestObjectDerived : public TestObjectBase { - public: - Map v_map; - Array v_array; - - // declare as one slot, with float as overflow - static constexpr const char* _type_key = "testing.TestObjectDerived"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjectDerived, TestObjectBase); -}; - -void TestRaiseError(String kind, String msg) { - throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE); -} - -void TestApply(Function f, PackedArgs args, Any* ret) { f.CallPacked(args, ret); } - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - - refl::ObjectDef() - .def_rw("v_i64", &TestObjectBase::v_i64, refl::DefaultValue(10), "i64 field") - .def_ro("v_f64", &TestObjectBase::v_f64, refl::DefaultValue(10.0)) - .def_rw("v_str", &TestObjectBase::v_str, refl::DefaultValue("hello")) - .def("add_i64", &TestObjectBase::AddI64, "add_i64 method"); - - refl::ObjectDef() - .def_ro("v_map", &TestObjectDerived::v_map) - .def_ro("v_array", &TestObjectDerived::v_array); - - refl::GlobalDef() - .def("testing.test_raise_error", TestRaiseError) - .def_packed("testing.nop", [](PackedArgs args, Any* ret) { *ret = args[0]; }) - .def_packed("testing.echo", [](PackedArgs args, Any* ret) { *ret = args[0]; }) - .def_packed("testing.apply", - [](PackedArgs args, Any* ret) { - auto f = args[0].cast(); - TestApply(f, args.Slice(1), ret); - }) - .def("testing.run_check_signal", - [](int nsec) { - for (int i = 0; i < nsec; ++i) 
{ - if (TVMFFIEnvCheckSignals() != 0) { - throw ffi::EnvErrorAlreadySet(); - } - std::this_thread::sleep_for(std::chrono::seconds(1)); - } - std::cout << "Function finished without catching signal" << std::endl; - }) - .def("testing.object_use_count", [](const Object* obj) { return obj->use_count(); }); -}); - -} // namespace ffi -} // namespace tvm diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc deleted file mode 100644 index 8db03bf28eb0..000000000000 --- a/ffi/src/ffi/function.cc +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/function.cc - * \brief Function call registry and safecall context - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Global function table. - * - - * \note We do not use mutex to guard updating of GlobalFunctionTable - * - * The assumption is that updating of GlobalFunctionTable will be done - * in the main thread during initialization or loading, or - * explicitly locked from the caller. 
- * - * Then the followup code will leverage the information - */ -class GlobalFunctionTable { - public: - // Note: this class is hidden from the public API, so we just - // use it as a private class as ObjectRef - class Entry : public Object, public TVMFFIMethodInfo { - public: - String name_data; - String doc_data; - String type_schema_data; - ffi::Function func_data; - - explicit Entry(const TVMFFIMethodInfo* method_info) { - // make copy of the metadata - name_data = String(method_info->name.data, method_info->name.size); - doc_data = String(method_info->doc.data, method_info->doc.size); - type_schema_data = String(method_info->type_schema.data, method_info->type_schema.size); - func_data = AnyView::CopyFromTVMFFIAny(method_info->method).cast(); - this->SyncMethodInfo(method_info->flags); - // no need to update method pointer as it would remain the same as func and we retained - } - explicit Entry(String name, ffi::Function func) : name_data(name), func_data(func) { - this->SyncMethodInfo(kTVMFFIFieldFlagBitMaskIsStaticMethod); - } - - private: - void SyncMethodInfo(int64_t flags) { - this->flags = flags; - this->name = TVMFFIByteArray{name_data.data(), name_data.size()}; - this->doc = TVMFFIByteArray{doc_data.data(), doc_data.size()}; - this->type_schema = TVMFFIByteArray{type_schema_data.data(), type_schema_data.size()}; - } - }; - - void Update(const String& name, Function func, bool can_override) { - if (table_.count(name)) { - if (!can_override) { - TVM_FFI_THROW(RuntimeError) << "Global Function `" << name << "` is already registered"; - } - } - table_.Set(name, ObjectRef(make_object(name, func))); - } - - void Update(const TVMFFIMethodInfo* method_info, bool can_override) { - String name(method_info->name.data, method_info->name.size); - if (table_.count(name)) { - if (!can_override) { - TVM_FFI_LOG_AND_THROW(RuntimeError) - << "Global Function `" << name << "` is already registered, possible causes:\n" - << "- Two GlobalDef().def registrations for the 
same function \n" - << "Please remove the duplicate registration."; - } - } - table_.Set(name, ObjectRef(make_object(method_info))); - } - - bool Remove(const String& name) { - auto it = table_.find(name); - if (it == table_.end()) return false; - table_.erase(name); - return true; - } - - const Entry* Get(const String& name) { - auto it = table_.find(name); - if (it == table_.end()) return nullptr; - const Object* obj = (*it).second.cast(); - return static_cast(obj); - } - - Array ListNames() const { - Array names; - names.reserve(table_.size()); - for (const auto& kv : table_) { - names.push_back(kv.first); - } - return names; - } - - static GlobalFunctionTable* Global() { - // We deliberately create a new instance via raw new - // This is because GlobalFunctionTable can contain callbacks into - // the host language (Python) and the resource can become invalid - // indeterministic order of destruction and forking. - // The resources will only be recycled during program exit. - static GlobalFunctionTable* inst = new GlobalFunctionTable(); - return inst; - } - - private: - Map table_; -}; -} // namespace ffi -} // namespace tvm - -int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self), - TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::Function func = tvm::ffi::Function::FromExternC(self, safe_call, deleter); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(func)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::Any result(*reinterpret_cast(any_view)); - *out = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(result)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, int override) { - using namespace tvm::ffi; - TVM_FFI_SAFE_CALL_BEGIN(); - String name_str(name->data, name->size); - 
GlobalFunctionTable::Global()->Update(name_str, GetRef(static_cast(f)), - override != 0); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo* method_info, int override) { - using namespace tvm::ffi; - TVM_FFI_SAFE_CALL_BEGIN(); - GlobalFunctionTable::Global()->Update(method_info, override != 0); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out) { - using namespace tvm::ffi; - TVM_FFI_SAFE_CALL_BEGIN(); - String name_str(name->data, name->size); - const GlobalFunctionTable::Entry* fp = GlobalFunctionTable::Global()->Get(name_str); - if (fp != nullptr) { - tvm::ffi::Function func(fp->func_data); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(func)); - } else { - *out = nullptr; - } - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result) { - using namespace tvm::ffi; - // NOTE: this is a tail call - return reinterpret_cast(func)->safe_call(func, args, num_args, result); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("ffi.FunctionRemoveGlobal", - [](const tvm::ffi::String& name) -> bool { - return tvm::ffi::GlobalFunctionTable::Global()->Remove(name); - }) - .def("ffi.FunctionListGlobalNamesFunctor", - []() { - // NOTE: we return functor instead of array - // so list global function names do not need to depend on array - // this is because list global function names usually is a core api that happens - // before array ffi functions are available. 
- tvm::ffi::Array names = - tvm::ffi::GlobalFunctionTable::Global()->ListNames(); - auto return_functor = [names](int64_t i) -> tvm::ffi::Any { - if (i < 0) { - return names.size(); - } else { - return names[i]; - } - }; - return tvm::ffi::Function::FromTyped(return_functor); - }) - .def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return val; }) - .def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return val; }); -}); diff --git a/ffi/src/ffi/ndarray.cc b/ffi/src/ffi/ndarray.cc deleted file mode 100644 index 41d4273b597c..000000000000 --- a/ffi/src/ffi/ndarray.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -/* - * \file src/ffi/ndarray.cc - * \brief NDArray C API implementation - */ -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("ffi.Shape", [](ffi::PackedArgs args, Any* ret) { - int64_t* mutable_data; - ObjectPtr shape = details::MakeEmptyShape(args.size(), &mutable_data); - for (int i = 0; i < args.size(); ++i) { - if (auto opt_int = args[i].try_cast()) { - mutable_data[i] = *opt_int; - } else { - TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments"; - } - } - *ret = Shape(shape); - }); -}); - -} // namespace ffi -} // namespace tvm - -int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t min_alignment, - int32_t require_contiguous, TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::NDArray nd = - tvm::ffi::NDArray::FromDLPack(from, static_cast(min_alignment), require_contiguous); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, int32_t min_alignment, - int32_t require_contiguous, TVMFFIObjectHandle* out) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::NDArray nd = tvm::ffi::NDArray::FromDLPackVersioned( - from, static_cast(min_alignment), require_contiguous); - *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(nd)); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out) { - TVM_FFI_SAFE_CALL_BEGIN(); - *out = tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned( - static_cast(from)) - ->ToDLPack(); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, DLManagedTensorVersioned** out) { - TVM_FFI_SAFE_CALL_BEGIN(); - *out = tvm::ffi::details::ObjectUnsafe::RawObjectPtrFromUnowned( - static_cast(from)) - ->ToDLPackVersioned(); - 
TVM_FFI_SAFE_CALL_END(); -} diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc deleted file mode 100644 index 61107cb63ff7..000000000000 --- a/ffi/src/ffi/object.cc +++ /dev/null @@ -1,449 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/object.cc - * \brief Registry to record dynamic types - */ -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Global registry that manages - * - * \note We do not use mutex to guard updating of TypeTable - * - * The assumption is that updating of TypeTable will be done - * in the main thread during initialization or loading, or - * explicitly locked from the caller. - * - * Then the followup code will leverage the information - */ -class TypeTable { - public: - /*! \brief Type information */ - struct Entry : public TypeInfo { - /*! \brief stored type key */ - String type_key_data; - /*! \brief acenstor information */ - std::vector type_acenstors_data; - /*! \brief type fields informaton */ - std::vector type_fields_data; - /*! \brief type methods informaton */ - std::vector type_methods_data; - /*! 
\brief extra information */ - TVMFFITypeMetadata metadata_data; - // NOTE: the indices in [index, index + num_reserved_slots) are - // reserved for the child-class of this type. - /*! \brief Total number of slots reserved for the type and its children. */ - int32_t num_slots; - /*! \brief number of allocated child slots. */ - int32_t allocated_slots; - /*! \brief Whether child can overflow. */ - bool child_slots_can_overflow{true}; - - Entry(int32_t type_index, int32_t type_depth, String type_key, int32_t num_slots, - bool child_slots_can_overflow, const Entry* parent) { - // setup fields in the class - this->type_key_data = std::move(type_key); - this->num_slots = num_slots; - this->allocated_slots = 1; - this->child_slots_can_overflow = child_slots_can_overflow; - // set up type acenstors information - if (type_depth != 0) { - TVM_FFI_ICHECK_NOTNULL(parent); - TVM_FFI_ICHECK_EQ(type_depth, parent->type_depth + 1); - type_acenstors_data.resize(type_depth); - // copy over parent's type information - for (int32_t i = 0; i < parent->type_depth; ++i) { - type_acenstors_data[i] = parent->type_acenstors[i]; - } - // set last type information to be parent - type_acenstors_data[parent->type_depth] = parent; - } - // initialize type info: no change to type_key and type_acenstors fields - // after this line - this->type_index = type_index; - this->type_depth = type_depth; - this->type_key = TVMFFIByteArray{this->type_key_data.data(), this->type_key_data.length()}; - this->type_key_hash = std::hash()(this->type_key_data); - this->type_acenstors = type_acenstors_data.data(); - // initialize the reflection information - this->num_fields = 0; - this->num_methods = 0; - this->fields = nullptr; - this->methods = nullptr; - this->metadata = nullptr; - } - }; - - struct TypeAttrColumnData : public TVMFFITypeAttrColumn { - std::vector data_; - }; - - int32_t GetOrAllocTypeIndex(String type_key, int32_t static_type_index, int32_t type_depth, - int32_t num_child_slots, bool 
child_slots_can_overflow, - int32_t parent_type_index) { - auto it = type_key2index_.find(type_key); - if (it != type_key2index_.end()) { - return type_table_[(*it).second]->type_index; - } - - // get parent's entry - Entry* parent = [&]() -> Entry* { - if (parent_type_index < 0) return nullptr; - // try to allocate from parent's type table. - TVM_FFI_ICHECK_LT(parent_type_index, type_table_.size()) - << " type_key=" << type_key << ", static_index=" << static_type_index; - return type_table_[parent_type_index].get(); - }(); - - // get allocated index - int32_t allocated_tindex = [&]() { - // Step 0: static allocation - if (static_type_index >= 0) { - TVM_FFI_ICHECK_LT(static_type_index, type_table_.size()); - TVM_FFI_ICHECK(type_table_[static_type_index] == nullptr) - << "Conflicting static index " << static_type_index << " between " - << ToStringView(type_table_[static_type_index]->type_key) << " and " << type_key; - return static_type_index; - } - TVM_FFI_ICHECK_NOTNULL(parent); - int num_slots = num_child_slots + 1; - if (parent->allocated_slots + num_slots <= parent->num_slots) { - // allocate the slot from parent's reserved pool - int32_t allocated_tindex = parent->type_index + parent->allocated_slots; - // update parent's state - parent->allocated_slots += num_slots; - return allocated_tindex; - } - // Step 2: allocate from overflow - TVM_FFI_ICHECK(parent->child_slots_can_overflow) - << "Reach maximum number of sub-classes for " << ToStringView(parent->type_key); - // allocate new entries. - int32_t allocated_tindex = type_counter_; - type_counter_ += num_slots; - TVM_FFI_ICHECK_LE(type_table_.size(), type_counter_); - type_table_.reserve(type_counter_); - // resize type table - while (static_cast(type_table_.size()) < type_counter_) { - type_table_.emplace_back(nullptr); - } - return allocated_tindex; - }(); - - // if parent cannot overflow, then this class cannot. 
- if (parent != nullptr && !(parent->child_slots_can_overflow)) { - child_slots_can_overflow = false; - } - // total number of slots include the type itself. - - if (parent != nullptr) { - TVM_FFI_ICHECK_GT(allocated_tindex, parent->type_index); - } - - type_table_[allocated_tindex] = - std::make_unique(allocated_tindex, type_depth, type_key, num_child_slots + 1, - child_slots_can_overflow, parent); - // update the key2index mapping. - type_key2index_.Set(type_key, allocated_tindex); - return allocated_tindex; - } - - int32_t TypeKeyToIndex(const TVMFFIByteArray* type_key) { - String type_key_str(type_key->data, type_key->size); - auto it = type_key2index_.find(type_key_str); - TVM_FFI_ICHECK(it != type_key2index_.end()) << "Cannot find type `" << type_key_str << "`"; - return static_cast((*it).second); - } - - Entry* GetTypeEntry(int32_t type_index) { - Entry* entry = nullptr; - if (type_index >= 0 && static_cast(type_index) < type_table_.size()) { - entry = type_table_[type_index].get(); - } - TVM_FFI_ICHECK(entry != nullptr) << "Cannot find type info for type_index=" << type_index; - return entry; - } - - void RegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) { - Entry* entry = GetTypeEntry(type_index); - TVMFFIFieldInfo field_data = *info; - field_data.name = this->CopyString(info->name); - field_data.doc = this->CopyString(info->doc); - field_data.type_schema = this->CopyString(info->type_schema); - if (info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_data.default_value = - this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value)).CopyToTVMFFIAny(); - } else { - field_data.default_value = AnyView(nullptr).CopyToTVMFFIAny(); - } - entry->type_fields_data.push_back(field_data); - // refresh ptr as the data can change - entry->fields = entry->type_fields_data.data(); - entry->num_fields = static_cast(entry->type_fields_data.size()); - } - - void RegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info) { - Entry* entry = 
GetTypeEntry(type_index); - TVMFFIMethodInfo method_data = *info; - method_data.name = this->CopyString(info->name); - method_data.doc = this->CopyString(info->doc); - method_data.type_schema = this->CopyString(info->type_schema); - method_data.method = this->CopyAny(AnyView::CopyFromTVMFFIAny(info->method)).CopyToTVMFFIAny(); - entry->type_methods_data.push_back(method_data); - entry->methods = entry->type_methods_data.data(); - entry->num_methods = static_cast(entry->type_methods_data.size()); - } - - void RegisterTypeMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) { - Entry* entry = GetTypeEntry(type_index); - if (entry->metadata != nullptr) { - TVM_FFI_LOG_AND_THROW(RuntimeError) - << "Overriding " << ToStringView(entry->type_key) << ", possible causes:\n" - << "- two ObjectDef() calls for the same T \n" - << "- when we forget to assign _type_key to ObjectRef that inherits from T\n" - << "- another type with the same key is already registered\n" - << "Cross check the reflection registration."; - } - entry->metadata_data = *metadata; - entry->metadata_data.doc = this->CopyString(metadata->doc); - entry->metadata = &(entry->metadata_data); - } - - void RegisterTypeAttr(int32_t type_index, const TVMFFIByteArray* name, const TVMFFIAny* value) { - AnyView value_view = AnyView::CopyFromTVMFFIAny(*value); - String name_str(*name); - size_t column_index = 0; - auto it = type_attr_name_to_column_index_.find(name_str); - if (it == type_attr_name_to_column_index_.end()) { - column_index = type_attr_columns_.size(); - type_attr_columns_.emplace_back(std::make_unique()); - type_attr_name_to_column_index_.Set(name_str, column_index); - } else { - column_index = (*it).second; - } - TypeAttrColumnData* column = type_attr_columns_[column_index].get(); - if (column->data_.size() < static_cast(type_index + 1)) { - column->data_.resize(type_index + 1, Any(nullptr)); - column->data = reinterpret_cast(column->data_.data()); - column->size = column->data_.size(); - } 
- if (type_index == kTVMFFINone) return; - if (column->data_[type_index] != nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type attribute `" << name_str << "` is already set for type `" - << TypeIndexToTypeKey(type_index) << "`"; - } - column->data_[type_index] = value_view; - } - const TVMFFITypeAttrColumn* GetTypeAttrColumn(const TVMFFIByteArray* name) { - String name_str(*name); - auto it = type_attr_name_to_column_index_.find(name_str); - if (it == type_attr_name_to_column_index_.end()) return nullptr; - return type_attr_columns_[(*it).second].get(); - } - - void Dump(int min_children_count) { - std::vector num_children(type_table_.size(), 0); - // expected child slots compute the expected slots - // based on the current child slot setting - std::vector expected_child_slots(type_table_.size(), 0); - // reverse accumulation so we can get total counts in a bottom-up manner. - for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) { - const Entry* ptr = it->get(); - if (ptr != nullptr && ptr->type_depth != 0) { - int parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index; - num_children[parent_index] += num_children[ptr->type_index] + 1; - if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) { - expected_child_slots[ptr->type_index] = ptr->num_slots - 1; - } - expected_child_slots[parent_index] += expected_child_slots[ptr->type_index] + 1; - } - } - - for (const auto& ptr : type_table_) { - if (ptr != nullptr && num_children[ptr->type_index] >= min_children_count) { - std::cerr << '[' << ptr->type_index << "]\t" << ToStringView(ptr->type_key); - if (ptr->type_depth != 0) { - int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index; - std::cerr << "\tparent=" << ToStringView(type_table_[parent_index]->type_key); - } else { - std::cerr << "\tparent=root"; - } - std::cerr << "\tnum_child_slots=" << ptr->num_slots - 1 - << "\tnum_children=" << num_children[ptr->type_index] - << "\texpected_child_slots=" << 
expected_child_slots[ptr->type_index] - << std::endl; - } - } - } - - static TypeTable* Global() { - static TypeTable inst; - return &inst; - } - - private: - TypeTable() { - type_table_.reserve(TypeIndex::kTVMFFIDynObjectBegin); - for (int32_t i = 0; i < TypeIndex::kTVMFFIDynObjectBegin; ++i) { - type_table_.emplace_back(nullptr); - } - // initialize the entry for object - this->GetOrAllocTypeIndex(String(Object::_type_key), Object::_type_index, Object::_type_depth, - Object::_type_child_slots, Object::_type_child_slots_can_overflow, - -1); - TVMFFITypeMetadata info; - info.total_size = sizeof(Object); - info.creator = nullptr; - info.doc = TVMFFIByteArray{nullptr, 0}; - RegisterTypeMetadata(Object::_type_index, &info); - // reserve the static types - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFINone, TypeIndex::kTVMFFINone); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIInt, TypeIndex::kTVMFFIInt); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIFloat, TypeIndex::kTVMFFIFloat); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIBool, TypeIndex::kTVMFFIBool); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIRawStr, TypeIndex::kTVMFFIRawStr); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIOpaquePtr, TypeIndex::kTVMFFIOpaquePtr); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIDataType, TypeIndex::kTVMFFIDataType); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIDevice, TypeIndex::kTVMFFIDevice); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIByteArrayPtr, TypeIndex::kTVMFFIByteArrayPtr); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef, - TypeIndex::kTVMFFIObjectRValueRef); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallStr, TypeIndex::kTVMFFISmallStr); - ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFISmallBytes, TypeIndex::kTVMFFISmallBytes); - // no need to reserve for object types as they will be registered - } - - void ReserveBuiltinTypeIndex(const char* type_key, int32_t static_type_index) { - this->GetOrAllocTypeIndex(String(type_key), 
static_type_index, 0, 0, false, -1); - } - - static ObjectPtr MakeInplaceString(const char* data, size_t length) { - ObjectPtr p = - make_inplace_array_object(length + 1); - static_assert(alignof(details::StringObj) % alignof(char) == 0); - static_assert(sizeof(details::StringObj) % alignof(char) == 0); - char* dest_data = reinterpret_cast(p.get()) + sizeof(details::StringObj); - p->data = dest_data; - p->size = length; - std::memcpy(dest_data, data, length); - dest_data[length] = '\0'; - return p; - } - - TVMFFIByteArray CopyString(TVMFFIByteArray str) { - if (str.size == 0) { - return TVMFFIByteArray{nullptr, 0}; - } - // use explicit object creation to ensure the space pointer to not move - auto str_obj = MakeInplaceString(str.data, str.size); - TVMFFIByteArray c_val{str_obj->data, str_obj->size}; - any_pool_.emplace_back(ObjectRef(std::move(str_obj))); - return c_val; - } - - AnyView CopyAny(Any val) { - AnyView view = AnyView(val); - any_pool_.emplace_back(std::move(val)); - return view; - } - - int64_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin}; - std::vector> type_table_; - Map type_key2index_; - std::vector any_pool_; - // type attribute columns - std::vector> type_attr_columns_; - Map type_attr_name_to_column_index_; -}; - -} // namespace ffi -} // namespace tvm - -int TVMFFIObjectFree(TVMFFIObjectHandle handle) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(handle); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) { - TVM_FFI_SAFE_CALL_BEGIN(); - out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::TypeTable::Global()->RegisterTypeField(type_index, info); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info) { - 
TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::TypeTable::Global()->RegisterTypeMethod(type_index, info); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::TypeTable::Global()->RegisterTypeMetadata(type_index, metadata); - TVM_FFI_SAFE_CALL_END(); -} - -int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* name, - const TVMFFIAny* value) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::TypeTable::Global()->RegisterTypeAttr(type_index, name, value); - TVM_FFI_SAFE_CALL_END(); -} - -const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* name) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::TypeTable::Global()->GetTypeAttrColumn(name); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeAttrColumn); -} - -int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, int32_t static_type_index, - int32_t type_depth, int32_t num_child_slots, - int32_t child_slots_can_overflow, int32_t parent_type_index) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - tvm::ffi::String s_type_key(type_key->data, type_key->size); - return tvm::ffi::TypeTable::Global()->GetOrAllocTypeIndex( - s_type_key, static_type_index, type_depth, num_child_slots, child_slots_can_overflow, - parent_type_index); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFITypeGetOrAllocIndex); -} - -const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::TypeTable::Global()->GetTypeEntry(type_index); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo); -} diff --git a/ffi/src/ffi/traceback.cc b/ffi/src/ffi/traceback.cc deleted file mode 100644 index 90d02121f0f5..000000000000 --- a/ffi/src/ffi/traceback.cc +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file traceback.cc - * \brief Traceback implementation on non-windows platforms - * \note We use the term "traceback" to be consistent with python naming convention. - */ -#ifndef _MSC_VER - -#include "./traceback.h" - -#include -#include - -#if TVM_FFI_USE_LIBBACKTRACE - -#include -#include - -#include -#include -#include -#include - -#if TVM_FFI_BACKTRACE_ON_SEGFAULT -#include -#endif - -namespace tvm { -namespace ffi { -namespace { - -void BacktraceCreateErrorCallback(void*, const char* msg, int) { - std::cerr << "Could not initialize backtrace state: " << msg << std::endl; -} - -backtrace_state* BacktraceCreate() { - return backtrace_create_state(nullptr, 1, BacktraceCreateErrorCallback, nullptr); -} - -static backtrace_state* _bt_state = BacktraceCreate(); - -std::string DemangleName(std::string name) { - int status = 0; - size_t length = name.size(); - char* demangled_name = abi::__cxa_demangle(name.c_str(), nullptr, &length, &status); - if (demangled_name && status == 0 && length > 0) { - name = demangled_name; - } - if (demangled_name) { - std::free(demangled_name); - } - return name; -} - -void BacktraceErrorCallback(void*, const char*, int) { - // do nothing -} - -void BacktraceSyminfoCallback(void* data, uintptr_t pc, const char* symname, uintptr_t, 
uintptr_t) { - auto str = reinterpret_cast(data); - - if (symname != nullptr) { - *str = DemangleName(symname); - } else { - std::ostringstream s; - s << "0x" << std::setfill('0') << std::setw(sizeof(uintptr_t) * 2) << std::hex << pc; - *str = s.str(); - } -} - -int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int lineno, - const char* symbol) { - auto stack_trace = reinterpret_cast(data); - std::string symbol_str = ""; - if (symbol) { - symbol_str = DemangleName(symbol); - } else { - // see if syminfo gives anything - backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback, BacktraceErrorCallback, &symbol_str); - } - symbol = symbol_str.data(); - - if (stack_trace->ExceedTracebackLimit()) { - return 1; - } - if (ShouldStopTraceback(filename, symbol)) { - return 1; - } - if (ShouldExcludeFrame(filename, symbol)) { - return 0; - } - stack_trace->Append(filename, symbol, lineno); - return 0; -} - -std::string Traceback() { - TracebackStorage traceback; - - if (_bt_state == nullptr) { - return ""; - } - // libbacktrace eats memory if run on multiple threads at the same time, so we guard against it - { - static std::mutex m; - std::lock_guard lock(m); - backtrace_full(_bt_state, 0, BacktraceFullCallback, BacktraceErrorCallback, &traceback); - } - return traceback.GetTraceback(); -} - -#if TVM_FFI_BACKTRACE_ON_SEGFAULT -void backtrace_handler(int sig) { - // Technically we shouldn't do any allocation in a signal handler, but - // Backtrace may allocate. What's the worst it could do? We're already - // crashing. - std::cerr << "!!!!!!! 
TVM FFI encountered a Segfault !!!!!!!\n" << Traceback() << std::endl; - - // Re-raise signal with default handler - struct sigaction act; - std::memset(&act, 0, sizeof(struct sigaction)); - act.sa_flags = SA_RESETHAND; - act.sa_handler = SIG_DFL; - sigaction(sig, &act, nullptr); - raise(sig); -} - -__attribute__((constructor)) void install_signal_handler(void) { - // this may override already installed signal handlers - std::signal(SIGSEGV, backtrace_handler); -} -#endif // TVM_FFI_BACKTRACE_ON_SEGFAULT -} // namespace -} // namespace ffi -} // namespace tvm - -const TVMFFIByteArray* TVMFFITraceback(const char*, int, const char*) { - static thread_local std::string traceback_str; - static thread_local TVMFFIByteArray traceback_array; - traceback_str = ::tvm::ffi::Traceback(); - traceback_array.data = traceback_str.data(); - traceback_array.size = traceback_str.size(); - return &traceback_array; -} -#else -// fallback implementation simply print out the last trace -const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) { - static thread_local std::string traceback_str; - static thread_local TVMFFIByteArray traceback_array; - std::ostringstream traceback_stream; - // python style backtrace - traceback_stream << " File \"" << filename << "\", line " << lineno << ", in " << func << '\n'; - traceback_str = traceback_stream.str(); - traceback_array.data = traceback_str.data(); - traceback_array.size = traceback_str.size(); - return &traceback_array; -} -#endif // TVM_FFI_USE_LIBBACKTRACE -#endif // _MSC_VER diff --git a/ffi/src/ffi/traceback.h b/ffi/src/ffi/traceback.h deleted file mode 100644 index 47b91e16b0f7..000000000000 --- a/ffi/src/ffi/traceback.h +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file traceback.h - * \brief Common headers for traceback. - * \note We use the term "traceback" to be consistent with python naming convention. - */ -#ifndef TVM_FFI_TRACEBACK_H_ -#define TVM_FFI_TRACEBACK_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4996) // std::getenv is unsafe -#endif - -inline int32_t GetTracebackLimit() { - if (const char* env = std::getenv("TVM_TRACEBACK_LIMIT")) { - return std::stoi(env); - } - return 512; -} - -#ifdef _MSC_VER -#pragma warning(pop) -#endif - -/*! 
- * \brief List frame patterns that should be excluded as they contain less information - */ -inline bool ShouldExcludeFrame(const char* filename, const char* symbol) { - if (filename) { - // Stack frames for TVM FFI - if (strstr(filename, "include/tvm/ffi/error.h")) { - return true; - } - if (strstr(filename, "include/tvm/ffi/function_details.h")) { - return true; - } - if (strstr(filename, "include/tvm/ffi/function.h")) { - return true; - } - if (strstr(filename, "include/tvm/ffi/any.h")) { - return true; - } - if (strstr(filename, "include/tvm/runtime/logging.h")) { - return true; - } - if (strstr(filename, "src/ffi/traceback.cc")) { - return true; - } - // C++ stdlib frames - if (strstr(filename, "include/c++/")) { - return true; - } - } - - if (symbol) { - // C++ stdlib frames - if (strstr(symbol, "__libc_")) { - return true; - } - } - if (strncmp(symbol, "TVMFFIErrorSetRaisedFromCStr", 28) == 0) { - return true; - } - // libffi.so stack frames. These may also show up as numeric - // addresses with no symbol name. This could be improved in the - // future by using dladdr() to check whether an address is contained - // in libffi.so - if (strstr(symbol, "ffi_call_")) { - return true; - } - return false; -} - -/** - * \brief List frames that should stop the traceback. - * \param filename The filename of the frame. - * \param symbol The symbol name of the frame. - * \return true if the frame should stop the traceback. - * \note We stop traceback at the FFI boundary. - */ -inline bool ShouldStopTraceback(const char* filename, const char* symbol) { - if (symbol != nullptr) { - if (strncmp(symbol, "TVMFFIFunctionCall", 14) == 0) { - return true; - } - // Python interpreter stack frames - // we stop traceback at the Python interpreter stack frames - // since these frame will be handled from by the python side. - if (strncmp(symbol, "_Py", 3) == 0 || strncmp(symbol, "PyObject", 9) == 0) { - return true; - } - } - return false; -} - -/*! 
- * \brief storage to store traceback - */ -struct TracebackStorage { - std::vector lines; - /*! \brief Maximum size of the traceback. */ - size_t max_frame_size = GetTracebackLimit(); - - void Append(const char* filename, const char* func, int lineno) { - // skip frames with empty filename - if (filename == nullptr) { - if (func != nullptr) { - if (strncmp(func, "0x0", 3) == 0) { - return; - } - filename = ""; - } else { - return; - } - } - std::ostringstream trackeback_stream; - trackeback_stream << " File \"" << filename << "\""; - if (lineno != 0) { - trackeback_stream << ", line " << lineno; - } - trackeback_stream << ", in " << func << '\n'; - lines.push_back(trackeback_stream.str()); - } - - bool ExceedTracebackLimit() const { return lines.size() >= max_frame_size; } - - // get traceback in the order of most recent call last - std::string GetTraceback() const { - std::string traceback; - for (auto it = lines.rbegin(); it != lines.rend(); ++it) { - traceback.insert(traceback.end(), it->begin(), it->end()); - } - return traceback; - } -}; - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_TRACEBACK_H_ diff --git a/ffi/src/ffi/traceback_win.cc b/ffi/src/ffi/traceback_win.cc deleted file mode 100644 index 8278de1d77cf..000000000000 --- a/ffi/src/ffi/traceback_win.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file traceback_win.cc - * \brief Traceback implementation on windows platform - * \note We use the term "traceback" to be consistent with python naming convention. - */ -#ifdef _MSC_VER - -// clang-format off -#include -#include // NOLINT(*) -// clang-format on - -#include -#include - -#include -#include - -#include "./traceback.h" - -namespace tvm { -namespace ffi { -namespace { - -std::string Traceback() { - TracebackStorage traceback; - HANDLE process = GetCurrentProcess(); - HANDLE thread = GetCurrentThread(); - - SymSetOptions(SYMOPT_LOAD_LINES | SYMOPT_UNDNAME); - SymInitialize(process, NULL, TRUE); - CONTEXT context = {}; - RtlCaptureContext(&context); - - STACKFRAME64 stack = {}; - DWORD machine_type; - -#if defined(_M_X64) - machine_type = IMAGE_FILE_MACHINE_AMD64; - stack.AddrPC.Offset = context.Rip; - stack.AddrFrame.Offset = context.Rbp; - stack.AddrStack.Offset = context.Rsp; -#elif defined(_M_IX86) - machine_type = IMAGE_FILE_MACHINE_I386; - stack.AddrPC.Offset = context.Eip; - stack.AddrFrame.Offset = context.Ebp; - stack.AddrStack.Offset = context.Esp; -#else -#error "Platform not supported!" 
-#endif - - stack.AddrPC.Mode = AddrModeFlat; - stack.AddrFrame.Mode = AddrModeFlat; - stack.AddrStack.Mode = AddrModeFlat; - - while (!traceback.ExceedTracebackLimit()) { - if (!StackWalk64(machine_type, process, thread, &stack, &context, nullptr, - SymFunctionTableAccess64, SymGetModuleBase64, nullptr)) { - break; - } - - if (stack.AddrPC.Offset == 0) { - break; - } - const char* filename = nullptr; - const char* symbol = ""; - int lineno = 0; - // Get file and line number - IMAGEHLP_LINE64 line_info; - ZeroMemory(&line_info, sizeof(IMAGEHLP_LINE64)); - line_info.SizeOfStruct = sizeof(IMAGEHLP_LINE64); - DWORD displacement32 = 0; - - if (SymGetLineFromAddr64(process, stack.AddrPC.Offset, &displacement32, &line_info)) { - filename = line_info.FileName; - lineno = line_info.LineNumber; - } - // allocate symbol info that aligns to the SYMBOL_INFO - // we use u64 here to be safe - size_t total_symbol_bytes = sizeof(SYMBOL_INFO) + MAX_SYM_NAME * sizeof(TCHAR); - size_t total_u64_words = (total_symbol_bytes + 7) / 8; - static_assert(8 % alignof(SYMBOL_INFO) == 0); - std::vector symbol_buffer(total_u64_words, 0); - PSYMBOL_INFO symbol_info = reinterpret_cast(symbol_buffer.data()); - symbol_info->SizeOfStruct = sizeof(SYMBOL_INFO); - symbol_info->MaxNameLen = MAX_SYM_NAME; - DWORD64 displacement = 0; - if (SymFromAddr(process, stack.AddrPC.Offset, &displacement, symbol_info)) { - symbol = symbol_info->Name; - } - - if (ShouldStopTraceback(filename, symbol)) { - break; - } - if (ShouldExcludeFrame(filename, symbol)) { - continue; - } - traceback.Append(filename, symbol, lineno); - } - SymCleanup(process); - return traceback.GetTraceback(); -} -} // namespace -} // namespace ffi -} // namespace tvm - -const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) { - static thread_local std::string traceback_str; - static thread_local TVMFFIByteArray traceback_array; - traceback_str = ::tvm::ffi::Traceback(); - traceback_array.data = 
traceback_str.data(); - traceback_array.size = traceback_str.size(); - return &traceback_array; -} -#endif // _MSC_VER diff --git a/ffi/tests/cpp/CMakeLists.txt b/ffi/tests/cpp/CMakeLists.txt deleted file mode 100644 index 0c820fc80ea8..000000000000 --- a/ffi/tests/cpp/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -file(GLOB _test_sources "${CMAKE_CURRENT_SOURCE_DIR}/test*.cc") -file(GLOB _test_extra_sources "${CMAKE_CURRENT_SOURCE_DIR}/extra/test*.cc") - -if (TVM_FFI_USE_EXTRA_CXX_API) - list(APPEND _test_sources ${_test_extra_sources}) -endif() - -add_executable( - tvm_ffi_tests - EXCLUDE_FROM_ALL - ${_test_sources} -) -set_target_properties( - tvm_ffi_tests PROPERTIES - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - CXX_EXTENSIONS OFF - MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>" - ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" -) -add_cxx_warning(tvm_ffi_tests) -add_sanitizer_address(tvm_ffi_tests) -add_dsymutil(tvm_ffi_tests) -add_msvc_flags(tvm_ffi_tests) -target_link_libraries(tvm_ffi_tests PRIVATE tvm_ffi_shared) -add_googletest(tvm_ffi_tests) - -if (MSVC) - target_link_options(tvm_ffi_tests PRIVATE /DEBUG) -endif() diff --git a/ffi/tests/cpp/extra/test_json_parser.cc b/ffi/tests/cpp/extra/test_json_parser.cc deleted file mode 100644 index a1cc2800094f..000000000000 --- a/ffi/tests/cpp/extra/test_json_parser.cc +++ /dev/null @@ -1,394 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -#include - -namespace { - -using namespace tvm::ffi; - -inline bool FastMathSafeIsNaN(double x) { -#ifdef __FAST_MATH__ - // Bit-level NaN detection (IEEE 754 double) - // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 - // NaN is encoded as all 1s in the exponent and non-zero in the mantissa - static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); - uint64_t bits = *reinterpret_cast(&x); - uint64_t exponent = (bits >> 52) & 0x7FF; - uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; - return (exponent == 0x7FF) && (mantissa != 0); -#else - // Safe to use std::isnan when fast-math is off - return std::isnan(x); -#endif -} - -inline bool FastMathSafeIsInf(double x) { -#ifdef __FAST_MATH__ - // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 - // Inf is encoded as all 1s in the exponent and zero in the mantissa - static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); - uint64_t bits = *reinterpret_cast(&x); - uint64_t exponent = (bits >> 52) & 0x7FF; - uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; - // inf is encoded as all 1s in the exponent and zero in the mantissa - return (exponent == 0x7FF) && (mantissa == 0); -#else - return std::isinf(x); -#endif -} - -TEST(JSONParser, BoolNull) { - // boolean value - EXPECT_EQ(json::Parse("true").cast(), true); - EXPECT_EQ(json::Parse("false").cast(), false); - EXPECT_EQ(json::Parse("null"), nullptr); -} - -TEST(JSONParser, WrongBoolNull) { - String error_msg; - EXPECT_EQ(json::Parse("nul", 
&error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("fals", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("\n\nfx", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 3 column 1 (char 2)"); - EXPECT_EQ(json::Parse("fx", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("n1", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("t1", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("f1", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); -} - -TEST(JSONParser, Number) { - // number - EXPECT_EQ(json::Parse("123").cast(), 123); - EXPECT_EQ(json::Parse("-124").cast(), -124); - EXPECT_EQ(json::Parse("123.456").cast(), 123.456); - // parsing scientific notation - EXPECT_EQ(json::Parse("1.456e12").cast(), 1.456e12); - // NaN - EXPECT_EQ(FastMathSafeIsNaN(json::Parse("NaN").cast()), true); - // Infinity - EXPECT_EQ(FastMathSafeIsInf(json::Parse("Infinity").cast()), true); - // -Infinity - EXPECT_EQ(FastMathSafeIsInf(-json::Parse("-Infinity").cast()), true); - - // Test zero variants - EXPECT_EQ(json::Parse("0").cast(), 0); - EXPECT_EQ(json::Parse("-0").cast(), -0.0); - EXPECT_EQ(json::Parse("0.0").cast(), 0.0); - - // Test very large numbers - EXPECT_EQ(json::Parse("9223372036854775807").cast(), - std::numeric_limits::max()); - EXPECT_EQ(json::Parse("-9223372036854775808").cast(), - std::numeric_limits::min()); - - // Test very small decimals - EXPECT_EQ(json::Parse("1e-10").cast(), 1e-10); - EXPECT_EQ(json::Parse("-1e-10").cast(), -1e-10); - - // Test scientific notation edge cases - EXPECT_EQ(json::Parse("1E+10").cast(), 1E+10); - EXPECT_EQ(json::Parse("1e+10").cast(), 1e+10); - 
EXPECT_EQ(json::Parse("1E-10").cast(), 1E-10); - EXPECT_EQ(json::Parse("123.456E+10").cast(), 123.456E+10); -} - -TEST(JSONParser, WrongNumber) { - String error_msg; - EXPECT_EQ(json::Parse("123.456.789", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - - // Test invalid number formats - EXPECT_EQ(json::Parse("123e", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("123e+", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - EXPECT_EQ(json::Parse("123E-", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); -} - -TEST(JSONParser, String) { - EXPECT_EQ(json::Parse("\"hello\"").cast(), "hello"); - EXPECT_EQ(json::Parse("\n\t \"hello\"\n\r").cast(), "hello"); - EXPECT_EQ(json::Parse("\"hello\\nworld\"").cast(), "hello\nworld"); - EXPECT_EQ(json::Parse("\"\"").cast(), ""); - // test escape characters - EXPECT_EQ(json::Parse("\"\\ta\\n\\/\\f\\\"\\\\\"").cast(), "\ta\n/\f\"\\"); - // test unicode code point - EXPECT_EQ(json::Parse("\"\\u0041\"").cast(), "A"); - // test unicode surrogate pair - EXPECT_EQ(json::Parse("\"\\uD83D\\uDE04hello\"").cast(), u8"\U0001F604hello"); -} - -TEST(JSONParser, WrongString) { - String error_msg; - EXPECT_EQ(json::Parse("\"hello", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Unterminated string starting at: line 1 column 1 (char 0)"); - - EXPECT_EQ(json::Parse("\"hello\x01\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid control character at: line 1 column 7 (char 6)"); - - EXPECT_EQ(json::Parse("\"hello\\uxx\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid \\uXXXX escape: line 1 column 8 (char 7)"); - - EXPECT_EQ(json::Parse("\"hello\\uDC00\\uDE04\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid surrogate pair of \\uXXXX escapes: line 1 column 8 (char 7)"); - - EXPECT_EQ(json::Parse("\"hello\\uD800\"", &error_msg), 
nullptr); - EXPECT_EQ(error_msg, "Invalid surrogate pair of \\uXXXX escapes: line 1 column 8 (char 7)"); - - EXPECT_EQ(json::Parse("\"hello\\uD800\\uxx\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid \\uXXXX escape: line 1 column 15 (char 14)"); - - EXPECT_EQ(json::Parse("\"hello\\a\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Invalid \\escape: line 1 column 8 (char 7)"); -} - -TEST(JSONParser, Array) { - EXPECT_TRUE(StructuralEqual()(json::Parse("[]"), json::Array{})); - - EXPECT_TRUE(StructuralEqual()(json::Parse("[1, 2,\n\t\"a\"]"), json::Array{1, 2, "a"})); -} - -TEST(JSONParser, WrongArray) { - String error_msg; - - EXPECT_EQ(json::Parse("]", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - - EXPECT_EQ(json::Parse("[1,]", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 4 (char 3)"); - - EXPECT_EQ(json::Parse("[", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 2 (char 1)"); - - EXPECT_EQ(json::Parse("[1a", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting ',' delimiter: line 1 column 3 (char 2)"); - - EXPECT_EQ(json::Parse("[1,2,3", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting ',' delimiter: line 1 column 7 (char 6)"); - - EXPECT_EQ(json::Parse("[1] a", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 6 (char 5)"); -} - -TEST(JSONParser, Object) { - EXPECT_TRUE(StructuralEqual()(json::Parse("{}"), json::Object{})); - - EXPECT_TRUE(StructuralEqual()(json::Parse("{\"a\": 1, \n\"b\": \t\"c\"} "), - json::Object{{"a", 1}, {"b", "c"}})); -} - -TEST(JSONParser, ObjectOrderPreserving) { - auto obj = json::Parse("{\"c\": 1, \"a\": 2, \"b\": 3} "); - json::Array keys; - for (auto& [key, value] : obj.cast()) { - keys.push_back(key); - } - EXPECT_TRUE(StructuralEqual()(keys, json::Array{"c", "a", "b"})); -} - -TEST(JSONParser, WrongObject) { - String error_msg; - EXPECT_EQ(json::Parse("{\"a\":", &error_msg), 
nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 6 (char 5)"); - - EXPECT_EQ(json::Parse("{", &error_msg), nullptr); - EXPECT_EQ(error_msg, - "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)"); - - // Test incomplete structures - EXPECT_EQ(json::Parse("{\"incomplete\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting ':' delimiter: line 1 column 14 (char 13)"); -} - -TEST(JSONParser, NestedObject) { - EXPECT_TRUE( - StructuralEqual()(json::Parse("{\"a\": \t{\"b\": 1}, \n\"c\": [1, 2, 3]}"), - json::Object{{"a", json::Object{{"b", 1}}}, {"c", json::Array{1, 2, 3}}})); - - EXPECT_TRUE(StructuralEqual()( - json::Parse("{\"a\": \t{\"b\": 1}, \n\"c\": [1, null, Infinity]}"), - json::Object{{"a", json::Object{{"b", 1}}}, - {"c", json::Array{1, nullptr, std::numeric_limits::infinity()}}})); - - EXPECT_TRUE(StructuralEqual()( - json::Parse("[{}, {\"a\": [1.1, 1000000]}]"), - json::Array{json::Object{}, json::Object{{"a", json::Array{1.1, 1000000}}}})); -} - -TEST(JSONParser, WrongNestedObject) { - String error_msg; - EXPECT_EQ(json::Parse("{\"a\":\n\n[1]", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting ',' delimiter: line 3 column 4 (char 10)"); - - EXPECT_EQ(json::Parse("{\"a\":\n\n[abc]}", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 3 column 2 (char 8)"); -} - -// edge cases -TEST(JSONParser, WhitespaceHandling) { - // Test various whitespace characters - EXPECT_EQ(json::Parse(" \t\n\r true \t\n\r ").cast(), true); - EXPECT_EQ(json::Parse("\n\n\n123\n\n\n").cast(), 123); - EXPECT_EQ(json::Parse(" \"hello world\" ").cast(), "hello world"); - - // Test whitespace in arrays and objects - EXPECT_TRUE(StructuralEqual()(json::Parse(" [ 1 , 2 , 3 ] "), json::Array{1, 2, 3})); - - EXPECT_TRUE(StructuralEqual()(json::Parse(" { \"a\" : 1 , \"b\" : 2 } "), - json::Object{{"a", 1}, {"b", 2}})); -} - -TEST(JSONParser, WrongEmptyAndMinimalInputs) { - String error_msg; - // Test empty string - 
EXPECT_EQ(json::Parse("", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 1 column 1 (char 0)"); - - // Test only whitespace - EXPECT_EQ(json::Parse(" \t\n ", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Expecting value: line 2 column 5 (char 9)"); -} - -TEST(JSONParser, UnicodeEdgeCases) { - // Test various unicode characters - EXPECT_EQ(json::Parse("\"\\u0000\"").cast(), std::string("\0", 1)); - // replace using \U to avoid encoding issues - EXPECT_EQ(json::Parse("\"\\u00FF\"").cast(), u8"\U000000FF"); - EXPECT_EQ(json::Parse("\"\\u4E2D\\u6587\"").cast(), u8"\U00004E2D\U00006587"); - - // Test multiple surrogate pairs - EXPECT_EQ(json::Parse("\"\\uD83D\\uDE00\\uD83D\\uDE01\"").cast(), - u8"\U0001F600\U0001F601"); -} - -TEST(JSONParser, LargeInputs) { - // Test large array - std::string large_array = "["; - for (int i = 0; i < 1000; ++i) { - if (i > 0) large_array += ","; - large_array += std::to_string(i); - } - large_array += "]"; - - auto result = json::Parse(large_array); - EXPECT_TRUE(result != nullptr); - EXPECT_EQ(result.cast().size(), 1000); - - // Test large object - std::string large_object = "{"; - for (int i = 0; i < 500; ++i) { - if (i > 0) large_object += ","; - large_object += "\"key" + std::to_string(i) + "\":" + std::to_string(i); - } - large_object += "}"; - - result = json::Parse(large_object); - EXPECT_TRUE(result != nullptr); - EXPECT_EQ(result.cast().size(), 500); -} - -TEST(JSONParser, MixedDataTypes) { - // Test complex nested structure with all data types - std::string complex_json = R"({ - "null_value": null, - "boolean_true": true, - "boolean_false": false, - "integer": 42, - "negative_integer": -42, - "float": 3.14159, - "scientific": 1.23e-4, - "string": "hello world", - "unicode_string": "Hello \u4e16\u754c \ud83c\udf0d", - "empty_string": "", - "empty_array": [], - "empty_object": {}, - "number_array": [1, 2, 3, 4, 5], - "mixed_array": [1, "two", true, null, 3.14], - "nested_object": { - "level1": { - "level2": 
{ - "data": [1, 2, {"nested_array": [true, false]}] - } - } - } - })"; - - auto result = json::Parse(complex_json); - - // Create expected structure for comparison - json::Object expected{ - {"null_value", nullptr}, - {"boolean_true", true}, - {"boolean_false", false}, - {"integer", 42}, - {"negative_integer", -42}, - {"float", 3.14159}, - {"scientific", 1.23e-4}, - {"string", "hello world"}, - {"unicode_string", u8"Hello \U00004E16\U0000754C \U0001F30D"}, - {"empty_string", ""}, - {"empty_array", json::Array{}}, - {"empty_object", json::Object{}}, - {"number_array", json::Array{1, 2, 3, 4, 5}}, - {"mixed_array", json::Array{1, "two", true, nullptr, 3.14}}, - {"nested_object", - json::Object{ - {"level1", - json::Object{ - {"level2", - json::Object{ - {"data", - json::Array{1, 2, - json::Object{{"nested_array", json::Array{true, false}}}}}}}}}}}}; - - EXPECT_TRUE(StructuralEqual()(result, expected)); -} - -TEST(JSONParser, WrongExtraData) { - String error_msg; - - EXPECT_EQ(json::Parse("truee", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 5 (char 4)"); - - EXPECT_EQ(json::Parse("true false", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 6 (char 5)"); - - EXPECT_EQ(json::Parse("123 456", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 5 (char 4)"); - - EXPECT_EQ(json::Parse("\"hello\" \"world\"", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 9 (char 8)"); - - EXPECT_EQ(json::Parse("{} []", &error_msg), nullptr); - EXPECT_EQ(error_msg, "Extra data: line 1 column 4 (char 3)"); -} -} // namespace diff --git a/ffi/tests/cpp/extra/test_json_writer.cc b/ffi/tests/cpp/extra/test_json_writer.cc deleted file mode 100644 index ae6172c2e53b..000000000000 --- a/ffi/tests/cpp/extra/test_json_writer.cc +++ /dev/null @@ -1,241 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -#include - -namespace { - -using namespace tvm::ffi; - -TEST(JSONWriter, BoolNull) { - // boolean value - EXPECT_EQ(json::Stringify(json::Value(true)), "true"); - EXPECT_EQ(json::Stringify(json::Value(false)), "false"); - EXPECT_EQ(json::Stringify(json::Value(nullptr)), "null"); -} - -TEST(JSONWriter, Integer) { - // positive integer - EXPECT_EQ(json::Stringify(json::Value(42)), "42"); - // negative integer - EXPECT_EQ(json::Stringify(json::Value(-123)), "-123"); - // zero - EXPECT_EQ(json::Stringify(json::Value(0)), "0"); - // large positive integer - EXPECT_EQ(json::Stringify(json::Value(std::numeric_limits::max())), - "9223372036854775807"); - // large negative integer - EXPECT_EQ(json::Stringify(json::Value(std::numeric_limits::min())), - "-9223372036854775808"); -} - -TEST(JSONWriter, Float) { - // regular float - EXPECT_EQ(json::Stringify(json::Value(2.5)), "2.5"); - // integer-like float (should have .0 suffix) - EXPECT_EQ(json::Stringify(json::Value(5.0)), "5.0"); - EXPECT_EQ(json::Stringify(json::Value(-10.0)), "-10.0"); - // zero float - EXPECT_EQ(json::Stringify(json::Value(0.0)), "0.0"); - // scientific notation for very small numbers - EXPECT_EQ(json::Stringify(json::Value(-7.89e-15)), "-7.89e-15"); - // short 
scientific notation (shorter than fixed-point) - EXPECT_EQ(json::Stringify(json::Value(2e-8)), "2e-08"); - // NaN - EXPECT_EQ(json::Stringify(json::Value(std::numeric_limits::quiet_NaN())), "NaN"); - // positive infinity - EXPECT_EQ(json::Stringify(json::Value(std::numeric_limits::infinity())), "Infinity"); - // negative infinity - EXPECT_EQ(json::Stringify(json::Value(-std::numeric_limits::infinity())), "-Infinity"); -} - -TEST(JSONWriter, String) { - // simple string - EXPECT_EQ(json::Stringify(json::Value(String("hello"))), "\"hello\""); - // empty string - EXPECT_EQ(json::Stringify(json::Value(String(""))), "\"\""); - // string with escaped characters - EXPECT_EQ(json::Stringify(json::Value(String("\"quoted\""))), "\"\\\"quoted\\\"\""); - EXPECT_EQ(json::Stringify(json::Value(String("backslash\\"))), "\"backslash\\\\\""); - EXPECT_EQ(json::Stringify(json::Value(String("forward/slash"))), "\"forward\\/slash\""); - EXPECT_EQ(json::Stringify(json::Value(String("line\nbreak"))), "\"line\\nbreak\""); - EXPECT_EQ(json::Stringify(json::Value(String("tab\there"))), "\"tab\\there\""); - EXPECT_EQ(json::Stringify(json::Value(String("carriage\rreturn"))), "\"carriage\\rreturn\""); - // string with control character - EXPECT_EQ(json::Stringify(json::Value(String(std::string("\x01", 1) + "control"))), - "\"\\u0001control\""); -} - -TEST(JSONWriter, Array) { - // empty array - json::Array empty_array; - EXPECT_EQ(json::Stringify(empty_array), "[]"); - - // single element array - json::Array single_array{42}; - EXPECT_EQ(json::Stringify(single_array), "[42]"); - - // multiple elements array - json::Array multi_array{1, "hello", true}; - EXPECT_EQ(json::Stringify(multi_array), "[1,\"hello\",true]"); - - // nested array - json::Array nested_array{json::Array{1, 2}, 3}; - EXPECT_EQ(json::Stringify(nested_array), "[[1,2],3]"); -} - -TEST(JSONWriter, Object) { - // empty object - json::Object empty_object; - EXPECT_EQ(json::Stringify(empty_object), "{}"); - - // single key-value 
pair - json::Object single_object{{String("key"), String("value")}}; - EXPECT_EQ(json::Stringify(single_object), "{\"key\":\"value\"}"); - - // multiple key-value pairs - insertion order preservation - json::Object multi_object{{"name", "Alice"}, {"age", 30}, {"active", true}, {"score", 95.5}}; - EXPECT_EQ(json::Stringify(multi_object), - "{\"name\":\"Alice\",\"age\":30,\"active\":true,\"score\":95.5}"); -} - -TEST(JSONWriter, InsertionOrderPreservation) { - // test that objects preserve insertion order - json::Object ordered_object{ - {"zebra", "last"}, {"alpha", "first"}, {"beta", "middle"}, {"gamma", 123}, {"delta", true}}; - EXPECT_EQ( - json::Stringify(ordered_object), - "{\"zebra\":\"last\",\"alpha\":\"first\",\"beta\":\"middle\",\"gamma\":123,\"delta\":true}"); - - // test with indentation to verify order is preserved - std::string ordered_indented = json::Stringify(ordered_object, 2); - EXPECT_EQ(ordered_indented, String(R"({ - "zebra": "last", - "alpha": "first", - "beta": "middle", - "gamma": 123, - "delta": true -})")); - - // test nested objects also preserve order - json::Object nested_ordered{ - {"outer1", - json::Object{{"inner_z", "z_value"}, {"inner_a", "a_value"}, {"inner_m", "m_value"}}}, - {"outer2", json::Object{{"third", 3}, {"first", 1}, {"second", 2}}}}; - std::string nested_ordered_indented = json::Stringify(nested_ordered, 2); - EXPECT_EQ(nested_ordered_indented, String(R"({ - "outer1": { - "inner_z": "z_value", - "inner_a": "a_value", - "inner_m": "m_value" - }, - "outer2": { - "third": 3, - "first": 1, - "second": 2 - } -})")); -} - -TEST(JSONWriter, NestedStructures) { - // object containing array - json::Object obj_with_array{{String("numbers"), json::Array{1, 2, 3}}}; - EXPECT_EQ(json::Stringify(obj_with_array), "{\"numbers\":[1,2,3]}"); - - // array containing object - json::Array arr_with_obj{json::Object{{String("key"), String("value")}}}; - EXPECT_EQ(json::Stringify(arr_with_obj), "[{\"key\":\"value\"}]"); - - // deeply nested 
structure - json::Object nested_obj{ - {String("nested"), json::Array{json::Object{{String("deep"), String("value")}}}}}; - EXPECT_EQ(json::Stringify(nested_obj), "{\"nested\":[{\"deep\":\"value\"}]}"); -} - -TEST(JSONWriter, Indentation) { - // test with indentation - json::Array arr{1, 2}; - std::string indented = json::Stringify(arr, 2); - EXPECT_EQ(indented, String(R"([ - 1, - 2 -])")); - - // object with indentation - json::Object obj{{"key", "value"}}; - std::string indented_obj = json::Stringify(obj, 2); - EXPECT_EQ(indented_obj, String(R"({ - "key": "value" -})")); - - // complex nested structure with multiple data types - // keep double as .5 so output is deterministic as they exactly rounds to power of 2 - json::Object complex_nested{ - {"name", "test"}, - {"count", 42}, - {"price", 3.5}, - {"active", true}, - {"metadata", nullptr}, - {"numbers", json::Array{1, 2, 3}}, - {"config", json::Object{{"enabled", false}, - {"timeout", 30.5}, - {"tags", json::Array{"production", "critical", nullptr}}}}, - {"matrix", json::Array{json::Array{1, 2}, json::Array{3.5, 4.5}, json::Array{"a", "b"}}}}; - std::string complex_indented = json::Stringify(complex_nested, 2); - EXPECT_EQ(complex_indented, String(R"({ - "name": "test", - "count": 42, - "price": 3.5, - "active": true, - "metadata": null, - "numbers": [ - 1, - 2, - 3 - ], - "config": { - "enabled": false, - "timeout": 30.5, - "tags": [ - "production", - "critical", - null - ] - }, - "matrix": [ - [ - 1, - 2 - ], - [ - 3.5, - 4.5 - ], - [ - "a", - "b" - ] - ] -})")); -} -} // namespace diff --git a/ffi/tests/cpp/extra/test_serialization.cc b/ffi/tests/cpp/extra/test_serialization.cc deleted file mode 100644 index 9d18e6a03e2d..000000000000 --- a/ffi/tests/cpp/extra/test_serialization.cc +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Serialization, BoolNull) { - json::Object expected_null = - json::Object{{"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "None"}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(nullptr), expected_null)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_null), nullptr)); - - json::Object expected_true = json::Object{ - {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(true), expected_true)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_true), true)); - - json::Object expected_false = json::Object{ - {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", false}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(false), expected_false)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_false), false)); -} - -TEST(Serialization, IntegerTypes) { - // Test positive integer - json::Object expected_int = json::Object{ - {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "int"}, {"data", 42}}}}}; - 
EXPECT_TRUE(StructuralEqual()(ToJSONGraph(static_cast(42)), expected_int)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_int), static_cast(42))); -} - -TEST(Serialization, FloatTypes) { - // Test positive float - json::Object expected_float = - json::Object{{"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "float"}, {"data", 3.14159}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(3.14159), expected_float)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_float), 3.14159)); -} - -TEST(Serialization, StringTypes) { - // Test short string - json::Object expected_short = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String("hello")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String("hello")), expected_short)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_short), String("hello"))); - - // Test long string - std::string long_str(1000, 'x'); - json::Object expected_long = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String(long_str)}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String(long_str)), expected_long)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_long), String(long_str))); - - // Test string with special characters - json::Object expected_special = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.String"}, - {"data", String("hello\nworld\t\"quotes\"")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String("hello\nworld\t\"quotes\"")), expected_special)); - EXPECT_TRUE( - StructuralEqual()(FromJSONGraph(expected_special), String("hello\nworld\t\"quotes\""))); -} - -TEST(Serialization, Bytes) { - // Test empty bytes - Bytes empty_bytes; - json::Object expected_empty = json::Object{ - {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", ""}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_bytes), 
expected_empty)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_bytes)); - - // Test bytes with that encoded as base64 - Bytes bytes_content = Bytes("abcd"); - json::Object expected_encoded = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", "YWJjZA=="}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(bytes_content), expected_encoded)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_encoded), bytes_content)); - - // Test bytes with that encoded as base64, that contains control characters via utf-8 - char bytes_v2_content[] = {0x01, 0x02, 0x03, 0x04, 0x01, 0x0b}; - Bytes bytes_v2 = Bytes(bytes_v2_content, sizeof(bytes_v2_content)); - json::Object expected_encoded_v2 = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", "AQIDBAEL"}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(bytes_v2), expected_encoded_v2)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_encoded_v2), bytes_v2)); -} - -TEST(Serialization, DataTypes) { - // Test int32 dtype - DLDataType int32_dtype; - int32_dtype.code = kDLInt; - int32_dtype.bits = 32; - int32_dtype.lanes = 1; - - json::Object expected_int32 = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("int32")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(int32_dtype), expected_int32)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_int32), int32_dtype)); - - // Test float64 dtype - DLDataType float64_dtype; - float64_dtype.code = kDLFloat; - float64_dtype.bits = 64; - float64_dtype.lanes = 1; - - json::Object expected_float64 = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("float64")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(float64_dtype), expected_float64)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_float64), float64_dtype)); - - // Test 
vector dtype - DLDataType vector_dtype; - vector_dtype.code = kDLFloat; - vector_dtype.bits = 32; - vector_dtype.lanes = 4; - - json::Object expected_vector = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("float32x4")}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(vector_dtype), expected_vector)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_vector), vector_dtype)); -} - -TEST(Serialization, DeviceTypes) { - // Test CPU device - DLDevice cpu_device; - cpu_device.device_type = kDLCPU; - cpu_device.device_id = 0; - - json::Object expected_cpu = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "Device"}, - {"data", json::Array{static_cast(kDLCPU), - static_cast(0)}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(cpu_device), expected_cpu)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_cpu), cpu_device)); - - // Test GPU device - DLDevice gpu_device; - gpu_device.device_type = kDLCUDA; - gpu_device.device_id = 1; - - json::Object expected_gpu = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{ - {"type", "Device"}, {"data", json::Array{static_cast(kDLCUDA), 1}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(gpu_device), expected_gpu)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_gpu), gpu_device)); -} - -TEST(Serialization, Arrays) { - // Test empty array - Array empty_array; - json::Object expected_empty = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_array), expected_empty)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_array)); - - // Test single element array - Array single_array; - single_array.push_back(Any(42)); - json::Object expected_single = - json::Object{{"root_index", 1}, - {"nodes", json::Array{ - json::Object{{"type", "int"}, {"data", 
static_cast(42)}}, - json::Object{{"type", "ffi.Array"}, {"data", json::Array{0}}}, - }}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_array), expected_single)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_array)); - - // Test duplicated element array - Array duplicated_array; - duplicated_array.push_back(42); - duplicated_array.push_back(42); - json::Object expected_duplicated = - json::Object{{"root_index", 1}, - {"nodes", json::Array{ - json::Object{{"type", "int"}, {"data", 42}}, - json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 0}}}, - }}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_array), expected_duplicated)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_array)); - // Test mixed element array, note that 42 and "hello" are duplicated and will - // be indexed as 0 and 1 - Array mixed_array; - mixed_array.push_back(42); - mixed_array.push_back(String("hello")); - mixed_array.push_back(true); - mixed_array.push_back(nullptr); - mixed_array.push_back(42); - mixed_array.push_back(String("hello")); - json::Object expected_mixed = json::Object{ - {"root_index", 4}, - {"nodes", json::Array{ - json::Object{{"type", "int"}, {"data", 42}}, - json::Object{{"type", "ffi.String"}, {"data", String("hello")}}, - json::Object{{"type", "bool"}, {"data", true}}, - json::Object{{"type", "None"}}, - json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 1, 2, 3, 0, 1}}}, - }}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(mixed_array), expected_mixed)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_mixed), mixed_array)); -} - -TEST(Serialization, Maps) { - // Test empty map - Map empty_map; - json::Object expected_empty = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Map"}, {"data", json::Array{}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_map), expected_empty)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), 
empty_map)); - - // Test single element map - Map single_map{{"key", 42}}; - json::Object expected_single = json::Object{ - {"root_index", 2}, - {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String("key")}}, - json::Object{{"type", "int"}, {"data", 42}}, - json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_map), expected_single)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_map)); - - // Test duplicated element map - Map duplicated_map{{"b", 42}, {"a", 42}}; - json::Object expected_duplicated = json::Object{ - {"root_index", 3}, - {"nodes", json::Array{ - json::Object{{"type", "ffi.String"}, {"data", "b"}}, - json::Object{{"type", "int"}, {"data", 42}}, - json::Object{{"type", "ffi.String"}, {"data", "a"}}, - json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1, 2, 1}}}, - - }}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_map), expected_duplicated)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_map)); -} - -TEST(Serialization, Shapes) { - Shape empty_shape; - - json::Object expected_empty_shape = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Shape"}, {"data", json::Array{}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_shape), expected_empty_shape)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty_shape), empty_shape)); - - Shape shape({1, 2, 3}); - json::Object expected_shape = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "ffi.Shape"}, {"data", json::Array{1, 2, 3}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(shape), expected_shape)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shape), shape)); -} - -TEST(Serialization, TestObjectVar) { - TVar x = TVar("x"); - json::Object expected_x = json::Object{ - {"root_index", 1}, - {"nodes", - json::Array{json::Object{{"type", "ffi.String"}, 
{"data", "x"}}, - json::Object{{"type", "test.Var"}, {"data", json::Object{{"name", 0}}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(x), expected_x)); - EXPECT_TRUE(StructuralEqual::Equal(FromJSONGraph(expected_x), x, /*map_free_vars=*/true)); -} - -TEST(Serialization, TestObjectIntCustomToJSON) { - TInt value = TInt(42); - json::Object expected_i = json::Object{ - {"root_index", 0}, - {"nodes", - json::Array{json::Object{{"type", "test.Int"}, {"data", json::Object{{"value", 42}}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(value), expected_i)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_i), value)); -} - -TEST(Serialization, TestObjectFunc) { - TVar x = TVar("x"); - // comment fields are ignored - TFunc fa = TFunc({x}, {x, x}, String("comment a")); - - json::Object expected_fa = json::Object{ - {"root_index", 5}, - {"nodes", - json::Array{ - json::Object{{"type", "ffi.String"}, {"data", "x"}}, // string "x" - json::Object{{"type", "test.Var"}, {"data", json::Object{{"name", 0}}}}, // var x - json::Object{{"type", "ffi.Array"}, {"data", json::Array{1}}}, // array [x] - json::Object{{"type", "ffi.Array"}, {"data", json::Array{1, 1}}}, // array [x, x] - json::Object{{"type", "ffi.String"}, {"data", "comment a"}}, // "comment a" - json::Object{{"type", "test.Func"}, - {"data", json::Object{{"params", 2}, {"body", 3}, {"comment", 4}}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fa), expected_fa)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fa), fa)); - - TFunc fb = TFunc({}, {}, std::nullopt); - json::Object expected_fb = json::Object{ - {"root_index", 3}, - {"nodes", - json::Array{ - json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}, - json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}, - json::Object{{"type", "None"}}, - json::Object{{"type", "test.Func"}, - {"data", json::Object{{"params", 0}, {"body", 1}, {"comment", 2}}}}}}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fb), expected_fb)); - 
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fb), fb)); -} - -TEST(Serialization, AttachMetadata) { - bool value = true; - json::Object metadata{{"version", "1.0"}}; - json::Object expected = - json::Object{{"root_index", 0}, - {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}}, - {"metadata", metadata}}; - EXPECT_TRUE(StructuralEqual()(ToJSONGraph(value, metadata), expected)); - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), value)); -} - -TEST(Serialization, ShuffleNodeOrder) { - // the FromJSONGraph is agnostic to the node order - // so we can shuffle the node order as it reads nodes lazily - Map duplicated_map{{"b", 42}, {"a", 42}}; - json::Object expected_shuffled = json::Object{ - {"root_index", 0}, - {"nodes", json::Array{ - json::Object{{"type", "ffi.Map"}, {"data", json::Array{2, 3, 1, 3}}}, - json::Object{{"type", "ffi.String"}, {"data", "a"}}, - json::Object{{"type", "ffi.String"}, {"data", "b"}}, - json::Object{{"type", "int"}, {"data", 42}}, - }}}; - EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shuffled), duplicated_map)); -} - -} // namespace diff --git a/ffi/tests/cpp/extra/test_structural_equal_hash.cc b/ffi/tests/cpp/extra/test_structural_equal_hash.cc deleted file mode 100644 index a05c50cc2617..000000000000 --- a/ffi/tests/cpp/extra/test_structural_equal_hash.cc +++ /dev/null @@ -1,178 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; -namespace refl = tvm::ffi::reflection; - -TEST(StructuralEqualHash, Array) { - Array a = {1, 2, 3}; - Array b = {1, 2, 3}; - EXPECT_TRUE(StructuralEqual()(a, b)); - EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); - - Array c = {1, 3}; - EXPECT_FALSE(StructuralEqual()(a, c)); - EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); - auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); - - // first directly interepret diff, - EXPECT_TRUE(diff_a_c.has_value()); - auto lhs_steps = (*diff_a_c).get<0>()->ToSteps(); - auto rhs_steps = (*diff_a_c).get<1>()->ToSteps(); - EXPECT_EQ(lhs_steps[0]->kind, refl::AccessKind::kArrayItem); - EXPECT_EQ(rhs_steps[0]->kind, refl::AccessKind::kArrayItem); - EXPECT_EQ(lhs_steps[0]->key.cast(), 1); - EXPECT_EQ(rhs_steps[0]->key.cast(), 1); - EXPECT_EQ(lhs_steps.size(), 1); - EXPECT_EQ(rhs_steps.size(), 1); - - // use structural equal for checking in future parts - // given we have done some basic checks above by directly interepret diff, - Array d = {1, 2}; - auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d); - auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath::FromSteps({ - refl::AccessStep::ArrayItem(2), - }), - refl::AccessPath::FromSteps({ - refl::AccessStep::ArrayItemMissing(2), - })); - // then use structural equal to check it - EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d)); -} - 
-TEST(StructuralEqualHash, Map) { - // same map but different insertion order - Map a = {{"a", 1}, {"b", 2}, {"c", 3}}; - Map b = {{"b", 2}, {"c", 3}, {"a", 1}}; - EXPECT_TRUE(StructuralEqual()(a, b)); - EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); - - Map c = {{"a", 1}, {"b", 2}, {"c", 4}}; - EXPECT_FALSE(StructuralEqual()(a, c)); - EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); - - auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); - auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath::Root()->MapItem("c"), - refl::AccessPath::Root()->MapItem("c")); - EXPECT_TRUE(diff_a_c.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c)); -} - -TEST(StructuralEqualHash, NestedMapArray) { - Map> a = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}}; - Map> b = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}}; - EXPECT_TRUE(StructuralEqual()(a, b)); - EXPECT_EQ(StructuralHash()(a), StructuralHash()(b)); - - Map> c = {{"a", {1, 2, 3}}, {"b", {4, "world", 6}}}; - EXPECT_FALSE(StructuralEqual()(a, c)); - EXPECT_NE(StructuralHash()(a), StructuralHash()(c)); - - auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c); - auto expected_diff_a_c = - refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b")->ArrayItem(1), - refl::AccessPath::Root()->MapItem("b")->ArrayItem(1)); - EXPECT_TRUE(diff_a_c.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c)); - - Map> d = {{"a", {1, 2, 3}}}; - auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d); - auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath::Root()->MapItem("b"), - refl::AccessPath::Root()->MapItemMissing("b")); - EXPECT_TRUE(diff_a_d.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d)); - - auto diff_d_a = StructuralEqual::GetFirstMismatch(d, a); - auto expected_diff_d_a = refl::AccessPathPair(refl::AccessPath::Root()->MapItemMissing("b"), - refl::AccessPath::Root()->MapItem("b")); -} - -TEST(StructuralEqualHash, FreeVar) { - TVar a 
= TVar("a"); - TVar b = TVar("b"); - EXPECT_TRUE(StructuralEqual::Equal(a, b, /*map_free_vars=*/true)); - EXPECT_FALSE(StructuralEqual::Equal(a, b)); - - EXPECT_NE(StructuralHash()(a), StructuralHash()(b)); - EXPECT_EQ(StructuralHash::Hash(a, /*map_free_vars=*/true), - StructuralHash::Hash(b, /*map_free_vars=*/true)); -} - -TEST(StructuralEqualHash, FuncDefAndIgnoreField) { - TVar x = TVar("x"); - TVar y = TVar("y"); - // comment fields are ignored - TFunc fa = TFunc({x}, {TInt(1), x}, String("comment a")); - TFunc fb = TFunc({y}, {TInt(1), y}, String("comment b")); - - TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, String("comment c")); - - EXPECT_TRUE(StructuralEqual()(fa, fb)); - EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb)); - - EXPECT_FALSE(StructuralEqual()(fa, fc)); - auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc); - auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath::FromSteps({ - refl::AccessStep::Attr("body"), - refl::AccessStep::ArrayItem(1), - }), - refl::AccessPath::FromSteps({ - refl::AccessStep::Attr("body"), - refl::AccessStep::ArrayItem(1), - })); - EXPECT_TRUE(diff_fa_fc.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); -} - -TEST(StructuralEqualHash, CustomTreeNode) { - TVar x = TVar("x"); - TVar y = TVar("y"); - // comment fields are ignored - TCustomFunc fa = TCustomFunc({x}, {TInt(1), x}, "comment a"); - TCustomFunc fb = TCustomFunc({y}, {TInt(1), y}, "comment b"); - - TCustomFunc fc = TCustomFunc({x}, {TInt(1), TInt(2)}, "comment c"); - - EXPECT_TRUE(StructuralEqual()(fa, fb)); - EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb)); - - EXPECT_FALSE(StructuralEqual()(fa, fc)); - auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc); - auto expected_diff_fa_fc = - refl::AccessPathPair(refl::AccessPath::Root()->Attr("body")->ArrayItem(1), - refl::AccessPath::Root()->Attr("body")->ArrayItem(1)); - EXPECT_TRUE(diff_fa_fc.has_value()); - EXPECT_TRUE(StructuralEqual()(diff_fa_fc, 
expected_diff_fa_fc)); -} - -} // namespace diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc deleted file mode 100644 index d1f56e1a93d9..000000000000 --- a/ffi/tests/cpp/test_any.cc +++ /dev/null @@ -1,415 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Any, Int) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `int`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - AnyView view1 = 1; - EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); - EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); - - auto int_v1 = view1.cast(); - EXPECT_EQ(int_v1, 1); - - int64_t v1 = 2; - view0 = v1; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); - EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2); -} - -TEST(Any, Enum) { - enum class ENum : int { - A = 1, - B = 2, - }; - - AnyView view0; - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - AnyView view1 = ENum::A; - EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); - EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); - - ENum v1 = view1.cast(); - EXPECT_EQ(v1, ENum::A); -} - -TEST(Any, bool) { - AnyView view0; - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `bool`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - AnyView view1 = true; - EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool); - EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); - - auto int_v1 = view1.cast(); - EXPECT_EQ(int_v1, 1); - - bool v1 = false; - 
view0 = v1; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIBool); - EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 0); -} - -TEST(Any, nullptrcmp) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - EXPECT_TRUE(view0 == nullptr); - EXPECT_FALSE(view0 != nullptr); - - view0 = 1; - EXPECT_TRUE(view0 != nullptr); - EXPECT_FALSE(view0 == nullptr); - - Any any0 = view0; - EXPECT_TRUE(any0 != nullptr); - EXPECT_FALSE(any0 == nullptr); - - any0 = nullptr; - EXPECT_TRUE(any0 == nullptr); - EXPECT_FALSE(any0 != nullptr); -} - -TEST(Any, Float) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `float`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - AnyView view1_int = 1; - auto float_v1 = view1_int.cast(); - EXPECT_EQ(float_v1, 1); - - AnyView view2 = 2.2; - EXPECT_EQ(view2.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); - EXPECT_EQ(view2.CopyToTVMFFIAny().v_float64, 2.2); - - float v1 = 2; - view0 = v1; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIFloat); - EXPECT_EQ(view0.CopyToTVMFFIAny().v_float64, 2); -} - -TEST(Any, Device) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `Device`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - DLDevice device{kDLCUDA, 1}; - 
- AnyView view1_device = device; - auto dtype_v1 = view1_device.cast(); - EXPECT_EQ(dtype_v1.device_type, kDLCUDA); - EXPECT_EQ(dtype_v1.device_id, 1); - - Any any2 = DLDevice{kDLCPU, 0}; - TVMFFIAny ffi_v2 = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(any2)); - EXPECT_EQ(ffi_v2.type_index, TypeIndex::kTVMFFIDevice); - EXPECT_EQ(ffi_v2.v_device.device_type, kDLCPU); - EXPECT_EQ(ffi_v2.v_device.device_id, 0); -} - -TEST(Any, DLTensor) { - AnyView view0; - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `DLTensor*`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - DLTensor dltensor; - - AnyView view1_dl = &dltensor; - auto dl_v1 = view1_dl.cast(); - EXPECT_EQ(dl_v1, &dltensor); -} - -TEST(Any, Object) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - // int object is not nullable - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - TInt v1(11); - EXPECT_EQ(v1.use_count(), 1); - // view won't increase refcount - AnyView view1 = v1; - EXPECT_EQ(v1.use_count(), 1); - // any will trigger ref count increase - Any any1 = v1; - EXPECT_EQ(v1.use_count(), 2); - // copy to another view - AnyView view2 = any1; - EXPECT_EQ(v1.use_count(), 2); - - // convert to weak raw object ptr - const TIntObj* v1_ptr = view2.cast(); - EXPECT_EQ(v1.use_count(), 2); - EXPECT_EQ(v1_ptr->value, 11); - Any any2 = v1_ptr; - EXPECT_EQ(v1.use_count(), 3); - EXPECT_TRUE(any2.as().has_value()); - - any2 = const_cast(v1_ptr); - EXPECT_TRUE(any2.as().has_value()); - - // convert to raw opaque ptr - void* raw_v1_ptr = const_cast(v1_ptr); - any2 = raw_v1_ptr; - EXPECT_TRUE(any2.as().value() == v1_ptr); - - // convert to ObjectRef - { - auto v1_obj_ref = view2.cast(); 
- EXPECT_EQ(v1.use_count(), 3); - any2 = v1_obj_ref; - EXPECT_EQ(v1.use_count(), 4); - EXPECT_TRUE(any2.as().has_value()); - any2.reset(); - } - - // convert that triggers error - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view1.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - std::cout << what; - EXPECT_NE(what.find("Cannot convert from type `test.Int` to `test.Float`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - // Try to convert to number - auto number0 = any1.cast(); - EXPECT_EQ(v1.use_count(), 3); - EXPECT_TRUE(number0.as()); - EXPECT_EQ(number0.as()->value, 11); - EXPECT_TRUE(!any1.as().has_value()); - - auto int1 = view2.cast(); - EXPECT_EQ(v1.use_count(), 4); - any1.reset(); - EXPECT_EQ(v1.use_count(), 3); -} - -TEST(Any, ObjectRefWithFallbackTraits) { - // Test case for TPrimExpr fallback from Any - Any any1 = TPrimExpr("float32", 3.14); - auto v0 = any1.cast(); - EXPECT_EQ(v0->value, 3.14); - EXPECT_EQ(v0->dtype, "float32"); - - any1 = true; - auto v1 = any1.cast(); - EXPECT_EQ(v1->value, 1); - EXPECT_EQ(v1->dtype, "bool"); - - any1 = int64_t(42); - auto v2 = any1.cast(); - EXPECT_EQ(v2->value, 42); - EXPECT_EQ(v2->dtype, "int64"); - - any1 = 2.718; - auto v3 = any1.cast(); - EXPECT_EQ(v3->value, 2.718); - EXPECT_EQ(v3->dtype, "float32"); - - // Test case for TPrimExpr fallback from AnyView - TPrimExpr texpr1("float32", 3.14); - AnyView view1 = texpr1; - auto v4 = view1.cast(); - EXPECT_EQ(v4->value, 3.14); - EXPECT_EQ(v4->dtype, "float32"); - - view1 = true; - auto v5 = view1.cast(); - EXPECT_EQ(v5->value, 1); - EXPECT_EQ(v5->dtype, "bool"); - - view1 = int64_t(42); - auto v6 = view1.cast(); - EXPECT_EQ(v6->value, 42); - EXPECT_EQ(v6->dtype, "int64"); - - view1 = 2.718; - auto v7 = view1.cast(); - EXPECT_EQ(v7->value, 2.718); - EXPECT_EQ(v7->dtype, "float32"); - - // Test case for TPrimExpr fallback from Any with String - any1 = 
std::string("test_string"); - auto v8 = any1.cast(); - EXPECT_EQ(v8->dtype, "test_string"); - EXPECT_EQ(v8->value, 0); - - // Test case for TPrimExpr fallback from AnyView with String - view1 = "test_string"; - auto v9 = view1.cast(); - EXPECT_EQ(v9->dtype, "test_string"); - EXPECT_EQ(v9->value, 0); -} - -TEST(Any, CastVsAs) { - AnyView view0 = 1; - // as only runs strict check - auto opt_v0 = view0.as(); - EXPECT_TRUE(opt_v0.has_value()); - EXPECT_EQ(opt_v0.value(), 1); - - auto opt_v1 = view0.as(); - EXPECT_TRUE(!opt_v1.has_value()); - auto opt_v2 = view0.as(); - EXPECT_TRUE(!opt_v2.has_value()); - - // try_cast will try run the conversion. - auto opt_v3 = view0.try_cast(); - EXPECT_TRUE(opt_v3.has_value()); - EXPECT_EQ(opt_v3.value(), 1); - auto opt_v4 = view0.try_cast(); - EXPECT_TRUE(opt_v4.has_value()); - EXPECT_EQ(opt_v4.value(), 1); - - Any any1 = true; - auto opt_v5 = any1.as(); - EXPECT_TRUE(opt_v5.has_value()); - EXPECT_EQ(opt_v5.value(), 1); - - auto opt_v6 = any1.try_cast(); - EXPECT_TRUE(opt_v6.has_value()); - EXPECT_EQ(opt_v6.value(), 1); - - auto opt_v7 = any1.try_cast(); - EXPECT_TRUE(opt_v7.has_value()); -} - -TEST(Any, ObjectMove) { - Any any1 = TPrimExpr("float32", 3.14); - auto v0 = std::move(any1).cast(); - EXPECT_EQ(v0->value, 3.14); - EXPECT_EQ(v0.use_count(), 1); - EXPECT_TRUE(any1 == nullptr); -} - -TEST(Any, AnyEqualHash) { - // small string - Any a = "a1"; - // on heap allocated string - Any b = String(std::string("a1")); - EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFISmallStr); - EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_TRUE(AnyEqual()(a, b)); - EXPECT_EQ(AnyHash()(a), AnyHash()(b)); - - Any c = Bytes("a1", 2); - Any d = Bytes(std::string("a1")); - EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFISmallBytes); - EXPECT_EQ(d.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_TRUE(AnyEqual()(c, d)); - EXPECT_EQ(AnyHash()(c), AnyHash()(d)); -} - -} // namespace diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc 
deleted file mode 100644 index 321af7ae16ac..000000000000 --- a/ffi/tests/cpp/test_array.cc +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Array, Basic) { - Array arr = {TInt(11), TInt(12)}; - TInt v1 = arr[0]; - EXPECT_EQ(v1->value, 11); - EXPECT_EQ(v1.use_count(), 2); - EXPECT_EQ(arr[1]->value, 12); -} - -TEST(Array, COWSet) { - Array arr = {TInt(11), TInt(12)}; - Array arr2 = arr; - EXPECT_EQ(arr.use_count(), 2); - arr.Set(1, TInt(13)); - EXPECT_EQ(arr.use_count(), 1); - EXPECT_EQ(arr[1]->value, 13); - EXPECT_EQ(arr2[1]->value, 12); -} - -TEST(Array, MutateInPlaceForUniqueReference) { - TInt x(1); - Array arr{x, x}; - EXPECT_TRUE(arr.unique()); - auto* before = arr.get(); - - arr.MutateByApply([](TInt) { return TInt(2); }); - auto* after = arr.get(); - EXPECT_EQ(before, after); -} - -TEST(Array, CopyWhenMutatingNonUniqueReference) { - TInt x(1); - Array arr{x, x}; - Array arr2 = arr; - - EXPECT_TRUE(!arr.unique()); - auto* before = arr.get(); - - arr.MutateByApply([](TInt) { return TInt(2); }); - auto* after = arr.get(); - 
EXPECT_NE(before, after); -} - -TEST(Array, Map) { - // Basic functionality - TInt x(1), y(1); - Array var_arr{x, y}; - Array expr_arr = - var_arr.Map([](TInt var) -> TNumber { return TFloat(static_cast(var->value + 1)); }); - - EXPECT_NE(var_arr.get(), expr_arr.get()); - EXPECT_TRUE(expr_arr[0]->IsInstance()); - EXPECT_TRUE(expr_arr[1]->IsInstance()); -} - -TEST(Array, Iterator) { - Array array{1, 2, 3}; - std::vector vector(array.begin(), array.end()); - EXPECT_EQ(vector[1], 2); -} - -TEST(Array, PushPop) { - Array a; - std::vector b; - for (int i = 0; i < 10; ++i) { - a.push_back(i); - b.push_back(i); - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), b.size()); - int n = static_cast(a.size()); - for (int j = 0; j < n; ++j) { - ASSERT_EQ(a[j], b[j]); - } - } - for (int i = 9; i >= 0; --i) { - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), b.size()); - a.pop_back(); - b.pop_back(); - int n = static_cast(a.size()); - for (int j = 0; j < n; ++j) { - ASSERT_EQ(a[j], b[j]); - } - } - ASSERT_EQ(a.empty(), true); -} - -TEST(Array, ResizeReserveClear) { - for (size_t n = 0; n < 10; ++n) { - Array a; - Array b; - a.resize(n); - b.reserve(n); - ASSERT_EQ(a.size(), n); - ASSERT_GE(a.capacity(), n); - a.clear(); - b.clear(); - ASSERT_EQ(a.size(), 0); - ASSERT_EQ(b.size(), 0); - } -} - -TEST(Array, InsertErase) { - Array a; - std::vector b; - for (int n = 1; n <= 10; ++n) { - a.insert(a.end(), n); - b.insert(b.end(), n); - for (int pos = 0; pos <= n; ++pos) { - a.insert(a.begin() + pos, pos); - b.insert(b.begin() + pos, pos); - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), n + 1); - ASSERT_EQ(b.size(), n + 1); - for (int k = 0; k <= n; ++k) { - ASSERT_EQ(a[k], b[k]); - } - a.erase(a.begin() + pos); - b.erase(b.begin() + pos); - } - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), n); - } -} - -TEST(Array, 
InsertEraseRange) { - Array range_a{-1, -2, -3, -4}; - std::vector range_b{-1, -2, -3, -4}; - Array a; - std::vector b; - - static_assert(std::is_same_v); - for (size_t n = 1; n <= 10; ++n) { - a.insert(a.end(), static_cast(n)); - b.insert(b.end(), static_cast(n)); - for (size_t pos = 0; pos <= n; ++pos) { - a.insert(a.begin() + pos, range_a.begin(), range_a.end()); - b.insert(b.begin() + pos, range_b.begin(), range_b.end()); - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), n + range_a.size()); - ASSERT_EQ(b.size(), n + range_b.size()); - size_t m = n + range_a.size(); - for (size_t k = 0; k < m; ++k) { - ASSERT_EQ(a[k], b[k]); - } - a.erase(a.begin() + pos, a.begin() + pos + range_a.size()); - b.erase(b.begin() + pos, b.begin() + pos + range_b.size()); - } - ASSERT_EQ(a.front(), b.front()); - ASSERT_EQ(a.back(), b.back()); - ASSERT_EQ(a.size(), n); - } -} - -TEST(Array, FuncArrayAnyArg) { - Function fadd_one = Function::FromTyped([](Array a) -> Any { return a[0].cast() + 1; }); - EXPECT_EQ(fadd_one(Array{1}).cast(), 2); -} - -TEST(Array, MapUniquePropogation) { - // Basic functionality - Array var_arr{TInt(1), TInt(2)}; - var_arr.MutateByApply([](TInt x) -> TInt { - EXPECT_TRUE(x.unique()); - return x; - }); -} - -TEST(Array, AnyImplicitConversion) { - Array arr0_mixed = {11.1, 1}; - EXPECT_EQ(arr0_mixed[1].cast(), 1); - - AnyView view0 = arr0_mixed; - auto arr0_float = view0.cast>(); - // they are not the same because arr_mixed - // stores arr_mixed[1] as int but we need to convert to float - EXPECT_TRUE(!arr0_float.same_as(arr0_mixed)); - EXPECT_EQ(arr0_float[1], 1.0); - - Any any1 = arr0_float; - // if storage check passes, the same array get returned - auto arr1_float = any1.cast>(); - EXPECT_TRUE(arr1_float.same_as(arr0_float)); - // total count equals 3 include any1 - EXPECT_EQ(arr1_float.use_count(), 3); - - // convert to Array do not need any conversion - auto arr1_mixed = any1.cast>(); - 
EXPECT_TRUE(arr1_mixed.same_as(arr1_float)); - EXPECT_EQ(arr1_float.use_count(), 4); -} - -TEST(Array, AnyConvertCheck) { - Array arr = {11.1, 1}; - EXPECT_EQ(arr[1].cast(), 1); - - AnyView view0 = arr; - auto arr1 = view0.cast>(); - EXPECT_EQ(arr1[0], 11.1); - EXPECT_EQ(arr1[1], 1.0); - - Any any1 = arr; - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = any1.cast>(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `Array[index 0: float]` to `Array`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - Array> arr_nested = {{}, {TInt(1), TFloat(2)}}; - any1 = arr_nested; - auto arr1_nested = any1.cast>>(); - EXPECT_EQ(arr1_nested.use_count(), 3); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = any1.cast>>(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("`Array[index 1: Array[index 0: test.Int]]` to `Array>`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Array, Upcast) { - Array a0 = {1, 2, 3}; - Array a1 = a0; - EXPECT_EQ(a1[0].cast(), 1); - EXPECT_EQ(a1[1].cast(), 2); - EXPECT_EQ(a1[2].cast(), 3); - - Array> a2 = {a0}; - Array> a3 = a2; - Array> a4 = a2; - - static_assert(details::type_contains_v, Array>); - static_assert(details::type_contains_v>); -} - -} // namespace diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc deleted file mode 100644 index 79fc9d7c2da1..000000000000 --- a/ffi/tests/cpp/test_dtype.cc +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -namespace { - -using namespace tvm::ffi; - -TEST(DType, StringConversion) { - DLDataType dtype = DLDataType{kDLFloat, 32, 1}; - EXPECT_EQ(DLDataTypeToString(dtype), "float32"); - EXPECT_EQ(StringToDLDataType("float32"), dtype); - - dtype = DLDataType{kDLInt, 16, 2}; - EXPECT_EQ(DLDataTypeToString(dtype), "int16x2"); - EXPECT_EQ(StringToDLDataType("int16x2"), dtype); - - dtype = DLDataType{kDLOpaqueHandle, 0, 0}; - EXPECT_EQ(DLDataTypeToString(dtype), ""); - EXPECT_EQ(StringToDLDataType("void"), dtype); - - // test bfloat with lanes - dtype = DLDataType{kDLBfloat, 16, 2}; - EXPECT_EQ(DLDataTypeToString(dtype), "bfloat16x2"); - EXPECT_EQ(StringToDLDataType("bfloat16x2"), dtype); - - // test float8 - dtype = DLDataType{kDLFloat8_e4m3fn, 8, 2}; - EXPECT_EQ(DLDataTypeToString(dtype), "float8_e4m3fnx2"); - EXPECT_EQ(StringToDLDataType("float8_e4m3fnx2"), dtype); -} - -TEST(DType, StringConversionAllDLPackTypes) { - std::vector> test_cases = { - {DLDataType{kDLFloat, 32, 1}, "float32"}, - {DLDataType{kDLInt, 16, 1}, "int16"}, - {DLDataType{kDLUInt, 16, 1}, "uint16"}, - {DLDataType{kDLBfloat, 16, 1}, "bfloat16"}, - {DLDataType{kDLFloat8_e3m4, 8, 1}, "float8_e3m4"}, - {DLDataType{kDLFloat8_e4m3, 8, 1}, "float8_e4m3"}, - {DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}, "float8_e4m3b11fnuz"}, - {DLDataType{kDLFloat8_e4m3fn, 8, 1}, 
"float8_e4m3fn"}, - {DLDataType{kDLFloat8_e4m3fnuz, 8, 1}, "float8_e4m3fnuz"}, - {DLDataType{kDLFloat8_e5m2, 8, 1}, "float8_e5m2"}, - {DLDataType{kDLFloat8_e5m2fnuz, 8, 1}, "float8_e5m2fnuz"}, - {DLDataType{kDLFloat8_e8m0fnu, 8, 1}, "float8_e8m0fnu"}, - {DLDataType{kDLFloat6_e2m3fn, 6, 1}, "float6_e2m3fn"}, - {DLDataType{kDLFloat6_e3m2fn, 6, 1}, "float6_e3m2fn"}, - {DLDataType{kDLFloat4_e2m1fn, 4, 1}, "float4_e2m1fn"}, - }; - - for (const auto& [dtype, str] : test_cases) { - EXPECT_EQ(DLDataTypeToString(dtype), str); - EXPECT_EQ(StringToDLDataType(str), dtype); - } -} - -TEST(DataType, AnyConversion) { - AnyView view0; - EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone); - - Optional opt_v0 = view0.as(); - EXPECT_TRUE(!opt_v0.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto v0 = view0.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `None` to `DataType`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - DLDataType dtype{kDLFloat, 32, 1}; - - AnyView view1_dtype = dtype; - auto dtype_v1 = view1_dtype.cast(); - EXPECT_EQ(dtype_v1.code, kDLFloat); - EXPECT_EQ(dtype_v1.bits, 32); - EXPECT_EQ(dtype_v1.lanes, 1); - - Any any2 = DLDataType{kDLInt, 16, 2}; - TVMFFIAny ffi_v2 = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(any2)); - EXPECT_EQ(ffi_v2.type_index, TypeIndex::kTVMFFIDataType); - EXPECT_EQ(ffi_v2.v_dtype.code, kDLInt); - EXPECT_EQ(ffi_v2.v_dtype.bits, 16); - EXPECT_EQ(ffi_v2.v_dtype.lanes, 2); -} - -// String can be automatically converted to DLDataType -TEST(DataType, AnyConversionWithString) { - AnyView view0 = "float32"; - - Optional opt_v0 = view0.try_cast(); - DLDataType dtype_v0 = opt_v0.value(); - EXPECT_EQ(dtype_v0.code, kDLFloat); - EXPECT_EQ(dtype_v0.bits, 32); - EXPECT_EQ(dtype_v0.lanes, 1); - - Any any = String("bfloat16x2"); - Optional opt_v1 = any.try_cast(); - 
EXPECT_EQ(opt_v1.value().code, kDLBfloat); - EXPECT_EQ(opt_v1.value().bits, 16); - EXPECT_EQ(opt_v1.value().lanes, 2); -} -} // namespace diff --git a/ffi/tests/cpp/test_error.cc b/ffi/tests/cpp/test_error.cc deleted file mode 100644 index 9938603a47ba..000000000000 --- a/ffi/tests/cpp/test_error.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include -#include -#include - -namespace { - -using namespace tvm::ffi; - -void ThrowRuntimeError() { TVM_FFI_THROW(RuntimeError) << "test0"; } - -TEST(Error, Traceback) { - EXPECT_THROW( - { - try { - ThrowRuntimeError(); - } catch (const Error& error) { - EXPECT_EQ(error.message(), "test0"); - EXPECT_EQ(error.kind(), "RuntimeError"); - std::string what = error.what(); - EXPECT_NE(what.find("line"), std::string::npos); - EXPECT_NE(what.find("ThrowRuntimeError"), std::string::npos); - EXPECT_NE(what.find("RuntimeError: test0"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(CheckError, Traceback) { - EXPECT_THROW( - { - try { - TVM_FFI_ICHECK_GT(2, 3); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "InternalError"); - std::string what = error.what(); - EXPECT_NE(what.find("line"), std::string::npos); - EXPECT_NE(what.find("2 > 3"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Error, AnyConvert) { - Any any = Error("TypeError", "here", "test0"); - Optional opt_err = any.as(); - EXPECT_EQ(opt_err.value().kind(), "TypeError"); - EXPECT_EQ(opt_err.value().message(), "here"); -} -} // namespace diff --git a/ffi/tests/cpp/test_function.cc b/ffi/tests/cpp/test_function.cc deleted file mode 100644 index c3c484f33317..000000000000 --- a/ffi/tests/cpp/test_function.cc +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Func, FromPacked) { - Function fadd1 = Function::FromPacked([](const AnyView* args, int32_t num_args, Any* rv) { - EXPECT_EQ(num_args, 1); - int32_t a = args[0].cast(); - *rv = a + 1; - }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - - Function fadd2 = Function::FromPacked([](const AnyView* args, int32_t num_args, Any* rv) { - EXPECT_EQ(num_args, 1); - auto a = args[0].cast(); - EXPECT_EQ(a.use_count(), 2); - *rv = a->value + 1; - }); - EXPECT_EQ(fadd2(TInt(12)).cast(), 13); -} - -TEST(Func, PackedArgs) { - Function fadd1 = Function::FromPacked([](PackedArgs args, Any* rv) { - EXPECT_EQ(args.size(), 1); - int32_t a = args[0].cast(); - *rv = a + 1; - }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - - Function fadd2 = Function::FromPacked([](PackedArgs args, Any* rv) { - EXPECT_EQ(args.size(), 1); - TInt a = args[0].cast(); - EXPECT_EQ(a.use_count(), 2); - *rv = a->value + 1; - }); - EXPECT_EQ(fadd2(TInt(12)).cast(), 13); - - TInt v(12); - AnyView data[3]; - PackedArgs::Fill(data, 3, 1, v); - EXPECT_EQ(data[0].cast(), 3); - EXPECT_EQ(data[1].cast(), 1); - EXPECT_EQ(data[2].cast()->value, 12); -} - -TEST(Func, FromTyped) { - // try decution - Function fadd1 = Function::FromTyped([](const int32_t& a) -> int { return a + 1; }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(1.1); - } catch (const Error& 
error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: int) -> int`. " - "Expected `int` but got `float`"); - throw; - } - }, - ::tvm::ffi::Error); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched number of arguments when calling: `(0: int) -> int`. " - "Expected 1 but got 0 arguments"); - throw; - } - }, - ::tvm::ffi::Error); - - // try decution - Function fpass_and_return = Function::FromTyped( - [](TInt x, int value, AnyView z) -> Function { - EXPECT_EQ(x.use_count(), 2); - EXPECT_EQ(x->value, value); - if (auto opt = z.as()) { - EXPECT_EQ(value, *opt); - } - return Function::FromTyped([value](int x) -> int { return x + value; }); - }, - "fpass_and_return"); - TInt a(11); - auto fret = fpass_and_return(std::move(a), 11, 11).cast(); - EXPECT_EQ(fret(12).cast(), 23); - - EXPECT_THROW( - { - try { - fpass_and_return(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched number of arguments when calling: " - "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> ffi.Function`. 
" - "Expected 3 but got 0 arguments"); - throw; - } - }, - ::tvm::ffi::Error); - - Function fconcact = - Function::FromTyped([](const String& a, const String& b) -> String { return a + b; }); - EXPECT_EQ(fconcact("abc", "def").cast(), "abcdef"); -} - -TEST(Func, PassReturnAny) { - Function fadd_one = Function::FromTyped([](Any a) -> Any { return a.cast() + 1; }); - EXPECT_EQ(fadd_one(1).cast(), 2); -} - -TEST(Func, Global) { - Function::SetGlobal("testing.add1", - Function::FromTyped([](const int32_t& a) -> int { return a + 1; })); - auto fadd1 = Function::GetGlobalRequired("testing.add1"); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - auto fnot_exist = Function::GetGlobal("testing.not_existing_func"); - EXPECT_TRUE(!fnot_exist); - - auto fname_functor = - Function::GetGlobal("ffi.FunctionListGlobalNamesFunctor").value()().cast(); - Array names; - int len = fname_functor(-1).cast(); - for (int i = 0; i < len; ++i) { - names.push_back(fname_functor(i).cast()); - } - EXPECT_TRUE(std::find(names.begin(), names.end(), "testing.add1") != names.end()); -} - -TEST(Func, TypedFunction) { - TypedFunction fadd1 = [](int a) -> int { return a + 1; }; - EXPECT_EQ(fadd1(1), 2); - - TypedFunction fadd2([](int a) -> int { return a + 2; }); - EXPECT_EQ(fadd2(1), 3); - EXPECT_EQ(fadd2.packed()(1).cast(), 3); - - TypedFunction fcheck_int; - EXPECT_TRUE(fcheck_int == nullptr); - fcheck_int = [](int a) -> void { EXPECT_EQ(a, 1); }; - fcheck_int(1); -} - -TEST(Func, TypedFunctionAsAny) { - TypedFunction fadd1 = [](int a) -> int { return a + 1; }; - Any fany(std::move(fadd1)); - EXPECT_TRUE(fadd1 == nullptr); - auto fadd1_dup = fany.cast>(); - EXPECT_EQ(fadd1_dup(1), 2); -} - -TEST(Func, TypedFunctionAsAnyView) { - TypedFunction fadd2 = [](int a) -> int { return a + 2; }; - AnyView fview(fadd2); - auto fadd2_dup = fview.cast>(); - EXPECT_EQ(fadd2_dup(1), 3); -} - -TEST(Func, ObjectRefWithFallbackTraits) { - // test cases to test automatic type conversion via ObjectRefWithFallbackTraits 
- // through TPrimExpr - Function freturn_primexpr = Function::FromTyped([](TPrimExpr a) -> TPrimExpr { return a; }); - - auto result_int = freturn_primexpr(1).cast(); - EXPECT_EQ(result_int->dtype, "int64"); - EXPECT_EQ(result_int->value, 1); - - // Test case for float - auto result_float = freturn_primexpr(2.5).cast(); - EXPECT_EQ(result_float->dtype, "float32"); - EXPECT_EQ(result_float->value, 2.5); - - // Test case for bool - auto result_bool = freturn_primexpr(true).cast(); - EXPECT_EQ(result_bool->dtype, "bool"); - EXPECT_EQ(result_bool->value, 1); - - // Test case for string - auto result_string = freturn_primexpr("test_string").cast(); - EXPECT_EQ(result_string->dtype, "test_string"); - EXPECT_EQ(result_string->value, 0); - - EXPECT_THROW( - { - try { - freturn_primexpr(TInt(1)); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ( - error.message(), - "Mismatched type on argument #0 when calling: `(0: test.PrimExpr) -> test.PrimExpr`. " - "Expected `test.PrimExpr` but got `test.Int`"); - throw; - } - }, - ::tvm::ffi::Error); -} - -} // namespace diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc deleted file mode 100644 index 98d8427c23a1..000000000000 --- a/ffi/tests/cpp/test_map.cc +++ /dev/null @@ -1,366 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Map, Basic) { - Map map0; - TInt k0(0); - map0.Set(k0, 1); - - EXPECT_EQ(map0.size(), 1); - - map0.Set(k0, 2); - EXPECT_EQ(map0.size(), 1); - - auto it = map0.find(k0); - EXPECT_TRUE(it != map0.end()); - EXPECT_EQ((*it).second, 2); -} - -TEST(Map, PODKey) { - Map map0; - - // int as key - map0.Set(1, 2); - // float key is different - map0.Set(1.1, 3); - EXPECT_EQ(map0.size(), 2); - - auto it = map0.find(1.1); - EXPECT_TRUE(it != map0.end()); - EXPECT_EQ((*it).second.cast(), 3); -} - -TEST(Map, Object) { - TInt x(1); - TInt z(100); - TInt zz(1000); - Map dict{{x, z}, {z, zz}}; - EXPECT_EQ(dict.size(), 2); - EXPECT_TRUE(dict[x].same_as(z)); - EXPECT_TRUE(dict.count(z)); - EXPECT_TRUE(!dict.count(zz)); -} - -TEST(Map, Str) { - TInt x(1); - TInt z(100); - Map dict{{"x", z}, {"z", z}}; - EXPECT_EQ(dict.size(), 2); - EXPECT_TRUE(dict["x"].same_as(z)); -} - -TEST(Map, Mutate) { - TInt x(1); - TInt z(100); - TInt zz(1000); - Map dict{{x, z}, {z, zz}}; - - EXPECT_TRUE(dict[x].same_as(z)); - dict.Set(x, zz); - auto dict2 = dict; - EXPECT_EQ(dict2.count(z), 1); - dict.Set(zz, x); - EXPECT_EQ(dict2.count(zz), 0); - EXPECT_EQ(dict.count(zz), 1); - - auto it = dict.find(zz); - EXPECT_TRUE(it != dict.end() && (*it).second.same_as(x)); - - it = dict2.find(zz); - EXPECT_TRUE(it == dict2.end()); -} - -TEST(Map, Clear) { - TInt x(1); - TInt z(100); - Map dict{{x, z}, {z, z}}; - EXPECT_EQ(dict.size(), 2); 
- dict.clear(); - EXPECT_EQ(dict.size(), 0); -} - -TEST(Map, Insert) { - auto check = [](const Map& result, - std::unordered_map expected) { - EXPECT_EQ(result.size(), expected.size()); - for (const auto& kv : result) { - EXPECT_TRUE(expected.count(kv.first)); - EXPECT_EQ(expected[kv.first], kv.second); - expected.erase(kv.first); - } - }; - Map result; - std::unordered_map expected; - char key = 'a'; - int64_t val = 1; - for (int i = 0; i < 26; ++i, ++key, ++val) { - std::string s(1, key); - result.Set(s, val); - expected[s] = val; - check(result, expected); - } -} - -TEST(Map, Erase) { - auto check = [](const Map& result, - std::unordered_map expected) { - EXPECT_EQ(result.size(), expected.size()); - for (const auto& kv : result) { - EXPECT_TRUE(expected.count(kv.first)); - EXPECT_EQ(expected[kv.first], kv.second); - expected.erase(kv.first); - } - }; - Map map{{"a", 1}, {"b", 2}, {"c", 3}, {"d", 4}, {"e", 5}}; - std::unordered_map stl; - std::transform(map.begin(), map.end(), std::inserter(stl, stl.begin()), - [](auto&& p) { return std::make_pair(p.first, p.second); }); - for (char c = 'a'; c <= 'e'; ++c) { - Map result = map; - std::unordered_map expected(stl); - std::string key(1, c); - result.erase(key); - expected.erase(key); - check(result, expected); - } -} - -TEST(Map, AnyImplicitConversion) { - Map map0; - map0.Set(1, 2); - map0.Set(2, 3.1); - EXPECT_EQ(map0.size(), 2); - - // check will trigger copy - AnyView view0 = map0; - auto map1 = view0.cast>(); - EXPECT_TRUE(!map1.same_as(map0)); - EXPECT_EQ(map1[1], 2); - EXPECT_EQ(map1[2], 3.1); - EXPECT_EQ(map1.use_count(), 1); - - auto map2 = view0.cast>(); - EXPECT_TRUE(map2.same_as(map0)); - EXPECT_EQ(map2.use_count(), 2); - - auto map3 = view0.cast>(); - EXPECT_TRUE(!map3.same_as(map0)); - EXPECT_EQ(map3.use_count(), 1); - - Map map4{{"yes", 1.1}, {"no", 2.2}}; - Any any1 = map4; - - auto map5 = any1.cast>(); - EXPECT_TRUE(map5.same_as(map4)); - EXPECT_EQ(map5.use_count(), 3); - - auto map6 = any1.cast>(); 
- EXPECT_TRUE(map6.same_as(map4)); - EXPECT_EQ(map6.use_count(), 4); - - EXPECT_EQ(map6["yes"].cast(), 1.1); - EXPECT_EQ(map6["no"].cast(), 2.2); - - auto map7 = any1.cast>(); - EXPECT_TRUE(map7.same_as(map4)); - EXPECT_EQ(map7.use_count(), 5); - - auto map8 = any1.cast>(); - EXPECT_TRUE(!map8.same_as(map4)); - EXPECT_EQ(map8.use_count(), 1); - EXPECT_EQ(map8["yes"]->value, 1.1); - EXPECT_EQ(map8["no"]->value, 2.2); -} - -TEST(Map, AnyConvertCheck) { - Map map = {{11, 1.1}}; - EXPECT_EQ(map[11].cast(), 1.1); - - AnyView view0 = map; - auto arr1 = view0.cast>(); - EXPECT_EQ(arr1[11], 1.1); - - Any any1 = map; - using WrongMap = Map; - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = any1.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE( - what.find( - "Cannot convert from type `Map[K, some value is float]` to `Map`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); - - using WrongMap2 = Map; - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = any1.cast(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - std::string what = error.what(); - EXPECT_NE(what.find("Cannot convert from type `Map[some key is int, V]` to " - "`Map`"), - std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Map, FunctionGetItem) { - Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any { return n->at(k); }, - "map_get_item"); - Map map{{"x", 1}, {"y", 2}}; - Any k("x"); - Any v = f(map, k); - EXPECT_EQ(v.cast(), 1); -} - -TEST(Map, Upcast) { - Map m0 = {{1, 2}, {3, 4}}; - Map m1 = m0; - EXPECT_EQ(m1[1].cast(), 2); - EXPECT_EQ(m1[3].cast(), 4); - static_assert(details::type_contains_v, Map>); - - Map> m2 = {{"x", {1}}, {"y", {2}}}; - Map> m3 = m2; -} - -template -void PrintMap(const Map& m0) { - std::cout << "{"; - for (auto it = m0.begin(); it != m0.end(); ++it) { - if (it != m0.begin()) { - std::cout << ", "; - } - 
std::cout << (*it).first << ": " << (*it).second; - } - std::cout << "}" << std::endl; -} - -TEST(Map, MapInsertOrder) { - // test that map preserves the insertion order - auto get_reverse_order = [](size_t size) { - std::vector reverse_order; - for (int i = static_cast(size); i != 0; --i) { - reverse_order.push_back(i - 1); - } - return reverse_order; - }; - - auto check_map = [&](Map m0, size_t size, const std::vector& order) { - auto lhs = m0.begin(); - auto rhs = order.begin(); - while (lhs != m0.end()) { - TVM_FFI_ICHECK_EQ((*lhs).first, "hello" + std::to_string(*rhs)); - TVM_FFI_ICHECK_EQ((*lhs).second, *rhs); - ++lhs; - ++rhs; - } - lhs = m0.end(); - rhs = order.begin() + size; - do { - --lhs; - --rhs; - TVM_FFI_ICHECK_EQ((*lhs).first, "hello" + std::to_string(*rhs)); - TVM_FFI_ICHECK_EQ((*lhs).second, *rhs); - } while (lhs != m0.begin()); - }; - - auto check_order = [&](std::vector order) { - Map m0; - for (size_t i = 0; i < order.size(); ++i) { - m0.Set("hello" + std::to_string(order[i]), order[i]); - check_map(m0, i + 1, order); - } - check_map(m0, order.size(), order); - // erase a few items - m0.erase("hello" + std::to_string(order[0])); - auto item0 = order[0]; - order.erase(order.begin()); - check_map(m0, order.size(), order); - // erase the middle part - if (order.size() > 1) { - m0.erase("hello" + std::to_string(order[1])); - order.erase(order.begin() + 1); - check_map(m0, order.size(), order); - } - // erase the end - m0.erase("hello" + std::to_string(order.back())); - auto item2 = order.back(); - order.erase(order.end() - 1); - check_map(m0, order.size(), order); - EXPECT_NE(m0.size(), 0); - // put back some items - order.push_back(item2); - m0.Set("hello" + std::to_string(item2), item2); - check_map(m0, order.size(), order); - order.push_back(item0); - m0.Set("hello" + std::to_string(item0), item0); - check_map(m0, order.size(), order); - }; - // test with 17 items: DenseMapObj - check_order(get_reverse_order(17)); - // test with 4 items: 
SmallMapObj - check_order(get_reverse_order(4)); -} - -TEST(Map, EmptyIter) { - Map m0; - EXPECT_EQ(m0.begin(), m0.end()); - // create a big map and then erase to keep a dense map empty - for (int i = 0; i < 10; ++i) { - m0.Set("hello" + std::to_string(i), i); - } - for (int i = 0; i < 10; ++i) { - m0.erase("hello" + std::to_string(i)); - } - EXPECT_EQ(m0.size(), 0); - // now m0 is dense map with all empty slots - EXPECT_EQ(m0.begin(), m0.end()); -} - -TEST(Map, DuplicatedKeysInit) { - std::vector> data = {{"a", 1}, {"a", 2}, {"a", 3}}; - Map map(data.begin(), data.end()); - EXPECT_EQ(map.size(), 1); - EXPECT_EQ(map["a"], 3); -} -} // namespace diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc deleted file mode 100644 index 3d7b00cd33c3..000000000000 --- a/ffi/tests/cpp/test_ndarray.cc +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include - -namespace { - -using namespace tvm::ffi; - -struct CPUNDAlloc { - void AllocData(DLTensor* tensor) { tensor->data = malloc(GetDataSize(*tensor)); } - void FreeData(DLTensor* tensor) { free(tensor->data); } -}; - -inline NDArray Empty(Shape shape, DLDataType dtype, DLDevice device) { - return NDArray::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); -} - -TEST(NDArray, Basic) { - NDArray nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); - Shape shape = nd.shape(); - EXPECT_EQ(shape.size(), 3); - EXPECT_EQ(shape[0], 1); - EXPECT_EQ(shape[1], 2); - EXPECT_EQ(shape[2], 3); - EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1})); - for (int64_t i = 0; i < shape.Product(); ++i) { - reinterpret_cast(nd->data)[i] = static_cast(i); - } - - Any any0 = nd; - NDArray nd2 = any0.as().value(); - EXPECT_EQ(nd2.shape(), shape); - EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1})); - for (int64_t i = 0; i < shape.Product(); ++i) { - EXPECT_EQ(reinterpret_cast(nd2->data)[i], i); - } - - EXPECT_EQ(nd.IsContiguous(), true); - EXPECT_EQ(nd2.use_count(), 3); -} - -TEST(NDArray, DLPack) { - NDArray nd = Empty({1, 2, 3}, DLDataType({kDLInt, 16, 1}), DLDevice({kDLCPU, 0})); - DLManagedTensor* dlpack = nd.ToDLPack(); - EXPECT_EQ(dlpack->dl_tensor.ndim, 3); - EXPECT_EQ(dlpack->dl_tensor.shape[0], 1); - EXPECT_EQ(dlpack->dl_tensor.shape[1], 2); - EXPECT_EQ(dlpack->dl_tensor.shape[2], 3); - EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLInt); - EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 16); - EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1); - EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); - EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); - EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); - EXPECT_EQ(dlpack->dl_tensor.strides, nullptr); - EXPECT_EQ(nd.use_count(), 2); - { - NDArray nd2 = NDArray::FromDLPack(dlpack); - EXPECT_EQ(nd2.use_count(), 1); - EXPECT_EQ(nd2->data, nd->data); - EXPECT_EQ(nd.use_count(), 2); - 
EXPECT_EQ(nd2.use_count(), 1); - } - EXPECT_EQ(nd.use_count(), 1); -} - -TEST(NDArray, DLPackVersioned) { - DLDataType dtype = DLDataType({kDLFloat4_e2m1fn, 4, 1}); - EXPECT_EQ(GetDataSize(2, dtype), 2 * 4 / 8); - NDArray nd = Empty({2}, dtype, DLDevice({kDLCPU, 0})); - DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned(); - EXPECT_EQ(dlpack->version.major, DLPACK_MAJOR_VERSION); - EXPECT_EQ(dlpack->version.minor, DLPACK_MINOR_VERSION); - EXPECT_EQ(dlpack->dl_tensor.ndim, 1); - EXPECT_EQ(dlpack->dl_tensor.shape[0], 2); - EXPECT_EQ(dlpack->dl_tensor.dtype.code, kDLFloat4_e2m1fn); - EXPECT_EQ(dlpack->dl_tensor.dtype.bits, 4); - EXPECT_EQ(dlpack->dl_tensor.dtype.lanes, 1); - EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); - EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); - EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); - EXPECT_EQ(dlpack->dl_tensor.strides, nullptr); - - EXPECT_EQ(nd.use_count(), 2); - { - NDArray nd2 = NDArray::FromDLPackVersioned(dlpack); - EXPECT_EQ(nd2.use_count(), 1); - EXPECT_EQ(nd2->data, nd->data); - EXPECT_EQ(nd.use_count(), 2); - EXPECT_EQ(nd2.use_count(), 1); - } - EXPECT_EQ(nd.use_count(), 1); -} -} // namespace diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc deleted file mode 100644 index 4b53a70b42a2..000000000000 --- a/ffi/tests/cpp/test_object.cc +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Object, RefCounter) { - ObjectPtr a = make_object(11); - ObjectPtr b = a; - - EXPECT_EQ(a->value, 11); - - EXPECT_EQ(a.use_count(), 2); - ObjectPtr aa = make_object(*a); - EXPECT_EQ(aa.use_count(), 1); - EXPECT_EQ(aa->value, 11); - - b.reset(); - EXPECT_EQ(a.use_count(), 1); - EXPECT_TRUE(b == nullptr); - EXPECT_EQ(b.use_count(), 0); - - ObjectPtr c = std::move(a); - EXPECT_EQ(c.use_count(), 1); - EXPECT_TRUE(a == nullptr); - - EXPECT_EQ(c->value, 11); -} - -TEST(Object, TypeInfo) { - const TypeInfo* info = TVMFFIGetTypeInfo(TIntObj::RuntimeTypeIndex()); - EXPECT_TRUE(info != nullptr); - EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex()); - EXPECT_EQ(info->type_depth, 2); - EXPECT_EQ(info->type_acenstors[0]->type_index, Object::_type_index); - EXPECT_EQ(info->type_acenstors[1]->type_index, TNumberObj::_type_index); - EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin); -} - -TEST(Object, InstanceCheck) { - ObjectPtr a = make_object(11); - ObjectPtr b = make_object(11); - - EXPECT_TRUE(a->IsInstance()); - EXPECT_TRUE(a->IsInstance()); - EXPECT_TRUE(a->IsInstance()); - EXPECT_TRUE(!a->IsInstance()); - - EXPECT_TRUE(a->IsInstance()); - EXPECT_TRUE(b->IsInstance()); - EXPECT_TRUE(!b->IsInstance()); - EXPECT_TRUE(b->IsInstance()); -} - -TEST(ObjectRef, as) { - ObjectRef a = TInt(10); - ObjectRef b = TFloat(20); - // nullable object - ObjectRef c(nullptr); - - EXPECT_TRUE(a.as() != 
nullptr); - EXPECT_TRUE(a.as() == nullptr); - EXPECT_TRUE(a.as() != nullptr); - - EXPECT_TRUE(b.as() == nullptr); - EXPECT_TRUE(b.as() != nullptr); - EXPECT_TRUE(b.as() != nullptr); - - EXPECT_TRUE(c.as() == nullptr); - EXPECT_TRUE(c.as() == nullptr); - EXPECT_TRUE(c.as() == nullptr); - - EXPECT_EQ(a.as()->value, 10); - EXPECT_EQ(b.as()->value, 20); -} - -TEST(Object, CAPIAccessor) { - ObjectRef a = TInt(10); - TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(a); - int32_t type_index = TVMFFIObjectGetTypeIndex(obj); - EXPECT_EQ(type_index, TIntObj::RuntimeTypeIndex()); -} -} // namespace diff --git a/ffi/tests/cpp/test_optional.cc b/ffi/tests/cpp/test_optional.cc deleted file mode 100644 index eb114df8a3fa..000000000000 --- a/ffi/tests/cpp/test_optional.cc +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Optional, TInt) { - Optional x; - Optional y = TInt(11); - static_assert(sizeof(Optional) == sizeof(ObjectRef)); - - EXPECT_TRUE(!x.has_value()); - EXPECT_EQ(x.value_or(TInt(12))->value, 12); - - EXPECT_TRUE(y.has_value()); - EXPECT_EQ(y.value_or(TInt(12))->value, 11); - - Any z_any = std::move(y); - EXPECT_TRUE(z_any != nullptr); - EXPECT_EQ((z_any.cast())->value, 11); - EXPECT_TRUE(!y.has_value()); - - // move from any to optional - auto y2 = std::move(z_any).cast>(); - EXPECT_EQ(y2.use_count(), 1); - EXPECT_TRUE(y2.has_value()); - EXPECT_EQ(y2.value_or(TInt(12))->value, 11); -} - -TEST(Optional, double) { - Optional x; - Optional y = 11.0; - static_assert(sizeof(Optional) > sizeof(ObjectRef)); - - EXPECT_TRUE(!x.has_value()); - EXPECT_EQ(x.value_or(12), 12); - EXPECT_TRUE(x != 12); - - EXPECT_TRUE(y.has_value()); - EXPECT_EQ(y.value_or(12), 11); - EXPECT_TRUE(y == 11); - EXPECT_TRUE(y != 12); -} - -TEST(Optional, AnyConvert_int) { - Optional opt_v0 = 1; - EXPECT_EQ(opt_v0.value(), 1); - EXPECT_TRUE(opt_v0.has_value()); - - AnyView view0 = opt_v0; - EXPECT_EQ(view0.cast(), 1); - - Any any1; - auto opt_v1 = std::move(any1).cast>(); - EXPECT_TRUE(!opt_v1.has_value()); - Optional opt_v2 = 11; - Any any2 = std::move(opt_v2); - EXPECT_EQ(any2.cast(), 11); -} - -TEST(Optional, AnyConvert_Array) { - AnyView view0; - Array> arr_nested = {{}, {TInt(1), TFloat(2)}}; - view0 = arr_nested; - - auto opt_arr = view0.cast>>>(); - EXPECT_EQ(arr_nested.use_count(), 2); - - auto arr1 = view0.cast>>>(); - EXPECT_EQ(arr_nested.use_count(), 3); - EXPECT_EQ(arr1.value()[1][1].as()->value, 2); - - Any any1; - auto arr2 = any1.cast>>>(); - EXPECT_TRUE(!arr2.has_value()); - - EXPECT_THROW( - { - try { - [[maybe_unused]] auto arr2 = view0.cast>>>(); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), 
"TypeError"); - std::string what = error.what(); - std::cout << what << std::endl; - EXPECT_NE(what.find("to `Optional>>`"), std::string::npos); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Optional, OptionalOfOptional) { - // testcase of optional - Optional> opt_opt_int; - EXPECT_TRUE(!opt_opt_int.has_value()); - - Optional> opt_opt_int2 = Optional(std::nullopt); - EXPECT_TRUE(opt_opt_int2.has_value()); - EXPECT_TRUE(!opt_opt_int2.value().has_value()); - - // Optional> - Optional> opt_opt_tint; - EXPECT_TRUE(!opt_opt_tint.has_value()); - - Optional> opt_opt_tint2 = Optional(std::nullopt); - EXPECT_TRUE(opt_opt_tint2.has_value()); - EXPECT_TRUE(!opt_opt_tint2.value().has_value()); - opt_opt_tint2 = std::nullopt; - EXPECT_TRUE(!opt_opt_tint2.has_value()); - - Optional> opt_opt_tint3 = Optional(TInt(42)); - EXPECT_TRUE(opt_opt_tint3.has_value()); - EXPECT_TRUE(opt_opt_tint3.value().has_value()); - EXPECT_EQ(opt_opt_tint3.value().value()->value, 42); -} - -TEST(Optional, ValueMove) { - Optional y = TInt(11); - TInt x = std::move(y).value(); - EXPECT_TRUE(!y.has_value()); - EXPECT_EQ(x->value, 11); - - Optional opt_tint = TInt(21); - EXPECT_TRUE(opt_tint.has_value()); - EXPECT_EQ((*opt_tint)->value, 21); - - TInt moved_tint = *std::move(opt_tint); - EXPECT_EQ(moved_tint->value, 21); - EXPECT_TRUE(!opt_tint.has_value()); -} - -TEST(Optional, OptionalInArray) { - // This pattern plus iteration may cause memory leak - // this is because arr[0] returns a temporary object - // and further call arr[0].value() may return a reference to - // the temporary object - Array>> arr = {Array({TInt(0), TInt(1)})}; - int counter = 0; - - for (const auto& x : arr[0].value()) { - EXPECT_EQ(x->value, counter++); - } - - Any any = arr; - auto opt_arr = any.cast>>>(); - EXPECT_EQ(opt_arr[0].value()[0]->value, 0); -} - -TEST(Optional, String) { - Optional opt_str; - EXPECT_TRUE(!opt_str.has_value()); - EXPECT_EQ(opt_str.value_or("default"), "default"); - EXPECT_TRUE(opt_str != 
"default"); - EXPECT_TRUE(opt_str != String("default")); - EXPECT_TRUE(opt_str == std::nullopt); - - opt_str = "hello"; - EXPECT_TRUE(opt_str.has_value()); - EXPECT_EQ(opt_str.value(), "hello"); - EXPECT_TRUE(opt_str == "hello"); - EXPECT_TRUE(opt_str == String("hello")); - EXPECT_TRUE(opt_str != std::nullopt); - static_assert(sizeof(Optional) == sizeof(String)); -} - -TEST(Optional, Bytes) { - Optional opt_bytes; - EXPECT_TRUE(!opt_bytes.has_value()); - EXPECT_EQ(opt_bytes.value_or(std::string("default")), "default"); - - opt_bytes = std::string("hello"); - EXPECT_TRUE(opt_bytes.has_value()); - EXPECT_EQ(opt_bytes.value().operator std::string(), "hello"); - EXPECT_TRUE(opt_bytes != std::nullopt); - static_assert(sizeof(Optional) == sizeof(Bytes)); -} -} // namespace diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc deleted file mode 100644 index 85da00c1321d..000000000000 --- a/ffi/tests/cpp/test_reflection.cc +++ /dev/null @@ -1,272 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -struct TestObjA : public Object { - int64_t x; - int64_t y; - - static constexpr const char* _type_key = "test.TestObjA"; - static constexpr bool _type_mutable = true; - TVM_FFI_DECLARE_BASE_OBJECT_INFO(TestObjA, Object); -}; - -struct TestObjADerived : public TestObjA { - int64_t z; - - static constexpr const char* _type_key = "test.TestObjADerived"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjADerived, TestObjA); -}; - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - - TIntObj::RegisterReflection(); - TFloatObj::RegisterReflection(); - TPrimExprObj::RegisterReflection(); - TVarObj::RegisterReflection(); - TFuncObj::RegisterReflection(); - TCustomFuncObj::RegisterReflection(); - - refl::ObjectDef().def_ro("x", &TestObjA::x).def_rw("y", &TestObjA::y); - refl::ObjectDef().def_ro("z", &TestObjADerived::z); -}); - -TEST(Reflection, GetFieldByteOffset) { - EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::x), sizeof(TVMFFIObject)); - EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::y), 8 + sizeof(TVMFFIObject)); - EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject)); -} - -TEST(Reflection, FieldGetter) { - ObjectRef a = TInt(10); - reflection::FieldGetter getter("test.Int", "value"); - EXPECT_EQ(getter(a).cast(), 10); - - ObjectRef b = TFloat(10.0); - reflection::FieldGetter getter_float("test.Float", "value"); - EXPECT_EQ(getter_float(b).cast(), 10.0); -} - -TEST(Reflection, FieldSetter) { - ObjectRef a = TFloat(10.0); - reflection::FieldSetter setter("test.Float", "value"); - setter(a, 20.0); - EXPECT_EQ(a.as()->value, 20.0); -} - -TEST(Reflection, FieldInfo) { - const TVMFFIFieldInfo* info_int = reflection::GetFieldInfo("test.Int", "value"); - EXPECT_FALSE(info_int->flags & 
kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_FALSE(info_int->flags & kTVMFFIFieldFlagBitMaskWritable); - EXPECT_EQ(Bytes(info_int->doc).operator std::string(), ""); - - const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float", "value"); - EXPECT_EQ(info_float->default_value.v_float64, 10.0); - EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_FALSE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable); - EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value field"); - - const TVMFFIFieldInfo* info_prim_expr_dtype = reflection::GetFieldInfo("test.PrimExpr", "dtype"); - AnyView default_value = AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value); - EXPECT_EQ(default_value.cast(), "float"); - EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable); - EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype field"); -} - -TEST(Reflection, MethodInfo) { - const TVMFFIMethodInfo* info_int_static_add = reflection::GetMethodInfo("test.Int", "static_add"); - EXPECT_TRUE(info_int_static_add->flags & kTVMFFIFieldFlagBitMaskIsStaticMethod); - EXPECT_EQ(Bytes(info_int_static_add->doc).operator std::string(), "static add method"); - - const TVMFFIMethodInfo* info_float_add = reflection::GetMethodInfo("test.Float", "add"); - EXPECT_FALSE(info_float_add->flags & kTVMFFIFieldFlagBitMaskIsStaticMethod); - EXPECT_EQ(Bytes(info_float_add->doc).operator std::string(), "add method"); - - const TVMFFIMethodInfo* info_float_sub = reflection::GetMethodInfo("test.Float", "sub"); - EXPECT_FALSE(info_float_sub->flags & kTVMFFIFieldFlagBitMaskIsStaticMethod); - EXPECT_EQ(Bytes(info_float_sub->doc).operator std::string(), ""); -} - -TEST(Reflection, CallMethod) { - Function static_int_add = reflection::GetMethod("test.Int", "static_add"); - EXPECT_EQ(static_int_add(TInt(1), TInt(2)).cast()->value, 3); - - 
Function float_add = reflection::GetMethod("test.Float", "add"); - EXPECT_EQ(float_add(TFloat(1), 2.0).cast(), 3.0); - - Function float_sub = reflection::GetMethod("test.Float", "sub"); - EXPECT_EQ(float_sub(TFloat(1), 2.0).cast(), -1.0); - - Function prim_expr_sub = reflection::GetMethod("test.PrimExpr", "sub"); - EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast(), -1.0); -} - -TEST(Reflection, ForEachFieldInfo) { - const TypeInfo* info = TVMFFIGetTypeInfo(TestObjADerived::RuntimeTypeIndex()); - Map field_name_to_offset; - reflection::ForEachFieldInfo(info, [&](const TVMFFIFieldInfo* field_info) { - field_name_to_offset.Set(String(field_info->name), field_info->offset); - }); - EXPECT_EQ(field_name_to_offset["x"], sizeof(TVMFFIObject)); - EXPECT_EQ(field_name_to_offset["y"], 8 + sizeof(TVMFFIObject)); - EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject)); -} - -TEST(Reflection, TypeAttrColumn) { - reflection::TypeAttrColumn size_attr("test.size"); - EXPECT_EQ(size_attr[TIntObj::_type_index].cast(), sizeof(TIntObj)); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("testing.Int_GetValue", &TIntObj::GetValue); -}); - -TEST(Reflection, FuncRegister) { - Function fget_value = Function::GetGlobalRequired("testing.Int_GetValue"); - TInt a(12); - EXPECT_EQ(fget_value(a).cast(), 12); -} - -TEST(Reflection, ObjectCreator) { - namespace refl = tvm::ffi::reflection; - refl::ObjectCreator creator("test.Int"); - EXPECT_EQ(creator(Map({{"value", 1}})).cast()->value, 1); -} - -TEST(Reflection, AccessPath) { - namespace refl = tvm::ffi::reflection; - - // Test basic path construction and ToSteps() - refl::AccessPath path = refl::AccessPath::Root()->Attr("body")->ArrayItem(1); - auto steps = path->ToSteps(); - EXPECT_EQ(steps.size(), 2); - EXPECT_EQ(steps[0]->kind, refl::AccessKind::kAttr); - EXPECT_EQ(steps[1]->kind, refl::AccessKind::kArrayItem); - EXPECT_EQ(steps[0]->key.cast(), "body"); - 
EXPECT_EQ(steps[1]->key.cast(), 1); - - // Test PathEqual with identical paths - refl::AccessPath path2 = refl::AccessPath::Root()->Attr("body")->ArrayItem(1); - EXPECT_TRUE(path->PathEqual(path2)); - EXPECT_TRUE(path->IsPrefixOf(path2)); - - // Test PathEqual with different paths - refl::AccessPath path3 = refl::AccessPath::Root()->Attr("body")->ArrayItem(2); - EXPECT_FALSE(path->PathEqual(path3)); - EXPECT_FALSE(path->IsPrefixOf(path3)); - - // Test prefix relationship - path4 extends path, so path should be prefix of path4 - refl::AccessPath path4 = refl::AccessPath::Root()->Attr("body")->ArrayItem(1)->Attr("body"); - EXPECT_FALSE(path->PathEqual(path4)); // Not equal (different lengths) - EXPECT_TRUE(path->IsPrefixOf(path4)); // But path is a prefix of path4 - - // Test completely different paths - refl::AccessPath path5 = refl::AccessPath::Root()->ArrayItem(0)->ArrayItem(1)->Attr("body"); - EXPECT_FALSE(path->PathEqual(path5)); - EXPECT_FALSE(path->IsPrefixOf(path5)); - - // Test Root path - refl::AccessPath root = refl::AccessPath::Root(); - auto root_steps = root->ToSteps(); - EXPECT_EQ(root_steps.size(), 0); - EXPECT_EQ(root->depth, 0); - EXPECT_TRUE(root->IsPrefixOf(path)); - EXPECT_TRUE(root->IsPrefixOf(root)); - EXPECT_TRUE(root->PathEqual(refl::AccessPath::Root())); - - // Test depth calculations - EXPECT_EQ(path->depth, 2); - EXPECT_EQ(path4->depth, 3); - EXPECT_EQ(root->depth, 0); - - // Test MapItem access - refl::AccessPath map_path = refl::AccessPath::Root()->Attr("data")->MapItem("key1"); - auto map_steps = map_path->ToSteps(); - EXPECT_EQ(map_steps.size(), 2); - EXPECT_EQ(map_steps[0]->kind, refl::AccessKind::kAttr); - EXPECT_EQ(map_steps[1]->kind, refl::AccessKind::kMapItem); - EXPECT_EQ(map_steps[0]->key.cast(), "data"); - EXPECT_EQ(map_steps[1]->key.cast(), "key1"); - - // Test MapItemMissing access - refl::AccessPath map_missing_path = refl::AccessPath::Root()->MapItemMissing(42); - auto map_missing_steps = map_missing_path->ToSteps(); - 
EXPECT_EQ(map_missing_steps.size(), 1); - EXPECT_EQ(map_missing_steps[0]->kind, refl::AccessKind::kMapItemMissing); - EXPECT_EQ(map_missing_steps[0]->key.cast(), 42); - - // Test ArrayItemMissing access - refl::AccessPath array_missing_path = refl::AccessPath::Root()->ArrayItemMissing(5); - auto array_missing_steps = array_missing_path->ToSteps(); - EXPECT_EQ(array_missing_steps.size(), 1); - EXPECT_EQ(array_missing_steps[0]->kind, refl::AccessKind::kArrayItemMissing); - EXPECT_EQ(array_missing_steps[0]->key.cast(), 5); - - // Test FromSteps static method - round trip conversion - auto original_steps = path->ToSteps(); - refl::AccessPath reconstructed = refl::AccessPath::FromSteps(original_steps); - EXPECT_TRUE(path->PathEqual(reconstructed)); - EXPECT_EQ(path->depth, reconstructed->depth); - - // Test complex prefix relationships - refl::AccessPath short_path = refl::AccessPath::Root()->Attr("x"); - refl::AccessPath medium_path = refl::AccessPath::Root()->Attr("x")->ArrayItem(0); - refl::AccessPath long_path = refl::AccessPath::Root()->Attr("x")->ArrayItem(0)->MapItem("z"); - - EXPECT_TRUE(short_path->IsPrefixOf(medium_path)); - EXPECT_TRUE(short_path->IsPrefixOf(long_path)); - EXPECT_TRUE(medium_path->IsPrefixOf(long_path)); - EXPECT_FALSE(medium_path->IsPrefixOf(short_path)); - EXPECT_FALSE(long_path->IsPrefixOf(medium_path)); - EXPECT_FALSE(long_path->IsPrefixOf(short_path)); - - // Test non-prefix relationships - refl::AccessPath branch1 = refl::AccessPath::Root()->Attr("x")->ArrayItem(0); - refl::AccessPath branch2 = refl::AccessPath::Root()->Attr("x")->ArrayItem(1); - EXPECT_FALSE(branch1->IsPrefixOf(branch2)); - EXPECT_FALSE(branch2->IsPrefixOf(branch1)); - EXPECT_FALSE(branch1->PathEqual(branch2)); - - // Test GetParent functionality - auto parent = path4->GetParent(); - EXPECT_TRUE(parent.has_value()); - EXPECT_TRUE(parent.value()->PathEqual(path)); - - auto root_parent = root->GetParent(); - EXPECT_FALSE(root_parent.has_value()); -} -} // namespace diff 
--git a/ffi/tests/cpp/test_rvalue_ref.cc b/ffi/tests/cpp/test_rvalue_ref.cc deleted file mode 100644 index dd211a34dc60..000000000000 --- a/ffi/tests/cpp/test_rvalue_ref.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(RValueRef, Basic) { - auto append = - Function::FromTyped([](RValueRef> ref, int val, bool is_unique) -> Array { - Array arr = *std::move(ref); - EXPECT_EQ(arr.unique(), is_unique); - arr.push_back(val); - return arr; - }); - auto a = append(RValueRef(Array({1, 2})), 3, true).cast>(); - EXPECT_EQ(a.size(), 3); - a = append(RValueRef(std::move(a)), 4, true).cast>(); - EXPECT_EQ(a.size(), 4); - // pass in lvalue instead, the append still will succeed but array will not be unique - a = append(a, 5, false).cast>(); - EXPECT_EQ(a.size(), 5); -} - -TEST(RValueRef, ParamChecking) { - // try decution - Function fadd1 = Function::FromTyped([](TInt a) -> int64_t { return a->value + 1; }); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(RValueRef(TInt(1))); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: test.Int) -> int`. " - "Expected `test.Int` but got `ObjectRValueRef`"); - throw; - } - }, - ::tvm::ffi::Error); - - Function fadd2 = Function::FromTyped([](RValueRef> a) -> int { - Array arr = *std::move(a); - return arr[0] + 1; - }); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd2(RValueRef(Array({1, 2.2}))); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ( - error.message(), - "Mismatched type on argument #0 when calling: `(0: RValueRef>) -> int`. 
" - "Expected `RValueRef>` but got `RValueRef`"); - throw; - } - }, - ::tvm::ffi::Error); - // triggered a rvalue based conversion - Function func3 = Function::FromTyped([](RValueRef a) -> String { - TPrimExpr expr = *std::move(a); - return expr->dtype; - }); - // EXPECT_EQ(func3(RValueRef(String("int32"))).cast(), "int32"); - // triggered a lvalue based conversion - // EXPECT_EQ(func3(String("int32")).cast(), "int32"); -} -} // namespace diff --git a/ffi/tests/cpp/test_shape.cc b/ffi/tests/cpp/test_shape.cc deleted file mode 100644 index 0ccba7820ad7..000000000000 --- a/ffi/tests/cpp/test_shape.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include -#include - -namespace { - -using namespace tvm::ffi; - -TEST(Shape, Basic) { - Shape shape = Shape({1, 2, 3}); - EXPECT_EQ(shape.size(), 3); - EXPECT_EQ(shape[0], 1); - EXPECT_EQ(shape[1], 2); - EXPECT_EQ(shape[2], 3); - - Shape shape2 = Shape(Array({4, 5, 6, 7})); - EXPECT_EQ(shape2.size(), 4); - EXPECT_EQ(shape2[0], 4); - EXPECT_EQ(shape2[1], 5); - EXPECT_EQ(shape2[2], 6); - EXPECT_EQ(shape2[3], 7); - - std::vector vec = {8, 9, 10}; - Shape shape3 = Shape(std::move(vec)); - EXPECT_EQ(shape3.size(), 3); - EXPECT_EQ(shape3[0], 8); - EXPECT_EQ(shape3[1], 9); - EXPECT_EQ(shape3[2], 10); - EXPECT_EQ(shape3.Product(), 8 * 9 * 10); - - Shape shape4 = Shape(); - EXPECT_EQ(shape4.size(), 0); - EXPECT_EQ(shape4.Product(), 1); -} - -TEST(Shape, AnyConvert) { - Shape shape0 = Shape({1, 2, 3}); - Any any0 = shape0; - - auto shape1 = any0.cast(); - EXPECT_EQ(shape1.size(), 3); - EXPECT_EQ(shape1[0], 1); - EXPECT_EQ(shape1[1], 2); - EXPECT_EQ(shape1[2], 3); - - Array arr({1, 2}); - AnyView any_view0 = arr; - auto shape2 = any_view0.cast(); - EXPECT_EQ(shape2.size(), 2); - EXPECT_EQ(shape2[0], 1); - EXPECT_EQ(shape2[1], 2); -} - -} // namespace diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc deleted file mode 100644 index 8522aa93a3b9..000000000000 --- a/ffi/tests/cpp/test_string.cc +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -namespace { - -using namespace tvm::ffi; - -TEST(String, MoveFromStd) { - using namespace std; - string source = "this is a string"; - string expect = source; - String s(std::move(source)); - string copy = (string)s; - EXPECT_EQ(copy, expect); - EXPECT_EQ(source.size(), 0); -} - -TEST(String, CopyFromStd) { - using namespace std; - string source = "this is a string"; - string expect = source; - String s{source}; - string copy = (string)s; - EXPECT_EQ(copy, expect); - EXPECT_EQ(source.size(), expect.size()); -} - -TEST(String, Assignment) { - using namespace std; - String s{string{"hello"}}; - s = string{"world"}; - EXPECT_EQ(s == "world", true); - string s2{"world2"}; - s = std::move(s2); - EXPECT_EQ(s == "world2", true); - - Any r; - r = String("hello"); - EXPECT_EQ(r != nullptr, true); -} - -TEST(String, empty) { - using namespace std; - String s{"hello"}; - EXPECT_EQ(s.empty(), false); - s = std::string(""); - EXPECT_EQ(s.empty(), true); -} - -TEST(String, Comparisons) { - using namespace std; - string source = "a string"; - string mismatch = "a string but longer"; - String s{"a string"}; - String m{mismatch}; - - EXPECT_EQ("a str" >= s, false); - EXPECT_EQ(s == source, true); - EXPECT_EQ(s == mismatch, false); - EXPECT_EQ(s == source.data(), true); - EXPECT_EQ(s == mismatch.data(), false); - - EXPECT_EQ(s < m, source < mismatch); - EXPECT_EQ(s > m, source > mismatch); - EXPECT_EQ(s <= m, source <= mismatch); - EXPECT_EQ(s >= m, source >= mismatch); - EXPECT_EQ(s == m, source == mismatch); - EXPECT_EQ(s != 
m, source != mismatch); - - EXPECT_EQ(m < s, mismatch < source); - EXPECT_EQ(m > s, mismatch > source); - EXPECT_EQ(m <= s, mismatch <= source); - EXPECT_EQ(m >= s, mismatch >= source); - EXPECT_EQ(m == s, mismatch == source); - EXPECT_EQ(m != s, mismatch != source); -} - -TEST(String, Compare) { - // string compare const char* - String s{"hello"}; - EXPECT_EQ(s.compare("hello"), 0); - EXPECT_EQ(s.compare(String("hello")), 0); - - EXPECT_EQ(s.compare("hallo"), 1); - EXPECT_EQ(s.compare(String("hallo")), 1); - EXPECT_EQ(s.compare("hfllo"), -1); - EXPECT_EQ(s.compare(String("hfllo")), -1); - // s is longer - EXPECT_EQ(s.compare("hell"), 1); - EXPECT_EQ(s.compare(String("hell")), 1); - // s is shorter - EXPECT_EQ(s.compare("hello world"), -1); - EXPECT_EQ(s.compare(String("helloworld")), -1); -} - -// Check '\0' handling -TEST(String, null_byte_handling) { - using namespace std; - // Ensure string still compares equal if it contains '\0'. - string v1 = "hello world"; - size_t v1_size = v1.size(); - v1[5] = '\0'; - EXPECT_EQ(v1[5], '\0'); - EXPECT_EQ(v1.size(), v1_size); - String str_v1{v1}; - EXPECT_EQ(str_v1.compare(v1), 0); - EXPECT_EQ(str_v1.size(), v1_size); - - // Ensure bytes after '\0' are taken into account for mismatches. - string v2 = "aaa one"; - string v3 = "aaa two"; - v2[3] = '\0'; - v3[3] = '\0'; - String str_v2{v2}; - String str_v3{v3}; - EXPECT_EQ(str_v2.compare(str_v3), -1); - EXPECT_EQ(str_v2.size(), 7); - // strcmp won't be able to detect the mismatch - EXPECT_EQ(strcmp(v2.data(), v3.data()), 0); - // string::compare can handle \0 since it knows size - EXPECT_LT(v2.compare(v3), 0); - - // If there is mismatch before '\0', should still handle it. 
- string v4 = "acc one"; - string v5 = "abb two"; - v4[3] = '\0'; - v5[3] = '\0'; - String str_v4{v4}; - String str_v5{v5}; - EXPECT_GT(str_v4.compare(str_v5), 0); - EXPECT_EQ(str_v4.size(), 7); - // strcmp is able to detect the mismatch - EXPECT_GT(strcmp(v4.data(), v5.data()), 0); - // string::compare can handle \0 since it knows size - EXPECT_GT(v4.compare(v5), 0); -} - -TEST(String, compare_same_memory_region_different_size) { - using namespace std; - string source = "a string"; - String str_source{source}; - char* memory = const_cast(str_source.data()); - EXPECT_EQ(str_source.compare(memory), 0); - // This changes the string size - memory[2] = '\0'; - // memory is logically shorter now - EXPECT_GT(str_source.compare(memory), 0); -} - -TEST(String, compare) { - using namespace std; - constexpr auto mismatch1_cstr = "a string but longer"; - string source = "a string"; - string mismatch1 = mismatch1_cstr; - string mismatch2 = "a strin"; - string mismatch3 = "a b"; - string mismatch4 = "a t"; - String str_source{source}; - String str_mismatch1{mismatch1_cstr}; - String str_mismatch2{mismatch2}; - String str_mismatch3{mismatch3}; - String str_mismatch4{mismatch4}; - - // compare with string - EXPECT_EQ(str_source.compare(source), 0); - EXPECT_TRUE(str_source == source); - EXPECT_TRUE(source == str_source); - EXPECT_TRUE(str_source <= source); - EXPECT_TRUE(source <= str_source); - EXPECT_TRUE(str_source >= source); - EXPECT_TRUE(source >= str_source); - EXPECT_LT(str_source.compare(mismatch1), 0); - EXPECT_TRUE(str_source < mismatch1); - EXPECT_TRUE(mismatch1 != str_source); - EXPECT_GT(str_source.compare(mismatch2), 0); - EXPECT_TRUE(str_source > mismatch2); - EXPECT_TRUE(mismatch2 < str_source); - EXPECT_GT(str_source.compare(mismatch3), 0); - EXPECT_TRUE(str_source > mismatch3); - EXPECT_LT(str_source.compare(mismatch4), 0); - EXPECT_TRUE(str_source < mismatch4); - EXPECT_TRUE(mismatch4 > str_source); - - // compare with char* - 
EXPECT_EQ(str_source.compare(source.data()), 0); - EXPECT_TRUE(str_source == source.data()); - EXPECT_TRUE(source.data() == str_source); - EXPECT_TRUE(str_source <= source.data()); - EXPECT_TRUE(source <= str_source.data()); - EXPECT_TRUE(str_source >= source.data()); - EXPECT_TRUE(source >= str_source.data()); - EXPECT_LT(str_source.compare(mismatch1.data()), 0); - EXPECT_TRUE(str_source < mismatch1.data()); - EXPECT_TRUE(str_source != mismatch1.data()); - EXPECT_TRUE(mismatch1.data() != str_source); - EXPECT_GT(str_source.compare(mismatch2.data()), 0); - EXPECT_TRUE(str_source > mismatch2.data()); - EXPECT_TRUE(mismatch2.data() < str_source); - EXPECT_GT(str_source.compare(mismatch3.data()), 0); - EXPECT_TRUE(str_source > mismatch3.data()); - EXPECT_LT(str_source.compare(mismatch4.data()), 0); - EXPECT_TRUE(str_source < mismatch4.data()); - EXPECT_TRUE(mismatch4.data() > str_source); - - // compare with String - EXPECT_LT(str_source.compare(str_mismatch1), 0); - EXPECT_TRUE(str_source < str_mismatch1); - EXPECT_GT(str_source.compare(str_mismatch2), 0); - EXPECT_TRUE(str_source > str_mismatch2); - EXPECT_GT(str_source.compare(str_mismatch3), 0); - EXPECT_TRUE(str_source > str_mismatch3); - EXPECT_LT(str_source.compare(str_mismatch4), 0); - EXPECT_TRUE(str_source < str_mismatch4); -} - -TEST(String, c_str) { - using namespace std; - string source = "this is a string"; - string mismatch = "mismatch"; - String s{source}; - - EXPECT_EQ(std::strcmp(s.c_str(), source.data()), 0); - EXPECT_NE(std::strcmp(s.c_str(), mismatch.data()), 0); -} - -TEST(String, hash) { - using namespace std; - string source = "this is a string"; - String s{source}; - std::hash()(s); - - std::unordered_map map; - String k1{string{"k1"}}; - string v1{"v1"}; - String k2{string{"k2"}}; - string v2{"v2"}; - map[k1] = v1; - map[k2] = v2; - - EXPECT_EQ(map[k1], v1); - EXPECT_EQ(map[k2], v2); -} - -TEST(String, Cast) { - using namespace std; - string source = "this is a string"; - String s{source}; - 
Any r = s; - String s2 = r.cast(); -} - -TEST(String, Concat) { - String s1("hello"); - String s2("world"); - std::string s3("world"); - String res1 = s1 + s2; - String res2 = s1 + s3; - String res3 = s3 + s1; - String res4 = s1 + "world"; - String res5 = "world" + s1; - - EXPECT_EQ(res1.compare("helloworld"), 0); - EXPECT_EQ(res2.compare("helloworld"), 0); - EXPECT_EQ(res3.compare("worldhello"), 0); - EXPECT_EQ(res4.compare("helloworld"), 0); - EXPECT_EQ(res5.compare("worldhello"), 0); - - String storage_scope; - String res = "The input storage scope \"" + storage_scope + "\" is invalid."; - EXPECT_EQ(res.compare("The input storage scope \"\" is invalid."), 0); -} - -TEST(String, Any) { - // test anyview promotion to any - AnyView view = "hello"; - EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIRawStr); - - Any b = view; - EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallStr); - EXPECT_EQ(b.as().value(), "hello"); - EXPECT_TRUE(b.as().has_value()); - EXPECT_EQ(b.try_cast().value(), "hello"); - - std::string s_world = "world"; - view = s_world; - EXPECT_EQ(view.try_cast().value(), "world"); - - String s{"hello"}; - Any a = s; - EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFISmallStr); - EXPECT_EQ(a.as().value(), "hello"); - EXPECT_EQ(a.try_cast().value(), "hello"); - - Any c = "long string very long"; - EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(c.as().value(), "long string very long"); - EXPECT_EQ(c.try_cast().value(), "long string very long"); -} - -TEST(String, Bytes) { - Bytes b0; - EXPECT_EQ(b0.size(), 0); - EXPECT_EQ(b0.operator std::string(), ""); - - // explicitly test zero element - std::string s = {'\0', 'a', 'b', 'c'}; - Bytes b = s; - EXPECT_EQ(b.size(), 4); - EXPECT_EQ(b.operator std::string(), s); - - TVMFFIByteArray arr{s.data(), static_cast(s.size())}; - Bytes b2 = arr; - EXPECT_EQ(b2.size(), 4); - EXPECT_EQ(b2.operator std::string(), s); -} - -TEST(String, BytesAny) { - std::string s = {'\0', 'a', 'b', 'c'}; - TVMFFIByteArray 
arr{s.data(), static_cast(s.size())}; - - AnyView view = &arr; - EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIByteArrayPtr); - EXPECT_EQ(view.try_cast().value().operator std::string(), s); - - Any b = view; - EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFISmallBytes); - - EXPECT_EQ(b.try_cast().value().operator std::string(), s); - EXPECT_EQ(b.cast(), s); - - std::string s2 = "hello long long long string"; - s2[0] = '\0'; - Any b2 = Bytes(s2); - EXPECT_EQ(b2.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(b2.try_cast().value(), s2); - EXPECT_EQ(b2.cast(), s2); -} - -TEST(String, StdString) { - std::string s1 = "test_string"; - AnyView view1 = s1; - EXPECT_EQ(view1.type_index(), TypeIndex::kTVMFFIRawStr); - EXPECT_EQ(view1.try_cast().value(), s1); - - TVMFFIByteArray arr1{s1.data(), static_cast(s1.size())}; - AnyView view2 = &arr1; - EXPECT_EQ(view2.type_index(), TypeIndex::kTVMFFIByteArrayPtr); - EXPECT_EQ(view2.try_cast().value(), s1); - - Bytes bytes1 = s1; - AnyView view3 = bytes1; - EXPECT_EQ(view3.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(view3.try_cast().value(), s1); - - String string1 = s1; - AnyView view4 = string1; - EXPECT_EQ(view4.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(view4.try_cast().value(), s1); - - // Test with Any - Any any1 = s1; - EXPECT_EQ(any1.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(any1.try_cast().value(), s1); - - Any any2 = &arr1; - EXPECT_EQ(any2.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(any2.try_cast().value(), s1); - - Any any3 = bytes1; - EXPECT_EQ(any3.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(any3.try_cast().value(), s1); - - Any any4 = string1; - EXPECT_EQ(any4.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(any4.try_cast().value(), s1); -} - -TEST(String, CAPIAccessor) { - using namespace std; - String s{"hello"}; - TVMFFIByteArray arr{s.data(), s.size()}; - EXPECT_EQ(arr.size, 5); - EXPECT_EQ(std::string(arr.data, arr.size), "hello"); -} - -TEST(String, BytesHash) { - 
std::vector data1(10); - std::vector data2(11); - for (size_t i = 0; i < data1.size(); ++i) { - data1[i] = i; - } - char* data1_ptr = reinterpret_cast(data1.data()); - char* data2_ptr = reinterpret_cast(data2.data()) + 1; - std::memcpy(data2_ptr, data1.data(), data1.size() * sizeof(int64_t)); - // has of aligned and unaligned data should be the same - uint64_t hash1 = details::StableHashBytes(data1_ptr, data1.size() * sizeof(int64_t)); - uint64_t hash2 = details::StableHashBytes(data2_ptr, data1.size() * sizeof(int64_t)); - EXPECT_EQ(hash1, hash2); -} - -TEST(String, StdHash) { - String s1 = "a"; - String s2(std::string("a")); - EXPECT_EQ(std::hash()(s1), std::hash()(s2)); - - Bytes s3("a", 1); - Bytes s4(std::string("a")); - EXPECT_EQ(std::hash()(s3), std::hash()(s4)); -} - -} // namespace diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc deleted file mode 100644 index 5735e86eca4d..000000000000 --- a/ffi/tests/cpp/test_tuple.cc +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Tuple, Basic) { - Tuple tuple0(1, 2.0f); - EXPECT_EQ(tuple0.get<0>(), 1); - EXPECT_EQ(tuple0.get<1>(), 2.0f); - - Tuple tuple1 = tuple0; - EXPECT_EQ(tuple0.use_count(), 2); - - // test copy on write - tuple1.Set<0>(3); - EXPECT_EQ(tuple0.get<0>(), 1); - EXPECT_EQ(tuple1.get<0>(), 3); - - EXPECT_EQ(tuple0.use_count(), 1); - EXPECT_EQ(tuple1.use_count(), 1); - - // copy on write not triggered because - // tuple1 is unique. - tuple1.Set<1>(4); - EXPECT_EQ(tuple1.get<1>(), 4.0f); - EXPECT_EQ(tuple1.use_count(), 1); - - // default state - Tuple tuple2; - EXPECT_EQ(tuple2.use_count(), 1); - tuple2.Set<0>(1); - tuple2.Set<1>(2.0f); - EXPECT_EQ(tuple2.get<0>(), 1); - EXPECT_EQ(tuple2.get<1>(), 2.0f); - - // tuple of object and primitive - Tuple tuple3(1, 2); - EXPECT_EQ(tuple3.get<0>()->value, 1); - EXPECT_EQ(tuple3.get<1>(), 2); - tuple3.Set<0>(4); - EXPECT_EQ(tuple3.get<0>()->value, 4); -} - -TEST(Tuple, AnyConvert) { - Tuple tuple0(1, 2); - AnyView view0 = tuple0; - Array arr0 = view0.as>().value(); - EXPECT_EQ(arr0.size(), 2); - EXPECT_EQ(arr0[0].as().value(), 1); - EXPECT_EQ(arr0[1].as().value()->value, 2); - - // directly reuse the underlying storage. 
- auto tuple1 = view0.cast>(); - EXPECT_TRUE(tuple0.same_as(tuple1)); - - Any any0 = view0; - // trigger a copy due to implict conversion - auto tuple2 = any0.cast>(); - EXPECT_TRUE(!tuple0.same_as(tuple2)); - EXPECT_EQ(tuple2.get<0>()->value, 1); - EXPECT_EQ(tuple2.get<1>()->value, 2); -} - -TEST(Tuple, FromTyped) { - // try decution - Function fadd1 = Function::FromTyped([](const Tuple& a) -> int { - return a.get<0>() + static_cast(a.get<1>()->value); - }); - int b = fadd1(Tuple(1, 2)).cast(); - EXPECT_EQ(b, 3); - - int c = fadd1(Array({1, 2})).cast(); - EXPECT_EQ(c, 3); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(Array({1.1, 2})); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: Tuple) -> int`. " - "Expected `Tuple` but got `Array[index 0: float]`"); - throw; - } - }, - ::tvm::ffi::Error); - - EXPECT_THROW( - { - try { - fadd1(Array({1.1})); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: Tuple) -> int`. 
" - "Expected `Tuple` but got `Array[size=1]`"); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Tuple, Upcast) { - Tuple t0(1, 2.0f); - Tuple t1 = t0; - EXPECT_EQ(t1.get<0>().cast(), 1); - EXPECT_EQ(t1.get<1>().cast(), 2.0f); - static_assert(details::type_contains_v, Tuple>); - static_assert(details::type_contains_v, Tuple>); - static_assert(details::type_contains_v, Tuple>); -} - -TEST(Tuple, ArrayIterForwarding) { - Tuple t0(1, 2); - Tuple t1(3, 4); - Array> arr0 = {t0, t1}; - std::vector> vec0 = {t0}; - vec0.insert(vec0.end(), arr0.begin(), arr0.end()); - EXPECT_EQ(vec0.size(), 3); - EXPECT_EQ(vec0[0].get<0>()->value, 1); - EXPECT_EQ(vec0[0].get<1>()->value, 2); - EXPECT_EQ(vec0[1].get<0>()->value, 1); - EXPECT_EQ(vec0[1].get<1>()->value, 2); - EXPECT_EQ(vec0[2].get<0>()->value, 3); - EXPECT_EQ(vec0[2].get<1>()->value, 4); -} - -TEST(Tuple, ArrayIterForwardSingleElem) { - Tuple t0(1); - Tuple t1(2); - Array> arr0 = {t0, t1}; - std::vector> vec0 = {t0}; - vec0.insert(vec0.end(), arr0.begin(), arr0.end()); - EXPECT_EQ(vec0.size(), 3); - EXPECT_EQ(vec0[0].get<0>()->value, 1); - EXPECT_EQ(vec0[1].get<0>()->value, 1); - EXPECT_EQ(vec0[2].get<0>()->value, 2); -} - -} // namespace diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc deleted file mode 100644 index 639e6ee671dd..000000000000 --- a/ffi/tests/cpp/test_variant.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include - -#include "./testing_object.h" - -namespace { - -using namespace tvm::ffi; -using namespace tvm::ffi::testing; - -TEST(Variant, Basic) { - Variant v1 = 1; - EXPECT_EQ(v1.get(), 1); - - Variant v2 = 2.0f; - EXPECT_EQ(v2.get(), 2.0f); - v2 = v1; - EXPECT_EQ(v2.get(), 1); -} - -TEST(Variant, AnyConvert) { - Variant v = 1; - AnyView view0 = v; - EXPECT_EQ(view0.as().value(), 1); - - // implicit convert to variant - Any any0 = 1; - auto v1 = any0.cast>>(); - EXPECT_EQ(v1.get()->value, 1); - - // move from any to variant - Variant v2 = TInt(1); - Any any1 = std::move(v2); - auto v3 = std::move(any1).cast>(); - auto v4 = std::move(v3).get(); - EXPECT_EQ(v4->value, 1); - EXPECT_EQ(v4.use_count(), 1); -} - -TEST(Variant, ObjectPtrHashEqual) { - TInt x = TInt(1); - TFloat y = TFloat(1.0f); - - Variant v0 = x; - Variant v1 = y; - Variant v2 = v1; - - EXPECT_EQ(ObjectPtrHash()(v0), ObjectPtrHash()(x)); - EXPECT_TRUE(!ObjectPtrEqual()(v0, v1)); - EXPECT_TRUE(!ObjectPtrEqual()(v0, v2)); -} - -TEST(Variant, FromTyped) { - // try decution - Function fadd1 = Function::FromTyped([](const Variant& a) -> int64_t { - if (auto opt_int = a.as()) { - return opt_int.value() + 1; - } else { - return a.get()->value + 1; - } - }); - int b = fadd1(1).cast(); - EXPECT_EQ(b, 2); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd1(1.1); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ( - error.message(), - "Mismatched type on argument #0 when calling: `(0: 
Variant) -> int`. " - "Expected `Variant` but got `float`"); - throw; - } - }, - ::tvm::ffi::Error); - - Function fadd2 = Function::FromTyped([](const Array>& a) -> int64_t { - if (auto opt_int = a[0].as()) { - return opt_int.value() + 1; - } else { - return a[0].get()->value + 1; - } - }); - int c = fadd2(Array({1, 2})).cast(); - EXPECT_EQ(c, 2); - - // convert that triggers error - EXPECT_THROW( - { - try { - fadd2(Array({1, 1.1})); - } catch (const Error& error) { - EXPECT_EQ(error.kind(), "TypeError"); - EXPECT_EQ(error.message(), - "Mismatched type on argument #0 when calling: `(0: Array>) -> int`. " - "Expected `Array>` but got `Array[index 1: float]`"); - throw; - } - }, - ::tvm::ffi::Error); -} - -TEST(Variant, Upcast) { - Array a0 = {1, 2, 3}; - static_assert(details::type_contains_v>, Array>); - Array> a1 = a0; - EXPECT_EQ(a1[0].get(), 1); -} - -TEST(Variant, AllObjectRef) { - Variant> v0 = TInt(1); - EXPECT_EQ(v0.get()->value, 1); - static_assert(std::is_base_of_v); - Any any0 = v0; - EXPECT_EQ(any0.cast()->value, 1); - auto v2 = any0.cast>>(); - EXPECT_TRUE(v0.same_as(v2)); - // assignment operator - v0 = Array({TInt(2), TInt(3)}); - EXPECT_EQ(v0.get>().size(), 2); - EXPECT_EQ(v0.get>()[0]->value, 2); - EXPECT_EQ(v0.get>()[1]->value, 3); - EXPECT_EQ(sizeof(v0), sizeof(ObjectRef)); -} - -TEST(Variant, PODSameAs) { - Variant v0 = 1; - Variant v1 = 1; - EXPECT_TRUE(v0.same_as(v1)); - String s = String("hello long str"); - v0 = s; - v1 = s; - EXPECT_TRUE(v0.same_as(v1)); - v1 = String("hello long str"); - EXPECT_TRUE(!v0.same_as(v1)); -} -} // namespace diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h deleted file mode 100644 index fe3ba1b013c0..000000000000 --- a/ffi/tests/cpp/testing_object.h +++ /dev/null @@ -1,304 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_FFI_TESTING_OBJECT_H_ -#define TVM_FFI_TESTING_OBJECT_H_ - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace testing { - -// We deliberately pad extra -// in the header to test cases -// where the object subclass address -// do not align with the base object address -// not handling properly will cause buffer overflow -class BasePad { - public: - int64_t extra[4]; -}; - -class TNumberObj : public BasePad, public Object { - public: - // declare as one slot, with float as overflow - static constexpr uint32_t _type_child_slots = 1; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "test.Number"; - TVM_FFI_DECLARE_BASE_OBJECT_INFO(TNumberObj, Object); -}; - -class TNumber : public ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS(TNumber, ObjectRef, TNumberObj); -}; - -class TIntObj : public TNumberObj { - public: - int64_t value; - - TIntObj() = default; - TIntObj(int64_t value) : value(value) {} - - int64_t GetValue() const { return value; } - - static constexpr const char* _type_key = "test.Int"; - - inline static void RegisterReflection(); - - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj); -}; - -class 
TInt : public TNumber { - public: - explicit TInt(int64_t value) { data_ = make_object(value); } - - static TInt StaticAdd(TInt lhs, TInt rhs) { return TInt(lhs->value + rhs->value); } - - TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj); -}; - -inline void TIntObj::RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("value", &TIntObj::value) - .def_static("static_add", &TInt::StaticAdd, "static add method"); - // define extra type attributes - refl::TypeAttrDef() - .def("test.GetValue", &TIntObj::GetValue) - .attr("test.size", sizeof(TIntObj)); - // custom json serialization - refl::TypeAttrDef() - .def("__data_to_json__", - [](const TIntObj* self) -> Map { - return Map{{"value", self->value}}; - }) - .def("__data_from_json__", [](Map json_obj) -> TInt { - return TInt(json_obj["value"].cast()); - }); -} - -class TFloatObj : public TNumberObj { - public: - double value; - - TFloatObj(double value) : value(value) {} - - double Add(double other) const { return value + other; } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0)) - .def("sub", - [](const TFloatObj* self, double other) -> double { return self->value - other; }) - .def("add", &TFloatObj::Add, "add method"); - } - - static constexpr const char* _type_key = "test.Float"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj); -}; - -class TFloat : public TNumber { - public: - explicit TFloat(double value) { data_ = make_object(value); } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS(TFloat, TNumber, TFloatObj); -}; - -class TPrimExprObj : public Object { - public: - std::string dtype; - double value; - - TPrimExprObj(std::string dtype, double value) : dtype(dtype), value(value) {} - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_rw("dtype", &TPrimExprObj::dtype, 
"dtype field", refl::DefaultValue("float")) - .def_ro("value", &TPrimExprObj::value, "value field", refl::DefaultValue(0)) - .def("sub", [](TPrimExprObj* self, double other) -> double { - // this is ok because TPrimExprObj is declared asmutable - return self->value - other; - }); - } - - static constexpr const char* _type_key = "test.PrimExpr"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr bool _type_mutable = true; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TPrimExprObj, Object); -}; - -class TPrimExpr : public ObjectRef { - public: - explicit TPrimExpr(std::string dtype, double value) { - data_ = make_object(dtype, value); - } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS(TPrimExpr, ObjectRef, TPrimExprObj); -}; - -class TVarObj : public Object { - public: - std::string name; - - // need default constructor for json serialization - TVarObj() = default; - TVarObj(std::string name) : name(name) {} - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("name", &TVarObj::name, - refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr const char* _type_key = "test.Var"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TVarObj, Object); -}; - -class TVar : public ObjectRef { - public: - explicit TVar(std::string name) { data_ = make_object(name); } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS(TVar, ObjectRef, TVarObj); -}; - -class TFuncObj : public Object { - public: - Array params; - Array body; - Optional comment; - - // need default constructor for json serialization - TFuncObj() = default; - TFuncObj(Array params, Array body, Optional comment) - : params(params), body(body), comment(comment) {} - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("params", &TFuncObj::params, refl::AttachFieldFlag::SEqHashDef()) - .def_ro("body", 
&TFuncObj::body) - .def_ro("comment", &TFuncObj::comment, refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr const char* _type_key = "test.Func"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFuncObj, Object); -}; - -class TFunc : public ObjectRef { - public: - explicit TFunc(Array params, Array body, Optional comment) { - data_ = make_object(params, body, comment); - } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS(TFunc, ObjectRef, TFuncObj); -}; - -class TCustomFuncObj : public Object { - public: - Array params; - Array body; - String comment; - - TCustomFuncObj(Array params, Array body, String comment) - : params(params), body(body), comment(comment) {} - - bool SEqual(const TCustomFuncObj* other, - ffi::TypedFunction cmp) const { - if (!cmp(params, other->params, true, "params")) { - return false; - } - if (!cmp(body, other->body, false, "body")) { - return false; - } - return true; - } - - uint64_t SHash(uint64_t init_hash, - ffi::TypedFunction hash) const { - uint64_t hash_value = init_hash; - hash_value = hash(params, hash_value, true); - hash_value = hash(body, hash_value, false); - return hash_value; - } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("params", &TCustomFuncObj::params) - .def_ro("body", &TCustomFuncObj::body) - .def_ro("comment", &TCustomFuncObj::comment); - refl::TypeAttrDef() - .def("__s_equal__", &TCustomFuncObj::SEqual) - .def("__s_hash__", &TCustomFuncObj::SHash); - } - - static constexpr const char* _type_key = "test.CustomFunc"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TCustomFuncObj, Object); -}; - -class TCustomFunc : public ObjectRef { - public: - explicit TCustomFunc(Array params, Array body, String comment) { - data_ = make_object(params, body, comment); - } - - 
TVM_FFI_DEFINE_OBJECT_REF_METHODS(TCustomFunc, ObjectRef, TCustomFuncObj); -}; - -} // namespace testing - -template <> -inline constexpr bool use_default_type_traits_v = true; - -template <> -struct TypeTraits - : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static testing::TPrimExpr ConvertFallbackValue(StrictBool value) { - return testing::TPrimExpr("bool", static_cast(value)); - } - - TVM_FFI_INLINE static testing::TPrimExpr ConvertFallbackValue(int64_t value) { - return testing::TPrimExpr("int64", static_cast(value)); - } - - TVM_FFI_INLINE static testing::TPrimExpr ConvertFallbackValue(double value) { - return testing::TPrimExpr("float32", static_cast(value)); - } - // hack into the dtype to store string - TVM_FFI_INLINE static testing::TPrimExpr ConvertFallbackValue(String value) { - return testing::TPrimExpr(value, 0); - } -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_TESTING_OBJECT_H_ diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 52e9e7209e89..e303b3becd54 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -33,6 +33,7 @@ #include #include #include +#include "tvm/ffi/object.h" namespace tvm { /*! \brief namespace of arithmetic analysis. */ @@ -103,8 +104,7 @@ class ConstIntBoundNode : public Object { static const constexpr int64_t kNegInf = -kPosInf; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "arith.ConstIntBound"; - TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ConstIntBound", ConstIntBoundNode, Object); }; /*! 
@@ -122,7 +122,7 @@ class ConstIntBound : public ObjectRef { static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; - TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ConstIntBound, ObjectRef, ConstIntBoundNode); }; /*! @@ -176,6 +176,8 @@ class ConstIntBoundAnalyzer { friend class ConstraintContext; explicit ConstIntBoundAnalyzer(Analyzer* parent); TVM_DLL ~ConstIntBoundAnalyzer(); + // Deep-copy internal state from another instance (for Analyzer::Clone) + void CopyFrom(const ConstIntBoundAnalyzer& other); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -216,8 +218,7 @@ class ModularSetNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "arith.ModularSet"; - TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.ModularSet", ModularSetNode, Object); }; /*! @@ -228,7 +229,7 @@ class ModularSet : public ObjectRef { public: TVM_DLL ModularSet(int64_t coeff, int64_t base); - TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ModularSet, ObjectRef, ModularSetNode); }; /*! @@ -256,6 +257,8 @@ class ModularSetAnalyzer { friend class ConstraintContext; explicit ModularSetAnalyzer(Analyzer* parent); TVM_DLL ~ModularSetAnalyzer(); + // Deep-copy internal state from another instance (for Analyzer::Clone) + void CopyFrom(const ModularSetAnalyzer& other); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -296,7 +299,7 @@ class RewriteSimplifier { * * \return an exit function that must be called to cleanup the constraint can be nullptr. 
*/ - TVM_DLL std::function EnterConstraint(const PrimExpr& constraint); + TVM_DLL std::function EnterConstraint(const PrimExpr& constraint, bool is_assume=false); /*! \brief Flags to enable more computationally-intensive simplifications * @@ -409,6 +412,8 @@ class RewriteSimplifier { friend class CanonicalSimplifier; explicit RewriteSimplifier(Analyzer* parent); TVM_DLL ~RewriteSimplifier(); + // Deep-copy internal state from another instance (for Analyzer::Clone) + void CopyFrom(const RewriteSimplifier& other); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -440,6 +445,8 @@ class CanonicalSimplifier { friend class ConstraintContext; explicit CanonicalSimplifier(Analyzer* parent); TVM_DLL ~CanonicalSimplifier(); + // Deep-copy internal state from another instance (for Analyzer::Clone) + void CopyFrom(const CanonicalSimplifier& other); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -525,6 +532,8 @@ class TransitiveComparisonAnalyzer { friend class ConstraintContext; TransitiveComparisonAnalyzer(); TVM_DLL ~TransitiveComparisonAnalyzer(); + // Deep-copy internal state from another instance (for Analyzer::Clone) + void CopyFrom(const TransitiveComparisonAnalyzer& other); class Impl; /*! \brief Internal impl */ std::unique_ptr impl_; @@ -555,8 +564,8 @@ class ConstraintContext { * \param analyzer The analyzer. * \param constraint The constraint to be applied. */ - ConstraintContext(Analyzer* analyzer, PrimExpr constraint) - : analyzer_(analyzer), constraint_(constraint) {} + ConstraintContext(Analyzer* analyzer, PrimExpr constraint, bool is_assume=false) + : analyzer_(analyzer), constraint_(constraint), is_assume_(is_assume) {} // enter the scope. void EnterWithScope(); // exit the scope. @@ -567,6 +576,7 @@ class ConstraintContext { PrimExpr constraint_; /*! \brief function to be called in recovery */ std::vector> recovery_functions_; + bool is_assume_; }; /*! 
@@ -582,7 +592,7 @@ class IntSetAnalyzer { * \param dom_map The domain map to indicate which variable to relax. * \return the result of the analysis. */ - TVM_DLL IntSet operator()(const PrimExpr& expr, const Map& dom_map); + TVM_DLL IntSet operator()(const PrimExpr& expr, const ffi::Map& dom_map); /*! * \brief Find a symbolic integer set that contains all possible @@ -618,11 +628,116 @@ class IntSetAnalyzer { friend class Analyzer; explicit IntSetAnalyzer(Analyzer* parent); TVM_DLL ~IntSetAnalyzer(); + // Deep-copy internal state from another instance (for Analyzer::Clone) + void CopyFrom(const IntSetAnalyzer& other); class Impl; /*! \brief Internal impl */ Impl* impl_; }; +class Z3Prover { + public: + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_range The range of allowed values for this var. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! + * \brief Whether can we prove expr is always true. + * + * \param expr The expression. + * \return Whether we can prove it. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + TVM_DLL bool CanProve(const PrimExpr & expr); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return an exit function that must be called to cleanup the constraint can be nullptr. + */ + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume=false); + + /*! 
+ * \brief Get the SMTLIB2 representation of the current context + * \param expr The optional expression to check + * \return The SMTLIB2 string + */ + ffi::String GetSMTLIB2(const ffi::Optional expr); + + /*! + * \brief Get statistics about Z3 prover + * \return The statistics string + */ + ffi::String GetStats(); + + /*! + * \brief Set timeout in milliseconds for Z3 prover + * \param timeout_ms The timeout in milliseconds + */ + void SetTimeoutMs(unsigned timeout_ms); + + /*! + * \brief Set resource limitation for Z3 prover + * \param rlimit the resource limit (e.g. the maximum number of solver steps) + */ + void SetRLimit(unsigned rlimit); + + /*! + * \brief Get the Z3 model for the given expression if satisfiable + * \param expr The expression to get the model for + * \return The model as a string + */ + ffi::String GetModel(const PrimExpr & expr); + + /*! + * \brief Count the number of integer values that satisfy the current constraints. + * + * This method uses Z3's model enumeration (AllSAT) to count how many distinct + * values of the given variable satisfy all current constraints. This is useful + * for determining the exact number of threads that will reach a synchronization + * point when the condition involves non-range constraints like modulo operations. + * + * For example, if the constraint is `threadIdx.x % 4 == 0` with `threadIdx.x in [0, 128)`, + * this method will return 32 (the values 0, 4, 8, ..., 124). + * + * \param var The variable to count satisfying values for. + * \param max_count Maximum number of solutions to enumerate (for safety). + * If more solutions exist, returns max_count. + * \param min_consecutive Minimum consecutive count requirement (default 1). + * Values must form groups of at least this many + * consecutive integers. E.g., with min_consecutive=4: + * {0,1,2,3,16,17,18,19} is valid, {0,1,4,5} is invalid. 
+ * \return The number of distinct values that satisfy the constraints, + * -1 if the problem is unsatisfiable or an error occurred, + * -2 if the min_consecutive constraint is not satisfied. + */ + TVM_DLL int64_t CountSatisfyingValues(const Var& var, int64_t max_count = 2048, int64_t min_consecutive = 1); + + private: + friend class Analyzer; + explicit Z3Prover(Analyzer* parent); + TVM_DLL ~Z3Prover(); + void CopyFrom(const Z3Prover & other); + class Impl; + Impl* impl_; +}; + /*! * \brief Analyzer that contains bunch of sub-analyzers. * @@ -652,8 +767,15 @@ class TVM_DLL Analyzer { IntSetAnalyzer int_set; /*! \brief sub-analyzer transitive comparisons */ TransitiveComparisonAnalyzer transitive_comparisons; + /*! \brief analyzer using z3 */ + Z3Prover z3_prover; /*! \brief constructor */ Analyzer(); + /*! + * \brief Create a deep copy of this Analyzer, including all sub-analyzer states. + * \return A new Analyzer with copied internal state. + */ + std::unique_ptr Clone() const; /*! * \brief Mark the value as non-negative value globally in analyzer. * @@ -704,7 +826,7 @@ class TVM_DLL Analyzer { * expression. This option should not be used if there is any dependency * between variables. */ - void Bind(const Map& variables, bool allow_override = false); + void Bind(const ffi::Map& variables, bool allow_override = false); /*! * \brief Whether can we prove expr >= val. @@ -786,6 +908,8 @@ class TVM_DLL Analyzer { * \note Analyzer will call into sub-analyzers to get the result. 
*/ PrimExpr Simplify(const PrimExpr& expr, int steps = 2); + + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume=false); }; } // namespace arith diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index cf84b9a3a641..6cde90b0b8e5 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -53,8 +53,8 @@ using tir::VarNode; * The deduce bound must implies e for all value in relax_map * \return An integer set that always satisfies the condition. */ -IntSet DeduceBound(PrimExpr v, PrimExpr cond, const Map& hint_map, - const Map& relax_map); +IntSet DeduceBound(PrimExpr v, PrimExpr cond, const ffi::Map& hint_map, + const ffi::Map& relax_map); /*! * \brief Same as DeduceBound with unordered_map signature. * diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 702edba1a462..d1e8f9475750 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -56,9 +56,7 @@ enum SignType { kPositive, kNegative, kZero, kUnknown }; */ class IntSetNode : public Object { public: - static constexpr const char* _type_key = "ir.IntSet"; - - TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.IntSet", IntSetNode, Object); }; /*! @@ -163,19 +161,19 @@ class IntSet : public ObjectRef { */ static IntSet Interval(PrimExpr min, PrimExpr max); - TVM_DEFINE_OBJECT_REF_METHODS(IntSet, ObjectRef, IntSetNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntSet, ObjectRef, IntSetNode); }; //----------------------------------------------- // Integer set legacy API. //------------------------------------------------ /*! - * \brief Convert std::unordered_map to Map + * \brief Convert std::unordered_map to ffi::Map * * \param dom_map The domain map to convert. * \return The converted map. */ -Map ConvertDomMap(const std::unordered_map& dom_map); +ffi::Map ConvertDomMap(const std::unordered_map& dom_map); /*! 
* \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. @@ -184,7 +182,7 @@ Map ConvertDomMap(const std::unordered_map& * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, const Map& dom_map); +IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map); /*! * \brief Find an symbolic integer set that contains all possible values of * e given the domain of each variables. @@ -193,7 +191,7 @@ IntSet EvalSet(PrimExpr e, const Map& dom_map); * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, const Map& dom_map); +IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -210,7 +208,7 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet EvalSet(Range r, const Map& dom_map); +IntSet EvalSet(Range r, const ffi::Map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over @@ -230,13 +228,13 @@ IntSet EvalSet(IntSet s, const std::unordered_map& dom_m */ IntSet EvalSet(Range r, const std::unordered_map& dom_map); /*! - * \brief Same as EvalSet, but takes Array + * \brief Same as EvalSet, but takes ffi::Array * * \param region The range to be evaluated. * \param dom_map The domain of each variable. * \return An array of integer sets that can cover all the possible values. */ -Array EvalSet(const Array& region, const Map& dom_map); +ffi::Array EvalSet(const ffi::Array& region, const ffi::Map& dom_map); /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; /*! 
@@ -255,42 +253,42 @@ ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, * \param sets The sets to be combined * \return the set after union */ -IntSet Union(const Array& sets); +IntSet Union(const ffi::Array& sets); /*! * \brief The union of N-dimensional integer sets * \param nd_int_sets A list of N-dimensional integer sets * \return An N-dimensional integer set as the result of union */ -Array UnionRegion(const Array>& nd_int_sets); +ffi::Array UnionRegion(const ffi::Array>& nd_int_sets); /*! * \brief Create a lower-bound of union set, where some of the segments may be dropped * \param sets The sets to be combined * \return the set after union */ -IntSet UnionLowerBound(const Array& sets); +IntSet UnionLowerBound(const ffi::Array& sets); /*! * \brief The union of N-dimensional integer sets * \param nd_int_sets A list of N-dimensional integer sets * \return An N-dimensional integer set as the result of union */ -Array UnionRegionLowerBound(const Array>& nd_int_sets); +ffi::Array UnionRegionLowerBound(const ffi::Array>& nd_int_sets); /*! * \brief Create an intersected set of all sets * \param sets The sets to be intersected * \return the set after intersected */ -IntSet Intersect(const Array& sets); +IntSet Intersect(const ffi::Array& sets); /*! * \brief Converts the Ranges to IntSets * \param var_dom The ranges of variables * \return The integer sets of the variables */ -Map AsIntSet(const Map& var_dom); +ffi::Map AsIntSet(const ffi::Map& var_dom); /*! * \brief Analyze the region with affine map, given the domain of variables and their predicate. 
@@ -302,10 +300,9 @@ Map AsIntSet(const Map& var_dom); * \return std::nullopt if the detection fails, or an array of arith::IntSet as the result of * analysis */ -TVM_DLL Optional> EstimateRegionStrictBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, - arith::Analyzer* analyzer); +TVM_DLL ffi::Optional> EstimateRegionStrictBound( + const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate, + arith::Analyzer* analyzer); /*! * \brief Analyze the region with affine map, given the domain of variables and their predicate. @@ -317,10 +314,9 @@ TVM_DLL Optional> EstimateRegionStrictBound(const Array& re * \return std::nullopt if the detection fails, or an array of arith::IntSet as the result of * analysis */ -TVM_DLL Optional> EstimateRegionLowerBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, - arith::Analyzer* analyzer); +TVM_DLL ffi::Optional> EstimateRegionLowerBound( + const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate, + arith::Analyzer* analyzer); /*! 
* \brief Analyze the region with affine map, given the domain of variables and their predicate @@ -332,10 +328,10 @@ TVM_DLL Optional> EstimateRegionLowerBound(const Array& reg * \param analyzer The analyzer used * \return an array of arith::IntSet as the result of analysis */ -TVM_DLL Array EstimateRegionUpperBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, - arith::Analyzer* analyzer); +TVM_DLL ffi::Array EstimateRegionUpperBound(const ffi::Array& region, + const ffi::Map& var_dom, + const PrimExpr& predicate, + arith::Analyzer* analyzer); } // namespace arith } // namespace tvm diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 6dfc2f0ecb88..b8f0ac6d4327 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -58,9 +58,9 @@ constexpr int kSimplifyRewriteCanonicalRewrite = 3; class IntGroupBoundsNode : public Object { public: PrimExpr coef; - Array lower; - Array equal; - Array upper; + ffi::Array lower; + ffi::Array equal; + ffi::Array upper; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -72,9 +72,7 @@ class IntGroupBoundsNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const char* _type_key = "arith.IntGroupBounds"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntGroupBounds", IntGroupBoundsNode, Object); }; /*! @@ -93,8 +91,8 @@ class IntGroupBounds : public ObjectRef { * \param equal equalities * \param upper the upper bounds (include) */ - TVM_DLL IntGroupBounds(PrimExpr coef, Array lower, Array equal, - Array upper); + TVM_DLL IntGroupBounds(PrimExpr coef, ffi::Array lower, ffi::Array equal, + ffi::Array upper); /*! * \brief Construct bounds from a range. @@ -106,7 +104,7 @@ class IntGroupBounds : public ObjectRef { /*! * \brief Perform substitution on all components of the struct. 
*/ - IntGroupBounds Substitute(const Map& subst) const; + IntGroupBounds Substitute(const ffi::Map& subst) const; /*! * \brief Find the best range from the grouped bounds. @@ -114,7 +112,7 @@ class IntGroupBounds : public ObjectRef { * \return The best range (has the least difference between the lower bound and upper bound). * undefined if (-inf, +inf). */ - Range FindBestRange(const Map& vranges_addl = {}) const; + Range FindBestRange(const ffi::Map& vranges_addl = {}) const; /*! * \brief Combine the bounds with another range. @@ -123,7 +121,7 @@ class IntGroupBounds : public ObjectRef { */ IntGroupBounds operator+(const Range& r); - TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntGroupBounds, ObjectRef, IntGroupBoundsNode); }; /*! @@ -134,14 +132,14 @@ class IntGroupBounds : public ObjectRef { class IntConstraintsNode : public Object { public: // e.g., \alpha, \beta, must be integers - Array variables; + ffi::Array variables; // e.g., 1 <= \alpha <= N, etc. // it is absolutely ok to include ranges for parameters // (variables that are not in this->variables) in this map - Map ranges; + ffi::Map ranges; // linear equalities or inequalities // e.g., A \alpha = \beta or A \alpha <= \beta - Array relations; + ffi::Array relations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -152,9 +150,7 @@ class IntConstraintsNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const char* _type_key = "arith.IntConstraints"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntConstraints", IntConstraintsNode, Object); }; /*! 
@@ -170,9 +166,10 @@ class IntConstraints : public ObjectRef { * \param relations The linear relations between the variables * (either equations or inequalities) */ - TVM_DLL IntConstraints(Array variables, Map ranges, Array relations); + TVM_DLL IntConstraints(ffi::Array variables, ffi::Map ranges, + ffi::Array relations); - TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntConstraints, ObjectRef, IntConstraintsNode); }; /*! @@ -193,8 +190,8 @@ class IntConstraintsTransformNode : public Object { public: IntConstraints src; IntConstraints dst; - Map src_to_dst; - Map dst_to_src; + ffi::Map src_to_dst; + ffi::Map dst_to_src; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -206,9 +203,8 @@ class IntConstraintsTransformNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const char* _type_key = "arith.IntConstraintsTransform"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntConstraintsTransform", IntConstraintsTransformNode, + Object); }; /*! @@ -228,7 +224,8 @@ class IntConstraintsTransform : public ObjectRef { * e.g., {m -> a, n -> -b} */ TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst, - Map src_to_dst, Map dst_to_src); + ffi::Map src_to_dst, + ffi::Map dst_to_src); /*! * \brief Chain-compose two IntConstraintsTransform together. 
@@ -239,10 +236,11 @@ class IntConstraintsTransform : public ObjectRef { */ IntConstraintsTransform operator+(const IntConstraintsTransform& other) const; - TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntConstraintsTransform, ObjectRef, + IntConstraintsTransformNode); }; -typedef std::pair, Array> PartialSolvedInequalities; +typedef std::pair, ffi::Array> PartialSolvedInequalities; /*! * \brief Obtain Smith Normal Form of linear equation A x = y. @@ -301,8 +299,9 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t * \param bounds grouped boundary of the variables. * \param relations other relations. */ -Array AsConditions(const Array& variables, const Map& bounds, - const Array& relations); +ffi::Array AsConditions(const ffi::Array& variables, + const ffi::Map& bounds, + const ffi::Array& relations); /*! * \brief Solve linear inequalities and infer the range of each variable. diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 25f8e14a7f7b..223fb3509571 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -66,9 +66,8 @@ namespace arith { */ class IterMapExprNode : public PrimExprNode { public: - static constexpr const char* _type_key = "arith.IterMapExpr"; static constexpr const uint32_t _type_child_slots = 2; - TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("arith.IterMapExpr", IterMapExprNode, PrimExprNode); }; /*! @@ -77,7 +76,7 @@ class IterMapExprNode : public PrimExprNode { */ class IterMapExpr : public PrimExpr { public: - TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IterMapExpr, PrimExpr, IterMapExprNode); }; /*! 
@@ -106,9 +105,7 @@ class IterMarkNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; - - static constexpr const char* _type_key = "arith.IterMark"; - TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IterMark", IterMarkNode, Object); }; /*! @@ -124,7 +121,7 @@ class IterMark : public ObjectRef { */ TVM_DLL IterMark(PrimExpr source, PrimExpr extent); - TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IterMark, ObjectRef, IterMarkNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMarkNode); }; @@ -154,8 +151,7 @@ class IterSplitExprNode : public IterMapExprNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "arith.IterSplitExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IterSplitExpr", IterSplitExprNode, IterMapExprNode); }; /*! @@ -185,7 +181,7 @@ class IterSplitExpr : public IterMapExpr { TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale); - TVM_DEFINE_OBJECT_REF_METHODS(IterSplitExpr, IterMapExpr, IterSplitExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IterSplitExpr, IterMapExpr, IterSplitExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSplitExprNode); }; @@ -197,7 +193,7 @@ class IterSplitExpr : public IterMapExpr { class IterSumExprNode : public IterMapExprNode { public: /*! \brief The args to the sum. */ - Array args; + ffi::Array args; /*! \brief The base offset. 
*/ PrimExpr base; @@ -209,8 +205,7 @@ class IterSumExprNode : public IterMapExprNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "arith.IterSumExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IterSumExpr", IterSumExprNode, IterMapExprNode); }; /*! @@ -224,9 +219,9 @@ class IterSumExpr : public IterMapExpr { * \param args The args to the sum. * \param base The base offset. */ - TVM_DLL IterSumExpr(Array args, PrimExpr base); + TVM_DLL IterSumExpr(ffi::Array args, PrimExpr base); - TVM_DEFINE_OBJECT_REF_METHODS(IterSumExpr, IterMapExpr, IterSumExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IterSumExpr, IterMapExpr, IterSumExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode); }; @@ -246,11 +241,11 @@ enum IterMapLevel { class IterMapResultNode : public Object { public: // The detected pattern if a match exists. - Array indices; + ffi::Array indices; // Any errors that occurred while converting the input indices. If // the array is empty, the conversion was successful. - Array errors; + ffi::Array errors; /*! \brief Boolean expression indicating if a specific value w * @@ -269,9 +264,7 @@ class IterMapResultNode : public Object { .def_ro("errors", &IterMapResultNode::errors) .def_ro("padding_predicate", &IterMapResultNode::padding_predicate); } - - static constexpr const char* _type_key = "arith.IterMapResult"; - TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IterMapResult", IterMapResultNode, Object); }; /*! @@ -281,7 +274,7 @@ class IterMapResultNode : public Object { class IterMapResult : public ObjectRef { public: // constructor - IterMapResult() { data_ = make_object(); } + IterMapResult() { data_ = ffi::make_object(); } /*! \return mutable pointers to the node. 
*/ IterMapResultNode* operator->() const { return static_cast(get_mutable()); } @@ -310,9 +303,10 @@ class IterMapResult : public ObjectRef { * \return The detected iteration result. * The return object's .indices is empty on failure. */ -IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, IterMapLevel check_level, - arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); +IterMapResult DetectIterMap(const ffi::Array& indices, + const ffi::Map& input_iters, const PrimExpr& predicate, + IterMapLevel check_level, arith::Analyzer* analyzer, + bool simplify_trivial_iterators = true); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices @@ -325,9 +319,11 @@ IterMapResult DetectIterMap(const Array& indices, const Map IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, IterMapLevel check_level, - arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); +ffi::Array IterMapSimplify(const ffi::Array& indices, + const ffi::Map& input_iters, + const PrimExpr& input_pred, IterMapLevel check_level, + arith::Analyzer* analyzer, + bool simplify_trivial_iterators = true); /*! * \brief Apply the inverse of the affine transformation to the outputs. @@ -349,8 +345,8 @@ Array IterMapSimplify(const Array& indices, const Map InverseAffineIterMap(const Array& iter_map, - const Array outputs); +ffi::Map InverseAffineIterMap(const ffi::Array& iter_map, + const ffi::Array outputs); /*! * \brief Detect if bindings can be written as @@ -379,11 +375,12 @@ Map InverseAffineIterMap(const Array& iter_map, len(bindings): the predicate of outer space and inner space Empty array if no match can be found. 
*/ -Array> SubspaceDivide(const Array& bindings, - const Map& input_iters, - const Array& sub_iters, const PrimExpr& predicate, - IterMapLevel check_level, arith::Analyzer* analyzer, - bool simplify_trivial_iterators = true); +ffi::Array> SubspaceDivide(const ffi::Array& bindings, + const ffi::Map& input_iters, + const ffi::Array& sub_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, + bool simplify_trivial_iterators = true); /*! * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. @@ -408,7 +405,7 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr); * \param analyzer The input analyzer. * \note This function is useful to detect iterator stride patterns. */ -IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iters, +IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input_iters, arith::Analyzer* analyzer); } // namespace arith diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index 5e1165d509c4..254c1d0933ec 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -37,7 +37,7 @@ namespace arith { * \param vars List of variables to be used in detection. * \return [coeff[i]] if it is possible, empty array if it is not. */ -Array DetectLinearEquation(const PrimExpr& e, const Array& vars); +ffi::Array DetectLinearEquation(const PrimExpr& e, const ffi::Array& vars); /*! * \brief Detect if expression corresponds to clip bound of the vars @@ -47,7 +47,7 @@ Array DetectLinearEquation(const PrimExpr& e, const Array& v * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value * return empty if the e does not match the pattern. 
*/ -Array DetectClipBound(const PrimExpr& e, const Array& vars); +ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& vars); } // namespace arith } // namespace tvm diff --git a/include/tvm/ir/analysis.h b/include/tvm/ir/analysis.h index ad95f2f0ebb5..5879f34633a2 100644 --- a/include/tvm/ir/analysis.h +++ b/include/tvm/ir/analysis.h @@ -55,7 +55,7 @@ class CalleeCollector { virtual void Mark(GlobalVar gvar) = 0; }; -Map> CollectCallMap(const IRModule& mod); +ffi::Map> CollectCallMap(const IRModule& mod); } // namespace ir } // namespace tvm diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 2553116634a2..e68261602a47 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -54,7 +54,7 @@ namespace tvm { template inline TObjectRef NullValue() { static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); - return TObjectRef(ObjectPtr(nullptr)); + return TObjectRef(ObjectPtr(nullptr)); } template <> @@ -68,11 +68,11 @@ inline DataType NullValue() { class AttrFieldInfoNode : public Object { public: /*! \brief name of the field */ - String name; + ffi::String name; /*! \brief type docstring information in str. */ - String type_info; + ffi::String type_info; /*! \brief detailed description of the type */ - String description; + ffi::String description; static void RegisterReflection() { namespace rfl = ffi::reflection; @@ -82,16 +82,15 @@ class AttrFieldInfoNode : public Object { .def_ro("description", &AttrFieldInfoNode::description); } - static constexpr const char* _type_key = "ir.AttrFieldInfo"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.AttrFieldInfo", AttrFieldInfoNode, Object); }; /*! 
\brief AttrFieldInfo */ class AttrFieldInfo : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrFieldInfo, ObjectRef, AttrFieldInfoNode); }; /*! @@ -122,9 +121,7 @@ class BaseAttrsNode : public Object { bool allow_unknown = false) = 0; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const char* _type_key = "ir.Attrs"; - TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, Object); }; /*! @@ -133,7 +130,7 @@ class BaseAttrsNode : public Object { */ class Attrs : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Attrs, ObjectRef, BaseAttrsNode); }; /*! @@ -145,7 +142,7 @@ class Attrs : public ObjectRef { class DictAttrsNode : public BaseAttrsNode { public: /*! \brief internal attrs map */ - Map dict; + ffi::Map dict; static void RegisterReflection() { namespace rfl = ffi::reflection; @@ -155,8 +152,7 @@ class DictAttrsNode : public BaseAttrsNode { void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final; // type info - static constexpr const char* _type_key = "ir.DictAttrs"; - TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, BaseAttrsNode); }; /*! @@ -165,11 +161,15 @@ class DictAttrsNode : public BaseAttrsNode { */ class DictAttrs : public Attrs { public: + /*! + * \brief constructor with UnsafeInit + */ + explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {} /*! * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. 
*/ - TVM_DLL explicit DictAttrs(Map dict = {}); + TVM_DLL explicit DictAttrs(ffi::Map dict = {}); // Utils for accessing attributes // This needs to be on DictAttrs, not DictAttrsNode because we return the default @@ -194,9 +194,9 @@ class DictAttrs : public Attrs { * \endcode */ template - Optional GetAttr( + ffi::Optional GetAttr( const std::string& attr_key, - Optional default_value = Optional(std::nullopt)) const { + ffi::Optional default_value = ffi::Optional(std::nullopt)) const { if (!defined()) return default_value; const DictAttrsNode* node = this->as(); auto it = node->dict.find(attr_key); @@ -208,8 +208,8 @@ class DictAttrs : public Attrs { } // variant that uses TObjectRef to enable implicit conversion to default value. template - Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { - return GetAttr(attr_key, Optional(default_value)); + ffi::Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, ffi::Optional(default_value)); } /*! * \brief Check whether the function has an non-zero integer attr. @@ -234,7 +234,8 @@ class DictAttrs : public Attrs { return GetAttr(attr_key, 0).value_or(0).IntValue() != 0; } - TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, DictAttrsNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, + DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -248,7 +249,7 @@ class DictAttrs : public Attrs { * * \returns The new DictAttrs with updated attributes. */ -DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); +DictAttrs WithAttrs(DictAttrs attrs, ffi::Map new_attrs); /*! * \brief Copy the DictAttrs, but overrides a single attribute. @@ -261,10 +262,10 @@ DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); * * \returns The new DictAttrs with updated attributes. 
*/ -DictAttrs WithAttr(DictAttrs attrs, String key, Any value); +DictAttrs WithAttr(DictAttrs attrs, ffi::String key, Any value); inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, Any value) { - return WithAttr(std::move(attrs), String(key), std::move(value)); + return WithAttr(std::move(attrs), ffi::String(key), std::move(value)); } /*! @@ -325,7 +326,7 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, Any attr_value) * \returns The new function or module with updated attributes. */ template -inline TFunc WithAttrs(TFunc input, Map attrs) { +inline TFunc WithAttrs(TFunc input, ffi::Map attrs) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); @@ -410,7 +411,7 @@ inline TAttrs AttrsWithDefaultValues() { finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv); return rv.cast(); } else { - auto n = make_object(); + auto n = ffi::make_object(); n->InitByPackedArgs(ffi::PackedArgs(nullptr, 0), false); return TAttrs(n); } diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 9f4f5770aa60..24553de6c408 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -64,7 +64,7 @@ class DiagnosticNode : public Object { */ ObjectRef loc; /*! \brief The diagnostic message. 
*/ - String message; + ffi::String message; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -75,8 +75,7 @@ class DiagnosticNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "Diagnostic"; - TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("Diagnostic", DiagnosticNode, Object); }; class Diagnostic : public ObjectRef { @@ -101,7 +100,7 @@ class Diagnostic : public ObjectRef { static DiagnosticBuilder Note(const Object* loc); static DiagnosticBuilder Help(const Object* loc); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Diagnostic, ObjectRef, DiagnosticNode); }; /*! @@ -167,9 +166,7 @@ class DiagnosticRendererNode : public Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("renderer", &DiagnosticRendererNode::renderer); } - - static constexpr const char* _type_key = "DiagnosticRenderer"; - TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("DiagnosticRenderer", DiagnosticRendererNode, Object); }; class DiagnosticRenderer : public ObjectRef { @@ -185,7 +182,8 @@ class DiagnosticRenderer : public ObjectRef { return static_cast(get_mutable()); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DiagnosticRenderer, ObjectRef, + DiagnosticRendererNode); }; class DiagnosticContextNode : public Object { @@ -194,7 +192,7 @@ class DiagnosticContextNode : public Object { IRModule module; /*! \brief The set of diagnostics to report. */ - Array diagnostics; + ffi::Array diagnostics; /*! \brief The renderer set for the context. 
*/ DiagnosticRenderer renderer; @@ -207,8 +205,7 @@ class DiagnosticContextNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "DiagnosticContext"; - TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("DiagnosticContext", DiagnosticContextNode, Object); }; class DiagnosticContext : public ObjectRef { @@ -238,7 +235,8 @@ class DiagnosticContext : public ObjectRef { return static_cast(get_mutable()); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticContext, ObjectRef, DiagnosticContextNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DiagnosticContext, ObjectRef, + DiagnosticContextNode); }; DiagnosticRenderer TerminalRenderer(std::ostream& ostream); diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 5afe464109cc..c0735b7cd69f 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -43,7 +43,7 @@ namespace tvm { class EnvFuncNode : public Object { public: /*! \brief Unique name of the global function */ - String name; + ffi::String name; /*! \brief The internal packed function */ ffi::Function func; /*! \brief constructor */ @@ -58,9 +58,7 @@ class EnvFuncNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "ir.EnvFunc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.EnvFunc", EnvFuncNode, Object); }; /*! @@ -71,6 +69,10 @@ class EnvFunc : public ObjectRef { public: EnvFunc() {} explicit EnvFunc(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit EnvFunc(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { return static_cast(get()); } /*! 
@@ -90,7 +92,7 @@ class EnvFunc : public ObjectRef { * \return The created global function. * \note The function can be unique */ - TVM_DLL static EnvFunc Get(const String& name); + TVM_DLL static EnvFunc Get(const ffi::String& name); /*! \brief specify container node */ using ContainerType = EnvFuncNode; }; @@ -117,6 +119,10 @@ class TypedEnvFunc : public ObjectRef { using TSelf = TypedEnvFunc; TypedEnvFunc() {} explicit TypedEnvFunc(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit TypedEnvFunc(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 9b7645b56a46..09c0363986cf 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -39,8 +39,6 @@ namespace tvm { -using tvm::String; - // Forward-declare VirtualDevice to avoid circular imports. class VirtualDevice; @@ -63,12 +61,10 @@ class BaseExprNode : public Object { refl::AttachFieldFlag::SEqHashIgnore()); } - static constexpr const char* _type_key = "ir.BaseExpr"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const uint32_t _type_child_slots = 64; - TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.BaseExpr", BaseExprNode, Object); }; /*! @@ -77,7 +73,7 @@ class BaseExprNode : public Object { */ class BaseExpr : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseExpr, ObjectRef, BaseExprNode); }; /*! 
@@ -117,9 +113,8 @@ class PrimExprNode : public BaseExprNode { TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - static constexpr const char* _type_key = "ir.PrimExpr"; static constexpr const uint32_t _type_child_slots = 40; - TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExpr", PrimExprNode, BaseExprNode); }; /*! @@ -142,13 +137,13 @@ class PrimExpr : public BaseExpr { /*! \return the data type of this expression. */ DataType dtype() const { return static_cast(get())->dtype; } - TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExpr, BaseExpr, PrimExprNode); /*! * \brief construct from string to form a StringImm. * \param value The value to be constructed. */ - TVM_DLL static PrimExpr ConvertFallbackValue(String value); // NOLINT(*) + TVM_DLL static PrimExpr ConvertFallbackValue(ffi::String value); // NOLINT(*) }; /*! @@ -160,9 +155,7 @@ class PrimExprConvertibleNode : public Object { public: virtual ~PrimExprConvertibleNode() {} virtual PrimExpr ToPrimExpr() const = 0; - - static constexpr const char* _type_key = "ir.PrimExprConvertible"; - TVM_DECLARE_BASE_OBJECT_INFO(PrimExprConvertibleNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExprConvertible", PrimExprConvertibleNode, Object); }; /*! 
@@ -171,23 +164,24 @@ class PrimExprConvertibleNode : public Object { */ class PrimExprConvertible : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(PrimExprConvertible, ObjectRef, PrimExprConvertibleNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExprConvertible, ObjectRef, + PrimExprConvertibleNode); }; namespace ffi { -// define automatic conversion from bool, int64_t, double, String to PrimExpr +// define automatic conversion from bool, int64_t, double, ffi::String to PrimExpr // These functions are declared early to avoid circular dependency template <> inline constexpr bool use_default_type_traits_v = false; template <> struct TypeTraits - : public ObjectRefWithFallbackTraitsBase { TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(StrictBool value); TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(int64_t value); TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(double value); - TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(String value) { + TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(ffi::String value) { return PrimExpr::ConvertFallbackValue(value); } TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(PrimExprConvertible value) { @@ -426,7 +420,7 @@ class RelaxExprNode : public BaseExprNode { * expression that encapsulate both static shape and * runtime information such as shape. */ - mutable Optional struct_info_ = Optional(); + mutable ffi::Optional struct_info_ = ffi::Optional(); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -434,9 +428,8 @@ class RelaxExprNode : public BaseExprNode { refl::AttachFieldFlag::SEqHashIgnore()); } - static constexpr const char* _type_key = "ir.RelaxExpr"; static constexpr const uint32_t _type_child_slots = 22; - TVM_DECLARE_BASE_OBJECT_INFO(RelaxExprNode, BaseExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("ir.RelaxExpr", RelaxExprNode, BaseExprNode); }; /*! 
@@ -445,7 +438,7 @@ class RelaxExprNode : public BaseExprNode { */ class RelaxExpr : public BaseExpr { public: - TVM_DEFINE_OBJECT_REF_METHODS(RelaxExpr, BaseExpr, RelaxExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RelaxExpr, BaseExpr, RelaxExprNode); }; class GlobalVar; @@ -460,7 +453,7 @@ class GlobalVar; class GlobalVarNode : public RelaxExprNode { public: /*! \brief The name of the variable, this only acts as a hint. */ - String name_hint; + ffi::String name_hint; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -478,8 +471,7 @@ class GlobalVarNode : public RelaxExprNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; - static constexpr const char* _type_key = "ir.GlobalVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelaxExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.GlobalVar", GlobalVarNode, RelaxExprNode); }; /*! @@ -488,9 +480,9 @@ class GlobalVarNode : public RelaxExprNode { */ class GlobalVar : public RelaxExpr { public: - TVM_DLL explicit GlobalVar(String name_hint, Span span = {}); + TVM_DLL explicit GlobalVar(ffi::String name_hint, Span span = {}); - TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelaxExpr, GlobalVarNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GlobalVar, RelaxExpr, GlobalVarNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); }; @@ -507,9 +499,7 @@ class IntImmNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &IntImmNode::value); } - - static constexpr const char* _type_key = "ir.IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IntImm", IntImmNode, PrimExprNode); }; /*! 
@@ -527,7 +517,7 @@ class IntImm : public PrimExpr { */ TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); }; @@ -544,9 +534,7 @@ class FloatImmNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &FloatImmNode::value); } - - static constexpr const char* _type_key = "ir.FloatImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.FloatImm", FloatImmNode, PrimExprNode); }; /*! @@ -564,7 +552,7 @@ class FloatImm : public PrimExpr { */ TVM_DLL FloatImm(DataType dtype, double value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloatImm, PrimExpr, FloatImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); }; @@ -580,7 +568,7 @@ class Bool : public IntImm { Bool operator!() const { return Bool((*this)->value == 0); } operator bool() const { return (*this)->value != 0; } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Bool, IntImm, IntImmNode); }; // Overload operators to make sure we have the most fine grained types. @@ -615,7 +603,11 @@ class Integer : public IntImm { /*! * \brief constructor from node. */ - explicit Integer(ObjectPtr node) : IntImm(node) {} + explicit Integer(ObjectPtr node) : IntImm(node) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit Integer(ffi::UnsafeInit tag) : IntImm(tag) {} /*! * \brief Construct integer from int value. 
*/ @@ -688,10 +680,9 @@ class RangeNode : public Object { .def_ro("span", &RangeNode::span, refl::AttachFieldFlag::SEqHashIgnore()); } - static constexpr const char* _type_key = "ir.Range"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.Range", RangeNode, Object); }; /*! \brief Range container */ @@ -716,7 +707,7 @@ class Range : public ObjectRef { */ static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span = Span()); // declare range. - TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Range, ObjectRef, RangeNode); }; namespace ffi { @@ -742,7 +733,9 @@ inline constexpr bool use_default_type_traits_v = false; template <> struct TypeTraits : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static Integer ConvertFallbackValue(int64_t value) { return Integer(value); } + TVM_FFI_INLINE static Integer ConvertFallbackValue(int64_t value) { + return Integer(TypeTraits::ConvertFallbackValue(value)); + } }; template <> diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 53f19ed3f17c..c440e6fc9e17 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -161,14 +161,14 @@ class BaseFuncNode : public RelaxExprNode { * \endcode */ template - Optional GetAttr(const std::string& attr_key, - Optional default_value = std::nullopt) const { + ffi::Optional GetAttr(const std::string& attr_key, + ffi::Optional default_value = std::nullopt) const { return attrs.GetAttr(attr_key, default_value); } // variant that uses TObjectRef to enable implicit conversion to default value. 
template - Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { - return GetAttr(attr_key, Optional(default_value)); + ffi::Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, ffi::Optional(default_value)); } /*! @@ -211,7 +211,7 @@ class BaseFuncNode : public RelaxExprNode { */ LinkageType GetLinkageType() const { - if (GetAttr(attr::kGlobalSymbol)) + if (GetAttr(attr::kGlobalSymbol)) return LinkageType::kExternal; else return LinkageType::kInternal; @@ -222,9 +222,8 @@ class BaseFuncNode : public RelaxExprNode { refl::ObjectDef().def_ro("attrs", &BaseFuncNode::attrs); } - static constexpr const char* _type_key = "ir.BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; - TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelaxExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("ir.BaseFunc", BaseFuncNode, RelaxExprNode); }; /*! @@ -233,7 +232,7 @@ class BaseFuncNode : public RelaxExprNode { */ class BaseFunc : public RelaxExpr { public: - TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelaxExpr, BaseFuncNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseFunc, RelaxExpr, BaseFuncNode); }; } // namespace tvm diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index e6ff10ad1bc4..892bba4da694 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -34,7 +34,7 @@ namespace tvm { /*! * \brief Abstract label for an area of memory. */ -using MemoryScope = String; +using MemoryScope = ffi::String; /*! * \brief GlobalInfo are globally static object that are referred by the IR itself. @@ -42,11 +42,9 @@ using MemoryScope = String; */ class GlobalInfoNode : public Object { public: - static constexpr const char* _type_key = "ir.GlobalInfo"; - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.GlobalInfo", GlobalInfoNode, Object); }; /*! 
@@ -55,7 +53,7 @@ class GlobalInfoNode : public Object { */ class GlobalInfo : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(GlobalInfo, ObjectRef, GlobalInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GlobalInfo, ObjectRef, GlobalInfoNode); }; /*! @@ -79,8 +77,7 @@ class VDeviceNode : public GlobalInfoNode { .def_ro("memory_scope", &VDeviceNode::memory_scope); } - static constexpr const char* _type_key = "ir.VDevice"; - TVM_DECLARE_FINAL_OBJECT_INFO(VDeviceNode, GlobalInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.VDevice", VDeviceNode, GlobalInfoNode); }; /*! @@ -90,7 +87,7 @@ class VDeviceNode : public GlobalInfoNode { class VDevice : public GlobalInfo { public: TVM_DLL explicit VDevice(Target tgt, int dev_id, MemoryScope mem_scope); - TVM_DEFINE_OBJECT_REF_METHODS(VDevice, GlobalInfo, VDeviceNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VDevice, GlobalInfo, VDeviceNode); }; /*! @@ -103,8 +100,7 @@ class DummyGlobalInfoNode : public GlobalInfoNode { refl::ObjectDef(); } - static constexpr const char* _type_key = "ir.DummyGlobalInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(DummyGlobalInfoNode, GlobalInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DummyGlobalInfo", DummyGlobalInfoNode, GlobalInfoNode); }; /*! @@ -113,7 +109,7 @@ class DummyGlobalInfoNode : public GlobalInfoNode { */ class DummyGlobalInfo : public GlobalInfo { public: - TVM_DEFINE_OBJECT_REF_METHODS(DummyGlobalInfo, GlobalInfo, DummyGlobalInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DummyGlobalInfo, GlobalInfo, DummyGlobalInfoNode); }; } // namespace tvm diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h index 8ed8e5ed4c13..076b8d927ece 100644 --- a/include/tvm/ir/global_var_supply.h +++ b/include/tvm/ir/global_var_supply.h @@ -58,7 +58,7 @@ class GlobalVarSupplyNode : public Object { * \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended * to the name. \return A unique GlobalVar. 
*/ - GlobalVar FreshGlobal(String name, bool add_prefix = true); + GlobalVar FreshGlobal(ffi::String name, bool add_prefix = true); /*! * \brief Looks up for a GlobalVar with the given name in this supply. @@ -67,7 +67,7 @@ class GlobalVarSupplyNode : public Object { * \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to * the name before performing the search. \return A cached GlobalVar. */ - GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true); + GlobalVar UniqueGlobalFor(const ffi::String& name, bool add_prefix = true); /*! * \brief Reserves an existing GlobalVar with this supply. @@ -84,9 +84,8 @@ class GlobalVarSupplyNode : public Object { /*! \brief The NameSupply used to generate unique name hints to GlobalVars. */ NameSupply name_supply_; - static constexpr const char* _type_key = "ir.GlobalVarSupply"; - - TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.GlobalVarSupply", GlobalVarSupplyNode, Object); private: std::unordered_map name_to_var_map_; @@ -111,7 +110,7 @@ class GlobalVarSupply : public ObjectRef { * guaranteed not to conflict with any GlobalVars that belong to the modules. \param modules Array * of IRModules. */ - TVM_DLL explicit GlobalVarSupply(const Array& modules); + TVM_DLL explicit GlobalVarSupply(const ffi::Array& modules); /*! * \brief Constructs a GlobalVarSupply from an IRModule. 
GlobalVars generated by this supply are @@ -120,8 +119,7 @@ class GlobalVarSupply : public ObjectRef { */ TVM_DLL explicit GlobalVarSupply(const IRModule module); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, - GlobalVarSupplyNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode); }; } // namespace tvm diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index 1a91371cd38f..c14549f41283 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -103,7 +103,7 @@ namespace instrument { class PassInstrumentNode : public Object { public: /*! \brief Name of this pass instrument object. */ - String name; + ffi::String name; virtual ~PassInstrumentNode() {} @@ -141,9 +141,7 @@ class PassInstrumentNode : public Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("name", &PassInstrumentNode::name); } - - static constexpr const char* _type_key = "instrument.PassInstrument"; - TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("instrument.PassInstrument", PassInstrumentNode, Object); }; /*! @@ -152,7 +150,7 @@ class PassInstrumentNode : public Object { */ class PassInstrument : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PassInstrument, ObjectRef, PassInstrumentNode); }; } // namespace instrument diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 6f7d6d2d130d..3f70b2e25540 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -57,18 +57,18 @@ class IRModule; class IRModuleNode : public Object { public: /*! \brief A map from ids to all global functions. */ - Map functions; + ffi::Map functions; /*! \brief The source map for the module. */ SourceMap source_map; /* \brief Additional attributes storing meta-data about the module. */ DictAttrs attrs; /*! 
\brief Globally static object that are referred by the IR itself */ - Map> global_infos; + ffi::Map> global_infos; /*! * \brief A map from string names to global variables that * ensures global uniqueness. */ - Map global_var_map_; + ffi::Map global_var_map_; /*! * \brief Get a module attribute. @@ -90,15 +90,15 @@ class IRModuleNode : public Object { * \endcode */ template - Optional GetAttr( + ffi::Optional GetAttr( const std::string& attr_key, - Optional default_value = Optional(std::nullopt)) const { + ffi::Optional default_value = ffi::Optional(std::nullopt)) const { return attrs.GetAttr(attr_key, default_value); } // variant that uses TObjectRef to enable implicit conversion to default value. template - Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { - return GetAttr(attr_key, Optional(default_value)); + ffi::Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, ffi::Optional(default_value)); } /*! @@ -179,7 +179,7 @@ class IRModuleNode : public Object { * \param name The name of the global info. * \param info The new array of global infos. */ - TVM_DLL void UpdateGlobalInfo(const String& name, const Array& info); + TVM_DLL void UpdateGlobalInfo(const ffi::String& name, const ffi::Array& info); /*! * \brief Remove a function from the global environment. @@ -192,21 +192,21 @@ class IRModuleNode : public Object { * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalVar(const String& name) const; + TVM_DLL bool ContainGlobalVar(const ffi::String& name) const; /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalVar GetGlobalVar(const String& str) const; + TVM_DLL GlobalVar GetGlobalVar(const ffi::String& str) const; /*! 
* \brief Collect all global vars defined in this module, ordered by * the global variable name. * \returns An array of global vars */ - TVM_DLL Array GetGlobalVars() const; + TVM_DLL ffi::Array GetGlobalVars() const; /*! * \brief Look up a global function by its variable. @@ -220,7 +220,7 @@ class IRModuleNode : public Object { * \param name The name of the function. * \returns The function named by the argument. */ - TVM_DLL BaseFunc Lookup(const String& name) const; + TVM_DLL BaseFunc Lookup(const ffi::String& name) const; /*! * \brief Update the functions inside this environment by @@ -237,14 +237,13 @@ class IRModuleNode : public Object { /*! * \brief The set of imported files. */ - TVM_DLL std::unordered_set Imports() const; + TVM_DLL std::unordered_set Imports() const; TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - static constexpr const char* _type_key = "ir.IRModule"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IRModule", IRModuleNode, Object); private: friend class IRModule; @@ -263,17 +262,21 @@ class IRModule : public ObjectRef { * \param attrs The module meta-data attributes. * \param global_infos Global infos in the module. */ - TVM_DLL explicit IRModule(Map functions, SourceMap map = {}, + TVM_DLL explicit IRModule(ffi::Map functions, SourceMap map = {}, DictAttrs attrs = DictAttrs(), - Map> global_infos = {}); + ffi::Map> global_infos = {}); /*! \brief default constructor */ - IRModule() : IRModule(Map({})) {} + IRModule() : IRModule(ffi::Map({})) {} /*! * \brief constructor * \param n The object pointer. */ - explicit IRModule(ObjectPtr n) : ObjectRef(n) {} + explicit IRModule(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit IRModule(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! \return mutable pointers to the node. 
*/ IRModuleNode* operator->() const { auto* ptr = get_mutable(); @@ -286,7 +289,7 @@ class IRModule : public ObjectRef { * imports. */ TVM_DLL static IRModule FromExpr(const RelaxExpr& expr, - const Map& global_funcs = {}); + const ffi::Map& global_funcs = {}); /*! * \brief Create a shallow copy of an IRModule. @@ -314,11 +317,11 @@ namespace attr { constexpr const char* kModuleName = "mod_name"; /* - * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The + * \brief All the runtime::Tensors extracted from PrimFunc tir::AllocateConst nodes. The * node will record the index into this array. See also kConstNameToConstant below, which is * the analog for Realy Functions. * - * Type: Array + * Type: ffi::Array */ constexpr const char* kConstants = "constants"; @@ -326,7 +329,7 @@ constexpr const char* kConstants = "constants"; * \brief All the runtime::Modules accumulated during compilation by external codegen. These * modules must be either directly linked or captured in the final compilation artifact. * - * Type: Array + * Type: ffi::Array */ constexpr const char* kExternalMods = "external_mods"; @@ -360,12 +363,12 @@ constexpr const char* kExternalMods = "external_mods"; constexpr const char* kSystemLibPrefix = "system_lib_prefix"; /*! - * \brief All the named runtime::NDArrays accumulated during compilation by external codegen. + * \brief All the named runtime::Tensors accumulated during compilation by external codegen. * Generally the associated runtime::Module will indicate it requires bindings for these names, * and during module initialization these bindings will be recovered from a ConstLoaderModule. * See also kConstantsArray above, which is the analog for PrimFuncs. 
* - * Type: Map + * Type: ffi::Map */ constexpr const char* kConstNameToConstant = "const_name_to_constant"; diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index 6eefaefea793..d3139ea2c821 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -50,7 +50,7 @@ class NameSupplyNode : public Object { * \param prefix The prefix to be used with this NameSupply. * \param name_map The map used to guarantee uniqueness. */ - NameSupplyNode(const String& prefix, std::unordered_map name_map) + NameSupplyNode(const ffi::String& prefix, std::unordered_map name_map) : prefix_(prefix), name_map(std::move(name_map)) {} /*! @@ -61,7 +61,8 @@ class NameSupplyNode : public Object { * \param add_underscore If set to true, add '_' between prefix and a digit. * \return A unique name. */ - String FreshName(const String& name, bool add_prefix = true, bool add_underscore = true); + ffi::String FreshName(const ffi::String& name, bool add_prefix = true, + bool add_underscore = true); /*! * \brief Reserves an existing name with this NameSupply. @@ -70,7 +71,7 @@ class NameSupplyNode : public Object { * name before reserving it. \return The name that was reserved with the NameSupply. It can be * different if a prefix is added. */ - String ReserveName(const String& name, bool add_prefix = true); + ffi::String ReserveName(const ffi::String& name, bool add_prefix = true); /*! * \brief Checks if this NameSupply already generated a name. @@ -79,17 +80,23 @@ class NameSupplyNode : public Object { * name before checking for it. \return True if the name has already been generated. False * otherwise. */ - bool ContainsName(const String& name, bool add_prefix = true); + bool ContainsName(const ffi::String& name, bool add_prefix = true); // Prefix for all GlobalVar names. It can be empty. 
std::string prefix_; - static constexpr const char* _type_key = "ir.NameSupply"; - TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object); + static constexpr const bool _type_mutable = true; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.NameSupply", NameSupplyNode, Object); private: /*! \brief Helper function to add the NameSupply prefix to the name. */ - String add_prefix_to_name(const String& name); + ffi::String add_prefix_to_name(const ffi::String& name); /*! * \brief Function that will generate a unique name. @@ -114,7 +121,7 @@ class NameSupply : public ObjectRef { * \param prefix The prefix to be used with this NameSupply. * \param name_map An optional map. */ - TVM_DLL explicit NameSupply(const String& prefix = "", + TVM_DLL explicit NameSupply(const ffi::String& prefix = "", std::unordered_map name_map = {}); /*! @@ -127,7 +134,7 @@ class NameSupply : public ObjectRef { TVM_DLL explicit NameSupply(Iter begin, Iter end, Lambda f) : NameSupply("", GetNameMap(begin, end, f)) {} - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(NameSupply, ObjectRef, NameSupplyNode); private: template diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 5f40ff4d3a7b..211fc3eecc1f 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -59,21 +59,21 @@ class OpAttrMap; class OpNode : public RelaxExprNode { public: /*! \brief name of the operator */ - String name; + ffi::String name; /*! \brief the type of the operator */ mutable FuncType op_type; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. */ - String description; + ffi::String description; /* \brief Information of input arguments to the operator */ - Array arguments; + ffi::Array arguments; /*! 
* \brief The type key of the attribute field * This can be empty, in which case it defaults to anything. */ - String attrs_type_key; + ffi::String attrs_type_key; /*! * \brief attribute type index, * this field varies in each run and is not exposed to frontend. @@ -104,8 +104,7 @@ class OpNode : public RelaxExprNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance; - static constexpr const char* _type_key = "ir.Op"; - TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelaxExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.Op", OpNode, RelaxExprNode); private: /*! \return the internal attr registry index. */ @@ -139,22 +138,22 @@ class Op : public RelaxExpr { * \tparam ValueType The type of the attribute. */ template - inline static OpAttrMap GetAttrMap(const String& attr_name); + inline static OpAttrMap GetAttrMap(const ffi::String& attr_name); /*! * \brief Checks if an attr map is present in the registry. * \param attr_name The name of the attribute. * \return bool True if the attr is present. */ - TVM_DLL static bool HasAttrMap(const String& attr_name); + TVM_DLL static bool HasAttrMap(const ffi::String& attr_name); /*! * \brief Get an Op for a given operator name. * Will raise an error if the op has not been registered. * \param op_name Name of the operator. * \return Pointer to a Op, valid throughout program lifetime. */ - TVM_DLL static const Op& Get(const String& op_name); + TVM_DLL static const Op& Get(const ffi::String& op_name); - TVM_DEFINE_OBJECT_REF_METHODS(Op, RelaxExpr, OpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Op, RelaxExpr, OpNode); private: /*! @@ -162,7 +161,7 @@ class Op : public RelaxExpr { * \param key The attribute key * \return The attr map. */ - TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer(const String& key); + TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer(const ffi::String& key); }; /*! 
@@ -201,7 +200,7 @@ class OpRegEntry { * \param key The attribute type key to be set. * \return reference to self. */ - inline OpRegEntry& set_attrs_type_key(const String& key); + inline OpRegEntry& set_attrs_type_key(const ffi::String& key); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. @@ -249,7 +248,7 @@ class OpRegEntry { * \param name The name of the operator. * \return the corresponding entry. */ - TVM_DLL static OpRegEntry& RegisterOrGet(const String& name); + TVM_DLL static OpRegEntry& RegisterOrGet(const ffi::String& name); private: template @@ -263,11 +262,11 @@ class OpRegEntry { // return internal pointer to op. inline OpNode* get(); // update the attribute OpAttrMap - TVM_DLL void UpdateAttr(const String& key, ffi::Any value, int plevel); + TVM_DLL void UpdateAttr(const ffi::String& key, ffi::Any value, int plevel); }; /*! - * \brief Map used to store meta-information about Op. + * \brief ffi::Map used to store meta-information about Op. * \tparam ValueType The type of the value stored in map. 
*/ template @@ -318,7 +317,7 @@ class OpAttrMap : public AttrRegistryMap { // implementations template -inline OpAttrMap Op::GetAttrMap(const String& key) { +inline OpAttrMap Op::GetAttrMap(const ffi::String& key) { return OpAttrMap(Op::GetAttrMapContainer(key)); } @@ -331,7 +330,7 @@ inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(* inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type, const std::string& description) { - auto n = make_object(); + auto n = ffi::make_object(); n->name = name; n->type_info = type; n->description = description; @@ -351,7 +350,7 @@ inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*) return *this; } -inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*) +inline OpRegEntry& OpRegEntry::set_attrs_type_key(const ffi::String& key) { // NOLINT(*) get()->attrs_type_key = key; get()->attrs_type_index = tvm::ffi::TypeKeyToIndex(key.c_str()); return *this; @@ -376,7 +375,7 @@ template inline ValueType OpAttrMap::get(const RelaxExpr& expr, ValueType def_value) const { ICHECK(expr.defined()); if (const OpNode* op = expr.as()) { - return this->map_.get(GetRef(op), def_value); + return this->map_.get(ffi::GetRef(op), def_value); } else { return def_value; } diff --git a/include/tvm/ir/replace_global_vars.h b/include/tvm/ir/replace_global_vars.h index ea91d46d7c0a..0ed25c9a0a7a 100644 --- a/include/tvm/ir/replace_global_vars.h +++ b/include/tvm/ir/replace_global_vars.h @@ -41,10 +41,10 @@ namespace transform { * * \return The updated IRModule */ -TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, Map replacements); +TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, ffi::Map replacements); struct GlobalVarReplacer { - using FType = NodeFunctor)>; + using FType = NodeFunctor)>; TVM_DLL static FType& vtable() { static FType inst; return inst; diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index c7fce1c5024c..c94fb6b0a120 
100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -46,7 +46,7 @@ class SourceName; class SourceNameNode : public Object { public: /*! \brief The source name. */ - String name; + ffi::String name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -54,8 +54,7 @@ class SourceNameNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "ir.SourceName"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.SourceName", SourceNameNode, Object); }; /*! @@ -70,9 +69,9 @@ class SourceName : public ObjectRef { * \param name Name of the operator. * \return SourceName valid throughout program lifetime. */ - TVM_DLL static SourceName Get(const String& name); + TVM_DLL static SourceName Get(const ffi::String& name); - TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SourceName, ObjectRef, SourceNameNode); }; /*! @@ -106,8 +105,7 @@ class SpanNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "ir.Span"; - TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.Span", SpanNode, Object); }; class Span : public ObjectRef { @@ -117,7 +115,7 @@ class Span : public ObjectRef { /*! \brief Merge two spans into one which captures the combined regions. */ TVM_DLL Span Merge(const Span& other) const; - TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Span, ObjectRef, SpanNode); }; /*! @@ -126,15 +124,13 @@ class Span : public ObjectRef { class SequentialSpanNode : public SpanNode { public: /*! \brief The original source list of spans to construct a sequential span. 
*/ - Array spans; + ffi::Array spans; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("spans", &SequentialSpanNode::spans); } - - static constexpr const char* _type_key = "ir.SequentialSpan"; - TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.SequentialSpan", SequentialSpanNode, SpanNode); }; /*! @@ -143,11 +139,11 @@ class SequentialSpanNode : public SpanNode { */ class SequentialSpan : public Span { public: - TVM_DLL SequentialSpan(Array spans); + TVM_DLL SequentialSpan(ffi::Array spans); TVM_DLL SequentialSpan(std::initializer_list init); - TVM_DEFINE_OBJECT_REF_METHODS(SequentialSpan, Span, SequentialSpanNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SequentialSpan, Span, SequentialSpanNode); }; /*! \brief A program source in any language. @@ -163,7 +159,7 @@ class SourceNode : public Object { SourceName source_name; /*! \brief The raw source. */ - String source; + ffi::String source; /*! \brief A mapping of line breaks into the raw source. */ std::vector> line_map; @@ -174,17 +170,15 @@ class SourceNode : public Object { .def_ro("source_name", &SourceNode::source_name) .def_ro("source", &SourceNode::source); } - - static constexpr const char* _type_key = "ir.Source"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.Source", SourceNode, Object); }; class Source : public ObjectRef { public: TVM_DLL Source(SourceName src_name, std::string source); - TVM_DLL tvm::String GetLine(int line); + TVM_DLL tvm::ffi::String GetLine(int line); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Source, ObjectRef, SourceNode); }; /*! @@ -197,7 +191,7 @@ class SourceMap; class SourceMapObj : public Object { public: /*! \brief The source mapping. 
*/ - Map source_map; + ffi::Map source_map; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -205,18 +199,17 @@ class SourceMapObj : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "ir.SourceMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapObj, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.SourceMap", SourceMapObj, Object); }; class SourceMap : public ObjectRef { public: - explicit SourceMap(Map source_map); + explicit SourceMap(ffi::Map source_map); explicit SourceMap(std::initializer_list> source_map) - : SourceMap(Map(source_map)) {} + : SourceMap(ffi::Map(source_map)) {} - SourceMap() : SourceMap(Map()) {} + SourceMap() : SourceMap(ffi::Map()) {} void Add(const Source& source); @@ -225,7 +218,7 @@ class SourceMap : public ObjectRef { return static_cast(get_mutable()); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SourceMap, ObjectRef, SourceMapObj); }; } // namespace tvm diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 45f97ff61f2b..77d90a0e9558 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -82,16 +82,16 @@ class PassContextNode : public Object { int opt_level{2}; /*! \brief The list of required passes. */ - Array required_pass; + ffi::Array required_pass; /*! \brief The list of disabled passes. */ - Array disabled_pass; + ffi::Array disabled_pass; /*! \brief The diagnostic context. */ - mutable Optional diag_ctx; + mutable ffi::Optional diag_ctx; /*! \brief Pass specific configurations. */ - Map config; + ffi::Map config; /*! \brief A list of pass instrument implementations. */ - Array instruments; + ffi::Array instruments; PassContextNode() = default; @@ -107,21 +107,21 @@ class PassContextNode : public Object { * \throw Error if the key exists but the value does not match TObjectRef. 
*/ template - Optional GetConfig( + ffi::Optional GetConfig( const std::string& key, - Optional default_value = Optional(std::nullopt)) const { + ffi::Optional default_value = ffi::Optional(std::nullopt)) const { if (!config.defined()) return default_value; auto it = config.find(key); if (it != config.end()) { - return Downcast>((*it).second); + return Downcast>((*it).second); } else { return default_value; } } // variant that uses TObjectRef to enable implicit conversion to default value. template - Optional GetConfig(const std::string& key, TObjectRef default_value) const { - return GetConfig(key, Optional(default_value)); + ffi::Optional GetConfig(const std::string& key, TObjectRef default_value) const { + return GetConfig(key, ffi::Optional(default_value)); } static void RegisterReflection() { @@ -134,10 +134,7 @@ class PassContextNode : public Object { .def_ro("config", &PassContextNode::config) .def_ro("diag_ctx", &PassContextNode::diag_ctx); } - - static constexpr const char* _type_key = "transform.PassContext"; - - TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.PassContext", PassContextNode, Object); }; /*! @@ -156,7 +153,14 @@ class PassContextNode : public Object { class PassContext : public ObjectRef { public: PassContext() {} - explicit PassContext(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor with UnsafeInit + */ + explicit PassContext(ffi::UnsafeInit tag) : ObjectRef(tag) {} + /*! + * \brief constructor with ObjectPtr + */ + explicit PassContext(ObjectPtr n) : ObjectRef(n) {} /*! * \brief const accessor. * \return const access pointer. @@ -189,7 +193,7 @@ class PassContext : public ObjectRef { * \brief Get all supported configuration names and metadata, registered within the PassContext. * \return Map indexed by the config name, pointing to the metadata map as key-value */ - TVM_DLL static Map> ListConfigs(); + TVM_DLL static ffi::Map> ListConfigs(); /*! 
* \brief Call instrument implementations' callbacks when entering PassContext. @@ -247,7 +251,7 @@ class PassContext : public ObjectRef { int32_t tindex = ffi::TypeToRuntimeTypeIndex::v(); auto type_key = ffi::TypeIndexToTypeKey(tindex); auto legalization = [=](ffi::Any value) -> ffi::Any { - if (auto opt_map = value.try_cast>()) { + if (auto opt_map = value.try_cast>()) { return ffi::reflection::ObjectCreator(type_key)(opt_map.value()); } else { auto opt_val = value.try_cast(); @@ -288,7 +292,7 @@ class PassContext : public ObjectRef { // The exit of a pass context scope. TVM_DLL void ExitWithScope(); // Register configuration key value type. - TVM_DLL static void RegisterConfigOption(const char* key, String value_type_str, + TVM_DLL static void RegisterConfigOption(const char* key, ffi::String value_type_str, std::function legalization); // Classes to get the Python `with` like syntax. @@ -318,13 +322,13 @@ class PassInfoNode : public Object { int opt_level; /*! \brief The name of an optimization/analysis pass. */ - String name; + ffi::String name; /*! \brief Boolean that tells whether this pass will be traced or not. */ bool traceable; /*! \brief The passes that are required to perform the current pass. */ - Array required; + ffi::Array required; PassInfoNode() = default; @@ -336,10 +340,7 @@ class PassInfoNode : public Object { .def_ro("required", &PassInfoNode::required) .def_ro("traceable", &PassInfoNode::traceable); } - - static constexpr const char* _type_key = "transform.PassInfo"; - - TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.PassInfo", PassInfoNode, Object); }; /*! @@ -355,9 +356,10 @@ class PassInfo : public ObjectRef { * \param required The passes that are required to perform the current pass. * \param traceable Boolean that tells whether the pass is traceable. 
*/ - TVM_DLL PassInfo(int opt_level, String name, Array required, bool traceable); + TVM_DLL PassInfo(int opt_level, ffi::String name, ffi::Array required, + bool traceable); - TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PassInfo, ObjectRef, PassInfoNode); }; /*! @@ -392,9 +394,7 @@ class PassNode : public Object { * \return The transformed module. */ virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; - - static constexpr const char* _type_key = "transform.Pass"; - TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("transform.Pass", PassNode, Object); }; class Pass : public ObjectRef { @@ -426,7 +426,7 @@ class Pass : public ObjectRef { */ IRModule operator()(IRModule mod, const PassContext& pass_ctx) const; - TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Pass, ObjectRef, PassNode); private: IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node, @@ -447,7 +447,7 @@ class SequentialNode : public PassNode { PassInfo pass_info; /*! \brief A list of passes that used to compose a sequential pass. */ - tvm::Array passes; + tvm::ffi::Array passes; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -485,9 +485,7 @@ class SequentialNode : public PassNode { * \return Return the updated module. */ IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; - - static constexpr const char* _type_key = "transform.Sequential"; - TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.Sequential", SequentialNode, PassNode); }; class Sequential : public Pass { @@ -498,7 +496,7 @@ class Sequential : public Pass { * \param passes The passes to apply. * \param pass_info The pass metadata. 
*/ - TVM_DLL Sequential(Array passes, PassInfo pass_info); + TVM_DLL Sequential(ffi::Array passes, PassInfo pass_info); /*! * \brief The constructor of `Sequential`. @@ -508,10 +506,10 @@ class Sequential : public Pass { * This allows users to only provide a list of passes and execute them * under a given context. */ - TVM_DLL Sequential(Array passes, String name = "sequential"); + TVM_DLL Sequential(ffi::Array passes, ffi::String name = "sequential"); Sequential() = default; - explicit Sequential(ObjectPtr n) : Pass(n) {} + explicit Sequential(ObjectPtr n) : Pass(n) {} const SequentialNode* operator->() const; using ContainerType = SequentialNode; @@ -528,7 +526,7 @@ class Sequential : public Pass { * \return The created module pass. */ TVM_DLL Pass CreateModulePass(std::function pass_func, - int opt_level, String name, Array required, + int opt_level, ffi::String name, ffi::Array required, bool traceable = false); /* @@ -553,16 +551,15 @@ TVM_DLL Pass CreateModulePass(std::function pas * * \return The modified IRModule to IRModule pass. */ -TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex, +TVM_DLL Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex, bool error_if_no_function_matches_regex = false); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). * \param header The header to be attached to the output. - * \param show_meta_data Whether should we show meta data. * \return The pass. 
*/ -TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false); +TVM_DLL Pass PrintIR(ffi::String header = ""); } // namespace transform } // namespace tvm diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 9d75e845f88f..117198214a0e 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -88,10 +88,9 @@ class TypeNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "ir.Type"; static constexpr const uint32_t _type_child_slots = 14; - TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.Type", TypeNode, Object); }; /*! @@ -100,7 +99,7 @@ class TypeNode : public Object { */ class Type : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ObjectRef, TypeNode); }; /*! @@ -122,9 +121,7 @@ class PrimTypeNode : public TypeNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("dtype", &PrimTypeNode::dtype); } - - static constexpr const char* _type_key = "ir.PrimType"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PrimType", PrimTypeNode, TypeNode); }; /* @@ -140,7 +137,7 @@ class PrimType : public Type { */ TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimType, Type, PrimTypeNode); }; /*! @@ -162,7 +159,7 @@ class PointerTypeNode : public TypeNode { /*! 
* \brief The storage scope of the pointer */ - String storage_scope; + ffi::String storage_scope; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -170,9 +167,7 @@ class PointerTypeNode : public TypeNode { .def_ro("element_type", &PointerTypeNode::element_type) .def_ro("storage_scope", &PointerTypeNode::storage_scope); } - - static constexpr const char* _type_key = "ir.PointerType"; - TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PointerType", PointerTypeNode, TypeNode); }; /* @@ -186,9 +181,9 @@ class PointerType : public Type { * \param element_type The type of the element which the pointer points to. * \param storage_scope The storage scope into which the pointer addresses */ - TVM_DLL explicit PointerType(Type element_type, String storage_scope = ""); + TVM_DLL explicit PointerType(Type element_type, ffi::String storage_scope = ""); - TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PointerType, Type, PointerTypeNode); }; /*! @@ -198,7 +193,7 @@ class PointerType : public Type { class TupleTypeNode : public TypeNode { public: /*! \brief The type of each field in the tuple. */ - Array fields; + ffi::Array fields; TupleTypeNode() {} @@ -208,9 +203,7 @@ class TupleTypeNode : public TypeNode { .def_ro("fields", &TupleTypeNode::fields) .def_ro("span", &TupleTypeNode::span); } - - static constexpr const char* _type_key = "ir.TupleType"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.TupleType", TupleTypeNode, TypeNode); }; /*! @@ -224,7 +217,7 @@ class TupleType : public Type { * \param fields Fields in the tuple. * \param span The span of the type. */ - TVM_DLL explicit TupleType(Array fields, Span span = Span()); + TVM_DLL explicit TupleType(ffi::Array fields, Span span = Span()); /*! * \brief Create an empty tuple type that constains nothing. 
@@ -232,7 +225,7 @@ class TupleType : public Type { */ TVM_DLL TupleType static Empty(); - TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TupleType, Type, TupleTypeNode); }; /*! @@ -260,7 +253,7 @@ inline bool IsVoidType(const Type& type) { class FuncTypeNode : public TypeNode { public: /*! \brief type type of arguments */ - Array arg_types; + ffi::Array arg_types; /*! \brief The type of return value. */ Type ret_type; @@ -271,9 +264,7 @@ class FuncTypeNode : public TypeNode { .def_ro("ret_type", &FuncTypeNode::ret_type) .def_ro("span", &FuncTypeNode::span); } - - static constexpr const char* _type_key = "ir.FuncType"; - TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.FuncType", FuncTypeNode, TypeNode); }; /*! @@ -289,9 +280,9 @@ class FuncType : public Type { * \param span The span information. * \sa FuncTypeNode for more docs about these fields. */ - TVM_DLL FuncType(Array arg_types, Type ret_type, Span span = Span()); + TVM_DLL FuncType(ffi::Array arg_types, Type ret_type, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FuncType, Type, FuncTypeNode); }; /*! @@ -304,9 +295,7 @@ class TensorMapTypeNode : public TypeNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("span", &TensorMapTypeNode::span); } - - static constexpr const char* _type_key = "ir.TensorMapType"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorMapTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.TensorMapType", TensorMapTypeNode, TypeNode); }; /*! 
@@ -317,8 +306,8 @@ class TensorMapType : public Type { public: TVM_DLL TensorMapType(Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, Type, TensorMapTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, Type, + TensorMapTypeNode); }; - } // namespace tvm #endif // TVM_IR_TYPE_H_ diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 858226354c66..b2878519c424 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -123,7 +123,7 @@ class TVM_DLL TypeMutator : public TypeFunctor { Type VisitType_(const PointerTypeNode* op) override; private: - Array MutateArray(Array arr); + ffi::Array MutateArray(ffi::Array arr); }; } // namespace tvm diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h index de005dcd125b..6c664b636925 100644 --- a/include/tvm/meta_schedule/arg_info.h +++ b/include/tvm/meta_schedule/arg_info.h @@ -33,8 +33,7 @@ namespace meta_schedule { /*! \brief The argument information. */ class ArgInfoNode : public runtime::Object { public: - static constexpr const char* _type_key = "meta_schedule.ArgInfo"; - TVM_DECLARE_BASE_OBJECT_INFO(ArgInfoNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.ArgInfo", ArgInfoNode, runtime::Object); public: /*! \brief Default destructor. */ @@ -60,16 +59,16 @@ class ArgInfo : public runtime::ObjectRef { * \param func The PrimFunc to get argument information from. * \return An array of the argument information derived. */ - TVM_DLL static Array FromPrimFunc(const tir::PrimFunc& func); + TVM_DLL static ffi::Array FromPrimFunc(const tir::PrimFunc& func); /*! * \brief Extract a list of the argument information from the entry func of an IRModule * \param mod The IRModule to extract argument information from. * \param remove_preproc Whether to remove the preprocessing blocks. * \return An array of the argument information derived. 
*/ - TVM_DLL static Array FromEntryFunc(const IRModule& mod, bool remove_preproc); + TVM_DLL static ffi::Array FromEntryFunc(const IRModule& mod, bool remove_preproc); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ArgInfo, runtime::ObjectRef, ArgInfoNode); protected: ArgInfo() = default; @@ -89,9 +88,7 @@ class TensorInfoNode : public ArgInfoNode { .def_ro("dtype", &TensorInfoNode::dtype) .def_ro("shape", &TensorInfoNode::shape); } - - static constexpr const char* _type_key = "meta_schedule.TensorInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, ArgInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TensorInfo", TensorInfoNode, ArgInfoNode); public: ObjectRef AsJSON() const; @@ -115,7 +112,7 @@ class TensorInfo : public ArgInfo { * \return The argument information parsed. */ TVM_DLL static TensorInfo FromJSON(const ObjectRef& json_obj); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorInfo, ArgInfo, TensorInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TensorInfo, ArgInfo, TensorInfoNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 7e0be7de8265..e4b5f011eb46 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -26,8 +26,8 @@ #include #include #include -#include #include +#include #include namespace tvm { @@ -41,7 +41,7 @@ class BuilderInputNode : public runtime::Object { /*! \brief The target to be built for. */ Target target; /*! \brief Parameters for Relax build module. 
*/ - Optional> params; + ffi::Optional> params; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -50,9 +50,8 @@ class BuilderInputNode : public runtime::Object { .def_ro("target", &BuilderInputNode::target) .def_ro("params", &BuilderInputNode::params); } - - static constexpr const char* _type_key = "meta_schedule.BuilderInput"; - TVM_DECLARE_FINAL_OBJECT_INFO(BuilderInputNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.BuilderInput", BuilderInputNode, + runtime::Object); }; /*! @@ -67,18 +66,19 @@ class BuilderInput : public runtime::ObjectRef { * \param target The target to be built for. * \param params Parameters for Relax build module. */ - TVM_DLL explicit BuilderInput(IRModule mod, Target target, - Optional> params = std::nullopt); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); + TVM_DLL explicit BuilderInput( + IRModule mod, Target target, + ffi::Optional> params = std::nullopt); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BuilderInput, runtime::ObjectRef, BuilderInputNode); }; /*! \brief The builder's output, containing the artifact path or error message if any. */ class BuilderResultNode : public runtime::Object { public: /*! \brief The path to the built artifact. */ - Optional artifact_path; + ffi::Optional artifact_path; /*! \brief The error message if any. */ - Optional error_msg; + ffi::Optional error_msg; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -86,9 +86,8 @@ class BuilderResultNode : public runtime::Object { .def_ro("artifact_path", &BuilderResultNode::artifact_path) .def_ro("error_msg", &BuilderResultNode::error_msg); } - - static constexpr const char* _type_key = "meta_schedule.BuilderResult"; - TVM_DECLARE_FINAL_OBJECT_INFO(BuilderResultNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.BuilderResult", BuilderResultNode, + runtime::Object); }; /*! 
@@ -102,8 +101,10 @@ class BuilderResult : public runtime::ObjectRef { * \param artifact_path The path to the built artifact. * \param error_msg The error message if any. */ - TVM_DLL explicit BuilderResult(Optional artifact_path, Optional error_msg); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderResult, runtime::ObjectRef, BuilderResultNode); + TVM_DLL explicit BuilderResult(ffi::Optional artifact_path, + ffi::Optional error_msg); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BuilderResult, runtime::ObjectRef, + BuilderResultNode); }; /*! \brief The abstract builder interface. */ @@ -116,16 +117,16 @@ class BuilderNode : public runtime::Object { * \param build_inputs The inputs to be built. * \return The build results. */ - virtual Array Build(const Array& build_inputs) = 0; + virtual ffi::Array Build(const ffi::Array& build_inputs) = 0; /*! * \brief The function type of `Build` method. * \param build_inputs The inputs to be built. * \return The build results. */ - using FBuild = ffi::TypedFunction(const Array&)>; + using FBuild = ffi::TypedFunction(const ffi::Array&)>; - static constexpr const char* _type_key = "meta_schedule.Builder"; - TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, runtime::Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Builder", BuilderNode, runtime::Object); }; /*! @@ -134,13 +135,20 @@ class BuilderNode : public runtime::Object { */ class Builder : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit Builder(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a builder with customized build method on the python-side. * \param f_build The packed function to the `Build` function.. * \return The Builder created. 
*/ static Builder PyBuilder(BuilderNode::FBuild f_build); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Builder, runtime::ObjectRef, BuilderNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Builder, runtime::ObjectRef, BuilderNode); }; /*! \brief An abstract builder with customized build method on the python-side. */ @@ -154,13 +162,11 @@ class PyBuilderNode : public BuilderNode { refl::ObjectDef().def_ro("f_build", &PyBuilderNode::f_build); } - Array Build(const Array& build_inputs) final { + ffi::Array Build(const ffi::Array& build_inputs) final { ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!"; return f_build(build_inputs); } - - static constexpr const char* _type_key = "meta_schedule.PyBuilder"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyBuilderNode, BuilderNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyBuilder", PyBuilderNode, BuilderNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index 9311fdef40c9..aaf4665c2729 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -47,13 +47,13 @@ class CostModelNode : public runtime::Object { * \brief Load the cost model from given file location. * \param path The file path. */ - virtual void Load(const String& path) = 0; + virtual void Load(const ffi::String& path) = 0; /*! * \brief Save the cost model to given file location. * \param path The file path. */ - virtual void Save(const String& path) = 0; + virtual void Save(const ffi::String& path) = 0; /*! * \brief Update the cost model given running results. @@ -61,8 +61,8 @@ class CostModelNode : public runtime::Object { * \param candidates The measure candidates. * \param results The running results of the measure candidates. 
*/ - virtual void Update(const TuneContext& context, const Array& candidates, - const Array& results) = 0; + virtual void Update(const TuneContext& context, const ffi::Array& candidates, + const ffi::Array& results) = 0; /*! * \brief Predict the normalized score (the larger the better) of given measure candidates. @@ -71,10 +71,10 @@ class CostModelNode : public runtime::Object { * \return The predicted normalized score. */ virtual std::vector Predict(const TuneContext& context, - const Array& candidates) = 0; + const ffi::Array& candidates) = 0; - static constexpr const char* _type_key = "meta_schedule.CostModel"; - TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.CostModel", CostModelNode, Object); }; /*! \brief The cost model with customized methods on the python-side. */ @@ -84,12 +84,12 @@ class PyCostModelNode : public CostModelNode { * \brief Load the cost model from given file location. * \param path The file path. */ - using FLoad = ffi::TypedFunction; + using FLoad = ffi::TypedFunction; /*! * \brief Save the cost model to given file location. * \param path The file path. */ - using FSave = ffi::TypedFunction; + using FSave = ffi::TypedFunction; /*! * \brief Update the cost model given running results. * \param context The tuning context. @@ -97,21 +97,21 @@ class PyCostModelNode : public CostModelNode { * \param results The running results of the measure candidates. * \return Whether cost model was updated successfully. */ - using FUpdate = ffi::TypedFunction&, - const Array&)>; + using FUpdate = ffi::TypedFunction&, + const ffi::Array&)>; /*! * \brief Predict the running results of given measure candidates. * \param context The tuning context. * \param candidates The measure candidates. * \param p_addr The address to save the estimated running results. 
*/ - using FPredict = - ffi::TypedFunction&, void* p_addr)>; + using FPredict = ffi::TypedFunction&, + void* p_addr)>; /*! * \brief Get the cost model as string with name. * \return The string representation of the cost model. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `Load` function. */ FLoad f_load; @@ -124,15 +124,13 @@ class PyCostModelNode : public CostModelNode { /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; - void Load(const String& path); - void Save(const String& path); - void Update(const TuneContext& context, const Array& candidates, - const Array& results); + void Load(const ffi::String& path); + void Save(const ffi::String& path); + void Update(const TuneContext& context, const ffi::Array& candidates, + const ffi::Array& results); std::vector Predict(const TuneContext& context, - const Array& candidates); - - static constexpr const char* _type_key = "meta_schedule.PyCostModel"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); + const ffi::Array& candidates); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyCostModel", PyCostModelNode, CostModelNode); }; /*! 
@@ -155,7 +153,7 @@ class CostModel : public runtime::ObjectRef { PyCostModelNode::FUpdate f_update, // PyCostModelNode::FPredict f_predict, // PyCostModelNode::FAsString f_as_string); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CostModel, ObjectRef, CostModelNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 29bc030c5b25..6f6b8bfca8d6 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -31,6 +31,7 @@ #include #include +#include #include namespace tvm { @@ -52,10 +53,7 @@ class WorkloadNode : public runtime::Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("mod", &WorkloadNode::mod); } - - static constexpr const char* _type_key = "meta_schedule.Workload"; - - TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.Workload", WorkloadNode, runtime::Object); /*! * \brief Export the workload to a JSON string. @@ -71,6 +69,7 @@ class WorkloadNode : public runtime::Object { class Workload : public runtime::ObjectRef { public: using THashCode = WorkloadNode::THashCode; + explicit Workload(ObjectPtr data) : ObjectRef(data) {} /*! * \brief Constructor of Workload. * \param mod The workload's IRModule. @@ -89,7 +88,7 @@ class Workload : public runtime::ObjectRef { */ TVM_DLL static Workload FromJSON(const ObjectRef& json_obj); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Workload, runtime::ObjectRef, WorkloadNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Workload, runtime::ObjectRef, WorkloadNode); }; /*! \brief The hash method for Workload */ @@ -117,13 +116,13 @@ class TuningRecordNode : public runtime::Object { /*! \brief The trace tuned. */ tir::Trace trace; /*! \brief The workload. */ - Workload workload{nullptr}; + Workload workload{ffi::UnsafeInit()}; /*! 
\brief The profiling result in seconds. */ - Optional> run_secs; + ffi::Optional> run_secs; /*! \brief The target for tuning. */ - Optional target; + ffi::Optional target; /*! \brief The argument information. */ - Optional> args_info; + ffi::Optional> args_info; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -134,10 +133,8 @@ class TuningRecordNode : public runtime::Object { .def_ro("target", &TuningRecordNode::target) .def_ro("args_info", &TuningRecordNode::args_info); } - - static constexpr const char* _type_key = "meta_schedule.TuningRecord"; - - TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TuningRecord", TuningRecordNode, + runtime::Object); /*! \brief Construct the measure candidate given the initial IR module and trace * stored in the tuning record. */ @@ -170,8 +167,9 @@ class TuningRecord : public runtime::ObjectRef { \param args_info The argument information of the tuning record. */ TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload, - Optional> run_secs, Optional target, - Optional> args_info); + ffi::Optional> run_secs, + ffi::Optional target, + ffi::Optional> args_info); /*! * \brief Create a tuning record from a json object. * \param json_obj The json object. @@ -179,7 +177,7 @@ class TuningRecord : public runtime::ObjectRef { * \return The tuning record created. */ TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TuningRecord, runtime::ObjectRef, TuningRecordNode); }; class Database; @@ -192,14 +190,14 @@ class DatabaseNode : public runtime::Object { * \param mod_eq_name A string to specify the module equality testing and hashing method. 
* It must be one of the followings: * - "structural": Use StructuralEqual/Hash - * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + * - "ignore-tensor": Same as "structural", but ignore tensor raw data during * equality testing and hashing. * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - * given module. The "ignore-ndarray" varint is used for the extracted blocks + * given module. The "ignore-tensor" varint is used for the extracted blocks * or in case no anchor block is found. * For the definition of the anchor block, see tvm/tir/analysis.h. */ - explicit DatabaseNode(String mod_eq_name = "structural"); + explicit DatabaseNode(ffi::String mod_eq_name = "structural"); /*! \brief Default destructor */ virtual ~DatabaseNode(); @@ -226,12 +224,12 @@ class DatabaseNode : public runtime::Object { * \param top_k The number of top records to be returned. * \return An array of top K tuning records for the given workload. */ - virtual Array GetTopK(const Workload& workload, int top_k) = 0; + virtual ffi::Array GetTopK(const Workload& workload, int top_k) = 0; /*! * \brief Get all tuning records from the database. * \return An Array of all the tuning records in the database. */ - virtual Array GetAllTuningRecords() = 0; + virtual ffi::Array GetAllTuningRecords() = 0; /*! * \brief Get the size of the database. * \return The size of the database. @@ -244,8 +242,8 @@ class DatabaseNode : public runtime::Object { * \param workload_name The name of the workload to be searched for. * \return The best record of the given workload; std::nullopt if not found. */ - virtual Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& workload_name); + virtual ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& workload_name); /*! * \brief Query the best schedule of the given workload from the database. 
* \param mod The IRModule to be searched for. @@ -253,8 +251,8 @@ class DatabaseNode : public runtime::Object { * \param workload_name The name of the workload to be searched for. * \return The schedule in the best schedule of the given workload; std::nullopt if not found. */ - virtual Optional QuerySchedule(const IRModule& mod, const Target& target, - const String& workload_name); + virtual ffi::Optional QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name); /*! * \brief Query the best IRModule of the given workload from the database. * \param mod The IRModule to be searched for. @@ -262,8 +260,8 @@ class DatabaseNode : public runtime::Object { * \param workload_name The name of the workload to be searched for. * \return The IRModule in the best IRModule of the given workload; std::nullopt if not found. */ - virtual Optional QueryIRModule(const IRModule& mod, const Target& target, - const String& workload_name); + virtual ffi::Optional QueryIRModule(const IRModule& mod, const Target& target, + const ffi::String& workload_name); /*! * \brief Prune the database and dump it a given database. * \param destination The destination database to be dumped to. @@ -275,8 +273,8 @@ class DatabaseNode : public runtime::Object { return *mod_eq_; } - static constexpr const char* _type_key = "meta_schedule.Database"; - TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Database", DatabaseNode, runtime::Object); private: /*! \brief The module equality testing and hashing method */ @@ -291,14 +289,14 @@ class PyDatabaseNode : public DatabaseNode { * \param mod_eq_name A string to specify the module equality testing and hashing method. 
* It must be one of the followings: * - "structural": Use StructuralEqual/Hash - * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + * - "ignore-tensor": Same as "structural", but ignore tensor raw data during * equality testing and hashing. * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - * given module. The "ignore-ndarray" varint is used for the extracted blocks + * given module. The "ignore-tensor" varint is used for the extracted blocks * or in case no anchor block is found. * For the definition of the anchor block, see tvm/tir/analysis.h. */ - explicit PyDatabaseNode(String mod_eq_name = "structural"); + explicit PyDatabaseNode(ffi::String mod_eq_name = "structural"); /*! * \brief The function type of `HasWorkload` method. @@ -323,12 +321,12 @@ class PyDatabaseNode : public DatabaseNode { * \param top_k The number of top records to be returned. * \return An array of top K tuning records for the given workload. */ - using FGetTopK = ffi::TypedFunction(const Workload&, int)>; + using FGetTopK = ffi::TypedFunction(const Workload&, int)>; /*! * \brief The function type of `GetAllTuningRecords` method. * \return An Array of all the tuning records in the database. */ - using FGetAllTuningRecords = ffi::TypedFunction()>; + using FGetAllTuningRecords = ffi::TypedFunction()>; /*! * \brief The function type of `QueryTuningRecord` method. * \param mod The IRModule to be searched for. @@ -336,8 +334,8 @@ class PyDatabaseNode : public DatabaseNode { * \param workload_name The name of the workload to be searched for. * \return The best record of the given workload; std::nullopt if not found. */ - using FQueryTuningRecord = - ffi::TypedFunction(const IRModule&, const Target&, const String&)>; + using FQueryTuningRecord = ffi::TypedFunction( + const IRModule&, const Target&, const ffi::String&)>; /*! * \brief The function type of `QuerySchedule` method. * \param mod The IRModule to be searched for. 
@@ -345,8 +343,8 @@ class PyDatabaseNode : public DatabaseNode { * \param workload_name The name of the workload to be searched for. * \return The schedule in the best schedule of the given workload; std::nullopt if not found. */ - using FQuerySchedule = - ffi::TypedFunction(const IRModule&, const Target&, const String&)>; + using FQuerySchedule = ffi::TypedFunction( + const IRModule&, const Target&, const ffi::String&)>; /*! * \brief The function type of `QueryIRModule` method. * \param mod The IRModule to be searched for. @@ -354,8 +352,8 @@ class PyDatabaseNode : public DatabaseNode { * \param workload_name The name of the workload to be searched for. * \return The IRModule in the best IRModule of the given workload; std::nullopt if not found. */ - using FQueryIRModule = - ffi::TypedFunction(const IRModule&, const Target&, const String&)>; + using FQueryIRModule = ffi::TypedFunction(const IRModule&, const Target&, + const ffi::String&)>; /*! * \brief The function type of `Size` method. * \return The size of the database. 
@@ -394,6 +392,8 @@ class PyDatabaseNode : public DatabaseNode { // `f_query_schedule` is not registered // `f_query_ir_module` is not registered // `f_size` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } bool HasWorkload(const IRModule& mod) final { @@ -412,19 +412,19 @@ class PyDatabaseNode : public DatabaseNode { f_commit_tuning_record(record); } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!"; return f_get_top_k(workload, top_k); } - Array GetAllTuningRecords() final { + ffi::Array GetAllTuningRecords() final { ICHECK(f_get_all_tuning_records != nullptr) << "PyDatabase's GetAllTuningRecords method not implemented!"; return f_get_all_tuning_records(); } - Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& workload_name) final { + ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { if (f_query_tuning_record == nullptr) { return DatabaseNode::QueryTuningRecord(mod, target, workload_name); } else { @@ -432,8 +432,8 @@ class PyDatabaseNode : public DatabaseNode { } } - Optional QuerySchedule(const IRModule& mod, const Target& target, - const String& workload_name) final { + ffi::Optional QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { if (f_query_schedule == nullptr) { return DatabaseNode::QuerySchedule(mod, target, workload_name); } else { @@ -441,8 +441,8 @@ class PyDatabaseNode : public DatabaseNode { } } - Optional QueryIRModule(const IRModule& mod, const Target& target, - const String& workload_name) final { + ffi::Optional QueryIRModule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { if (f_query_ir_module == nullptr) { return DatabaseNode::QueryIRModule(mod, target, 
workload_name); } else { @@ -455,8 +455,8 @@ class PyDatabaseNode : public DatabaseNode { return f_size(); } - static constexpr const char* _type_key = "meta_schedule.PyDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyDatabase", PyDatabaseNode, DatabaseNode); }; /*! @@ -465,11 +465,18 @@ class PyDatabaseNode : public DatabaseNode { */ class Database : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit Database(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief An in-memory database. * \param mod_eq_name A string to specify the module equality testing and hashing method. */ - TVM_DLL static Database MemoryDatabase(String mod_eq_name = "structural"); + TVM_DLL static Database MemoryDatabase(ffi::String mod_eq_name = "structural"); /*! * \brief A database for injecting handcrafted schedule functions. * \param schedule_fn The function to do scheduling, which takes a TIR schedule, @@ -477,7 +484,7 @@ class Database : public runtime::ObjectRef { * \param mod_eq_name A string to specify the module equality testing and hashing method. */ TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction schedule_fn, - String mod_eq_name = "structural"); + ffi::String mod_eq_name = "structural"); /*! * \brief Create a default database that uses JSON file for tuning records. * \param path_workload The path to the workload table. @@ -485,8 +492,8 @@ class Database : public runtime::ObjectRef { * \param allow_missing Whether to create new file when the given path is not found. * \param mod_eq_name A string to specify the module equality testing and hashing method. 
*/ - TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, - bool allow_missing, String mod_eq_name = "structural"); + TVM_DLL static Database JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record, + bool allow_missing, ffi::String mod_eq_name = "structural"); /*! * \brief A database composed of multiple databases, allowing users to guide IR rewriting using * combined knowledge of those databases. To each query, it returns the best record among all the @@ -494,7 +501,7 @@ class Database : public runtime::ObjectRef { * \param databases The list of databases to be combined. * \return The combined database. */ - TVM_DLL static Database UnionDatabase(Array databases); + TVM_DLL static Database UnionDatabase(ffi::Array databases); /*! * \brief A database composed of multiple databases, allowing users to guide IR rewriting using * combined knowledge of those databases. To each query, it returns the record from the first @@ -502,7 +509,7 @@ class Database : public runtime::ObjectRef { * \param databases The database to be subsetted. * \return The subsetted database. */ - TVM_DLL static Database OrderedUnionDatabase(Array databases); + TVM_DLL static Database OrderedUnionDatabase(ffi::Array databases); /*! * \brief Create a database with customized methods on the python-side. * \param f_has_workload The packed function of `HasWorkload`. @@ -526,15 +533,15 @@ class Database : public runtime::ObjectRef { PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size, - String mod_eq_name = "structural"); + ffi::String mod_eq_name = "structural"); /*! \return The current Database in the scope. */ - static Optional Current(); + static ffi::Optional Current(); /*! \brief Entering the scope of the context manager */ void EnterWithScope(); /*! 
\brief Exiting the scope of the context manager */ void ExitWithScope(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Database, runtime::ObjectRef, DatabaseNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index 57debfee2267..646ec3c00cf0 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -42,13 +42,13 @@ namespace meta_schedule { class ExtractedTaskNode : public runtime::Object { public: /*! \brief The name of the task extracted */ - String task_name; + ffi::String task_name; /*! \brief The high-level IR */ IRModule mod; /*! \brief Target */ Target target; /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ - Array dispatched; + ffi::Array dispatched; /*! \brief Weight of the task */ int weight; @@ -62,9 +62,9 @@ class ExtractedTaskNode : public runtime::Object { .def_ro("weight", &ExtractedTaskNode::weight); } - static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.ExtractedTask", ExtractedTaskNode, + runtime::Object); }; /*! 
@@ -73,10 +73,10 @@ class ExtractedTaskNode : public runtime::Object { */ class ExtractedTask : public runtime::ObjectRef { public: - explicit ExtractedTask(String task_name, IRModule mod, Target target, Array dispatched, - int weight); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, - ExtractedTaskNode); + explicit ExtractedTask(ffi::String task_name, IRModule mod, Target target, + ffi::Array dispatched, int weight); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ExtractedTask, runtime::ObjectRef, + ExtractedTaskNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index 88bf056ebb6f..9a339d39e7ba 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -25,8 +25,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace meta_schedule { @@ -40,20 +40,19 @@ class FeatureExtractorNode : public runtime::Object { virtual ~FeatureExtractorNode() = default; static void RegisterReflection() { - // No fields to register + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } /*! * \brief Extract features from the given measure candidate. * \param context The tuning context for feature extraction. * \param candidates The measure candidates to extract features from. - * \return The feature ndarray extracted. + * \return The feature tensor extracted. */ - virtual Array ExtractFrom(const TuneContext& context, - const Array& candidates) = 0; - - static constexpr const char* _type_key = "meta_schedule.FeatureExtractor"; - TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object); + virtual ffi::Array ExtractFrom( + const TuneContext& context, const ffi::Array& candidates) = 0; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.FeatureExtractor", FeatureExtractorNode, Object); }; /*! \brief The feature extractor with customized methods on the python-side. 
*/ @@ -63,15 +62,15 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { * \brief Extract features from the given measure candidate. * \param context The tuning context for feature extraction. * \param candidates The measure candidates to extract features from. - * \return The feature ndarray extracted. + * \return The feature tensor extracted. */ - using FExtractFrom = ffi::TypedFunction( - const TuneContext& context, const Array& candidates)>; + using FExtractFrom = ffi::TypedFunction( + const TuneContext& context, const ffi::Array& candidates)>; /*! * \brief Get the feature extractor as string with name. * \return The string of the feature extractor. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `ExtractFrom` function. */ FExtractFrom f_extract_from; @@ -81,13 +80,14 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { static void RegisterReflection() { // `f_extract_from` is not registered // `f_as_string` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - Array ExtractFrom(const TuneContext& context, - const Array& candidates) final; - - static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode); + ffi::Array ExtractFrom( + const TuneContext& context, const ffi::Array& candidates) final; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyFeatureExtractor", PyFeatureExtractorNode, + FeatureExtractorNode); }; /*! 
@@ -119,7 +119,7 @@ class FeatureExtractor : public runtime::ObjectRef { TVM_DLL static FeatureExtractor PyFeatureExtractor( PyFeatureExtractorNode::FExtractFrom f_extract_from, PyFeatureExtractorNode::FAsString f_as_string); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FeatureExtractor, ObjectRef, FeatureExtractorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FeatureExtractor, ObjectRef, FeatureExtractorNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index d7377c3e5d1f..9e7d49a0c9d4 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -43,7 +43,8 @@ class MeasureCallbackNode : public runtime::Object { virtual ~MeasureCallbackNode() = default; static void RegisterReflection() { - // No fields to register + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } /*! @@ -54,14 +55,14 @@ class MeasureCallbackNode : public runtime::Object { * \param builder_results The builder results by building the measure candidates. * \param runner_results The runner results by running the built measure candidates. */ - virtual void Apply(const TaskScheduler& task_scheduler, // - int task_id, // - const Array& measure_candidates, // - const Array& builder_results, // - const Array& runner_results) = 0; - - static constexpr const char* _type_key = "meta_schedule.MeasureCallback"; - TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); + virtual void Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const ffi::Array& measure_candidates, // + const ffi::Array& builder_results, // + const ffi::Array& runner_results) = 0; + + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.MeasureCallback", MeasureCallbackNode, Object); }; /*! \brief The measure callback with customized methods on the python-side. 
*/ @@ -76,16 +77,16 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { * \param results The runner results by running the built measure candidates. * \return Whether the measure callback was successfully applied. */ - using FApply = ffi::TypedFunction& measure_candidates, // - const Array& builds, // - const Array& results)>; + using FApply = ffi::TypedFunction& measure_candidates, // + const ffi::Array& builds, // + const ffi::Array& results)>; /*! * \brief Get the measure callback function as string with name. * \return The string of the measure callback function. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `Apply` function. */ FApply f_apply; @@ -95,16 +96,17 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { static void RegisterReflection() { // `f_apply` is not registered // `f_as_string` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - void Apply(const TaskScheduler& task_scheduler, // - int task_id, // - const Array& measure_candidates, // - const Array& builds, // - const Array& results); - - static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); + void Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const ffi::Array& measure_candidates, // + const ffi::Array& builds, // + const ffi::Array& results); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyMeasureCallback", PyMeasureCallbackNode, + MeasureCallbackNode); }; /*! @@ -137,8 +139,8 @@ class MeasureCallback : public runtime::ObjectRef { TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, PyMeasureCallbackNode::FAsString f_as_string); /*! \brief The default list of measure callbacks. 
*/ - TVM_DLL static Array Default(); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); + TVM_DLL static ffi::Array Default(); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MeasureCallback, ObjectRef, MeasureCallbackNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/measure_candidate.h b/include/tvm/meta_schedule/measure_candidate.h index 0aee01fff5eb..557e9a3139d2 100644 --- a/include/tvm/meta_schedule/measure_candidate.h +++ b/include/tvm/meta_schedule/measure_candidate.h @@ -35,7 +35,7 @@ class MeasureCandidateNode : public runtime::Object { /*! \brief The schedule for measurement. */ tir::Schedule sch; /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ - Array args_info; + ffi::Array args_info; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -43,9 +43,7 @@ class MeasureCandidateNode : public runtime::Object { .def_ro("sch", &MeasureCandidateNode::sch) .def_ro("args_info", &MeasureCandidateNode::args_info); } - - static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; - TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MeasureCandidate", MeasureCandidateNode, Object); }; /*! @@ -59,8 +57,8 @@ class MeasureCandidate : public runtime::ObjectRef { * \param sch The schedule for measurement. * \param args_info The argument information, e.g., (shape, dtype) for tensors. 
*/ - TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); + TVM_DLL MeasureCandidate(tir::Schedule sch, ffi::Array args_info); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(MeasureCandidate, ObjectRef, MeasureCandidateNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 701045b7fb3f..05489c755217 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -41,7 +41,8 @@ class MutatorNode : public runtime::Object { virtual ~MutatorNode() = default; static void RegisterReflection() { - // No fields to register + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } /*! @@ -57,8 +58,8 @@ class MutatorNode : public runtime::Object { * \param rand_state The random state for mutation. * \return None if mutator failed, otherwise return the mutated trace. */ - virtual Optional Apply(const tir::Trace& trace, - support::LinearCongruentialEngine::TRandState* rand_state) = 0; + virtual ffi::Optional Apply( + const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) = 0; /*! * \brief Clone the mutator. @@ -66,8 +67,8 @@ class MutatorNode : public runtime::Object { */ virtual Mutator Clone() const = 0; - static constexpr const char* _type_key = "meta_schedule.Mutator"; - TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Mutator", MutatorNode, Object); }; /*! @@ -86,7 +87,7 @@ class Mutator : public runtime::ObjectRef { * \param trace The given trace for mutation. * \return None if mutator failed, otherwise return the mutated trace. */ - using FApply = ffi::TypedFunction( + using FApply = ffi::TypedFunction( const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; /*! * \brief Clone the mutator. 
@@ -97,7 +98,7 @@ class Mutator : public runtime::ObjectRef { * \brief Get the mutator as string with name. * \return The string of the mutator. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */ TVM_DLL static Mutator MutateTileSize(); /*! @@ -132,15 +133,15 @@ class Mutator : public runtime::ObjectRef { TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, FApply f_apply, FClone f_clone, FAsString f_as_string); /*! \brief Create default mutators for LLVM */ - TVM_DLL static Map DefaultLLVM(); + TVM_DLL static ffi::Map DefaultLLVM(); /*! \brief Create default mutators for CUDA */ - TVM_DLL static Map DefaultCUDA(); + TVM_DLL static ffi::Map DefaultCUDA(); /*! \brief Create default mutators for CUDA with TensorCore */ - TVM_DLL static Map DefaultCUDATensorCore(); + TVM_DLL static ffi::Map DefaultCUDATensorCore(); /*! \brief Create default mutators for Hexagon */ - TVM_DLL static Map DefaultHexagon(); + TVM_DLL static ffi::Map DefaultHexagon(); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mutator, ObjectRef, MutatorNode); }; /*! \brief The mutator with customized methods on the python-side. 
*/ @@ -167,12 +168,10 @@ class PyMutatorNode : public MutatorNode { } void InitializeWithTuneContext(const TuneContext& context) final; - Optional Apply(const tir::Trace& trace, - support::LinearCongruentialEngine::TRandState* rand_state) final; + ffi::Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) final; Mutator Clone() const final; - - static constexpr const char* _type_key = "meta_schedule.PyMutator"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyMutator", PyMutatorNode, MutatorNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index c511271d20a9..948f75210701 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -40,7 +40,8 @@ class PostprocNode : public runtime::Object { virtual ~PostprocNode() = default; static void RegisterReflection() { - // No fields to register + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } /*! @@ -63,8 +64,8 @@ class PostprocNode : public runtime::Object { */ virtual Postproc Clone() const = 0; - static constexpr const char* _type_key = "meta_schedule.Postproc"; - TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Postproc", PostprocNode, Object); }; /*! @@ -93,7 +94,7 @@ class Postproc : public runtime::ObjectRef { * \brief Get the postprocessor function as string with name. * \return The string of the postprocessor function. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! * \brief Create a postprocessor with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. 
@@ -163,17 +164,19 @@ class Postproc : public runtime::ObjectRef { */ TVM_DLL static Postproc RewriteLayout(); /*! \brief Create default postprocessors for LLVM */ - TVM_DLL static Array DefaultLLVM(); + TVM_DLL static ffi::Array DefaultLLVM(); /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */ - TVM_DLL static Array DefaultCPUTensorization(); + TVM_DLL static ffi::Array DefaultCPUTensorization(); + /*! \brief Create default postprocessors for RISCV */ + TVM_DLL static ffi::Array DefaultRISCV(); /*! \brief Create default postprocessors for CUDA */ - TVM_DLL static Array DefaultCUDA(); + TVM_DLL static ffi::Array DefaultCUDA(); /*! \brief Create default postprocessors for CUDA with TensorCore */ - TVM_DLL static Array DefaultCUDATensorCore(); + TVM_DLL static ffi::Array DefaultCUDATensorCore(); /*! \brief Create default postprocessors for Hexagon */ - TVM_DLL static Array DefaultHexagon(); + TVM_DLL static ffi::Array DefaultHexagon(); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Postproc, ObjectRef, PostprocNode); }; /*! \brief The postprocessor with customized methods on the python-side. 
*/ @@ -197,14 +200,14 @@ class PyPostprocNode : public PostprocNode { // `f_apply` is not registered // `f_clone` is not registered // `f_as_string` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } void InitializeWithTuneContext(const TuneContext& context) final; bool Apply(const tir::Schedule& sch) final; Postproc Clone() const final; - - static constexpr const char* _type_key = "meta_schedule.PyPostproc"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyPostproc", PyPostprocNode, PostprocNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/profiler.h b/include/tvm/meta_schedule/profiler.h index c3754e0211a1..5b82e6606b98 100644 --- a/include/tvm/meta_schedule/profiler.h +++ b/include/tvm/meta_schedule/profiler.h @@ -60,18 +60,18 @@ class ProfilerNode : public runtime::Object { ffi::Function total_timer; static void RegisterReflection() { - // `stats_sec` is not registered - // `total_timer` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - static constexpr const char* _type_key = "meta_schedule.Profiler"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProfilerNode, runtime::Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.Profiler", ProfilerNode, runtime::Object); public: /*! \brief Get the internal stats of the running time */ - Map Get() const; + ffi::Map Get() const; /*! \brief Return a summary of profiling results as table format */ - String Table() const; + ffi::String Table() const; }; /*! @@ -81,20 +81,20 @@ class ProfilerNode : public runtime::Object { class Profiler : public runtime::ObjectRef { public: Profiler(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Profiler, runtime::ObjectRef, ProfilerNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Profiler, runtime::ObjectRef, ProfilerNode); /*! 
\brief Entering the scope of the context manager */ void EnterWithScope(); /*! \brief Exiting the scope of the context manager */ void ExitWithScope(); /*! \brief Returns the current profiler */ - static Optional Current(); + static ffi::Optional Current(); /*! * \brief Profile the time usage in the given scope in the given name. * \param name Name for the scope. * \return A scope timer for time profiling. */ - static ScopedTimer TimedScope(String name); + static ScopedTimer TimedScope(ffi::String name); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 1bfda4820f6a..a88ae5feac1c 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -35,11 +35,11 @@ namespace meta_schedule { class RunnerInputNode : public runtime::Object { public: /*! \brief The path to the built artifact. */ - String artifact_path; + ffi::String artifact_path; /*! \brief The type of device. */ - String device_type; + ffi::String device_type; /*! \brief The argument information. */ - Array args_info; + ffi::Array args_info; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -48,10 +48,7 @@ class RunnerInputNode : public runtime::Object { .def_ro("device_type", &RunnerInputNode::device_type) .def_ro("args_info", &RunnerInputNode::args_info); } - - static constexpr const char* _type_key = "meta_schedule.RunnerInput"; - - TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RunnerInput", RunnerInputNode, runtime::Object); }; /*! @@ -66,17 +63,18 @@ class RunnerInput : public runtime::ObjectRef { * \param device_type The type of device. * \param args_info The argument information. 
*/ - TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array args_info); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode); + TVM_DLL explicit RunnerInput(ffi::String artifact_path, ffi::String device_type, + ffi::Array args_info); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RunnerInput, runtime::ObjectRef, RunnerInputNode); }; /*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */ class RunnerResultNode : public runtime::Object { public: /*! \brief The run time in seconds.*/ - Optional> run_secs; + ffi::Optional> run_secs; /*! \brief The error message, if any. */ - Optional error_msg; + ffi::Optional error_msg; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -84,10 +82,8 @@ class RunnerResultNode : public runtime::Object { .def_ro("run_secs", &RunnerResultNode::run_secs) .def_ro("error_msg", &RunnerResultNode::error_msg); } - - static constexpr const char* _type_key = "meta_schedule.RunnerResult"; - - TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RunnerResult", RunnerResultNode, + runtime::Object); }; /*! @@ -101,8 +97,9 @@ class RunnerResult : public runtime::ObjectRef { * \brief The run time in seconds. * \brief The error message, if any. */ - TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); + TVM_DLL explicit RunnerResult(ffi::Optional> run_secs, + ffi::Optional error_msg); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RunnerResult, runtime::ObjectRef, RunnerResultNode); }; /*! @@ -131,6 +128,8 @@ class RunnerFutureNode : public runtime::Object { static void RegisterReflection() { // `f_done` is not registered // `f_result` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } /*! 
@@ -149,9 +148,8 @@ class RunnerFutureNode : public runtime::Object { ICHECK(f_result != nullptr) << "PyRunnerFuture's Result method not implemented!"; return f_result(); } - - static constexpr const char* _type_key = "meta_schedule.RunnerFuture"; - TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RunnerFuture", RunnerFutureNode, + runtime::Object); }; /*! @@ -169,8 +167,7 @@ class RunnerFuture : public runtime::ObjectRef { * \param f_result The packed function to fetch runner output if it is ready. */ TVM_DLL explicit RunnerFuture(FDone f_done, FResult f_result); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerFuture, runtime::ObjectRef, - RunnerFutureNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RunnerFuture, runtime::ObjectRef, RunnerFutureNode); }; /*! \brief The abstract runner interface. */ @@ -182,7 +179,7 @@ class RunnerNode : public runtime::Object { * \return The runner futures. * \sa RunnerFuture */ - using FRun = ffi::TypedFunction(Array)>; + using FRun = ffi::TypedFunction(ffi::Array)>; /*! \brief Default destructor */ virtual ~RunnerNode() = default; @@ -192,10 +189,15 @@ class RunnerNode : public runtime::Object { * \param runner_inputs The runner's inputs. * \return The runner futures. */ - virtual Array Run(Array runner_inputs) = 0; + virtual ffi::Array Run(ffi::Array runner_inputs) = 0; - static constexpr const char* _type_key = "meta_schedule.Runner"; - TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.Runner", RunnerNode, runtime::Object); }; /*! @@ -205,14 +207,18 @@ class RunnerNode : public runtime::Object { class Runner : public runtime::ObjectRef { public: using FRun = RunnerNode::FRun; - + /*! + * \brief Constructor from ObjectPtr. 
+ * \param data The object pointer. + */ + explicit Runner(ObjectPtr data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); } /*! * \brief Create a runner with customized build method on the python-side. * \param f_run The packed function to run the built artifacts and get runner futures. * \return The runner created. */ TVM_DLL static Runner PyRunner(FRun f_run); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Runner, runtime::ObjectRef, RunnerNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Runner, runtime::ObjectRef, RunnerNode); }; /*! \brief An abstract runner with customized build method on the python-side. */ @@ -223,15 +229,15 @@ class PyRunnerNode : public RunnerNode { static void RegisterReflection() { // `f_run` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - Array Run(Array runner_inputs) final { + ffi::Array Run(ffi::Array runner_inputs) final { ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; return f_run(runner_inputs); } - - static constexpr const char* _type_key = "meta_schedule.PyRunner"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyRunner", PyRunnerNode, RunnerNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h index 125d6dc11fc8..aa3df4e7d443 100644 --- a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h +++ b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h @@ -36,7 +36,7 @@ namespace meta_schedule { * \return A sampler that returns a random thread extent. */ std::function MakeFactorSampler(tir::Schedule sch, - Array thread_extents); + ffi::Array thread_extents); /*! * \brief Bind blockIdx.x and threadIdx.x to the given loop @@ -47,9 +47,9 @@ std::function MakeFactorSampler(tir::Schedule sch, * \param get_factor A function that returns the tiling factor. 
* \return The binded loops in the order of blockIdx.x, threadIdx.x, and the rest. */ -Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, // - int64_t max_threadblocks, int64_t max_threads_per_block, - std::function get_factor = nullptr); +ffi::Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, // + int64_t max_threadblocks, int64_t max_threads_per_block, + std::function get_factor = nullptr); /*! * \brief Bind the given block if it is not bound to blockIdx or threadIdx. diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 9011ebe0c12f..be9074acbde7 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -43,7 +43,8 @@ class ScheduleRuleNode : public runtime::Object { virtual ~ScheduleRuleNode() = default; static void RegisterReflection() { - // No fields to register + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } /*! @@ -59,7 +60,7 @@ class ScheduleRuleNode : public runtime::Object { * \param block The specific block to apply the schedule rule. * \return The list of schedules generated by applying the schedule rule. */ - virtual Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; + virtual ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; /*! * \brief Deep clone the schedule rule. @@ -67,8 +68,8 @@ class ScheduleRuleNode : public runtime::Object { */ virtual ScheduleRule Clone() const = 0; - static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; - TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.ScheduleRule", ScheduleRuleNode, Object); }; /*! @@ -89,12 +90,12 @@ class ScheduleRule : public runtime::ObjectRef { * \return The list of schedules generated by applying the schedule rule. 
*/ using FApply = - ffi::TypedFunction(const tir::Schedule&, const tir::BlockRV&)>; + ffi::TypedFunction(const tir::Schedule&, const tir::BlockRV&)>; /*! * \brief Get the schedule rule as string with name. * \return The string of the schedule rule. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! * \brief The function type of `Clone` method. * \return The cloned schedule rule. @@ -125,7 +126,7 @@ class ScheduleRule : public runtime::ObjectRef { bool disallow_if_then_else, // bool require_injective, // bool require_ordered, // - Optional> disallow_op); + ffi::Optional> disallow_op); /*! * \brief Inline blocks that produce a constant scalar. Such blocks get in the way of @@ -155,13 +156,14 @@ class ScheduleRule : public runtime::ObjectRef { * ignored by default. This function should return True for a block that should be tiled. * \return The schedule rule created */ - TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // - Optional> tile_binds, // - Optional max_innermost_factor, // - Optional> vector_load_lens, // - Optional> reuse_read, // - Optional> reuse_write, - Optional filter_fn = std::nullopt); + TVM_DLL static ScheduleRule MultiLevelTiling( + ffi::String structure, // + ffi::Optional> tile_binds, // + ffi::Optional max_innermost_factor, // + ffi::Optional> vector_load_lens, // + ffi::Optional> reuse_read, // + ffi::Optional> reuse_write, + ffi::Optional filter_fn = std::nullopt); /*! * \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic. 
@@ -181,9 +183,12 @@ class ScheduleRule : public runtime::ObjectRef { * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin( - String intrin_name, String structure, Optional> tile_binds, - Optional max_innermost_factor, Optional> vector_load_lens, - Optional> reuse_read, Optional> reuse_write); + ffi::String intrin_name, ffi::String structure, + ffi::Optional> tile_binds, + ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write); /*! * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate @@ -206,10 +211,12 @@ class ScheduleRule : public runtime::ObjectRef { * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingTensorCore( - Array> intrin_groups, String structure, - Optional> tile_binds, Optional max_innermost_factor, - Optional> vector_load_lens, Optional> reuse_read, - Optional> reuse_write, bool use_software_pipeline); + ffi::Array> intrin_groups, ffi::String structure, + ffi::Optional> tile_binds, + ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write, bool use_software_pipeline); /*! * \brief Extension of MultiLevelTiling for backends with wide vectors. @@ -223,8 +230,10 @@ class ScheduleRule : public runtime::ObjectRef { * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingWideVector( - String structure, Integer vector_length_in_bits, Optional max_innermost_factor, - Optional> reuse_read, Optional> reuse_write); + ffi::String structure, Integer vector_length_in_bits, + ffi::Optional max_innermost_factor, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write); /*! 
* \brief Create a rule: add-rfactor to some blocks if needed @@ -235,14 +244,14 @@ class ScheduleRule : public runtime::ObjectRef { * limit \return The schedule rule created */ TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, // - Optional max_innermost_factor); + ffi::Optional max_innermost_factor); /*! * \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks * correspondingly when needed * \param thread_extents Candidates of thread axis extent (values are required to be positive). * \return The schedule rule created */ - TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); + TVM_DLL static ScheduleRule CrossThreadReduction(ffi::Array thread_extents); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The schedule rule created @@ -261,9 +270,9 @@ class ScheduleRule : public runtime::ObjectRef { * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma. * \return The schedule rule created */ - TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // - int max_vectorize_extent, // - Array unroll_max_steps, // + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // + ffi::Array unroll_max_steps, // bool unroll_explicit); /*! * \brief Auto bind loops around the block to BlockIdx and ThreadIdx @@ -273,7 +282,7 @@ class ScheduleRule : public runtime::ObjectRef { * when this schedule rule is created. * \return The schedule rule created */ - TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array thread_extents, + TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, ffi::Array thread_extents, int max_threads_per_block = -1); /*! * \brief Create a schedule rule with customized methods on the python-side. @@ -290,19 +299,21 @@ class ScheduleRule : public runtime::ObjectRef { FAsString f_as_string); /*! 
\brief Create default schedule rules for LLVM */ - TVM_DLL static Array DefaultLLVM(); + TVM_DLL static ffi::Array DefaultLLVM(); /*! \brief Create default schedule rules for x86 (AVX512 and VNNI) */ - TVM_DLL static Array DefaultX86(const String& type); + TVM_DLL static ffi::Array DefaultX86(const ffi::String& type); /*! \brief Create default schedule rules for CUDA */ - TVM_DLL static Array DefaultCUDA(); + TVM_DLL static ffi::Array DefaultCUDA(); /*! \brief Create default postprocessors for CUDA with TensorCore */ - TVM_DLL static Array DefaultCUDATensorCore(); + TVM_DLL static ffi::Array DefaultCUDATensorCore(); /*! \brief Create default schedule rules for Hexagon */ - TVM_DLL static Array DefaultHexagon(); + TVM_DLL static ffi::Array DefaultHexagon(); /*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */ - TVM_DLL static Array DefaultARM(const String& type); + TVM_DLL static ffi::Array DefaultARM(const ffi::String& type); + /*! \brief Create default schedule rules for RISCV CPU (RVV) */ + TVM_DLL static ffi::Array DefaultRISCV(int vlen); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScheduleRule, ObjectRef, ScheduleRuleNode); }; /*! \brief The schedule rule with customized methods on the python-side. 
*/ @@ -327,14 +338,15 @@ class PyScheduleRuleNode : public ScheduleRuleNode { // `f_apply` is not registered // `f_as_string` is not registered // `f_clone` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } void InitializeWithTuneContext(const TuneContext& context) final; - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; ScheduleRule Clone() const final; - - static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyScheduleRule", PyScheduleRuleNode, + ScheduleRuleNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 9e1af10a01d6..714c43470f05 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -98,9 +98,9 @@ class SearchStrategyNode : public runtime::Object { * and reset the search strategy. */ virtual void PreTuning(int max_trials, int num_trials_per_iter, - const Array& design_spaces, - const Optional& database, - const Optional& cost_model) = 0; + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) = 0; /*! * \brief Post-tuning for the search strategy. @@ -113,15 +113,15 @@ class SearchStrategyNode : public runtime::Object { * \brief Generate measure candidates from design spaces for measurement. * \return The measure candidates generated, nullptr if finished. */ - virtual Optional> GenerateMeasureCandidates() = 0; + virtual ffi::Optional> GenerateMeasureCandidates() = 0; /*! * \brief Update the search strategy with measurement results. * \param measure_candidates The candidates to be measured. * \param results The measurement results from the runner. 
*/ - virtual void NotifyRunnerResults(const Array& measure_candidates, - const Array& results) = 0; + virtual void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results) = 0; /*! * \brief Clone the search strategy. @@ -129,8 +129,8 @@ class SearchStrategyNode : public runtime::Object { */ virtual SearchStrategy Clone() const = 0; - static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; - TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.SearchStrategy", SearchStrategyNode, Object); }; /*! @@ -147,22 +147,23 @@ class SearchStrategy : public runtime::ObjectRef { /*! * \brief The function type of `PreTuning` method. */ - using FPreTuning = - ffi::TypedFunction&, - const Optional&, const Optional&)>; + using FPreTuning = ffi::TypedFunction&, + const ffi::Optional&, const ffi::Optional&)>; /*! \brief The function type of `PostTuning` method. */ using FPostTuning = ffi::TypedFunction; /*! * \brief The function type of `GenerateMeasureCandidates` method. * \return The measure candidates generated, nullptr if finished. */ - using FGenerateMeasureCandidates = ffi::TypedFunction>()>; + using FGenerateMeasureCandidates = + ffi::TypedFunction>()>; /*! * \brief The function type of `NotifyRunnerResults` method. * \param results The measurement results from the runner. */ - using FNotifyRunnerResults = - ffi::TypedFunction&, const Array&)>; + using FNotifyRunnerResults = ffi::TypedFunction&, + const ffi::Array&)>; /*! * \brief The function type of `Clone` method. * \return The cloned search strategy. @@ -215,7 +216,7 @@ class SearchStrategy : public runtime::ObjectRef { int genetic_max_fail_count, // double eps_greedy); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SearchStrategy, ObjectRef, SearchStrategyNode); }; /*! 
\brief The python side customizable class for measure candidate generation */ @@ -248,19 +249,22 @@ class PySearchStrategyNode : public SearchStrategyNode { // `f_generate_measure_candidates` is not registered // `f_notify_runner_results` is not registered // `f_clone` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } void InitializeWithTuneContext(const TuneContext& context) final; - void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, - const Optional& database, const Optional& cost_model) final; + void PreTuning(int max_trials, int num_trials_per_iter, + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) final; void PostTuning() final; - Optional> GenerateMeasureCandidates() final; - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results); + ffi::Optional> GenerateMeasureCandidates() final; + void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results); SearchStrategy Clone() const final; - - static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; - TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PySearchStrategy", PySearchStrategyNode, + SearchStrategyNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 7b26b56abbed..460a41e44a20 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -76,11 +76,11 @@ class SpaceGenerator; class SpaceGeneratorNode : public runtime::Object { public: /*! \brief The schedule rules. */ - Optional> sch_rules; + ffi::Optional> sch_rules; /*! \brief The postprocessors. */ - Optional> postprocs; + ffi::Optional> postprocs; /*! \brief The probability of using certain mutator. 
*/ - Optional> mutator_probs; + ffi::Optional> mutator_probs; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -105,7 +105,7 @@ class SpaceGeneratorNode : public runtime::Object { * \param mod The module used for design space generation. * \return The generated design spaces, i.e., schedules. */ - virtual Array GenerateDesignSpace(const IRModule& mod) = 0; + virtual ffi::Array GenerateDesignSpace(const IRModule& mod) = 0; /*! * \brief Clone the space generator. @@ -113,8 +113,8 @@ class SpaceGeneratorNode : public runtime::Object { */ virtual SpaceGenerator Clone() const = 0; - static constexpr const char* _type_key = "meta_schedule.SpaceGenerator"; - TVM_DECLARE_BASE_OBJECT_INFO(SpaceGeneratorNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.SpaceGenerator", SpaceGeneratorNode, Object); }; /*! @@ -123,6 +123,13 @@ class SpaceGeneratorNode : public runtime::Object { */ class SpaceGenerator : public runtime::ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit SpaceGenerator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief The function type of `InitializeWithTuneContext` method. * \param context The tuning context for initialization. @@ -133,7 +140,7 @@ class SpaceGenerator : public runtime::ObjectRef { * \param mod The module used for design space generation. * \return The generated design spaces, i.e., schedules. */ - using FGenerateDesignSpace = ffi::TypedFunction(const IRModule&)>; + using FGenerateDesignSpace = ffi::TypedFunction(const IRModule&)>; /*! * \brief The function type of `Clone` method. * \return The cloned space generator. @@ -155,8 +162,9 @@ class SpaceGenerator : public runtime::ObjectRef { * \return The design space generator created. 
*/ TVM_DLL static SpaceGenerator PySpaceGenerator( - Optional> sch_rules, Optional> postprocs, - Optional> mutator_probs, + ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs, FInitializeWithTuneContext f_initialize_with_tune_context, FGenerateDesignSpace f_generate_design_space, FClone f_clone); /*! @@ -164,15 +172,15 @@ class SpaceGenerator : public runtime::ObjectRef { * \param schedule_fn The schedule function, which can have the following signatures: * 1) void(Schedule) * 2) Schedule(Schedule) - * 3) Array(Schedule) + * 3) ffi::Array(Schedule) * \param sch_rules The schedule rules. * \param postprocs The postprocessors. * \param mutator_probs The probability of using certain mutator. */ - TVM_DLL static SpaceGenerator ScheduleFn(ffi::Function schedule_fn, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs); + TVM_DLL static SpaceGenerator ScheduleFn( + ffi::Function schedule_fn, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs); /*! * \brief Create a design space generator that is union of multiple design space generators. * \param space_generators An array of design space generators to be unioned. @@ -181,10 +189,11 @@ class SpaceGenerator : public runtime::ObjectRef { * \param mutator_probs The probability of using certain mutator. * \return The design space generator created. */ - TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs); + TVM_DLL static SpaceGenerator SpaceGeneratorUnion( + ffi::Array space_generators, + ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs); /*! * \brief Create a design space generator that generates design spaces by applying schedule * rules to blocks in post-DFS order. @@ -194,11 +203,11 @@ class SpaceGenerator : public runtime::ObjectRef { * \param mutator_probs The probability of using certain mutator. 
* \return The design space generator created. */ - TVM_DLL static SpaceGenerator PostOrderApply(ffi::Function f_block_filter, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); + TVM_DLL static SpaceGenerator PostOrderApply( + ffi::Function f_block_filter, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; /*! \brief The design space generator with customized methods on the python-side. */ @@ -218,14 +227,15 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { // `f_initialize_with_tune_context` is not registered // `f_generate_design_space` is not registered // `f_clone` is not registered + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } void InitializeWithTuneContext(const TuneContext& context) final; - Array GenerateDesignSpace(const IRModule& mod) final; + ffi::Array GenerateDesignSpace(const IRModule& mod) final; SpaceGenerator Clone() const final; - - static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; - TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PySpaceGenerator", PySpaceGeneratorNode, + SpaceGeneratorNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 9c1300d2433f..1cc56f251f10 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -40,7 +40,7 @@ namespace meta_schedule { class TaskRecordNode : public runtime::Object { public: /*! \brief The tune context of the task. */ - TuneContext ctx{nullptr}; + TuneContext ctx{ffi::UnsafeInit()}; /*! \brief The weight of the task */ double task_weight{1.0}; /*! 
\brief The FLOP count of the task */ @@ -54,11 +54,11 @@ class TaskRecordNode : public runtime::Object { /*! \brief The latency of each run, in milliseconds. */ std::vector latency_ms = {}; /*! \brief The measure candidates. */ - Optional> measure_candidates = std::nullopt; + ffi::Optional> measure_candidates = std::nullopt; /*! \brief The building results. */ - Optional> builder_results = std::nullopt; + ffi::Optional> builder_results = std::nullopt; /*! \brief Packed functions to fetch the runner results asynchronously. */ - Optional> runner_futures = std::nullopt; + ffi::Optional> runner_futures = std::nullopt; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -74,8 +74,8 @@ class TaskRecordNode : public runtime::Object { .def_ro("runner_futures", &TaskRecordNode::runner_futures); } - static constexpr const char* _type_key = "meta_schedule.TaskRecord"; - TVM_DECLARE_FINAL_OBJECT_INFO(TaskRecordNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TaskRecord", TaskRecordNode, Object); }; /*! @@ -87,7 +87,7 @@ class TaskRecord : public runtime::ObjectRef { /*! \brief Constructor */ explicit TaskRecord(TuneContext task, double task_weight); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskRecord, ObjectRef, TaskRecordNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TaskRecord, ObjectRef, TaskRecordNode); }; /*! @@ -131,13 +131,13 @@ class TaskSchedulerNode : public runtime::Object { /*! \brief The tuning task's logging function. */ ffi::Function logger; /*! \brief Records for each task */ - Array tasks_; + ffi::Array tasks_; /*! \brief The list of measure callbacks of the scheduler. */ - Array measure_callbacks_; + ffi::Array measure_callbacks_; /*! \brief The database used in tuning */ - Optional database_; + ffi::Optional database_; /*! \brief The cost model used in tuning */ - Optional cost_model_; + ffi::Optional cost_model_; /*! 
\brief The number of remaining tasks to be tuned. */ int remaining_tasks_; @@ -164,7 +164,7 @@ class TaskSchedulerNode : public runtime::Object { * \param task_id The task id to be joined. * \return The results from the runner. */ - virtual Array JoinRunningTask(int task_id); + virtual ffi::Array JoinRunningTask(int task_id); /*! * \brief Jointly tune a given list of tasks. * \param tasks The tasks to be tuned @@ -178,16 +178,16 @@ class TaskSchedulerNode : public runtime::Object { * \param database The database used in tuning * \param cost_model The cost model used in tuning */ - virtual void Tune(Array tasks, // - Array task_weights, // - int max_trials_global, // - int max_trials_per_task, // - int num_trials_per_iter, // - Builder builder, // - Runner runner, // - Array measure_callbacks, // - Optional database, // - Optional cost_model); + virtual void Tune(ffi::Array tasks, // + ffi::Array task_weights, // + int max_trials_global, // + int max_trials_per_task, // + int num_trials_per_iter, // + Builder builder, // + Runner runner, // + ffi::Array measure_callbacks, // + ffi::Optional database, // + ffi::Optional cost_model); /*! * \brief Terminate a task * \param task_id The id of the task to be terminated @@ -201,8 +201,8 @@ class TaskSchedulerNode : public runtime::Object { /*! \brief Print out a human-readable format of the tuning statistics. */ void PrintTuningStatistics(); - static constexpr const char* _type_key = "meta_schedule.TaskScheduler"; - TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.TaskScheduler", TaskSchedulerNode, Object); }; class TaskScheduler; @@ -219,18 +219,18 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { * \brief The function type of `JoinRunningTask` method. * \param task_id The task id to be joined. */ - using FJoinRunningTask = ffi::TypedFunction(int)>; + using FJoinRunningTask = ffi::TypedFunction(int)>; /*! 
\brief The function type of `Tune` method. */ - using FTune = ffi::TypedFunction tasks, // - Array task_weights, // - int max_trials_global, // - int max_trials_per_task, // - int num_trials_per_iter, // - Builder builder, // - Runner runner, // - Array measure_callbacks, // - Optional database, // - Optional cost_model)>; + using FTune = ffi::TypedFunction tasks, // + ffi::Array task_weights, // + int max_trials_global, // + int max_trials_per_task, // + int num_trials_per_iter, // + Builder builder, // + Runner runner, // + ffi::Array measure_callbacks, // + ffi::Optional database, // + ffi::Optional cost_model)>; /*! \brief The packed function to the `NextTaskId` function. */ FNextTaskId f_next_task_id; @@ -245,14 +245,13 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { } int NextTaskId() final; - Array JoinRunningTask(int task_id) final; - void Tune(Array tasks, Array task_weights, int max_trials_global, + ffi::Array JoinRunningTask(int task_id) final; + void Tune(ffi::Array tasks, ffi::Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, - Array measure_callbacks, Optional database, - Optional cost_model) final; - - static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode); + ffi::Array measure_callbacks, ffi::Optional database, + ffi::Optional cost_model) final; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PyTaskScheduler", PyTaskSchedulerNode, + TaskSchedulerNode); }; /*! @@ -261,6 +260,9 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { */ class TaskScheduler : public runtime::ObjectRef { public: + explicit TaskScheduler(ObjectPtr data) : runtime::ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a task scheduler that fetches tasks in a round-robin fashion. * \param logger The tuning task's logging function. 
@@ -288,7 +290,7 @@ class TaskScheduler : public runtime::ObjectRef { TVM_DLL static TaskScheduler PyTaskScheduler( ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TaskScheduler, ObjectRef, TaskSchedulerNode); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 47326ac46b99..a36a946d0ae5 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -48,15 +48,15 @@ class TuneContextNode : public runtime::Object { using TRandState = support::LinearCongruentialEngine::TRandState; /*! \brief The workload to be tuned. */ - Optional mod; + ffi::Optional mod; /*! \brief The target to be tuned for. */ - Optional target; + ffi::Optional target; /*! \brief The design space generator. */ - Optional space_generator; + ffi::Optional space_generator; /*! \brief The search strategy. */ - Optional search_strategy; + ffi::Optional search_strategy; /*! \brief The name of the tuning task. */ - Optional task_name; + ffi::Optional task_name; /*! \brief The number of threads to be used. */ int num_threads; /*! \brief The random state. */ @@ -87,8 +87,8 @@ class TuneContextNode : public runtime::Object { */ TuneContext Clone() const; - static constexpr const char* _type_key = "meta_schedule.TuneContext"; - TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TuneContext", TuneContextNode, Object); }; /*! @@ -98,6 +98,13 @@ class TuneContextNode : public runtime::Object { class TuneContext : public runtime::ObjectRef { public: using TRandState = support::LinearCongruentialEngine::TRandState; + /*! 
+ * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit TuneContext(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Constructor. * \param mod The workload to be tuned. @@ -109,11 +116,12 @@ class TuneContext : public runtime::ObjectRef { * \param rand_state The random state. * \param logger The tuning task's logging function. */ - TVM_DLL explicit TuneContext(Optional mod, Optional target, - Optional space_generator, - Optional search_strategy, Optional task_name, - int num_threads, TRandState rand_state, ffi::Function logger); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); + TVM_DLL explicit TuneContext(ffi::Optional mod, ffi::Optional target, + ffi::Optional space_generator, + ffi::Optional search_strategy, + ffi::Optional task_name, int num_threads, + TRandState rand_state, ffi::Function logger); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TuneContext, ObjectRef, TuneContextNode); }; } // namespace meta_schedule diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index 37dc710ac161..e273fa8f5fe1 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -86,7 +86,7 @@ class AttrRegistryMapContainerMap { private: /*! \brief The name of the attr field */ - String attr_name_; + ffi::String attr_name_; /*! \brief The internal data. */ std::vector> data_; /*! \brief The constructor */ @@ -97,7 +97,7 @@ class AttrRegistryMapContainerMap { }; /*! - * \brief Map used to store meta-data. + * \brief ffi::Map used to store meta-data. * \tparam KeyType The type of the key * \tparam ValueType The type of the value stored in map. 
*/ diff --git a/include/tvm/node/cast.h b/include/tvm/node/cast.h index ae23c9e9aa33..32d4be721656 100644 --- a/include/tvm/node/cast.h +++ b/include/tvm/node/cast.h @@ -45,19 +45,20 @@ namespace tvm { template >> inline SubRef Downcast(BaseRef ref) { + using ContainerType = typename SubRef::ContainerType; if (ref.defined()) { - if (!ref->template IsInstance()) { + if (!ref->template IsInstance()) { TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key << " failed."; } - return SubRef(ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); + return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr( + ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); } else { if constexpr (ffi::is_optional_type_v || SubRef::_type_is_nullable) { - return SubRef(ffi::ObjectPtr(nullptr)); + return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); } - TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" - << SubRef::ContainerType::_type_key - << "` is not allowed. Use Downcast> instead."; + TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" << ContainerType::_type_key + << "` is not allowed. Use Downcast> instead."; TVM_FFI_UNREACHABLE(); } } diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 7c8c2bfb9214..d5716f96f6d5 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -34,7 +34,8 @@ namespace tvm { * \param fields The fields of the object. * \return The created object. 
*/ -TVM_DLL ffi::Any CreateObject(const String& type_key, const Map& fields); +TVM_DLL ffi::Any CreateObject(const ffi::String& type_key, + const ffi::Map& fields); } // namespace tvm #endif // TVM_NODE_REFLECTION_H_ diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 05687d70d742..f3e0edab6e07 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -83,7 +83,7 @@ inline std::ostream& operator<<(std::ostream& os, const Any& n) { // NOLINT(*) } template -inline std::ostream& operator<<(std::ostream& os, const Variant& n) { // NOLINT(*) +inline std::ostream& operator<<(std::ostream& os, const ffi::Variant& n) { // NOLINT(*) ReprPrinter(os).Print(Any(n)); return os; } @@ -94,7 +94,7 @@ inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) { namespace refl = ffi::reflection; switch (step->kind) { case refl::AccessKind::kAttr: { - os << '.' << step->key.cast(); + os << '.' << step->key.cast(); return os; } case refl::AccessKind::kArrayItem: { @@ -106,7 +106,7 @@ inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) { return os; } case refl::AccessKind::kAttrMissing: { - os << ".key.cast() << "`>"; + os << ".key.cast() << "`>"; return os; } case refl::AccessKind::kArrayItemMissing: { @@ -125,7 +125,7 @@ inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) { } inline std::ostream& operator<<(std::ostream& os, const AccessPath& path) { - Array steps = path->ToSteps(); + ffi::Array steps = path->ToSteps(); os << ""; for (const auto& step : steps) { os << step; diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index d046dbfae732..ac293c88e884 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -40,7 +40,7 @@ namespace tvm { class PrinterConfigNode : public ffi::Object { public: /*! 
\brief A stack that tracks the names of the binding hierarchy */ - Array binding_names = {}; + ffi::Array binding_names = {}; /*! \brief Whether or not to show metadata. */ bool show_meta = false; /*! \brief The prefix of IR nodes */ @@ -113,13 +113,13 @@ class PrinterConfigNode : public ffi::Object { bool show_all_struct_info = true; /* \brief Object path to be underlined */ - Array path_to_underline; + ffi::Array path_to_underline; /*! \brief Object path to be annotated. */ - Map path_to_annotate; + ffi::Map path_to_annotate; /*! \brief Object to be underlined. */ - Array obj_to_underline = Array(); + ffi::Array obj_to_underline = ffi::Array(); /*! \brief Object to be annotated. */ - Map obj_to_annotate = Map(); + ffi::Map obj_to_annotate = ffi::Map(); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -146,33 +146,35 @@ class PrinterConfigNode : public ffi::Object { .def_ro("obj_to_annotate", &PrinterConfigNode::obj_to_annotate); } - Array GetBuiltinKeywords(); + ffi::Array GetBuiltinKeywords(); - static constexpr const char* _type_key = "script.PrinterConfig"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrinterConfigNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.PrinterConfig", PrinterConfigNode, Object); }; class PrinterConfig : public ObjectRef { public: - explicit PrinterConfig(Map config_dict = Map()); + explicit PrinterConfig( + ffi::Map config_dict = ffi::Map()); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrinterConfig, runtime::ObjectRef, - PrinterConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrinterConfig, runtime::ObjectRef, + PrinterConfigNode); }; /*! \brief Legacy behavior of ReprPrinter. 
*/ class TVMScriptPrinter { public: /* Convert the object to TVMScript format */ - static std::string Script(const ObjectRef& node, const Optional& cfg); + static std::string Script(const ObjectRef& node, const ffi::Optional& cfg); // Allow registration to be printer. using FType = NodeFunctor; TVM_DLL static FType& vtable(); }; -#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \ - std::string Script(const Optional& config = std::nullopt) const { \ - return TVMScriptPrinter::Script(GetRef(this), config.value_or(PrinterConfig())); \ +#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \ + std::string Script(const ffi::Optional& config = std::nullopt) const { \ + return TVMScriptPrinter::Script(ffi::GetRef(this), \ + config.value_or(PrinterConfig())); \ } } // namespace tvm diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 12ba59118b72..4f00e1770b41 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -58,10 +58,10 @@ class BaseValueEqual { bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; } - bool operator()(const Optional& lhs, const Optional& rhs) const { + bool operator()(const ffi::Optional& lhs, const ffi::Optional& rhs) const { return lhs == rhs; } - bool operator()(const Optional& lhs, const Optional& rhs) const { + bool operator()(const ffi::Optional& lhs, const ffi::Optional& rhs) const { return lhs == rhs; } bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; } diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 0aca92d0e28a..ba7cbaf88aa6 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -25,7 +25,7 @@ #include #include -#include +#include #include #include @@ -78,14 +78,14 @@ class BaseValueHash { uint64_t operator()(const std::string& key) const { return 
tvm::ffi::details::StableHashBytes(key.data(), key.length()); } - uint64_t operator()(const Optional& key) const { + uint64_t operator()(const ffi::Optional& key) const { if (key.has_value()) { return Reinterpret(*key); } else { return 0; } } - uint64_t operator()(const Optional& key) const { + uint64_t operator()(const ffi::Optional& key) const { if (key.has_value()) { return Reinterpret(*key); } else { diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 267eb1b66eeb..73d1a3dbebce 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -53,7 +53,7 @@ namespace relax { * if result is false, there is still possibility that * two shapes equals to each other during runtime. */ -TVM_DLL bool CanProveShapeEqual(const Array& lhs, const Array& rhs, +TVM_DLL bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Array& rhs, arith::Analyzer* ana); /*! @@ -155,11 +155,11 @@ TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Ca * * \return the corresponding erased struct info. */ -TVM_DLL StructInfo -EraseToWellDefined(const StructInfo& info, - std::function(const tir::Var& var)> f_shape_var_map = nullptr, - std::function(const Var& var)> f_var_map = nullptr, - arith::Analyzer* ana = nullptr); +TVM_DLL StructInfo EraseToWellDefined( + const StructInfo& info, + std::function(const tir::Var& var)> f_shape_var_map = nullptr, + std::function(const Var& var)> f_var_map = nullptr, + arith::Analyzer* ana = nullptr); /*! * \brief EraseToWellDefined variant with map. @@ -174,8 +174,9 @@ EraseToWellDefined(const StructInfo& info, * * \return the corresponding erased struct info. */ -TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, - Map var_map, arith::Analyzer* ana = nullptr); +TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, + ffi::Map shape_var_map, + ffi::Map var_map, arith::Analyzer* ana = nullptr); /*! * \brief Fine grained result of base check. 
@@ -289,7 +290,7 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, * \param sinfo The struct info object to be analyzed. * \return The list of TIR variables that appear in the input struct info. */ -TVM_DLL Array TIRVarsInStructInfo(const StructInfo& sinfo); +TVM_DLL ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo); /*! * \brief Get the TIR variables that appear in the input struct info. @@ -303,7 +304,7 @@ TVM_DLL Array TIRVarsInStructInfo(const StructInfo& sinfo); * deduplicated, each TIR variable will appear at most once, and in * order of occurrence. */ -TVM_DLL Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); +TVM_DLL ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); /*! \brief Collect expressions whose usage requires them to be non-negative * @@ -316,7 +317,7 @@ TVM_DLL Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); * * \return A list of non-negative expressions. */ -TVM_DLL Array CollectNonNegativeExpressions(const StructInfo& sinfo); +TVM_DLL ffi::Array CollectNonNegativeExpressions(const StructInfo& sinfo); /*! * \brief Get the TIR variables that defined in the input function. @@ -324,7 +325,7 @@ TVM_DLL Array CollectNonNegativeExpressions(const StructInfo& sinfo); * \param expr The relax expression (e.g. a Function) to be analyzed. * \return The list of TIR variables that are defined in the input function. */ -TVM_DLL Array DefinedSymbolicVars(const Expr& expr); +TVM_DLL ffi::Array DefinedSymbolicVars(const Expr& expr); /*! * \brief Get the TIR variables that are used but not defined in the input function. @@ -332,7 +333,7 @@ TVM_DLL Array DefinedSymbolicVars(const Expr& expr); * \param expr The relax expression (e.g. a Function) to be analyzed. * \return The list of TIR variables that are used but not defined in the input function. 
*/ -TVM_DLL Array FreeSymbolicVars(const Expr& expr); +TVM_DLL ffi::Array FreeSymbolicVars(const Expr& expr); //----------------------------------- // General IR analysis //----------------------------------- @@ -346,7 +347,7 @@ TVM_DLL Array FreeSymbolicVars(const Expr& expr); * * \return List of bound vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array BoundVars(const Expr& expr); +TVM_DLL tvm::ffi::Array BoundVars(const Expr& expr); /*! * \brief Get free type parameters from expression expr. @@ -358,7 +359,7 @@ TVM_DLL tvm::Array BoundVars(const Expr& expr); * * \return List of free vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array FreeVars(const Expr& expr); +TVM_DLL tvm::ffi::Array FreeVars(const Expr& expr); /*! * \brief Get all variables from expression expr. @@ -367,7 +368,7 @@ TVM_DLL tvm::Array FreeVars(const Expr& expr); * * \return List of all vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array AllVars(const Expr& expr); +TVM_DLL tvm::ffi::Array AllVars(const Expr& expr); /*! * \brief Get all global variables from expression expr. @@ -379,7 +380,7 @@ TVM_DLL tvm::Array AllVars(const Expr& expr); * * \return List of all global variables, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); +TVM_DLL tvm::ffi::Array AllGlobalVars(const Expr& expr); /*! * \brief Find all sets of recursive or mutually recursive functions in the module. @@ -404,7 +405,7 @@ TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); * If a function is simply recursive and not mutually recursive with any other, * then it will be listed as a group by itself. */ -TVM_DLL tvm::Array> DetectRecursion(const IRModule& m); +TVM_DLL tvm::ffi::Array> DetectRecursion(const IRModule& m); /*! * \brief Analyze var -> value mapping from VarBindings. @@ -412,7 +413,7 @@ TVM_DLL tvm::Array> DetectRecursion(const IRModule& m); * \param m The IRModule to check. 
* \return Var -> Value (Expr) */ -TVM_DLL Map AnalyzeVar2Value(const IRModule& m); +TVM_DLL ffi::Map AnalyzeVar2Value(const IRModule& m); /*! * \brief Analyze var -> value mapping from VarBindings. @@ -420,7 +421,7 @@ TVM_DLL Map AnalyzeVar2Value(const IRModule& m); * \param expr The expression to check. * \return Var -> Value (Expr) */ -TVM_DLL Map AnalyzeVar2Value(const Expr& expr); +TVM_DLL ffi::Map AnalyzeVar2Value(const Expr& expr); /*! * \brief Analyze var -> value mapping from VarBindings. @@ -428,7 +429,7 @@ TVM_DLL Map AnalyzeVar2Value(const Expr& expr); * \param dfb The dataflow block to check. * \return Var -> Value (Expr) */ -TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); +TVM_DLL ffi::Map AnalyzeVar2Value(const DataflowBlock& dfb); /*! * \brief Return a mapping from variable name to its Bindings. @@ -436,7 +437,7 @@ TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); * \param fn The function to be analyzed. * \return A mapping from variable name to its Bindings. */ -TVM_DLL Map> NameToBinding(const Function& fn); +TVM_DLL ffi::Map> NameToBinding(const Function& fn); /*! * \brief Get the use-def chain of variables inside a dataflow block. @@ -444,7 +445,7 @@ TVM_DLL Map> NameToBinding(const Function& fn); * \param dfb The dataflow block to be analyzed. * \return A map mapping variable definitions to a set of uses. */ -TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); +TVM_DLL ffi::Map> DataflowBlockUseDef(const DataflowBlock& dfb); /*! * \brief Get the use-def chain of variables inside a function. @@ -457,7 +458,7 @@ TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); * variables whose usage occurs outside of any variable binding, * typically the output body of a relax::Function or a relax::SeqExpr. */ -std::pair>, Array> FunctionUseDef(const Expr& expr); +std::pair>, ffi::Array> FunctionUseDef(const Expr& expr); /*! 
\brief A utility struct returned by CollectVarUsage */ @@ -466,19 +467,19 @@ struct VarUsageInfo { * * This is equivalent to the output of AnalyzeVar2Value */ - Map bound_values; + ffi::Map bound_values; /* \brief The map from variables to downstream usages of the variable * * This is equivalent to the first output of FunctionUseDef. */ - Map> downstream_usage; + ffi::Map> downstream_usage; /* \brief A list of variables produced as output * * This is equivalent to the second output of FunctionUseDef */ - Array outputs; + ffi::Array outputs; }; /*! \brief Collect variable bindings and usage @@ -541,8 +542,8 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); * Also, an impure call in a *nested* function does *not* mean that the outer expression contains * an impure call--it only does if the nested function is *later called*. */ -TVM_DLL Optional FindImpureCall( - const Expr& expr, const Optional& own_name = Optional(std::nullopt)); +TVM_DLL ffi::Optional FindImpureCall( + const Expr& expr, const ffi::Optional& own_name = ffi::Optional(std::nullopt)); /*! * \brief Check if the given expression (likely a function body) contains any impure calls. @@ -555,8 +556,8 @@ TVM_DLL Optional FindImpureCall( * Also, an impure call in a *nested* function does *not* mean that the outer expression contains * an impure call--it only does if the nested function is *later called*. */ -TVM_DLL bool ContainsImpureCall(const Expr& expr, - const Optional& own_name = Optional(std::nullopt)); +TVM_DLL bool ContainsImpureCall( + const Expr& expr, const ffi::Optional& own_name = ffi::Optional(std::nullopt)); /*! * \brief Check if the IRModule is well formed. @@ -569,7 +570,7 @@ TVM_DLL bool ContainsImpureCall(const Expr& expr, * where `check_struct_info` might be false, so that other well-formed requirements * will be well tested and will not be blocked by not having structure info. 
*/ -TVM_DLL bool WellFormed(Variant obj, bool check_struct_info = true); +TVM_DLL bool WellFormed(ffi::Variant obj, bool check_struct_info = true); /*! * \brief Using the layout transforms on the outputs, suggest layout transformation on the blocks @@ -581,8 +582,8 @@ TVM_DLL bool WellFormed(Variant obj, bool check_struct_info * from the object (block or buffer) to it's index map transformation. */ -TVM_DLL Map> SuggestLayoutTransforms( - const Function& fn, Array write_buffer_transformations); +TVM_DLL ffi::Map> SuggestLayoutTransforms( + const Function& fn, ffi::Array write_buffer_transformations); /* \brief Collect variables whose value can be computed at compile-time * @@ -597,7 +598,7 @@ TVM_DLL Map> SuggestLayoutTransforms( * \return The set of variables that can be computed at compile-time, * in order of their occurrence within the function. */ -TVM_DLL Array ComputableAtCompileTime(const Function& func); +TVM_DLL ffi::Array ComputableAtCompileTime(const Function& func); } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h index e6736dd2e731..09d40b4ed98e 100644 --- a/include/tvm/relax/attrs/ccl.h +++ b/include/tvm/relax/attrs/ccl.h @@ -32,7 +32,7 @@ namespace relax { /*! \brief Attributes used in allreduce operators */ struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter { - String op_type; + ffi::String op_type; bool in_group; static void RegisterReflection() { @@ -45,9 +45,7 @@ struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter { "Whether the reduction operation performs in group or globally or in group as " "default."); } - - static constexpr const char* _type_key = "relax.attrs.AllReduceAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllReduceAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllReduceAttrs", AllReduceAttrs, BaseAttrsNode); }; // struct AllReduceAttrs /*! 
\brief Attributes used in allgather operators */ @@ -65,9 +63,7 @@ struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter { "Whether the allgather operation performs in group or globally or in group as " "default."); } - - static constexpr const char* _type_key = "relax.attrs.AllGatherAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllGatherAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllGatherAttrs", AllGatherAttrs, BaseAttrsNode); }; // struct AllGatherAttrs /*! \brief Attributes used in scatter operators */ @@ -85,9 +81,8 @@ struct ScatterCollectiveAttrs : public tvm::AttrsNodeReflAdapter { refl::ObjectDef().def_ro("dtype", &InitAttrs::dtype, "The data type of the created tensor."); } - - static constexpr const char* _type_key = "relax.attrs.InitAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(InitAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InitAttrs", InitAttrs, BaseAttrsNode); }; // struct InitAttrs /*! \brief Attributes used in tril and triu operator */ @@ -53,9 +51,7 @@ struct TriluAttrs : public AttrsNodeReflAdapter { "k", &TriluAttrs::k, "The number of diagonals above or below the main diagonal to exclude or include."); } - - static constexpr const char* _type_key = "relax.attrs.TriluAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TriluAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TriluAttrs", TriluAttrs, BaseAttrsNode); }; // struct TriluAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h index 5f72b284d562..dd07e3b54851 100644 --- a/include/tvm/relax/attrs/datatype.h +++ b/include/tvm/relax/attrs/datatype.h @@ -37,9 +37,7 @@ struct AstypeAttrs : public AttrsNodeReflAdapter { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("dtype", &AstypeAttrs::dtype, "Target data type"); } - - static constexpr const char* _type_key = "relax.attrs.AstypeAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AstypeAttrs, BaseAttrsNode); + 
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AstypeAttrs", AstypeAttrs, BaseAttrsNode); }; // struct AstypeAttrs. /*! \brief Attributes used in wrap_param operator */ @@ -50,9 +48,7 @@ struct WrapParamAttrs : public AttrsNodeReflAdapter { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("dtype", &WrapParamAttrs::dtype, "Target data type"); } - - static constexpr const char* _type_key = "relax.attrs.WrapParamAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(WrapParamAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.WrapParamAttrs", WrapParamAttrs, BaseAttrsNode); }; // struct WrapParamAttrs. } // namespace relax diff --git a/include/tvm/relax/attrs/distributed.h b/include/tvm/relax/attrs/distributed.h index 08a508a9bd53..356a248ba220 100644 --- a/include/tvm/relax/attrs/distributed.h +++ b/include/tvm/relax/attrs/distributed.h @@ -44,9 +44,8 @@ struct DistributionAttrs : public AttrsNodeReflAdapter { .def_ro("placement", &DistributionAttrs::placement, "The placement of a tensor's distribution plan"); } - - static constexpr const char* _type_key = "relax.attrs.DistributionAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(DistributionAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DistributionAttrs", DistributionAttrs, + BaseAttrsNode); }; // struct DistributionAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h index 544ad1ebd1dc..b367ce58433d 100644 --- a/include/tvm/relax/attrs/image.h +++ b/include/tvm/relax/attrs/image.h @@ -31,11 +31,11 @@ namespace relax { /*! 
\brief Attributes used in image resize2d operator */ struct Resize2DAttrs : public AttrsNodeReflAdapter { - Array roi; - String layout; - String method; - String coordinate_transformation_mode; - String rounding_method; + ffi::Array roi; + ffi::String layout; + ffi::String method; + ffi::String coordinate_transformation_mode; + ffi::String rounding_method; double cubic_alpha; int cubic_exclude; double extrapolation_value; @@ -75,11 +75,31 @@ struct Resize2DAttrs : public AttrsNodeReflAdapter { "The dtype of the output tensor. It it is not specified, the output will have the same " "dtype as input if not specified."); } - - static constexpr const char* _type_key = "relax.attrs.Resize2DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Resize2DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs", Resize2DAttrs, BaseAttrsNode); }; // struct Resize2dAttrs +/*! \brief Attributes used in image grid_sample operator */ +struct GridSampleAttrs : public AttrsNodeReflAdapter { + ffi::String method; + ffi::String layout; + ffi::String padding_mode; + bool align_corners; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("method", &GridSampleAttrs::method, + "Interpolation method. Can be 'nearest', 'bilinear', or 'bicubic'.") + .def_ro("layout", &GridSampleAttrs::layout, + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc.") + .def_ro("padding_mode", &GridSampleAttrs::padding_mode, + "Padding mode for outside grid values. 
Can be 'zeros', 'border', or 'reflection'.") + .def_ro("align_corners", &GridSampleAttrs::align_corners, + "If True, the corner pixels of the input and output tensors are aligned."); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GridSampleAttrs", GridSampleAttrs, BaseAttrsNode); +}; // struct GridSampleAttrs + } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h index cc914449db30..0ea7c06bacc0 100644 --- a/include/tvm/relax/attrs/index.h +++ b/include/tvm/relax/attrs/index.h @@ -31,8 +31,8 @@ namespace relax { /*! \brief Attributes used in take operator */ struct TakeAttrs : public AttrsNodeReflAdapter { - Optional axis; - String mode; + ffi::Optional axis; + ffi::String mode; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -41,9 +41,7 @@ struct TakeAttrs : public AttrsNodeReflAdapter { .def_ro("mode", &TakeAttrs::mode, "The mode for handling out-of-bounds indices.", refl::DefaultValue("fast")); } - - static constexpr const char* _type_key = "relax.attrs.TakeAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TakeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TakeAttrs", TakeAttrs, BaseAttrsNode); }; // struct TakeAttrs /*! 
\brief Attributes used in strided_slice operator */ @@ -58,9 +56,8 @@ struct StridedSliceAttrs : public AttrsNodeReflAdapter { "out of bound indices will be clipped to the bound.", refl::DefaultValue(true)); } - - static constexpr const char* _type_key = "relax.attrs.StridedSliceAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(StridedSliceAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StridedSliceAttrs", StridedSliceAttrs, + BaseAttrsNode); }; // struct StridedSliceAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h index 041b9cb1bef4..f95d817f1e4d 100644 --- a/include/tvm/relax/attrs/linear_algebra.h +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -38,23 +38,19 @@ struct MatmulAttrs : public AttrsNodeReflAdapter { refl::ObjectDef().def_ro("out_dtype", &MatmulAttrs::out_dtype, "The data type of the output tensor"); } - - static constexpr const char* _type_key = "relax.attrs.MatmulAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(MatmulAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MatmulAttrs", MatmulAttrs, BaseAttrsNode); }; // struct MatmulAttrs /*! 
\brief Attributes used in einsum operator */ struct EinsumAttrs : public AttrsNodeReflAdapter { - String subscripts; + ffi::String subscripts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("subscripts", &EinsumAttrs::subscripts, "The einsum expression string"); } - - static constexpr const char* _type_key = "relax.attrs.EinsumAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(EinsumAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.EinsumAttrs", EinsumAttrs, BaseAttrsNode); }; // struct EinsumAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 6a7cfe0baba2..21184848e3c7 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -32,7 +32,7 @@ namespace relax { /*! \brief Attributes used in concat operators */ struct ConcatAttrs : public AttrsNodeReflAdapter { - Optional axis; + ffi::Optional axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -40,14 +40,12 @@ struct ConcatAttrs : public AttrsNodeReflAdapter { "The axis at which the input arrays are concatenated." "Should lie in range `[-ndim, ndim)`."); } - - static constexpr const char* _type_key = "relax.attrs.ConcatAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ConcatAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ConcatAttrs", ConcatAttrs, BaseAttrsNode); }; // struct ConcatAttrs /*! 
\brief Attributes used in expand_dims operators */ struct ExpandDimsAttrs : public AttrsNodeReflAdapter { - Array axis; + ffi::Array axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -57,9 +55,7 @@ struct ExpandDimsAttrs : public AttrsNodeReflAdapter { "All values are required to lie in range `[-data.ndim - 1, data.ndim]`, " "with the convention of negative indexing."); } - - static constexpr const char* _type_key = "relax.attrs.ExpandDimsAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ExpandDimsAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ExpandDimsAttrs", ExpandDimsAttrs, BaseAttrsNode); }; // struct ExpandDimsAttrs /*! \brief Attributes used in layout_transform operator */ @@ -67,20 +63,20 @@ struct LayoutTransformAttrs : public AttrsNodeReflAdapter tir::IndexMap index_map; // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future. - Optional pad_value; + ffi::Optional pad_value; /*! * axis_separators between input axes when generating flattened output axes. For buffers * representing flat 1-d memory (e.g. any buffer in RAM), this should be an empty array. * For buffers representing non-flat memory, each entry in axis_separators should be the * first input axis that is part of a new flattened axis. */ - Optional> axis_separators; + ffi::Optional> axis_separators; /*! * axis_separators for input buffers. * Needed to identify if the input buffer to layout_transform * contains axis separator. 
*/ - Optional> input_axis_separators; + ffi::Optional> input_axis_separators; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -96,23 +92,21 @@ struct LayoutTransformAttrs : public AttrsNodeReflAdapter .def_ro("input_axis_separators", &LayoutTransformAttrs::input_axis_separators, "The separators between axes to regenerate output"); } - - static constexpr const char* _type_key = "relax.attrs.LayoutTransformAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(LayoutTransformAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayoutTransformAttrs", LayoutTransformAttrs, + BaseAttrsNode); }; // struct LayoutTransformAttrs /*! \brief Attributes used in permute_dims operator */ struct PermuteDimsAttrs : public AttrsNodeReflAdapter { - Optional> axes; + ffi::Optional> axes; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro( "axes", &PermuteDimsAttrs::axes, "The target axes order, reverse order if not specified."); } - - static constexpr const char* _type_key = "relax.attrs.PermuteDimsAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(PermuteDimsAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PermuteDimsAttrs", PermuteDimsAttrs, + BaseAttrsNode); }; // struct PermuteDimsAttrs /*! \brief Attributes used in split operator */ @@ -127,14 +121,12 @@ struct SplitAttrs : public AttrsNodeReflAdapter { "The input array of indices or the number of split sections.") .def_ro("axis", &SplitAttrs::axis, "The axis to be splitted"); } - - static constexpr const char* _type_key = "relax.attrs.SplitAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SplitAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SplitAttrs", SplitAttrs, BaseAttrsNode); }; // struct SplitAttrs /*! 
\brief Attributes used in squeeze operators */ struct SqueezeAttrs : public AttrsNodeReflAdapter { - Optional> axis; + ffi::Optional> axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -144,14 +136,12 @@ struct SqueezeAttrs : public AttrsNodeReflAdapter { "Else, the dimension in axes get squeezed." "It is an error if an axis does not has dimension 1."); } - - static constexpr const char* _type_key = "relax.attrs.SqueezeAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SqueezeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SqueezeAttrs", SqueezeAttrs, BaseAttrsNode); }; // struct SqueezeAttrs /*! \brief Attributes used in stack operators */ struct StackAttrs : public AttrsNodeReflAdapter { - Optional axis; + ffi::Optional axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -162,15 +152,13 @@ struct StackAttrs : public AttrsNodeReflAdapter { "so it must be in range [-ndim-1, ndim] where ndim is the " "number of dimensions of the input tensors."); } - - static constexpr const char* _type_key = "relax.attrs.StackAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(StackAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StackAttrs", StackAttrs, BaseAttrsNode); }; // struct StackAttrs /*! \brief Attributes used in repeat operators */ struct RepeatAttrs : public AttrsNodeReflAdapter { int repeats; - Optional axis; + ffi::Optional axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -181,23 +169,19 @@ struct RepeatAttrs : public AttrsNodeReflAdapter { "counting from the backward. By default, use the flattened input array, and " "return a flat output array."); } - - static constexpr const char* _type_key = "relax.attrs.RepeatAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(RepeatAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RepeatAttrs", RepeatAttrs, BaseAttrsNode); }; // struct RepeatAttrs /*! 
\brief Attributes used in tile operators */ struct TileAttrs : public AttrsNodeReflAdapter { - Array repeats; + ffi::Array repeats; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("repeats", &TileAttrs::repeats, "The number of repetitions of data along each axis."); } - - static constexpr const char* _type_key = "relax.attrs.TileAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TileAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TileAttrs", TileAttrs, BaseAttrsNode); }; // struct TileAttrs /*! \brief Attributes used in flip operators */ @@ -210,9 +194,7 @@ struct FlipAttrs : public AttrsNodeReflAdapter { "The axis along which to flip over.", refl::DefaultValue(NullValue())); } - - static constexpr const char* _type_key = "relax.attrs.FlipAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(FlipAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, BaseAttrsNode); }; // struct FlipAttrs /*! \brief Attributes used in gather_elements operators */ @@ -225,9 +207,8 @@ struct GatherElementsAttrs : public AttrsNodeReflAdapter { "The axis along which to index.", refl::DefaultValue(0)); } - - static constexpr const char* _type_key = "relax.attrs.GatherElementsAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(GatherElementsAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherElementsAttrs", GatherElementsAttrs, + BaseAttrsNode); }; // struct GatherElementsAttrs /*! 
\brief Attributes used in gather_nd operators */ @@ -239,9 +220,7 @@ struct GatherNDAttrs : public AttrsNodeReflAdapter { refl::ObjectDef().def_ro("batch_dims", &GatherNDAttrs::batch_dims, "The number of batch dims.", refl::DefaultValue(0)); } - - static constexpr const char* _type_key = "relax.attrs.GatherNDAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(GatherNDAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherNDAttrs", GatherNDAttrs, BaseAttrsNode); }; // struct GatherNDAttrs /*! \brief Attributes used in index_put operator */ @@ -257,29 +236,25 @@ struct IndexPutAttrs : public AttrsNodeReflAdapter { "otherwise performs tensor[indices] = values.", refl::DefaultValue(false)); } - - static constexpr const char* _type_key = "relax.attrs.IndexPutAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(IndexPutAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.IndexPutAttrs", IndexPutAttrs, BaseAttrsNode); }; // struct IndexPutAttrs /*! \brief Attribute used in meshgrid operator */ struct MeshgridAttrs : public AttrsNodeReflAdapter { - Optional indexing; + ffi::Optional indexing; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("indexing", &MeshgridAttrs::indexing, "Specifies how the grid dimensions are ordered."); } - - static constexpr const char* _type_key = "relax.attrs.MeshgridAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(MeshgridAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MeshgridAttrs", MeshgridAttrs, BaseAttrsNode); }; /*! 
\brief Attributes used in scatter_elements operators */ struct ScatterElementsAttrs : public AttrsNodeReflAdapter { Integer axis; - String reduction; + ffi::String reduction; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -291,14 +266,13 @@ struct ScatterElementsAttrs : public AttrsNodeReflAdapter "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".", refl::DefaultValue("update")); } - - static constexpr const char* _type_key = "relax.attrs.ScatterElementsAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScatterElementsAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterElementsAttrs", ScatterElementsAttrs, + BaseAttrsNode); }; // struct ScatterElementsAttrs /*! \brief Attributes used in scatter_nd operators */ struct ScatterNDAttrs : public AttrsNodeReflAdapter { - String reduction; + ffi::String reduction; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -308,9 +282,7 @@ struct ScatterNDAttrs : public AttrsNodeReflAdapter { "either \"update\", \"add\", \"mul\", \"min\" or \"max\".", refl::DefaultValue("update")); } - - static constexpr const char* _type_key = "relax.attrs.ScatterNDAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScatterNDAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterNDAttrs", ScatterNDAttrs, BaseAttrsNode); }; // struct ScatterNDAttrs /*! \brief Attributes used in slice_scatter operator */ @@ -323,9 +295,8 @@ struct SliceScatterAttrs : public AttrsNodeReflAdapter { "the dimension to insert the slice into ", refl::DefaultValue(0)); } - - static constexpr const char* _type_key = "relax.attrs.SliceScatterAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SliceScatterAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SliceScatterAttrs", SliceScatterAttrs, + BaseAttrsNode); }; // struct SliceScatterAttrs /*! 
\brief Attributes used in one_hot operator */ @@ -339,9 +310,7 @@ struct OneHotAttrs : public AttrsNodeReflAdapter { .def_ro("depth", &OneHotAttrs::depth, "Depth of the one hot dimension.") .def_ro("axis", &OneHotAttrs::axis, "Axis to fill.", refl::DefaultValue(-1)); } - - static constexpr const char* _type_key = "relax.attrs.OneHotAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(OneHotAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.OneHotAttrs", OneHotAttrs, BaseAttrsNode); }; // struct OneHotAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 9f09bce6af2c..13a54a16b378 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -31,13 +31,13 @@ namespace relax { /*! \brief Attributes used in Conv1d operator */ struct Conv1DAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -70,20 +70,18 @@ struct Conv1DAttrs : public AttrsNodeReflAdapter { .def_ro("out_dtype", &Conv1DAttrs::out_dtype, "Output data type, set to explicit type under mixed precision setting"); } - - static constexpr const char* _type_key = "relax.attrs.Conv1DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv1DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DAttrs", Conv1DAttrs, BaseAttrsNode); }; // struct Conv1dAttrs /*! 
\brief Attributes used in Conv2d operator */ struct Conv2DAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -118,20 +116,18 @@ struct Conv2DAttrs : public AttrsNodeReflAdapter { .def_ro("out_dtype", &Conv2DAttrs::out_dtype, "Output data type, set to explicit type under mixed precision setting"); } - - static constexpr const char* _type_key = "relax.attrs.Conv2DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv2DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DAttrs", Conv2DAttrs, BaseAttrsNode); }; // struct Conv2dAttrs /*! \brief Attributes used in Conv3d operator */ struct Conv3DAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -168,21 +164,19 @@ struct Conv3DAttrs : public AttrsNodeReflAdapter { .def_ro("out_dtype", &Conv3DAttrs::out_dtype, "Output data type, set to explicit type under mixed precision setting"); } - - static constexpr const char* _type_key = "relax.attrs.Conv3DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv3DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DAttrs", Conv3DAttrs, BaseAttrsNode); }; // struct Conv3dAttrs /*! 
\brief Attributes used in Conv1DTranspose operator */ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array output_padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array output_padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -218,21 +212,20 @@ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter .def_ro("out_dtype", &Conv1DTransposeAttrs::out_dtype, "Output data type, set to explicit type under mixed precision setting"); } - - static constexpr const char* _type_key = "relax.attrs.Conv1DTransposeAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv1DTransposeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DTransposeAttrs", Conv1DTransposeAttrs, + BaseAttrsNode); }; // struct Conv1DTransposeAttrs /*! 
\brief Attributes used in Conv2d operator */ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array output_padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array output_padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -270,21 +263,20 @@ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter .def_ro("out_dtype", &Conv2DTransposeAttrs::out_dtype, "Output data type, set to explicit type under mixed precision setting"); } - - static constexpr const char* _type_key = "relax.attrs.Conv2DTransposeAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv2DTransposeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DTransposeAttrs", Conv2DTransposeAttrs, + BaseAttrsNode); }; // struct Conv2DTransposeAttrs /*! \brief Attributes used in max_pool1d and avg_pool1d operator */ struct Pool1DAttrs : public AttrsNodeReflAdapter { - Array pool_size; - Array strides; - Array padding; - Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; - String layout; - String out_layout; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -313,21 +305,19 @@ struct Pool1DAttrs : public AttrsNodeReflAdapter { "'N', 'C', 'W' stands for batch, channel, and width" "dimensions respectively. Pooling is applied on the 'W' dimensions."); } - - static constexpr const char* _type_key = "relax.attrs.Pool1DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Pool1DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool1DAttrs", Pool1DAttrs, BaseAttrsNode); }; // struct Pool1dAttrs /*! 
\brief Attributes used in max_pool2d and avg_pool2d operator */ struct Pool2DAttrs : public AttrsNodeReflAdapter { - Array pool_size; - Array strides; - Array padding; - Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; - String layout; - String out_layout; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -358,21 +348,19 @@ struct Pool2DAttrs : public AttrsNodeReflAdapter { "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); } - - static constexpr const char* _type_key = "relax.attrs.Pool2DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Pool2DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool2DAttrs", Pool2DAttrs, BaseAttrsNode); }; // struct Pool2dAttrs /*! \brief Attributes used in max_pool3d and avg_pool3d operator */ struct Pool3DAttrs : public AttrsNodeReflAdapter { - Array pool_size; - Array strides; - Array padding; - Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; - String layout; - String out_layout; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -403,16 +391,14 @@ struct Pool3DAttrs : public AttrsNodeReflAdapter { "dimensions respectively. Pooling is applied on the 'D', 'H' and" "'W' dimensions."); } - - static constexpr const char* _type_key = "relax.attrs.Pool3DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Pool3DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool3DAttrs", Pool3DAttrs, BaseAttrsNode); }; // struct Pool3dAttrs /*! 
\brief Attributes for 1d adaptive pool operator */ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { - Optional> output_size; - String layout; - String out_layout; + ffi::Optional> output_size; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -429,16 +415,15 @@ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { "dimensions respectively. Pooling is applied on the" "'W' dimensions."); } - - static constexpr const char* _type_key = "relax.attrs.AdaptivePool1DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AdaptivePool1DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool1DAttrs", AdaptivePool1DAttrs, + BaseAttrsNode); }; // struct AdaptivePool1DAttrs /*! \brief Attributes for 2d adaptive pool operator */ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { - Optional> output_size; - String layout; - String out_layout; + ffi::Optional> output_size; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -455,16 +440,15 @@ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { "dimensions respectively. Pooling is applied on the 'H' and" "'W' dimensions."); } - - static constexpr const char* _type_key = "relax.attrs.AdaptivePool2DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AdaptivePool2DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool2DAttrs", AdaptivePool2DAttrs, + BaseAttrsNode); }; // struct AdaptivePool2DAttrs /*! \brief Attributes for 3d adaptive pool operator */ struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter { - Optional> output_size; - String layout; - String out_layout; + ffi::Optional> output_size; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -481,9 +465,8 @@ struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter { "dimensions respectively. 
Pooling is applied on 'D', 'H' and" "'W' dimensions."); } - - static constexpr const char* _type_key = "relax.attrs.AdaptivePool3DAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AdaptivePool3DAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool3DAttrs", AdaptivePool3DAttrs, + BaseAttrsNode); }; // struct AdaptivePool3DAttrs /*! \brief Attributes used in softmax operators */ @@ -495,9 +478,7 @@ struct SoftmaxAttrs : public AttrsNodeReflAdapter { refl::ObjectDef().def_ro("axis", &SoftmaxAttrs::axis, "The axis to sum over when computing softmax."); } - - static constexpr const char* _type_key = "relax.attrs.SoftmaxAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SoftmaxAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftmaxAttrs", SoftmaxAttrs, BaseAttrsNode); }; /*! \brief Attributes used in softmax operators */ @@ -509,9 +490,7 @@ struct LeakyReluAttrs : public AttrsNodeReflAdapter { refl::ObjectDef().def_ro("alpha", &LeakyReluAttrs::alpha, "The slope of the negative part."); } - - static constexpr const char* _type_key = "relax.attrs.LeakyReluAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(LeakyReluAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LeakyReluAttrs", LeakyReluAttrs, BaseAttrsNode); }; /*! \brief Attributes used in softplus operators */ @@ -527,9 +506,7 @@ struct SoftplusAttrs : public AttrsNodeReflAdapter { .def_ro("threshold", &SoftplusAttrs::threshold, "Value determining when to use linear approximation for numerical stability."); } - - static constexpr const char* _type_key = "relax.attrs.SoftplusAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SoftplusAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftplusAttrs", SoftplusAttrs, BaseAttrsNode); }; /*! 
\brief Attributes used in PReLU operator */ @@ -541,9 +518,7 @@ struct PReluAttrs : public AttrsNodeReflAdapter { refl::ObjectDef().def_ro("axis", &PReluAttrs::axis, "The axis along which the alpha values are applied."); } - - static constexpr const char* _type_key = "relax.attrs.PReluAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(PReluAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PReluAttrs", PReluAttrs, BaseAttrsNode); }; /*! \brief Attributes used in batch_norm operator */ @@ -570,14 +545,12 @@ struct BatchNormAttrs : public AttrsNodeReflAdapter { .def_ro("training", &BatchNormAttrs::training, "Whether we are training (i.e., not in eval mode)."); } - - static constexpr const char* _type_key = "relax.attrs.BatchNormAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(BatchNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BatchNormAttrs", BatchNormAttrs, BaseAttrsNode); }; // struct BatchNormAttrs /*! \brief Attributes used in layer_norm operator */ struct LayerNormAttrs : public AttrsNodeReflAdapter { - Array axes; + ffi::Array axes; double epsilon; bool center; bool scale; @@ -594,16 +567,14 @@ struct LayerNormAttrs : public AttrsNodeReflAdapter { .def_ro("scale", &LayerNormAttrs::scale, "Indicating if the gamma scale will be multiplied."); } - - static constexpr const char* _type_key = "relax.attrs.LayerNormAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(LayerNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayerNormAttrs", LayerNormAttrs, BaseAttrsNode); }; // struct LayerNormAttrs /*! 
\brief Attributes used in group_norm operator */ struct GroupNormAttrs : public AttrsNodeReflAdapter { int num_groups; int channel_axis; - Array axes; + ffi::Array axes; double epsilon; bool center; bool scale; @@ -625,15 +596,13 @@ struct GroupNormAttrs : public AttrsNodeReflAdapter { .def_ro("scale", &GroupNormAttrs::scale, "Indicating if the gamma scale will be multiplied."); } - - static constexpr const char* _type_key = "relax.attrs.GroupNormAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(GroupNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GroupNormAttrs", GroupNormAttrs, BaseAttrsNode); }; // struct GroupNormAttrs /*! \brief Attributes used in instance_norm operator */ struct InstanceNormAttrs : public AttrsNodeReflAdapter { int channel_axis; - Array axes; + ffi::Array axes; double epsilon; bool center; bool scale; @@ -652,14 +621,13 @@ struct InstanceNormAttrs : public AttrsNodeReflAdapter { .def_ro("scale", &InstanceNormAttrs::scale, "Indicating if the gamma scale will be multiplied."); } - - static constexpr const char* _type_key = "relax.attrs.InstanceNormAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(InstanceNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InstanceNormAttrs", InstanceNormAttrs, + BaseAttrsNode); }; // struct InstanceNormAttrs /*! \brief Attributes used in rms_norm operator */ struct RMSNormAttrs : public AttrsNodeReflAdapter { - Array axes; + ffi::Array axes; double epsilon; static void RegisterReflection() { @@ -670,14 +638,12 @@ struct RMSNormAttrs : public AttrsNodeReflAdapter { .def_ro("epsilon", &RMSNormAttrs::epsilon, "Small float added to variance to avoid dividing by zero"); } - - static constexpr const char* _type_key = "relax.attrs.RMSNormAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(RMSNormAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RMSNormAttrs", RMSNormAttrs, BaseAttrsNode); }; // struct RMSNormAttrs /*! 
\brief Attributes used in nll_loss operator */ struct NLLLossAttrs : public AttrsNodeReflAdapter { - String reduction; + ffi::String reduction; int ignore_index; static void RegisterReflection() { @@ -689,9 +655,7 @@ struct NLLLossAttrs : public AttrsNodeReflAdapter { refl::DefaultValue("mean")) .def_ro("ignore_index", &NLLLossAttrs::ignore_index, "The target value to ignore."); } - - static constexpr const char* _type_key = "relax.attrs.NLLLossAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(NLLLossAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NLLLossAttrs", NLLLossAttrs, BaseAttrsNode); }; // struct NLLLossAttrs /*! \brief Attributes used in dropout operator */ @@ -704,16 +668,14 @@ struct DropoutAttrs : public AttrsNodeReflAdapter { "rate", &DropoutAttrs::rate, "Fraction of the input that gets dropped out during training time"); } - - static constexpr const char* _type_key = "relax.attrs.DropoutAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(DropoutAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DropoutAttrs", DropoutAttrs, BaseAttrsNode); }; // struct DropoutAttrs /*! \brief Attributes used in Attention operator */ struct AttentionAttrs : public AttrsNodeReflAdapter { - Optional scale; - Optional causal_mask; - Optional window_size; + ffi::Optional scale; + ffi::Optional causal_mask; + ffi::Optional window_size; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -726,16 +688,14 @@ struct AttentionAttrs : public AttrsNodeReflAdapter { .def_ro("window_size", &AttentionAttrs::window_size, "The size of the window for sliding-window attention."); } - - static constexpr const char* _type_key = "relax.attrs.AttentionAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AttentionAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AttentionAttrs", AttentionAttrs, BaseAttrsNode); }; // struct AttentionAttrs /*! 
\brief Attributes used for the padding operator */ struct PadAttrs : public AttrsNodeReflAdapter { - Array pad_width; + ffi::Array pad_width; double pad_value = 0.0; - tvm::String pad_mode; + tvm::ffi::String pad_mode; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -751,9 +711,7 @@ struct PadAttrs : public AttrsNodeReflAdapter { "\"reflect\" pads by reflecting values with respect to the edges.", refl::DefaultValue("constant")); } - - static constexpr const char* _type_key = "relax.attrs.PadAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(PadAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PadAttrs", PadAttrs, BaseAttrsNode); }; /*! \brief Attributes used for the pixel shuffle operator */ @@ -766,9 +724,8 @@ struct PixelShuffleAttrs : public AttrsNodeReflAdapter { &PixelShuffleAttrs::upscale_factor, "Scale factor for spatial upsampling."); } - - static constexpr const char* _type_key = "relax.attrs.PixelShuffleAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(PixelShuffleAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PixelShuffleAttrs", PixelShuffleAttrs, + BaseAttrsNode); }; } // namespace relax diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 337f8dc4cbc2..54640901ff53 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -32,8 +32,8 @@ namespace relax { /*! 
\brief Attributes used in call_tir_with_grad */ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter { - String te_grad_name; - Map te_grad_kwargs; + ffi::String te_grad_name; + ffi::Map te_grad_kwargs; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -44,9 +44,8 @@ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter .def_ro("te_grad_kwargs", &CallTIRWithGradAttrs::te_grad_kwargs, "The keyword arguments passed to the te gradient function."); } - - static constexpr const char* _type_key = "relax.attrs.CallTIRWithGradAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(CallTIRWithGradAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallTIRWithGradAttrs", CallTIRWithGradAttrs, + BaseAttrsNode); }; // struct CallTIRAttrs /*! \brief Attributes used in call_tir_inplace */ @@ -58,16 +57,15 @@ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { * store the `i`th output. If an element has the value -1, that means a new tensor should be * allocated for that output. */ - Array inplace_indices; + ffi::Array inplace_indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("inplace_indices", &CallTIRInplaceAttrs::inplace_indices); } - - static constexpr const char* _type_key = "relax.attrs.CallTIRInplaceAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(CallTIRInplaceAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallTIRInplaceAttrs", CallTIRInplaceAttrs, + BaseAttrsNode); }; // struct CallTIRInplaceAttrs /*! 
\brief Attributes used in call_inplace_packed */ @@ -79,16 +77,15 @@ struct CallInplacePackedAttrs : public AttrsNodeReflAdapter inplace_indices; + ffi::Array inplace_indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("inplace_indices", &CallInplacePackedAttrs::inplace_indices); } - - static constexpr const char* _type_key = "relax.attrs.CallInplacePackedAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(CallInplacePackedAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.CallInplacePackedAttrs", CallInplacePackedAttrs, + BaseAttrsNode); }; // struct CallInplacePackedAttrs /*! \brief Attributes used in to_vdevice */ @@ -100,26 +97,25 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter { refl::ObjectDef().def_ro("dst_vdevice", &ToVDeviceAttrs::dst_vdevice, "The destination device where the data is copied to."); } - - static constexpr const char* _type_key = "relax.attrs.ToVDeviceAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ToVDeviceAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ToVDeviceAttrs", ToVDeviceAttrs, BaseAttrsNode); }; // struct ToVDeviceAttrs /*! 
\brief Attributes used in hint_on_device */ struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { - int32_t dev_type; - int32_t dev_id; + int32_t device_type; + int32_t index; + MemoryScope memory_scope; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("dev_type", &HintOnDeviceAttrs::dev_type, + .def_ro("device_type", &HintOnDeviceAttrs::device_type, "The device type where the data is supposed to be executed.") - .def_ro("dev_id", &HintOnDeviceAttrs::dev_id, "The device id."); + .def_ro("index", &HintOnDeviceAttrs::index, "The device id.") + .def_ro("memory_scope", &HintOnDeviceAttrs::memory_scope, "The device memory scope."); } - - static constexpr const char* _type_key = "relax.attrs.HintOnDeviceAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(HintOnDeviceAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.HintOnDeviceAttrs", HintOnDeviceAttrs, + BaseAttrsNode); }; // struct HintOnDeviceAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/qdq.h b/include/tvm/relax/attrs/qdq.h index 71343f10beb4..ffb554994f98 100644 --- a/include/tvm/relax/attrs/qdq.h +++ b/include/tvm/relax/attrs/qdq.h @@ -43,9 +43,7 @@ struct QuantizeAttrs : public AttrsNodeReflAdapter { "Default value is -1, which corresponds to the last axis.", refl::DefaultValue(-1)); } - - static constexpr const char* _type_key = "relax.attrs.QuantizeAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(QuantizeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.QuantizeAttrs", QuantizeAttrs, BaseAttrsNode); }; // QuantizeAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/sampling.h b/include/tvm/relax/attrs/sampling.h index 8144e85e1623..53fd3a140497 100644 --- a/include/tvm/relax/attrs/sampling.h +++ b/include/tvm/relax/attrs/sampling.h @@ -39,9 +39,8 @@ struct MultinomialFromUniformAttrs : public AttrsNodeReflAdapter { - Optional axis; + ffi::Optional axis; bool keepdims; static void 
RegisterReflection() { @@ -44,9 +44,8 @@ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter { "with size " "one."); } - - static constexpr const char* _type_key = "relax.attrs.ArgmaxArgminAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ArgmaxArgminAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgmaxArgminAttrs", ArgmaxArgminAttrs, + BaseAttrsNode); }; // struct ArgmaxArgminAttrs /*! \brief Attributes for bucketize operator */ @@ -62,9 +61,7 @@ struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter { .def_ro("right", &BucketizeAttrs::right, "Determines the behavior for values in boundaries"); } - - static constexpr const char* _type_key = "relax.attrs.BucketizeAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(BucketizeAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BucketizeAttrs", BucketizeAttrs, BaseAttrsNode); }; // struct BucketizeAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index 81705c71a261..0731c6cf4f6d 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -47,9 +47,7 @@ struct SortAttrs : public AttrsNodeReflAdapter { "If it is not specified, it defaults to the ascending order.", refl::DefaultValue(false)); } - - static constexpr const char* _type_key = "relax.attrs.SortAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SortAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SortAttrs", SortAttrs, BaseAttrsNode); }; // struct SortAttrs /*! 
\brief Attributes used in argsort operator */ @@ -72,9 +70,7 @@ struct ArgsortAttrs : public AttrsNodeReflAdapter { .def_ro("dtype", &ArgsortAttrs::dtype, "DType of the output indices.", refl::DefaultValue(NullValue())); } - - static constexpr const char* _type_key = "relax.attrs.ArgsortAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ArgsortAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs, BaseAttrsNode); }; // struct ArgsortAttrs /*! \brief Attributes used in topk operator */ @@ -82,7 +78,7 @@ struct TopKAttrs : public AttrsNodeReflAdapter { int k; int axis; bool largest; - String ret_type; + ffi::String ret_type; DataType dtype; static void RegisterReflection() { @@ -104,9 +100,7 @@ struct TopKAttrs : public AttrsNodeReflAdapter { .def_ro("dtype", &TopKAttrs::dtype, "Data type of the output indices.", refl::DefaultValue(NullValue())); } - - static constexpr const char* _type_key = "relax.attrs.TopKAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TopKAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs, BaseAttrsNode); }; // struct TopKAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h index c61169dc9923..433524116d3c 100644 --- a/include/tvm/relax/attrs/statistical.h +++ b/include/tvm/relax/attrs/statistical.h @@ -31,7 +31,7 @@ namespace relax { /*! 
\brief Attributes for statistical operators */ struct StatisticalAttrs : public AttrsNodeReflAdapter { - Optional> axis; + ffi::Optional> axis; bool keepdims; static void RegisterReflection() { @@ -44,14 +44,13 @@ struct StatisticalAttrs : public AttrsNodeReflAdapter { "with size " "one."); } - - static constexpr const char* _type_key = "relax.attrs.StatisticalAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(StatisticalAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StatisticalAttrs", StatisticalAttrs, + BaseAttrsNode); }; // struct StatisticalAttrs /*! \brief Attributes used in scan operators like cumsum, cumprod */ struct ScanopAttrs : public AttrsNodeReflAdapter { - Optional axis; + ffi::Optional axis; DataType dtype; Bool exclusive = Bool(false); @@ -67,9 +66,7 @@ struct ScanopAttrs : public AttrsNodeReflAdapter { .def_ro("exclusive", &ScanopAttrs::exclusive, "The first element is not included", refl::DefaultValue(Bool(false))); } - - static constexpr const char* _type_key = "relax.attrs.ScanopAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScanopAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScanopAttrs", ScanopAttrs, BaseAttrsNode); }; // struct ScanopAttrs } // namespace relax diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h new file mode 100644 index 000000000000..2fd98533b589 --- /dev/null +++ b/include/tvm/relax/attrs/vision.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/relax/attrs/vision.h + * \brief Auxiliary attributes for vision operators. + */ +#ifndef TVM_RELAX_ATTRS_VISION_H_ +#define TVM_RELAX_ATTRS_VISION_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in AllClassNonMaximumSuppression operator */ +struct AllClassNonMaximumSuppressionAttrs + : public AttrsNodeReflAdapter { + ffi::String output_format; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro( + "output_format", &AllClassNonMaximumSuppressionAttrs::output_format, + "Output format, onnx or tensorflow. Returns outputs in a way that can be easily " + "consumed by each frontend."); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllClassNonMaximumSuppressionAttrs", + AllClassNonMaximumSuppressionAttrs, BaseAttrsNode); +}; // struct AllClassNonMaximumSuppressionAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_VISION_H_ diff --git a/include/tvm/relax/backend/adreno/transform.h b/include/tvm/relax/backend/adreno/transform.h new file mode 100644 index 000000000000..891a19187739 --- /dev/null +++ b/include/tvm/relax/backend/adreno/transform.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/backend/adreno/transform.h + * \brief Adreno GPU specific transformation passes. + */ +#ifndef TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_ +#define TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_ + +#include +#include +namespace tvm { +namespace relax { +namespace backend { +namespace adreno { +namespace transform { + +using Pass = tvm::transform::Pass; +using PassInfo = tvm::transform::PassInfo; +using PassContext = tvm::transform::PassContext; +using Function = tvm::relax::Function; +using DataflowBlock = tvm::relax::DataflowBlock; +using tvm::relax::transform::CreateFunctionPass; +using tvm::transform::CreateModulePass; + +/*! + * \brief This pass is designed to annotate the memory scope information via VDevice attribute. + * This pass needs operator attributes, which in general vanish after legalization. + * FuseOps and FuseTIR are modified to pass on the operator specific attributes and also + * op_pattern details as part of the PrimFunc. This pass is Adreno specific and annotates each + * BindingVar with the appropriate HintOnDevice. The RealizeVDevice pass that follows handles these + * hints. After this pass we also invoke SpecializePrimFuncBasedOnCallSite, which updates the + * var_buffer_map based on this new VDevice information.
+ */ +TVM_DLL Pass AnnotateCustomMemoryScope(Target target); + +/*! + * \brief This is a texture specific pass that can optimize unnecessary to_device copies. + * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + * store into global scope avoiding an unnecessary device copy. + */ +TVM_DLL Pass FoldVDeviceScopeChange(); + +} // namespace transform +} // namespace adreno +} // namespace backend +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_ diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h index bdb405a0af6e..90d5b1540ee0 100644 --- a/include/tvm/relax/binding_rewrite.h +++ b/include/tvm/relax/binding_rewrite.h @@ -46,7 +46,7 @@ class DataflowBlockRewriteNode : public Object { /*! \brief Insert a Binding statement. */ void Add(Binding binding); /*! \brief Insert an expression as VarBinding with variable name. */ - void Add(String var_name, Expr expr, bool is_dfvar = false) { + void Add(ffi::String var_name, Expr expr, bool is_dfvar = false) { auto var = is_dfvar ? DataflowVar(var_name, GetStructInfo(expr)) // : Var(var_name, GetStructInfo(expr)); Add(VarBinding(std::move(var), std::move(expr))); @@ -74,18 +74,16 @@ class DataflowBlockRewriteNode : public Object { .def_ro("dfb", &DataflowBlockRewriteNode::dfb_) .def_ro("root_fn", &DataflowBlockRewriteNode::root_fn_); } - - static constexpr const char* _type_key = "relax.DataflowBlockRewrite"; - TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.DataflowBlockRewrite", DataflowBlockRewriteNode, Object); protected: friend class DataflowBlockRewrite; - DataflowBlock dfb_; //!< The rewritten dataflow block. - Optional root_fn_; //!< The rewritten function. - const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. - Map> to_users_; //!< Map from variable to its users.
- Array fn_outputs_; //!< Variables required by function outputs. + DataflowBlock dfb_; //!< The rewritten dataflow block. + ffi::Optional root_fn_; //!< The rewritten function. + const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. + ffi::Map> to_users_; //!< Map from variable to its users. + ffi::Array fn_outputs_; //!< Variables required by function outputs. private: NameSupply name_supply_; //!< Name supply for tracking and generating unique names. @@ -108,7 +106,8 @@ class DataflowBlockRewrite : public ObjectRef { return static_cast(get_mutable()); } - TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowBlockRewrite, ObjectRef, + DataflowBlockRewriteNode); }; } // namespace relax diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index c33d99b5f91f..2ab6b52f4a91 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -104,7 +104,7 @@ class BlockBuilderNode : public Object { * GlobalVar directly. * \return The global var bound to the added function. */ - virtual GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) = 0; + virtual GlobalVar AddFunction(const BaseFunc& func, ffi::String func_name_hint) = 0; /*! * \brief Update a Relax function or a TIR PrimFunc in the internal context module. @@ -128,7 +128,7 @@ class BlockBuilderNode : public Object { * \return The Expr bound to the input \p var. * \note For function parameters, this function returns std::nullopt. */ - virtual Optional LookupBinding(const Var& var) = 0; + virtual ffi::Optional LookupBinding(const Var& var) = 0; /*! * \brief Begin a new scope, with optional parameters that @@ -144,7 +144,7 @@ class BlockBuilderNode : public Object { * * \sa EndScope */ - virtual void BeginScope(Optional> params) = 0; + virtual void BeginScope(ffi::Optional> params) = 0; /*! 
* \brief Begin a new scope, which inherits visible parameters from @@ -204,7 +204,7 @@ class BlockBuilderNode : public Object { * \note This Emit function normalizes the \p expr, and * performs shape and type deductions by calling Normalize. */ - virtual Var Emit(Expr expr, String name_hint = "") = 0; + virtual Var Emit(Expr expr, ffi::String name_hint = "") = 0; /*! * \brief Emit a MatchCast. @@ -213,7 +213,7 @@ class BlockBuilderNode : public Object { * \param name_hint Name hint for the bound variable. * \return The variable bound to the MatchCast. */ - virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0; + virtual Var EmitMatchCast(Expr value, StructInfo struct_info, ffi::String name_hint = "") = 0; /*! * \brief Generate an output for the current dataflow block. @@ -221,7 +221,7 @@ class BlockBuilderNode : public Object { * \param name_hint Name hint for the bound variable. * \return The variable bound to \p output. */ - virtual Var EmitOutput(Expr output, String name_hint = "") = 0; + virtual Var EmitOutput(Expr output, ffi::String name_hint = "") = 0; /*! * \brief Emit a binding that is already normalized. @@ -257,8 +257,8 @@ class BlockBuilderNode : public Object { */ virtual arith::Analyzer* GetAnalyzer() = 0; - static constexpr const char* _type_key = "relax.BlockBuilder"; - TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("relax.BlockBuilder", BlockBuilderNode, Object); }; class BlockBuilder : public ObjectRef { @@ -274,7 +274,7 @@ class BlockBuilder : public ObjectRef { * ctx_mod so you can lookup the context functions for cross function * call analysis. */ - TVM_DLL static BlockBuilder Create(Optional ctx_mod); + TVM_DLL static BlockBuilder Create(ffi::Optional ctx_mod); /*! 
\brief A marker struct to disable FNormalize * @@ -315,10 +315,10 @@ class BlockBuilder : public ObjectRef { * ctx_mod so you can lookup the context functions for cross function * call analysis. */ - TVM_DLL static BlockBuilder Create(Optional ctx_mod, + TVM_DLL static BlockBuilder Create(ffi::Optional ctx_mod, DisableOperatorSpecificNormalizationForTVMScript tag); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BlockBuilder, ObjectRef, BlockBuilderNode); }; } // namespace relax diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index 80359135c200..8a834d1fcd01 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -44,11 +44,12 @@ namespace relax { * \return true if matched * \return false if unmatched */ -bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings = std::nullopt); +bool MatchExpr(DFPattern pattern, Expr expr, + ffi::Optional> bindings = std::nullopt); /* \brief Similar to above, but return pairs of a matching pattern and an expression. */ -Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, - Optional> bindings = std::nullopt); +ffi::Optional> ExtractMatchedExpr( + DFPattern pattern, Expr expr, ffi::Optional> bindings = std::nullopt); /** * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. @@ -56,8 +57,8 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, * \param dfb The function to match. * \return Matched patterns and corresponding bound variables */ -TVM_DLL Optional> MatchGraph(const PatternContext& ctx, - const DataflowBlock& dfb); +TVM_DLL ffi::Optional> MatchGraph(const PatternContext& ctx, + const DataflowBlock& dfb); /** * \brief Rewrite a function with the given pattern and the rewriter function. 
@@ -70,7 +71,8 @@ TVM_DLL Optional> MatchGraph(const PatternContext& ctx, */ TVM_DLL Function RewriteBindings( const PatternContext& ctx, - ffi::TypedFunction(Map, Map)> rewriter, Function f); + ffi::TypedFunction(ffi::Map, ffi::Map)> rewriter, + Function f); /** * \brief Rewrite a function with the given pattern and the rewriter function. @@ -96,7 +98,7 @@ TVM_DLL Function RewriteBindings( * \return The updated function, if any updates were applied. */ TVM_DLL Function RewriteCall(const DFPattern& pattern, - ffi::TypedFunction)> rewriter, + ffi::TypedFunction)> rewriter, Function func); } // namespace relax diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index c302b29864ab..1925d5ae148d 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -90,9 +90,8 @@ TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs); */ class DFPatternNode : public Object { public: - static constexpr const char* _type_key = "DFPatternNode"; static constexpr const uint32_t _type_child_slots = 21; - TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.DFPattern", DFPatternNode, Object); }; /*! @@ -113,7 +112,7 @@ class DFPattern : public ObjectRef { /*! \brief Syntatic Sugar for creating a NotPattern */ TVM_DLL NotPattern operator~() const; /*! \brief Syntatic Sugar for creating an AttrPattern */ - TVM_DLL AttrPattern HasAttr(const Map& attrs) const; + TVM_DLL AttrPattern HasAttr(const ffi::Map& attrs) const; /*! \brief Syntatic Sugar for creating a StructInfoPattern */ TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const; /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ @@ -121,7 +120,7 @@ class DFPattern : public ObjectRef { /*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const; /*! 
\brief Syntatic Sugar for creating a ShapePattern */ - TVM_DLL ShapePattern HasShape(const Array& shape) const; + TVM_DLL ShapePattern HasShape(const ffi::Array& shape) const; /*! \brief Syntatic Sugar for creating a ShapePattern */ TVM_DLL SameShapeConstraint HasSameShapeAs(const DFPattern& other) const; /*! \brief Syntatic Sugar for duplicating the current pattern */ @@ -130,7 +129,7 @@ class DFPattern : public ObjectRef { /*! \brief Implicit conversion from DFPattern to PatternSeq */ TVM_DLL operator PatternSeq() const; - TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DFPattern, ObjectRef, DFPatternNode); }; /*! \brief Constraint of a DFPattern edge (producer -> consumer) in graph-level matching */ @@ -165,7 +164,7 @@ struct PairCons { class DFConstraintNode : public Object { public: /*! \brief Return the patterns on which the constraint depends */ - virtual Array GetDependentPatterns() const = 0; + virtual ffi::Array GetDependentPatterns() const = 0; /*! \brief Convert the constraint to a PrimExpr * @@ -195,16 +194,15 @@ class DFConstraintNode : public Object { * sufficient for the constraint to be satisfied. */ virtual std::tuple AsPrimExpr( - std::function(const DFPatternNode*)> match_state) const = 0; + std::function(const DFPatternNode*)> match_state) const = 0; - static constexpr const char* _type_key = "DFConstraintNode"; static constexpr const uint32_t _type_child_slots = 1; - TVM_DECLARE_BASE_OBJECT_INFO(DFConstraintNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.DFConstraint", DFConstraintNode, Object); }; class DFConstraint : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(DFConstraint, ObjectRef, DFConstraintNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DFConstraint, ObjectRef, DFConstraintNode); }; /*! 
@@ -213,16 +211,14 @@ class DFConstraint : public ObjectRef { */ class PatternSeqNode final : public Object { public: - tvm::Array patterns; /*!< The sequence of DFPatterns */ + tvm::ffi::Array patterns; /*!< The sequence of DFPatterns */ std::vector pair_constraints; /*!< Constraints between the previous and next patterns */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("patterns", &PatternSeqNode::patterns); } - - static constexpr const char* _type_key = "relax.dpl.PatternSeq"; - TVM_DECLARE_BASE_OBJECT_INFO(PatternSeqNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.PatternSeq", PatternSeqNode, Object); }; /*! @@ -232,7 +228,7 @@ class PatternSeqNode final : public Object { class PatternSeq final : public ObjectRef { public: TVM_DLL explicit PatternSeq(DFPattern init_pattern); - TVM_DLL explicit PatternSeq(tvm::Array patterns, bool only_used_by = false); + TVM_DLL explicit PatternSeq(tvm::ffi::Array patterns, bool only_used_by = false); PatternSeq UsedBy(PatternSeq other, int index = -1) const; PatternSeq OnlyUsedBy(PatternSeq other, int index = -1) const; @@ -244,7 +240,7 @@ class PatternSeq final : public ObjectRef { friend PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index); friend PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index); - TVM_DEFINE_OBJECT_REF_METHODS(PatternSeq, ObjectRef, PatternSeqNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PatternSeq, ObjectRef, PatternSeqNode); }; /*! @@ -269,9 +265,7 @@ class PatternContextNode : public Object { // Non-edge constraints std::vector validation_constraints; - - static constexpr const char* _type_key = "relax.dpl.PatternContext"; - TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.PatternContext", PatternContextNode, Object); }; /*! 
@@ -280,6 +274,7 @@ class PatternContextNode : public Object { */ class PatternContext : public ObjectRef { public: + explicit PatternContext(ffi::UnsafeInit tag) : ObjectRef(tag) {} TVM_DLL explicit PatternContext(ObjectPtr n) : ObjectRef(n) {} TVM_DLL explicit PatternContext(bool incremental = false); @@ -329,7 +324,7 @@ class PatternContext : public ObjectRef { } /*! \brief Get the constraint context object on the top of the stack */ - TVM_DLL static Optional Current(); + TVM_DLL static ffi::Optional Current(); /*! \brief The RAII-like entry of a constraint context scope */ TVM_DLL void EnterWithScope() const; @@ -352,9 +347,7 @@ class ExprPatternNode : public DFPatternNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("expr", &ExprPatternNode::expr); } - - static constexpr const char* _type_key = "relax.dpl.ExprPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.ExprPattern", ExprPatternNode, DFPatternNode); }; /*! @@ -364,7 +357,7 @@ class ExprPatternNode : public DFPatternNode { class ExprPattern : public DFPattern { public: TVM_DLL explicit ExprPattern(Expr expr); - TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExprPattern, DFPattern, ExprPatternNode); }; /*! 
@@ -374,17 +367,16 @@ class ExprPattern : public DFPattern { */ class VarPatternNode : public DFPatternNode { public: - String name; - const String& name_hint() const { return name; } + ffi::String name; + const ffi::String& name_hint() const { return name; } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("name", &VarPatternNode::name); } - static constexpr const char* _type_key = "relax.dpl.VarPattern"; static constexpr const uint32_t _type_child_slots = 1; - TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.VarPattern", VarPatternNode, DFPatternNode); }; /*! @@ -398,8 +390,8 @@ class VarPattern : public DFPattern { * * \param name_hint Variable name to match. Any if empty (""). */ - TVM_DLL VarPattern(String name_hint); - TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); + TVM_DLL VarPattern(ffi::String name_hint); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VarPattern, DFPattern, VarPatternNode); }; /*! @@ -412,9 +404,8 @@ class DataflowVarPatternNode : public VarPatternNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "relax.dpl.DataflowVarPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarPatternNode, VarPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.DataflowVarPattern", DataflowVarPatternNode, + VarPatternNode); }; /*! @@ -424,8 +415,8 @@ class DataflowVarPatternNode : public VarPatternNode { class DataflowVarPattern : public DFPattern { public: /*! \sa VarPattern::VarPattern */ - TVM_DLL DataflowVarPattern(String name_hint); - TVM_DEFINE_OBJECT_REF_METHODS(DataflowVarPattern, DFPattern, DataflowVarPatternNode); + TVM_DLL DataflowVarPattern(ffi::String name_hint); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowVarPattern, DFPattern, DataflowVarPatternNode); }; /*! 
@@ -434,8 +425,8 @@ class DataflowVarPattern : public DFPattern { */ class GlobalVarPatternNode : public VarPatternNode { public: - static constexpr const char* _type_key = "relax.dpl.GlobalVarPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.GlobalVarPattern", GlobalVarPatternNode, + DFPatternNode); }; /*! @@ -444,8 +435,8 @@ class GlobalVarPatternNode : public VarPatternNode { */ class GlobalVarPattern : public DFPattern { public: - TVM_DLL GlobalVarPattern(String name_hint); - TVM_DEFINE_OBJECT_REF_METHODS(GlobalVarPattern, DFPattern, GlobalVarPatternNode); + TVM_DLL GlobalVarPattern(ffi::String name_hint); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GlobalVarPattern, DFPattern, GlobalVarPatternNode); }; /*! @@ -458,9 +449,8 @@ class ConstantPatternNode : public DFPatternNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "relax.dpl.ConstantPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.ConstantPattern", ConstantPatternNode, + DFPatternNode); }; /*! @@ -469,7 +459,7 @@ class ConstantPatternNode : public DFPatternNode { */ class ConstantPattern : public DFPattern { public: - TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ConstantPattern, DFPattern, ConstantPatternNode); }; /*! @@ -483,8 +473,8 @@ class CallPatternNode : public DFPatternNode { * - relax::Op which corresponds to the primitive operators. * - user defined functions (Function, GlobalVar, Var). */ - DFPattern op; /*!< The operator (function) being invoked */ - tvm::Array args; /*!< The arguments of the function call */ + DFPattern op; /*!< The operator (function) being invoked */ + tvm::ffi::Array args; /*!< The arguments of the function call */ /*! * \note If varg_default_wildcard is true. 
Given args of [pA, pB], when matching a call whose * arguments are [A, B, ...], the pattern will still match despite N(args) < N(call.args). That @@ -501,15 +491,13 @@ class CallPatternNode : public DFPatternNode { .def_ro("op", &CallPatternNode::op) .def_ro("args", &CallPatternNode::args); } - - static constexpr const char* _type_key = "relax.dpl.CallPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.CallPattern", CallPatternNode, DFPatternNode); }; class CallPattern : public DFPattern { public: - TVM_DLL CallPattern(DFPattern op, Array args, bool varg_default_wildcard = false); - TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); + TVM_DLL CallPattern(DFPattern op, ffi::Array args, bool varg_default_wildcard = false); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CallPattern, DFPattern, CallPatternNode); }; /*! @@ -519,15 +507,13 @@ class CallPattern : public DFPattern { */ class PrimArrPatternNode : public DFPatternNode { public: - Array fields; /*!< The array to match */ + ffi::Array fields; /*!< The array to match */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("fields", &PrimArrPatternNode::fields); } - - static constexpr const char* _type_key = "relax.dpl.PrimArrPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrimArrPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.PrimArrPattern", PrimArrPatternNode, DFPatternNode); }; /*! @@ -536,8 +522,8 @@ class PrimArrPatternNode : public DFPatternNode { */ class PrimArrPattern : public DFPattern { public: - TVM_DLL PrimArrPattern(Array arr); - TVM_DEFINE_OBJECT_REF_METHODS(PrimArrPattern, DFPattern, PrimArrPatternNode); + TVM_DLL PrimArrPattern(ffi::Array arr); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimArrPattern, DFPattern, PrimArrPatternNode); }; /*! 
@@ -547,7 +533,7 @@ class PrimArrPattern : public DFPattern { */ class FunctionPatternNode : public DFPatternNode { public: - tvm::Array params; /*!< The parameters of the function */ + tvm::ffi::Array params; /*!< The parameters of the function */ /*! * \note Note that in Relax, the function body is a SeqExpr which contains * 1) SeqExprNode::blocks, which is a list of blocks of statements; and 2) @@ -562,9 +548,8 @@ class FunctionPatternNode : public DFPatternNode { .def_ro("params", &FunctionPatternNode::params) .def_ro("body", &FunctionPatternNode::body); } - - static constexpr const char* _type_key = "relax.dpl.FunctionPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.FunctionPattern", FunctionPatternNode, + DFPatternNode); }; /*! @@ -578,9 +563,9 @@ class FunctionPattern : public DFPattern { * \param params The parameters of the function. * \param body The body of the function. */ - TVM_DLL FunctionPattern(tvm::Array params, DFPattern body); + TVM_DLL FunctionPattern(tvm::ffi::Array params, DFPattern body); - TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FunctionPattern, DFPattern, FunctionPatternNode); }; /*! @@ -589,15 +574,13 @@ class FunctionPattern : public DFPattern { */ class TuplePatternNode : public DFPatternNode { public: - tvm::Array fields; /*!< The fields of the tuple */ + tvm::ffi::Array fields; /*!< The fields of the tuple */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("fields", &TuplePatternNode::fields); } - - static constexpr const char* _type_key = "relax.dpl.TuplePattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.TuplePattern", TuplePatternNode, DFPatternNode); }; /*! 
@@ -606,8 +589,8 @@ class TuplePatternNode : public DFPatternNode { */ class TuplePattern : public DFPattern { public: - TVM_DLL explicit TuplePattern(tvm::Array fields); - TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); + TVM_DLL explicit TuplePattern(tvm::ffi::Array fields); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TuplePattern, DFPattern, TuplePatternNode); }; /*! @@ -616,16 +599,15 @@ class TuplePattern : public DFPattern { */ class UnorderedTuplePatternNode : public DFPatternNode { public: - tvm::Array fields; /*!< The fields of the tuple */ + tvm::ffi::Array fields; /*!< The fields of the tuple */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("fields", &UnorderedTuplePatternNode::fields); } - - static constexpr const char* _type_key = "relax.dpl.UnorderedTuplePattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(UnorderedTuplePatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.UnorderedTuplePattern", UnorderedTuplePatternNode, + DFPatternNode); }; /*! @@ -634,8 +616,9 @@ class UnorderedTuplePatternNode : public DFPatternNode { */ class UnorderedTuplePattern : public DFPattern { public: - TVM_DLL explicit UnorderedTuplePattern(tvm::Array fields); - TVM_DEFINE_OBJECT_REF_METHODS(UnorderedTuplePattern, DFPattern, UnorderedTuplePatternNode); + TVM_DLL explicit UnorderedTuplePattern(tvm::ffi::Array fields); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(UnorderedTuplePattern, DFPattern, + UnorderedTuplePatternNode); }; /*! @@ -654,9 +637,8 @@ class TupleGetItemPatternNode : public DFPatternNode { .def_ro("tuple", &TupleGetItemPatternNode::tuple) .def_ro("index", &TupleGetItemPatternNode::index); } - - static constexpr const char* _type_key = "relax.dpl.TupleGetItemPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.TupleGetItemPattern", TupleGetItemPatternNode, + DFPatternNode); }; /*! 
@@ -666,7 +648,8 @@ class TupleGetItemPatternNode : public DFPatternNode { class TupleGetItemPattern : public DFPattern { public: TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); - TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TupleGetItemPattern, DFPattern, + TupleGetItemPatternNode); }; /*! @@ -684,9 +667,7 @@ class AndPatternNode : public DFPatternNode { .def_ro("left", &AndPatternNode::left) .def_ro("right", &AndPatternNode::right); } - - static constexpr const char* _type_key = "relax.dpl.AndPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(AndPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.AndPattern", AndPatternNode, DFPatternNode); }; /*! @@ -696,7 +677,7 @@ class AndPatternNode : public DFPatternNode { class AndPattern : public DFPattern { public: TVM_DLL AndPattern(DFPattern lhs, DFPattern rhs); - TVM_DEFINE_OBJECT_REF_METHODS(AndPattern, DFPattern, AndPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AndPattern, DFPattern, AndPatternNode); }; /*! @@ -714,9 +695,7 @@ class OrPatternNode : public DFPatternNode { .def_ro("left", &OrPatternNode::left) .def_ro("right", &OrPatternNode::right); } - - static constexpr const char* _type_key = "relax.dpl.OrPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(OrPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.OrPattern", OrPatternNode, DFPatternNode); }; /*! @@ -726,7 +705,7 @@ class OrPatternNode : public DFPatternNode { class OrPattern : public DFPattern { public: TVM_DLL OrPattern(DFPattern left, DFPattern right); - TVM_DEFINE_OBJECT_REF_METHODS(OrPattern, DFPattern, OrPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(OrPattern, DFPattern, OrPatternNode); }; /*! 
@@ -741,9 +720,7 @@ class NotPatternNode : public DFPatternNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("reject", &NotPatternNode::reject); } - - static constexpr const char* _type_key = "relax.dpl.NotPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(NotPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.NotPattern", NotPatternNode, DFPatternNode); }; /*! @@ -753,7 +730,7 @@ class NotPatternNode : public DFPatternNode { class NotPattern : public DFPattern { public: TVM_DLL NotPattern(DFPattern reject); - TVM_DEFINE_OBJECT_REF_METHODS(NotPattern, DFPattern, NotPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NotPattern, DFPattern, NotPatternNode); }; /*! @@ -766,9 +743,8 @@ class WildcardPatternNode : public DFPatternNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "relax.dpl.WildcardPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.WildcardPattern", WildcardPatternNode, + DFPatternNode); }; /*! @@ -778,13 +754,17 @@ class WildcardPatternNode : public DFPatternNode { class WildcardPattern : public DFPattern { public: WildcardPattern(); + explicit WildcardPattern(ObjectPtr data) : DFPattern(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } // Declaring WildcardPattern declared as non-nullable avoids the // default zero-parameter constructor for ObjectRef with `data_ = // nullptr`. This allows a zero-parameter constructor to be // declared here, to create a valid wildcard instance. - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WildcardPattern, DFPattern, WildcardPatternNode); }; /*! 
@@ -802,15 +782,14 @@ class StructInfoPatternNode : public DFPatternNode { .def_ro("pattern", &StructInfoPatternNode::pattern) .def_ro("struct_info", &StructInfoPatternNode::struct_info); } - - static constexpr const char* _type_key = "relax.dpl.StructInfoPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(StructInfoPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.StructInfoPattern", StructInfoPatternNode, + DFPatternNode); }; class StructInfoPattern : public DFPattern { public: TVM_DLL StructInfoPattern(DFPattern pattern, StructInfo struct_info); - TVM_DEFINE_OBJECT_REF_METHODS(StructInfoPattern, DFPattern, StructInfoPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StructInfoPattern, DFPattern, StructInfoPatternNode); }; /*! @@ -819,8 +798,8 @@ class StructInfoPattern : public DFPattern { */ class ShapePatternNode : public DFPatternNode { public: - DFPattern pattern; /*!< The root pattern to match */ - Array shape; /*!< The shape to match */ + DFPattern pattern; /*!< The root pattern to match */ + ffi::Array shape; /*!< The shape to match */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -828,9 +807,7 @@ class ShapePatternNode : public DFPatternNode { .def_ro("pattern", &ShapePatternNode::pattern) .def_ro("shape", &ShapePatternNode::shape); } - - static constexpr const char* _type_key = "relax.dpl.ShapePattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.ShapePattern", ShapePatternNode, DFPatternNode); }; /*! @@ -839,8 +816,8 @@ class ShapePatternNode : public DFPatternNode { */ class ShapePattern : public DFPattern { public: - TVM_DLL ShapePattern(DFPattern pattern, Array type); - TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode); + TVM_DLL ShapePattern(DFPattern pattern, ffi::Array type); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ShapePattern, DFPattern, ShapePatternNode); }; /*! 
@@ -849,20 +826,19 @@ class ShapePattern : public DFPattern { */ class SameShapeConstraintNode : public DFConstraintNode { public: - Array args; /*!< The patterns with matching shapes */ + ffi::Array args; /*!< The patterns with matching shapes */ - Array GetDependentPatterns() const override { return args; } + ffi::Array GetDependentPatterns() const override { return args; } std::tuple AsPrimExpr( - std::function(const DFPatternNode*)> match_state) const override; + std::function(const DFPatternNode*)> match_state) const override; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("args", &SameShapeConstraintNode::args); } - - static constexpr const char* _type_key = "relax.dpl.SameShapeConstraint"; - TVM_DECLARE_FINAL_OBJECT_INFO(SameShapeConstraintNode, DFConstraintNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.SameShapeConstraint", SameShapeConstraintNode, + DFConstraintNode); }; /*! @@ -871,8 +847,9 @@ class SameShapeConstraintNode : public DFConstraintNode { */ class SameShapeConstraint : public DFConstraint { public: - TVM_DLL SameShapeConstraint(Array args); - TVM_DEFINE_OBJECT_REF_METHODS(SameShapeConstraint, DFConstraint, SameShapeConstraintNode); + TVM_DLL SameShapeConstraint(ffi::Array args); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SameShapeConstraint, DFConstraint, + SameShapeConstraintNode); }; /*! @@ -890,9 +867,8 @@ class DataTypePatternNode : public DFPatternNode { .def_ro("pattern", &DataTypePatternNode::pattern) .def_ro("dtype", &DataTypePatternNode::dtype); } - - static constexpr const char* _type_key = "relax.dpl.DataTypePattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.DataTypePattern", DataTypePatternNode, + DFPatternNode); }; /*! 
@@ -902,7 +878,7 @@ class DataTypePatternNode : public DFPatternNode { class DataTypePattern : public DFPattern { public: TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); - TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypePattern, DFPattern, DataTypePatternNode); }; /*! @@ -920,9 +896,7 @@ class AttrPatternNode : public DFPatternNode { .def_ro("pattern", &AttrPatternNode::pattern) .def_ro("attrs", &AttrPatternNode::attrs); } - - static constexpr const char* _type_key = "relax.dpl.AttrPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.AttrPattern", AttrPatternNode, DFPatternNode); }; /*! @@ -932,7 +906,7 @@ class AttrPatternNode : public DFPatternNode { class AttrPattern : public DFPattern { public: TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs); - TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrPattern, DFPattern, AttrPatternNode); }; /*! @@ -942,19 +916,18 @@ class AttrPattern : public DFPattern { */ class ExternFuncPatternNode : public DFPatternNode { public: - String global_symbol_; /*!< The global symbol name of the external function */ + ffi::String global_symbol_; /*!< The global symbol name of the external function */ /*! \brief The external function name */ - const String& global_symbol() const { return global_symbol_; } + const ffi::String& global_symbol() const { return global_symbol_; } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("global_symbol", &ExternFuncPatternNode::global_symbol_); } - - static constexpr const char* _type_key = "relax.dpl.ExternFuncPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncPatternNode, DFPatternNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.ExternFuncPattern", ExternFuncPatternNode, + DFPatternNode); }; /*! 
@@ -963,12 +936,12 @@ class ExternFuncPatternNode : public DFPatternNode { */ class ExternFuncPattern : public DFPattern { public: - TVM_DLL ExternFuncPattern(String global_symbol); - TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode); + TVM_DLL ExternFuncPattern(ffi::String global_symbol); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternFuncPattern, DFPattern, ExternFuncPatternNode); }; /*! \brief Syntatic Sugar for creating a VarPattern with a name */ -VarPattern IsVar(const String& name); +VarPattern IsVar(const ffi::String& name); /*! \brief Syntatic Sugar for creating a ConstantPattern */ ConstantPattern IsConst(); /*! \brief Syntatic Sugar for creating a WildcardPattern */ @@ -976,26 +949,27 @@ WildcardPattern Wildcard(); /*! \brief Syntatic Sugar for creating a ExprPattern */ ExprPattern IsExpr(const Expr& expr); /*! \brief Syntatic Sugar for creating a ExprPattern base on an Op */ -ExprPattern IsOp(const String& op_name); +ExprPattern IsOp(const ffi::String& op_name); /*! \brief Syntatic Sugar for call_tir (return a tensor) */ // Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo -CallPattern IsCallTIR(const String& name, Optional args = std::nullopt); +CallPattern IsCallTIR(const ffi::String& name, ffi::Optional args = std::nullopt); /*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */ -CallPattern IsCallTIR(const String& name, TuplePattern var_args); +CallPattern IsCallTIR(const ffi::String& name, TuplePattern var_args); /*! \brief Syntatic Sugar for call_dps_packed (return a tensor) */ -CallPattern IsCallDPSPacked(const String& name, Optional args = std::nullopt); +CallPattern IsCallDPSPacked(const ffi::String& name, + ffi::Optional args = std::nullopt); /*! \brief Syntatic Sugar for call_dps_packed (return a tuple of tensor) */ -CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args); +CallPattern IsCallDPSPacked(const ffi::String& name, TuplePattern var_args); /*! 
\brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */ -DFPattern IsTuple(const Array& fields, bool unordered = false); +DFPattern IsTuple(const ffi::Array& fields, bool unordered = false); /*! \brief Syntatic Sugar for creating a TupleGetItemPattern */ TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index = -1); /*! \brief Implementation of the templated CallPattern syntax sugar */ template CallPattern DFPattern::operator()(Args&&... args) const { - return CallPattern(GetRef(this->get()), - Array({std::forward(args)...})); + return CallPattern(ffi::GetRef(this->get()), + ffi::Array({std::forward(args)...})); } } // namespace relax diff --git a/include/tvm/relax/distributed/axis_group_graph.h b/include/tvm/relax/distributed/axis_group_graph.h index 565aaa0835f5..ddb618e06b1f 100644 --- a/include/tvm/relax/distributed/axis_group_graph.h +++ b/include/tvm/relax/distributed/axis_group_graph.h @@ -58,7 +58,8 @@ class BufferAxisHash { * \param analyzer The analyzer * \return The iter var whose extent to be changed */ -Var GetShardingVarFromIndex(PrimExpr index, Map var_range, arith::Analyzer* analyzer); +Var GetShardingVarFromIndex(PrimExpr index, ffi::Map var_range, + arith::Analyzer* analyzer); /*! * \brief Construct an axis group graph from a PrimFunc. 
Two buffer axis are connected if they @@ -69,7 +70,7 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { static std::vector> GetTIRVarAxisGraph(const PrimFunc& prim_func) { BufferAxisGraphExtractor extractor; extractor(prim_func->body); - Map inverse_buffer_map; + ffi::Map inverse_buffer_map; for (const auto& pr : prim_func->buffer_map) { inverse_buffer_map.Set(pr.second, pr.first); } @@ -162,14 +163,14 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { arith::Analyzer analyzer; for (const auto& access_pr : buffer_access_indices_) { Buffer buffer = access_pr.first; - Array indices = access_pr.second; + ffi::Array indices = access_pr.second; for (int i = 0; i < static_cast(indices.size()); i++) { for (const auto& another_access_pr : buffer_access_indices_) { if (another_access_pr.first.same_as(buffer)) { continue; } Buffer another_buffer = another_access_pr.first; - Array another_indices = another_access_pr.second; + ffi::Array another_indices = another_access_pr.second; for (int j = 0; j < static_cast(another_indices.size()); j++) { if (Match(indices[i], buffer->shape[i], another_indices[j], another_buffer->shape[j], &analyzer)) { @@ -192,9 +193,9 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { buffer_axis_graph_[axis2].push_back(axis1); } - std::vector>> buffer_access_indices_; + std::vector>> buffer_access_indices_; std::unordered_map, BufferAxisHash> buffer_axis_graph_; - Map iter_var_range_; + ffi::Map iter_var_range_; std::string func_name; }; } // namespace tir @@ -439,7 +440,7 @@ class AxisGroupGraph { } } ICHECK(specs.size() == 1) << "multiple possible sharding for axis: (" - << GetRef(axis.tensor) << ", " << axis.dim << ")"; + << ffi::GetRef(axis.tensor) << ", " << axis.dim << ")"; } } diff --git a/include/tvm/relax/distributed/global_info.h b/include/tvm/relax/distributed/global_info.h index 5e0afc0dcaa7..2bb8d8772b06 100644 --- a/include/tvm/relax/distributed/global_info.h +++ b/include/tvm/relax/distributed/global_info.h 
@@ -40,10 +40,10 @@ class DeviceMeshNode : public GlobalInfoNode { ffi::Shape shape; /*! \brief device ids in the mesh*/ - Array device_ids; + ffi::Array device_ids; /*! \brief Optionally use range to represent device_ids*/ - Optional device_range; + ffi::Optional device_range; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -52,9 +52,7 @@ class DeviceMeshNode : public GlobalInfoNode { .def_ro("device_ids", &DeviceMeshNode::device_ids) .def_ro("device_range", &DeviceMeshNode::device_range); } - - static constexpr const char* _type_key = "relax.distributed.DeviceMesh"; - TVM_DECLARE_FINAL_OBJECT_INFO(DeviceMeshNode, GlobalInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.distributed.DeviceMesh", DeviceMeshNode, GlobalInfoNode); }; /*! @@ -63,9 +61,9 @@ class DeviceMeshNode : public GlobalInfoNode { */ class DeviceMesh : public GlobalInfo { public: - TVM_DLL DeviceMesh(ffi::Shape shape, Array device_ids); + TVM_DLL DeviceMesh(ffi::Shape shape, ffi::Array device_ids); TVM_DLL DeviceMesh(ffi::Shape shape, Range device_range); - TVM_DEFINE_OBJECT_REF_METHODS(DeviceMesh, GlobalInfo, DeviceMeshNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DeviceMesh, GlobalInfo, DeviceMeshNode); }; } // namespace distributed diff --git a/include/tvm/relax/distributed/struct_info.h b/include/tvm/relax/distributed/struct_info.h index cd4c2e7daef2..9ca3b1513828 100644 --- a/include/tvm/relax/distributed/struct_info.h +++ b/include/tvm/relax/distributed/struct_info.h @@ -51,9 +51,8 @@ class PlacementSpecNode : public Object { .def_ro("kind", &PlacementSpecNode::kind); } - static constexpr const char* _type_key = "relax.distributed.PlacementSpec"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_DECLARE_BASE_OBJECT_INFO(PlacementSpecNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("relax.distributed.PlacementSpec", PlacementSpecNode, Object); }; /*! 
@@ -66,7 +65,7 @@ class PlacementSpec : public ObjectRef { TVM_DLL static PlacementSpec Replica(); - TVM_DEFINE_OBJECT_REF_METHODS(PlacementSpec, ObjectRef, PlacementSpecNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PlacementSpec, ObjectRef, PlacementSpecNode); }; class ShardingNode : public PlacementSpecNode { @@ -79,16 +78,16 @@ class ShardingNode : public PlacementSpecNode { refl::ObjectDef().def_ro("sharding_dim", &ShardingNode::sharding_dim); } - TVM_DECLARE_FINAL_OBJECT_INFO(ShardingNode, PlacementSpecNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.distributed.Sharding", ShardingNode, PlacementSpecNode); }; /*! \brief Describes how data is distributed in each dimension of the device mesh*/ class PlacementNode : public Object { public: /*! \brief specs for each dim of device mesh.*/ - Array dim_specs; + ffi::Array dim_specs; - String ToString() const; + ffi::String ToString() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -96,8 +95,7 @@ class PlacementNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - static constexpr const char* _type_key = "relax.distributed.Placement"; - TVM_DECLARE_FINAL_OBJECT_INFO(PlacementNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.distributed.Placement", PlacementNode, Object); }; /*! @@ -106,10 +104,10 @@ class PlacementNode : public Object { */ class Placement : public ObjectRef { public: - TVM_DLL explicit Placement(Array dim_specs); + TVM_DLL explicit Placement(ffi::Array dim_specs); /*! \brief replica dim is printed as "R" and sharding dim is printed as "S[i]".]*/ - static Placement FromText(String text_repr); - TVM_DEFINE_OBJECT_REF_METHODS(Placement, ObjectRef, PlacementNode); + static Placement FromText(ffi::String text_repr); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Placement, ObjectRef, PlacementNode); }; /*! 
@@ -137,9 +135,8 @@ class DTensorStructInfoNode : public StructInfoNode { .def_ro("placement", &DTensorStructInfoNode::placement) .def_ro("tensor_sinfo", &DTensorStructInfoNode::tensor_sinfo); } - - static constexpr const char* _type_key = "relax.DTensorStructInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(DTensorStructInfoNode, StructInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.DTensorStructInfo", DTensorStructInfoNode, + StructInfoNode); }; /*! @@ -158,7 +155,7 @@ class DTensorStructInfo : public StructInfo { TVM_DLL DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(DTensorStructInfo, StructInfo, DTensorStructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DTensorStructInfo, StructInfo, DTensorStructInfoNode); }; } // namespace distributed diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index dd0539cb9666..4fd0fd66bb90 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -62,7 +62,7 @@ class ExecBuilderNode : public Object { * \param init_register_size Initial setting of register file size. */ void EmitFunction(const std::string& func, int64_t num_inputs, - Optional> param_names, + ffi::Optional> param_names, vm::VMFuncInfo::FuncKind kind = vm::VMFuncInfo::FuncKind::kVMFunc, int64_t init_register_size = 0); /*! @@ -142,8 +142,8 @@ class ExecBuilderNode : public Object { refl::ObjectDef(); } - static constexpr const char* _type_key = "relax.ExecBuilder"; - TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ExecBuilder", ExecBuilderNode, Object); private: /*! 
@@ -174,7 +174,7 @@ class ExecBuilderNode : public Object { class ExecBuilder : public ObjectRef { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecBuilder, ObjectRef, ExecBuilderNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExecBuilder, ObjectRef, ExecBuilderNode); }; } // namespace relax diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 22cda9e06635..9b5a3176f413 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -53,7 +53,7 @@ class IdNode : public Object { * this only acts as a hint to the user, * and is not used for equality. */ - String name_hint; + ffi::String name_hint; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -62,9 +62,7 @@ class IdNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; - static constexpr const char* _type_key = "relax.Id"; - - TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.Id", IdNode, Object); }; class Id : public ObjectRef { @@ -73,9 +71,9 @@ class Id : public ObjectRef { * \brief The constructor * \param name_hint The name of the variable. */ - TVM_DLL explicit Id(String name_hint); + TVM_DLL explicit Id(ffi::String name_hint); - TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Id, ObjectRef, IdNode); }; /*! @@ -122,10 +120,9 @@ class StructInfoNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "ir.StructInfo"; static constexpr const uint32_t _type_child_slots = 7; - TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("ir.StructInfo", StructInfoNode, Object); }; /*! 
@@ -134,7 +131,7 @@ class StructInfoNode : public Object { */ class StructInfo : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(StructInfo, ObjectRef, StructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StructInfo, ObjectRef, StructInfoNode); }; /*! @@ -152,18 +149,22 @@ class CallNode : public ExprNode { Expr op; /*! \brief The arguments(inputs) of the call */ - tvm::Array args; + tvm::ffi::Array args; /*! \brief The additional attributes */ Attrs attrs; /*! * \brief The structure info arguments of a CallNode. - * sinfo_args is designed to be non-empty only for intrinsic op (e.g., + * sinfo_args is by default designed to be non-empty only for intrinsic op (e.g., * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main * usage of structure info inference. + * + * Regular ops also at times may have sinfo_args defined to specialize partial + * or complete structure info. Like VDevice customization with mixed input memory_scopes. + * The customized pass can set this info and operator specific inference will respect it. */ - Array sinfo_args; + ffi::Array sinfo_args; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -173,9 +174,7 @@ class CallNode : public ExprNode { .def_ro("attrs", &CallNode::attrs) .def_ro("sinfo_args", &CallNode::sinfo_args); } - - static constexpr const char* _type_key = "relax.expr.Call"; - TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Call", CallNode, ExprNode); }; class Call : public Expr { @@ -188,10 +187,10 @@ class Call : public Expr { * \param sinfo_args The structure info arguments passed to a function. * \param span The source span of the expression. 
*/ - TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), - Array sinfo_args = Array(), Span span = Span()); + TVM_DLL Call(Expr op, ffi::Array args, Attrs attrs = Attrs(), + ffi::Array sinfo_args = ffi::Array(), Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, Expr, CallNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); }; @@ -200,25 +199,24 @@ class Call : public Expr { * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -Call WithFields(Call call, Optional opt_op = Optional(), - Optional> opt_args = Optional>(), - Optional opt_attrs = Optional(), - Optional> opt_sinfo_args = Optional>(), - Optional opt_span = Optional()); +Call WithFields( + Call call, ffi::Optional opt_op = ffi::Optional(), + ffi::Optional> opt_args = ffi::Optional>(), + ffi::Optional opt_attrs = ffi::Optional(), + ffi::Optional> opt_sinfo_args = ffi::Optional>(), + ffi::Optional opt_span = ffi::Optional()); /*! \brief Tuple container */ class TupleNode : public ExprNode { public: /*! \brief the fields of the tuple */ - tvm::Array fields; + tvm::ffi::Array fields; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("fields", &TupleNode::fields); } - - static constexpr const char* _type_key = "relax.expr.Tuple"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Tuple", TupleNode, ExprNode); }; class Tuple : public Expr { @@ -228,15 +226,15 @@ class Tuple : public Expr { * \param fields The fields of a tuple. * \param span The source span of the expression. */ - TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); + TVM_DLL explicit Tuple(tvm::ffi::Array fields, Span span = Span()); /*! * \brief Utility constructor to handle conversion to relax::Expr * * If the calling scope already has an array of a specific type of - * relax expression (e.g. 
`Array`), it must be converted + * relax expression (e.g. `ffi::Array`), it must be converted * into an array of base type. This constructor handles the - * conversion to the base `Array`. + * conversion to the base `ffi::Array`. * * \tparam RelaxExpr The type of relax expression passed in as an argument. * @@ -245,10 +243,10 @@ class Tuple : public Expr { * \param span The source span of the expression. */ template >> - TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()) + TVM_DLL explicit Tuple(tvm::ffi::Array fields, Span span = Span()) : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {} - TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tuple, Expr, TupleNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); }; @@ -257,8 +255,9 @@ class Tuple : public Expr { * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), - Optional opt_span = Optional()); +Tuple WithFields(Tuple tuple, + ffi::Optional> opt_fields = ffi::Optional>(), + ffi::Optional opt_span = ffi::Optional()); /*! \brief Get index-th field out of a tuple. 
*/ class TupleGetItemNode : public ExprNode { @@ -274,9 +273,7 @@ class TupleGetItemNode : public ExprNode { .def_ro("tuple_value", &TupleGetItemNode::tuple) .def_ro("index", &TupleGetItemNode::index); } - - static constexpr const char* _type_key = "relax.expr.TupleGetItem"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.TupleGetItem", TupleGetItemNode, ExprNode); }; class TupleGetItem : public Expr { @@ -289,7 +286,7 @@ class TupleGetItem : public Expr { */ TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TupleGetItem, Expr, TupleGetItemNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode); }; @@ -298,9 +295,10 @@ class TupleGetItem : public Expr { * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), - Optional opt_index = Optional(), - Optional opt_span = Optional()); +TupleGetItem WithFields(TupleGetItem tuple_get_item, + ffi::Optional opt_tuple = ffi::Optional(), + ffi::Optional opt_index = ffi::Optional(), + ffi::Optional opt_span = ffi::Optional()); /*! * \brief Base type of all (non-function) leaf Exprs. @@ -308,9 +306,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = */ class LeafExprNode : public ExprNode { public: - static constexpr const char* _type_key = "relax.expr.LeafExpr"; static constexpr const uint32_t _type_child_slots = 7; - TVM_DECLARE_BASE_OBJECT_INFO(LeafExprNode, ExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.LeafExpr", LeafExprNode, ExprNode); }; /*! 
@@ -319,7 +316,7 @@ class LeafExprNode : public ExprNode { */ class LeafExpr : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(LeafExpr, Expr, LeafExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LeafExpr, Expr, LeafExprNode); }; /*! \brief A shape expression which allows users to construct a shape containing PrimExpr. @@ -327,21 +324,19 @@ class LeafExpr : public Expr { class ShapeExprNode : public LeafExprNode { public: /*! The values of the shape expression. */ - Array values; + ffi::Array values; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("values", &ShapeExprNode::values); } - - static constexpr const char* _type_key = "relax.expr.ShapeExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, LeafExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ShapeExpr", ShapeExprNode, LeafExprNode); }; class ShapeExpr : public LeafExpr { public: - TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, LeafExpr, ShapeExprNode); + TVM_DLL explicit ShapeExpr(ffi::Array values, Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ShapeExpr, LeafExpr, ShapeExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode); }; @@ -353,7 +348,7 @@ class VarNode : public LeafExprNode { Id vid; /*! 
\return The name hint of the variable */ - const String& name_hint() const { return vid->name_hint; } + const ffi::String& name_hint() const { return vid->name_hint; } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -379,19 +374,19 @@ class VarNode : public LeafExprNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; - static constexpr const char* _type_key = "relax.expr.Var"; static constexpr const uint32_t _type_child_slots = 1; - TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Var", VarNode, LeafExprNode); }; class Var : public LeafExpr { public: - TVM_DLL explicit Var(String name_hint, Optional struct_info_annotation, + TVM_DLL explicit Var(ffi::String name_hint, ffi::Optional struct_info_annotation, Span span = Span()) : Var(Id(name_hint), struct_info_annotation, span) {} - TVM_DLL explicit Var(Id vid, Optional struct_info_annotation, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode); + TVM_DLL explicit Var(Id vid, ffi::Optional struct_info_annotation, + Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Var, LeafExpr, VarNode); VarNode* CopyOnWrite(); }; @@ -407,20 +402,19 @@ class DataflowVarNode : public VarNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; - static constexpr const char* _type_key = "relax.expr.DataflowVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowVar", DataflowVarNode, VarNode); }; class DataflowVar : public Var { public: - TVM_DLL explicit DataflowVar(String name_hint, Optional struct_info_annotation, - Span span = Span()) + TVM_DLL explicit DataflowVar(ffi::String name_hint, + ffi::Optional struct_info_annotation, Span span = Span()) : DataflowVar(Id(name_hint), struct_info_annotation, span) {} - TVM_DLL explicit DataflowVar(Id vid, Optional struct_info_annotation, 
+ TVM_DLL explicit DataflowVar(Id vid, ffi::Optional struct_info_annotation, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowVar, Var, DataflowVarNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode); }; @@ -432,7 +426,7 @@ class DataflowVar : public Var { class ConstantNode : public LeafExprNode { public: /*! \brief The data of the tensor */ - runtime::NDArray data; + runtime::Tensor data; /*! \return The corresponding tensor type of the data */ TensorType tensor_type() const; @@ -444,9 +438,7 @@ class ConstantNode : public LeafExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("data", &ConstantNode::data); } - - static constexpr const char* _type_key = "relax.expr.Constant"; - TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Constant", ConstantNode, LeafExprNode); }; class Constant : public LeafExpr { @@ -458,11 +450,11 @@ class Constant : public LeafExpr { * If not specified, infer it from data. * \param span The source span of the expression. */ - TVM_DLL explicit Constant(runtime::NDArray data, - Optional struct_info_annotation = std::nullopt, + TVM_DLL explicit Constant(runtime::Tensor data, + ffi::Optional struct_info_annotation = std::nullopt, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Constant, LeafExpr, ConstantNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); }; @@ -480,9 +472,7 @@ class PrimValueNode : public LeafExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &PrimValueNode::value); } - - static constexpr const char* _type_key = "relax.expr.PrimValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrimValueNode, LeafExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.PrimValue", PrimValueNode, LeafExprNode); }; /*! 
@@ -506,7 +496,7 @@ class PrimValue : public LeafExpr { */ TVM_DLL static PrimValue Int64(int64_t value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(PrimValue, LeafExpr, PrimValueNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimValue, LeafExpr, PrimValueNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode); }; @@ -516,15 +506,13 @@ class PrimValue : public LeafExpr { class StringImmNode : public LeafExprNode { public: /*! \brief The data value. */ - String value; + ffi::String value; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &StringImmNode::value); } - - static constexpr const char* _type_key = "relax.expr.StringImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, LeafExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.StringImm", StringImmNode, LeafExprNode); }; /*! @@ -538,9 +526,9 @@ class StringImm : public LeafExpr { * \param value The value input. * \param span The source span of the expression. */ - TVM_DLL explicit StringImm(String value, Span span = Span()); + TVM_DLL explicit StringImm(ffi::String value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(StringImm, LeafExpr, StringImmNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, LeafExpr, StringImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); }; @@ -556,9 +544,7 @@ class DataTypeImmNode : public LeafExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &DataTypeImmNode::value); } - - static constexpr const char* _type_key = "relax.expr.DataTypeImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(DataTypeImmNode, LeafExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataTypeImm", DataTypeImmNode, LeafExprNode); }; /*! 
@@ -574,7 +560,7 @@ class DataTypeImm : public LeafExpr { */ TVM_DLL explicit DataTypeImm(DataType value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(DataTypeImm, LeafExpr, DataTypeImmNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypeImm, LeafExpr, DataTypeImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode); }; @@ -592,10 +578,9 @@ class BindingNode : public Object { .def_ro("var", &BindingNode::var, refl::AttachFieldFlag::SEqHashDef()); } - static constexpr const char* _type_key = "relax.expr.Binding"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.Binding", BindingNode, Object); }; class Binding : public ObjectRef { @@ -603,7 +588,8 @@ class Binding : public ObjectRef { Binding() = default; public: - explicit Binding(ObjectPtr n) : ObjectRef(n) {} + explicit Binding(ObjectPtr n) : ObjectRef(n) {} + explicit Binding(ffi::UnsafeInit tag) : ObjectRef(tag) {} TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding); const BindingNode* operator->() const { return static_cast(data_.get()); } const BindingNode* get() const { return operator->(); } @@ -630,9 +616,7 @@ class MatchCastNode : public BindingNode { .def_ro("value", &MatchCastNode::value) .def_ro("struct_info", &MatchCastNode::struct_info, refl::AttachFieldFlag::SEqHashDef()); } - - static constexpr const char* _type_key = "relax.expr.MatchCast"; - TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.MatchCast", MatchCastNode, BindingNode); }; /*! 
@@ -643,7 +627,7 @@ class MatchCast : public Binding { public: TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchCast, Binding, MatchCastNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode); }; @@ -665,22 +649,19 @@ class VarBindingNode : public BindingNode { ffi::TypedFunction equal) const; uint64_t SHash(uint64_t init_hash, ffi::TypedFunction hash) const; - - static constexpr const char* _type_key = "relax.expr.VarBinding"; - - TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.VarBinding", VarBindingNode, BindingNode); }; class VarBinding : public Binding { public: TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VarBinding, Binding, VarBindingNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode); }; class BindingBlockNode : public Object { public: - Array bindings; + ffi::Array bindings; mutable Span span; static void RegisterReflection() { @@ -692,15 +673,13 @@ class BindingBlockNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "relax.expr.BindingBlock"; - - TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("relax.expr.BindingBlock", BindingBlockNode, Object); }; class BindingBlock : public ObjectRef { public: - TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); + TVM_DLL explicit BindingBlock(ffi::Array bindings, Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BindingBlock, ObjectRef, BindingBlockNode); BindingBlockNode* CopyOnWrite(); }; @@ -711,16 +690,14 @@ 
class DataflowBlockNode : public BindingBlockNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "relax.expr.DataflowBlock"; - - TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.DataflowBlock", DataflowBlockNode, + BindingBlockNode); }; class DataflowBlock : public BindingBlock { public: - TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); + TVM_DLL explicit DataflowBlock(ffi::Array bindings, Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowBlock, BindingBlock, DataflowBlockNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode); }; @@ -730,7 +707,7 @@ class DataflowBlock : public BindingBlock { */ class SeqExprNode : public ExprNode { public: - Array blocks; + ffi::Array blocks; Expr body; static void RegisterReflection() { @@ -739,10 +716,7 @@ class SeqExprNode : public ExprNode { .def_ro("blocks", &SeqExprNode::blocks) .def_ro("body", &SeqExprNode::body); } - - static constexpr const char* _type_key = "relax.expr.SeqExpr"; - - TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.SeqExpr", SeqExprNode, ExprNode); }; class SeqExpr : public Expr { @@ -760,8 +734,8 @@ class SeqExpr : public Expr { */ TVM_DLL SeqExpr(Expr body); // NOLINT(*) - TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); + TVM_DLL explicit SeqExpr(ffi::Array blocks, Expr body, Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqExpr, Expr, SeqExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode); }; @@ -794,8 +768,7 @@ class IfNode : public ExprNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; - static constexpr const char* _type_key = "relax.expr.If"; - 
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.If", IfNode, ExprNode); }; class If : public Expr { @@ -819,7 +792,7 @@ class If : public Expr { */ TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(If, Expr, IfNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); }; @@ -828,16 +801,16 @@ class If : public Expr { * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -If WithFields(If if_expr, Optional opt_cond = Optional(), - Optional opt_true_branch = Optional(), - Optional opt_false_branch = Optional(), - Optional opt_span = Optional()); +If WithFields(If if_expr, ffi::Optional opt_cond = ffi::Optional(), + ffi::Optional opt_true_branch = ffi::Optional(), + ffi::Optional opt_false_branch = ffi::Optional(), + ffi::Optional opt_span = ffi::Optional()); /*! \brief A Relax function. */ class FunctionNode : public BaseFuncNode { public: /*! \brief The parameters to the function. */ - Array params; + ffi::Array params; /*! \brief The body of the function. */ SeqExpr body; /*! \brief The return type of the function. */ @@ -855,8 +828,7 @@ class FunctionNode : public BaseFuncNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindDAGNode; - static constexpr const char* _type_key = "relax.expr.Function"; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.Function", FunctionNode, BaseFuncNode); }; class Function : public BaseFunc { @@ -882,18 +854,19 @@ class Function : public BaseFunc { * * \param span The source span of the expression. 
*/ - TVM_DLL explicit Function(Array params, Expr body, Optional ret_struct_info, - bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span()); + TVM_DLL explicit Function(ffi::Array params, Expr body, + ffi::Optional ret_struct_info, bool is_pure = true, + DictAttrs attrs = DictAttrs(), Span span = Span()); /*! * \brief Mimics the constructor but without body Expr. * \note ret_struct_info is required, since it can not deduced by the body. */ - TVM_DLL static Function CreateEmpty(Array params, StructInfo ret_struct_info, + TVM_DLL static Function CreateEmpty(ffi::Array params, StructInfo ret_struct_info, bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, BaseFunc, FunctionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); }; @@ -932,23 +905,21 @@ constexpr const char* kNumInput = "num_input"; class ExternFuncNode : public BaseFuncNode { public: /*! \brief The name of global symbol. 
*/ - String global_symbol; + ffi::String global_symbol; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("global_symbol", &ExternFuncNode::global_symbol); } - - static constexpr const char* _type_key = "relax.expr.ExternFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.ExternFunc", ExternFuncNode, BaseFuncNode); }; class ExternFunc : public BaseFunc { public: - TVM_DLL ExternFunc(String global_symbol, Span span = Span()); - TVM_DLL ExternFunc(String global_symbol, StructInfo struct_info, Span span = Span()); + TVM_DLL ExternFunc(ffi::String global_symbol, Span span = Span()); + TVM_DLL ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternFunc, BaseFunc, ExternFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); }; diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 7634bc34a26f..afacb81e4072 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -379,7 +379,7 @@ class ExprMutatorBase : public ExprFunctor { */ bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) { if (const StructInfoNode* sinfo = struct_info.as()) { - return this->VisitExprDepStructInfoField(GetRef(sinfo)).same_as(struct_info); + return this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)).same_as(struct_info); } else { return true; } @@ -421,7 +421,7 @@ class ExprMutator : public ExprMutatorBase { public: using ExprMutatorBase::VisitExpr_; - ExprMutator(Optional mod = std::nullopt) { builder_ = BlockBuilder::Create(mod); } + ExprMutator(ffi::Optional mod = std::nullopt) { builder_ = BlockBuilder::Create(mod); } Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const DataflowVarNode* 
op) override; @@ -502,7 +502,8 @@ class ExprMutator : public ExprMutatorBase { * * \note The body_expr must be an SeqExpr in the normal form. */ - Expr VisitWithNewScope(const Expr& body_expr, Optional> params = std::nullopt); + Expr VisitWithNewScope(const Expr& body_expr, + ffi::Optional> params = std::nullopt); /*! * \brief Rewrite the expr with a new scope, used in the branches of If. @@ -526,7 +527,7 @@ class ExprMutator : public ExprMutatorBase { * \return The value bound to the input \p var. * \note For function parameters, this function returns std::nullopt. */ - Optional LookupBinding(const Var& var); + ffi::Optional LookupBinding(const Var& var); /*! * \brief Post-order rewrite a node and normalize. diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 8620ad80bda7..77f001630f75 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -140,20 +140,20 @@ class NestedMsg { data_ = std::move(other); return *this; } - // Array> handling - NestedMsg(Array, void> other) // NOLINT(*) + // ffi::Array> handling + NestedMsg(ffi::Array, void> other) // NOLINT(*) : data_(other) {} - NestedMsg& operator=(Array, void> other) { + NestedMsg& operator=(ffi::Array, void> other) { data_ = std::move(other); return *this; } // initializer list handling NestedMsg(std::initializer_list> other) // NOLINT(*) - : NestedMsg(Array, void>(other)) {} + : NestedMsg(ffi::Array, void>(other)) {} NestedMsg& operator=(std::initializer_list> other) { - return operator=(Array, void>(other)); + return operator=(ffi::Array, void>(other)); } // delete the int constructor @@ -190,8 +190,9 @@ class NestedMsg { * \return a corresponding nested array. * \note This checks if the underlying data type is array. 
*/ - Array, void> NestedArray() const { - return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck, void>>(data_); + ffi::Array, void> NestedArray() const { + return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck, void>>( + data_); } private: @@ -238,8 +239,8 @@ bool Equal(const NestedMsg& lhs, const NestedMsg& rhs, FType fequal) { return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue()); } else { if (!rhs.IsNested()) return false; - Array> arr_lhs = lhs.NestedArray(); - Array> arr_rhs = rhs.NestedArray(); + ffi::Array> arr_lhs = lhs.NestedArray(); + ffi::Array> arr_rhs = rhs.NestedArray(); if (arr_lhs.size() != arr_rhs.size()) return false; for (size_t i = 0; i < arr_lhs.size(); ++i) { if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false; @@ -264,7 +265,7 @@ bool Equal(const NestedMsg& lhs, const NestedMsg& rhs, FType fequal) { template NestedMsg MapToNestedMsg(Expr expr, FType fmapleaf) { if (auto* tuple = expr.as()) { - Array> res; + ffi::Array> res; res.reserve(tuple->fields.size()); for (Expr x : tuple->fields) { res.push_back(MapToNestedMsg(x, fmapleaf)); @@ -291,7 +292,7 @@ NestedMsg MapToNestedMsg(Expr expr, FType fmapleaf) { template NestedMsg MapToNestedMsg(StructInfo sinfo, FType fmapleaf) { if (auto* tuple = sinfo.as()) { - Array> res; + ffi::Array> res; res.reserve(tuple->fields.size()); for (StructInfo x : tuple->fields) { res.push_back(MapToNestedMsg(x, fmapleaf)); @@ -320,7 +321,7 @@ template NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { auto sinfo = GetStructInfo(expr); if (auto* tuple = sinfo.as()) { - Array> res; + ffi::Array> res; res.reserve(tuple->fields.size()); for (size_t i = 0; i < tuple->fields.size(); ++i) { Expr field; @@ -346,9 +347,9 @@ NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { * * \param msg The input nested message. * \param fmapleaf The mapping function for each leaf with signature - * `TargetType fmapleaf(Optional)`. + * `TargetType fmapleaf(ffi::Optional)`. 
* \param fcombine The function for combining all childs of a node into TargetType with signature - * `TargetType fmapleaf(Array)`. + * `TargetType fmapleaf(ffi::Array)`. * \tparam TargetType the target type to map nested msg to. * \tparam T the content type of nested msg. * \tparam FMapLeaf The leaf mapping function type. @@ -362,8 +363,8 @@ TargetType NestedMsgTo(NestedMsg msg, FMapLeaf fmapleaf, FCombine fcombine) { return fmapleaf(msg.LeafValue()); } else { ICHECK(msg.IsNested()); - Array> arr = msg.NestedArray(); - Array subexpr; + ffi::Array> arr = msg.NestedArray(); + ffi::Array subexpr; subexpr.reserve(arr.size()); for (size_t i = 0; i < arr.size(); ++i) { subexpr.push_back(NestedMsgTo(arr[i], fmapleaf, fcombine)); @@ -380,14 +381,14 @@ TargetType NestedMsgTo(NestedMsg msg, FMapLeaf fmapleaf, FCombine fcombine) { * then recursively combines the results as tuple expr. * * \param msg The input nested message. - * \param fmapleaf The mapping function for each leaf with signature `Expr fmapleaf(Optional)`. - * \tparam T the content type of nested msg. - * \tparam FType The mapping function type. + * \param fmapleaf The mapping function for each leaf with signature `Expr + * fmapleaf(ffi::Optional)`. \tparam T the content type of nested msg. \tparam FType The mapping + * function type. 
*/ template Expr NestedMsgToExpr(NestedMsg msg, FType fmapleaf) { - return NestedMsgTo(msg, fmapleaf, [](Array arr) { - Optional simplified_tuple; + return NestedMsgTo(msg, fmapleaf, [](ffi::Array arr) { + ffi::Optional simplified_tuple; bool simplified_flag = false; if (arr.size() >= 1) { simplified_flag = true; @@ -436,11 +437,11 @@ NestedMsg CombineNestedMsg(NestedMsg lhs, NestedMsg rhs, FType fcombine } else { ICHECK(lhs.IsNested()); ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested"; - Array> arr_lhs = lhs.NestedArray(); - Array> arr_rhs = rhs.NestedArray(); + ffi::Array> arr_lhs = lhs.NestedArray(); + ffi::Array> arr_rhs = rhs.NestedArray(); ICHECK_EQ(arr_lhs.size(), arr_rhs.size()) << "Cannot combine two nested array with different sizes"; - Array> res; + ffi::Array> res; res.reserve(arr_lhs.size()); for (size_t i = 0; i < arr_lhs.size(); ++i) { res.push_back(CombineNestedMsg(arr_lhs[i], arr_rhs[i], fcombine)); @@ -465,8 +466,8 @@ NestedMsg MapNestedMsg(NestedMsg msg, FType fmapleaf) { return fmapleaf(msg.LeafValue()); } else { ICHECK(msg.IsNested()); - Array> arr = msg.NestedArray(); - Array> res; + ffi::Array> arr = msg.NestedArray(); + ffi::Array> res; res.reserve(arr.size()); for (int i = 0; i < static_cast(arr.size()); ++i) { res.push_back(MapNestedMsg(arr[i], fmapleaf)); @@ -492,7 +493,7 @@ template void DecomposeNestedMsg(Expr expr, NestedMsg msg, FType fvisitleaf) { if (auto* tuple = expr.as()) { ICHECK(msg.IsNested()) << "Expected nested to match tuple"; - Array> arr = msg.NestedArray(); + ffi::Array> arr = msg.NestedArray(); ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size"; for (size_t i = 0; i < arr.size(); ++i) { DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf); @@ -511,7 +512,7 @@ void DecomposeNestedMsg(Expr expr, NestedMsg msg, FType fvisitleaf) { * * \param expr The input expression to be transform.  * \param msgs The input messages to guide the transformation. 
- * \param ftransleaf with signature ftransleaf(Expr, Array>)->Expr + * \param ftransleaf with signature ftransleaf(Expr, ffi::Array>)->Expr * \tparam T the content type of nested msg * \tparam N the number of messages * \tparam FType The visit function type. @@ -520,13 +521,13 @@ template Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftransleaf) { StructInfo sinfo = GetStructInfo(expr); if (const auto* tuple = sinfo.as()) { - std::array>, N> msg_arrays; + std::array>, N> msg_arrays; for (size_t i = 0; i < N; ++i) { ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; msg_arrays[i] = msgs[i].NestedArray(); } bool same = true; - Array fields; + ffi::Array fields; fields.reserve(tuple->fields.size()); for (size_t i = 0; i < tuple->fields.size(); ++i) { Expr field; @@ -560,7 +561,7 @@ Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftran * * \param sinfo The input sinfo to be transform.  * \param msgs The input messages to guide the transformation. - * \param ftransleaf with signature ftransleaf(StructInfo, Array>)->StructInfo + * \param ftransleaf with signature ftransleaf(StructInfo, ffi::Array>)->StructInfo * \tparam T the content type of nested msg * \tparam N the number of messages * \tparam FType The visit function type. 
@@ -569,13 +570,13 @@ template StructInfo TransformTupleLeaf(StructInfo sinfo, std::array, N> msgs, FType ftransleaf) { if (const auto* tuple = sinfo.as()) { - std::array>, N> msg_arrays; + std::array>, N> msg_arrays; for (size_t i = 0; i < N; ++i) { ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; msg_arrays[i] = msgs[i].NestedArray(); } bool same = true; - Array fields; + ffi::Array fields; fields.reserve(tuple->fields.size()); for (size_t i = 0; i < tuple->fields.size(); ++i) { StructInfo field = tuple->fields[i]; @@ -638,7 +639,7 @@ struct TypeTraits> : public TypeTraitsBase { } TVM_FFI_INLINE static relax::NestedMsg MoveFromAnyAfterCheck(TVMFFIAny* src) { - return relax::NestedMsg(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src))); + return relax::NestedMsg(details::AnyUnsafe::MoveTVMFFIAnyToAny(src)); } static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { @@ -654,7 +655,7 @@ struct TypeTraits> : public TypeTraitsBase { } if (src->type_index == TypeIndex::kTVMFFIArray) { const ArrayObj* n = reinterpret_cast(src->v_obj); - Array> result; + ffi::Array> result; result.reserve(n->size()); for (size_t i = 0; i < n->size(); i++) { const Any& any_v = (*n)[i]; @@ -672,6 +673,14 @@ struct TypeTraits> : public TypeTraitsBase { TVM_FFI_INLINE static std::string TypeStr() { return "NestedMsg<" + details::Type2Str::v() + ">"; } + + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":"NestedMsg","args":[)"; + oss << details::TypeSchema::v(); + oss << "]}"; + return oss.str(); + } }; } // namespace ffi } // namespace tvm diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index bd9c59da3acb..2e686035b20c 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -65,7 +65,7 @@ using FInferStructInfo = ffi::TypedFunction( +using FPrimalGradient = ffi::TypedFunction( const Var& orig_var, const Call& orig_call, const Var& output_grad, const 
BlockBuilder& ctx)>; } // namespace relax diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index a897f031a289..f08d737fdca5 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -39,9 +41,7 @@ class ObjectStructInfoNode : public StructInfoNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "relax.ObjectStructInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ObjectStructInfo", ObjectStructInfoNode, StructInfoNode); }; /*! @@ -52,7 +52,7 @@ class ObjectStructInfo : public StructInfo { public: TVM_DLL ObjectStructInfo(Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ObjectStructInfo, StructInfo, ObjectStructInfoNode); }; /*! @@ -61,7 +61,7 @@ class ObjectStructInfo : public StructInfo { class PrimStructInfoNode : public StructInfoNode { public: /*! \brief Underlying primitive value, if known */ - Optional value; + ffi::Optional value; /*! \brief Underlying data type of the primitive value */ DataType dtype; @@ -72,9 +72,7 @@ class PrimStructInfoNode : public StructInfoNode { .def_ro("value", &PrimStructInfoNode::value) .def_ro("dtype", &PrimStructInfoNode::dtype); } - - static constexpr const char* _type_key = "relax.PrimStructInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PrimStructInfo", PrimStructInfoNode, StructInfoNode); }; /*! 
@@ -89,7 +87,7 @@ class PrimStructInfo : public StructInfo { /* Construct a PrimStructInfo with a known value */ TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimStructInfo, StructInfo, PrimStructInfoNode); }; /*! @@ -98,7 +96,7 @@ class PrimStructInfo : public StructInfo { class ShapeStructInfoNode : public StructInfoNode { public: /*! \brief optionally stores the symbolic value patterns of the shape */ - Optional> values; + ffi::Optional> values; /*! * \brief The number of dimension of the shape, can be unknown. * \sa kUnknownNDim @@ -114,9 +112,7 @@ class ShapeStructInfoNode : public StructInfoNode { .def_ro("values", &ShapeStructInfoNode::values) .def_ro("ndim", &ShapeStructInfoNode::ndim); } - - static constexpr const char* _type_key = "relax.ShapeStructInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ShapeStructInfo", ShapeStructInfoNode, StructInfoNode); }; /*! @@ -130,7 +126,7 @@ class ShapeStructInfo : public StructInfo { * \param values The symbolic shape values * \param span The span of the AST. */ - TVM_DLL ShapeStructInfo(Array values, Span span = Span()); + TVM_DLL ShapeStructInfo(ffi::Array values, Span span = Span()); /*! * \brief Construction with known unknown symbolic shape patterns. * \param ndim Number of dimensions -- can be kUnknownNDim @@ -138,7 +134,7 @@ class ShapeStructInfo : public StructInfo { */ TVM_DLL ShapeStructInfo(int ndim, Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ShapeStructInfo, StructInfo, ShapeStructInfoNode); }; /*! @@ -150,11 +146,11 @@ class TensorStructInfoNode : public StructInfoNode { * \brief optionally store the shape expression of the tensor. 
* \note shape must be normalized: it can only be std::nullopt or ShapeExpr or Var. */ - Optional shape; + ffi::Optional shape; /*! \brief The virtual device, indicates where the tensor * is expected to be executed. */ - Optional vdevice; + ffi::Optional vdevice; /*! \brief The content data type, use void to denote the dtype is unknown. */ DataType dtype; /*! @@ -170,7 +166,7 @@ class TensorStructInfoNode : public StructInfoNode { bool IsUnknownDtype() const { return dtype.is_void(); } /*! \return Shape if it is known. */ - Optional> GetShape() const { + ffi::Optional> GetShape() const { if (!shape.defined()) return {}; ShapeStructInfo shape_sinfo = Downcast(this->shape.value()->struct_info_); return shape_sinfo->values; @@ -184,9 +180,7 @@ class TensorStructInfoNode : public StructInfoNode { .def_ro("vdevice", &TensorStructInfoNode::vdevice) .def_ro("ndim", &TensorStructInfoNode::ndim); } - - static constexpr const char* _type_key = "relax.TensorStructInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TensorStructInfo", TensorStructInfoNode, StructInfoNode); }; /*! @@ -204,8 +198,8 @@ class TensorStructInfo : public StructInfo { * * \note shape must already be normalized. */ - TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Optional vdevice = std::nullopt, - Span span = Span()); + TVM_DLL TensorStructInfo(Expr shape, DataType dtype, + ffi::Optional vdevice = std::nullopt, Span span = Span()); /*! * \brief Construction with an unknown shape expression. @@ -214,10 +208,10 @@ class TensorStructInfo : public StructInfo { * \param vdevice The virtual device. * \param span The span of the AST. 
*/ - TVM_DLL TensorStructInfo(DataType dtype, int ndim, Optional vdevice = std::nullopt, + TVM_DLL TensorStructInfo(DataType dtype, int ndim, ffi::Optional vdevice = std::nullopt, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorStructInfo, StructInfo, TensorStructInfoNode); }; /*! @@ -226,15 +220,13 @@ class TensorStructInfo : public StructInfo { class TupleStructInfoNode : public StructInfoNode { public: /*! \brief The struct info of tuple fields. */ - Array fields; + ffi::Array fields; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("fields", &TupleStructInfoNode::fields); } - - static constexpr const char* _type_key = "relax.TupleStructInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TupleStructInfo", TupleStructInfoNode, StructInfoNode); }; /*! @@ -248,9 +240,9 @@ class TupleStructInfo : public StructInfo { * \param fields Struct info of tuple fields. * \param span The span of the AST. */ - TVM_DLL TupleStructInfo(Array fields, Span span = Span()); + TVM_DLL TupleStructInfo(ffi::Array fields, Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TupleStructInfo, StructInfo, TupleStructInfoNode); }; /*! @@ -274,7 +266,7 @@ class FuncStructInfoNode : public StructInfoNode { * \note When params is std::nullopt means the function can take arbitrary number of arguments. * We define such functions as Opaque function. */ - Optional> params; + ffi::Optional> params; /*! * \brief The struct info of the function's return value. 
*/ @@ -284,7 +276,7 @@ class FuncStructInfoNode : public StructInfoNode { * \note When derive_func is not empty, then params should be std::nullopt, * ret should be ObjectStructInfo() */ - Optional derive_func; + ffi::Optional derive_func; /*! * \brief Whether the function is pure. * \note This parameter should be set to true only if the function is pure on all inputs. @@ -306,9 +298,7 @@ class FuncStructInfoNode : public StructInfoNode { .def_ro("derive_func", &FuncStructInfoNode::derive_func) .def_ro("purity", &FuncStructInfoNode::purity); } - - static constexpr const char* _type_key = "relax.FuncStructInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.FuncStructInfo", FuncStructInfoNode, StructInfoNode); }; /*! @@ -317,6 +307,10 @@ class FuncStructInfoNode : public StructInfoNode { */ class FuncStructInfo : public StructInfo { public: + explicit FuncStructInfo(ObjectPtr data) : StructInfo(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } /*! * \brief Constructor from parameter struct info and return value struct info. * \param params The struct info of function parameters. @@ -327,7 +321,7 @@ class FuncStructInfo : public StructInfo { * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from * params. If you are unsure, you can always erase ret to static. */ - TVM_DLL FuncStructInfo(Array params, StructInfo ret, bool purity = true, + TVM_DLL FuncStructInfo(ffi::Array params, StructInfo ret, bool purity = true, Span span = Span()); /*! @@ -358,7 +352,7 @@ class FuncStructInfo : public StructInfo { TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false, Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FuncStructInfo, StructInfo, FuncStructInfoNode); }; /*! 
@@ -369,10 +363,10 @@ class FuncStructInfo : public StructInfo { * \tparam T the underlying structure info type */ template -inline Optional MatchStructInfo(const Expr& expr) { +inline ffi::Optional MatchStructInfo(const Expr& expr) { using TNode = typename T::ContainerType; if (const TNode* ptr = expr->struct_info_.as()) { - return GetRef(ptr); + return ffi::GetRef(ptr); } else { return std::nullopt; } @@ -401,7 +395,7 @@ inline const T* GetStructInfoAs(const Expr& expr) { inline StructInfo GetStructInfo(const Expr& expr) { auto* ptr = expr->struct_info_.as(); ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; - return GetRef(ptr); + return ffi::GetRef(ptr); } /*! diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h index 1397bafc36ff..6bd36560a6ac 100644 --- a/include/tvm/relax/tir_pattern.h +++ b/include/tvm/relax/tir_pattern.h @@ -41,9 +41,9 @@ class MatchResultNode : public Object { /*! The matched tir pattern*/ TIRPattern pattern; /*! \brief The evaluated values of symbolic vars. */ - Array symbol_values; + ffi::Array symbol_values; /*! \brief The matched buffers of input and output. */ - Array matched_buffers; + ffi::Array matched_buffers; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -52,9 +52,7 @@ class MatchResultNode : public Object { .def_ro("symbol_values", &MatchResultNode::symbol_values) .def_ro("matched_buffers", &MatchResultNode::matched_buffers); } - - static constexpr const char* _type_key = "relax.MatchResult"; - TVM_DECLARE_FINAL_OBJECT_INFO(MatchResultNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.MatchResult", MatchResultNode, Object); }; /*! @@ -68,13 +66,13 @@ class MatchResult : public ObjectRef { * \param symbol_values The evaluated values of symbolic vars. * \param matched_buffers The matched buffers of input and output. 
*/ - TVM_DLL explicit MatchResult(TIRPattern pattern, Array symbol_values, - Array matched_buffers); + TVM_DLL explicit MatchResult(TIRPattern pattern, ffi::Array symbol_values, + ffi::Array matched_buffers); - TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchResult, ObjectRef, MatchResultNode); }; -using FCodegen = ffi::TypedFunction(Array match_results)>; +using FCodegen = ffi::TypedFunction(ffi::Array match_results)>; } // namespace relax } // namespace tvm #endif // TVM_RELAX_TIR_PATTERN_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 1567294a4b38..786dfdcdf98c 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -54,8 +54,8 @@ using tvm::transform::CreateModulePass; * \return The created function pass. */ TVM_DLL Pass CreateFunctionPass(std::function pass_func, - int opt_level, String name, tvm::Array required, - bool traceable = false); + int opt_level, ffi::String name, + tvm::ffi::Array required, bool traceable = false); /*! * \brief Create a dataflowblock pass. @@ -70,7 +70,7 @@ TVM_DLL Pass CreateFunctionPass(std::function pass_func, int opt_level, - String name, tvm::Array required, bool traceable = false); + ffi::String name, tvm::ffi::Array required, bool traceable = false); /*! * \brief Perform lambda lifting to lift functions from nested into global. @@ -125,18 +125,19 @@ TVM_DLL Pass RewriteDataflowReshape(); * The pass will reuse allocated memory to its best effort, in order to * reduce the total amount of allocated memory size. * - * The pass "supports" dynamic shape in the way of TIR variable upper bound - * annotation. We can optionally annotate the attribute "tir_var_upper_bound" - * to Relax functions. The attribute value is a dict from strings to integers, - * denoting the name of TIR variables to the upper bound values of the TIR vars. 
- * Note: The annotated upper bound attribute only applies to TIR vars in the + * The pass "supports" dynamic shape in the way of TIR variable bound + * annotations. We can optionally annotate the attributes "tir_var_upper_bound" + * and "tir_var_lower_bound" to Relax functions. The attribute values are dicts + * from strings to integers, denoting the name of TIR variables to the bound + * values of the TIR vars. + * Note: The annotated bound attributes only apply to TIR vars in the * function signature for clarity. * * For example, we can annotate a Relax function with - * `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`. - * It means the maximum value of variable that names "n" in the function - * signature will have upper bound 1024. And we will use 1024 as its value - * during memory planning. + * `R.func_attr({"tir_var_lower_bound": {"n": 1}, "tir_var_upper_bound": {"n": 1024}})`. + * It means the variable that names "n" in the function signature will have + * range [1, 1024]. And we will use these bounds during memory planning. + * If lower bound is not specified, it defaults to 0. * * \return The pass. */ @@ -196,7 +197,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false); * * \return The Pass. */ -TVM_DLL Pass BindParams(String func_name, Map params); +TVM_DLL Pass BindParams(ffi::String func_name, ffi::Map params); /*! * \brief Bind symbolic vars to constant shape values. @@ -213,8 +214,8 @@ TVM_DLL Pass BindParams(String func_name, Map params); * * \return The Pass. */ -TVM_DLL Pass BindSymbolicVars(Map, PrimExpr> binding_map, - Optional func_name = std::nullopt); +TVM_DLL Pass BindSymbolicVars(ffi::Map, PrimExpr> binding_map, + ffi::Optional func_name = std::nullopt); /*! * \brief Fold constant expressions within dataflow blocks. @@ -244,11 +245,14 @@ TVM_DLL Pass FoldConstant(); * * \param cmap The customized operator legalization function map. The customized function * will override the default one. 
+ * \param skip_ops The list operator names which need to be skipped from legalization * \param enable_warning A boolean value indicating if to print warnings for TIR functions not * showing up in the database. * \return The Pass. */ -TVM_DLL Pass LegalizeOps(Optional> cmap, bool enable_warning = false); +TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, + ffi::Optional> skip_ops, + bool enable_warning = false); /*! * \brief Propagate virtual device information. @@ -303,7 +307,8 @@ TVM_DLL Pass SplitLayoutRewritePreproc(); * * \return The Pass. */ -TVM_DLL Pass LiftTransformParams(Variant> shared_transform = Bool(false)); +TVM_DLL Pass +LiftTransformParams(ffi::Variant> shared_transform = Bool(false)); /*! * \brief Update virtual device. @@ -364,7 +369,7 @@ class FusionPatternNode : public Object { * \brief The name of pattern. It becomes the value of the kComposite attribute * of a fused function after successful matching */ - String name; + ffi::String name; /*! * \brief The dataflow pattern that will be used to match expression in the DataflowBlock. @@ -376,7 +381,7 @@ class FusionPatternNode : public Object { * \brief The map which is used to extract important expressions from the pattern match * result. All DFPattern in this map should be part of the `pattern`. */ - Map annotation_patterns; + ffi::Map annotation_patterns; /*! * \brief The function to determine whether the match result is accepted. This can be @@ -385,15 +390,15 @@ class FusionPatternNode : public Object { * It should have signature * bool(const PatternCheckContext& context) */ - Optional check; + ffi::Optional check; /*! 
* \brief The function to get attributes for fused function * * It should have signature - * Map(const Map& context) + * ffi::Map(const ffi::Map& context) */ - Optional attrs_getter; + ffi::Optional attrs_getter; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -404,20 +409,19 @@ class FusionPatternNode : public Object { .def_ro("check", &FusionPatternNode::check) .def_ro("attrs_getter", &FusionPatternNode::attrs_getter); } - - static constexpr const char* _type_key = "relax.transform.FusionPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode, Object); }; class FusionPattern : public ObjectRef { public: - FusionPattern(String name, DFPattern pattern, Map annotation_patterns, - Optional check, Optional attrs_getter); + FusionPattern(ffi::String name, DFPattern pattern, + ffi::Map annotation_patterns, + ffi::Optional check, ffi::Optional attrs_getter); - FusionPattern(String name, DFPattern pattern) + FusionPattern(ffi::String name, DFPattern pattern) : FusionPattern(name, pattern, {}, std::nullopt, std::nullopt) {} - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FusionPattern, ObjectRef, FusionPatternNode); }; /*! @@ -434,25 +438,25 @@ class PatternCheckContextNode : public Object { * \brief A map which contains all expressions matched by the sub patterns in * FusionPattern::annotation_patterns. */ - Map annotated_expr; + ffi::Map annotated_expr; /*! * \brief Map from variable to its value. It contains variables from bindings that * is being fused by FuseOpsByPattern. */ - Map matched_bindings; + ffi::Map matched_bindings; /*! * \brief A map mapping variable definitions to a set of uses. It has all variables * used in the function. */ - Map> var_usages; + ffi::Map> var_usages; /*! * \brief Map from value to its bound variable. 
It doesn't have variables after the * matched expression. */ - Map value_to_bound_var; + ffi::Map value_to_bound_var; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -463,19 +467,19 @@ class PatternCheckContextNode : public Object { .def_ro("var_usages", &PatternCheckContextNode::var_usages) .def_ro("value_to_bound_var", &PatternCheckContextNode::value_to_bound_var); } - - static constexpr const char* _type_key = "relax.transform.PatternCheckContext"; - TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode, + Object); }; class PatternCheckContext : public ObjectRef { public: - PatternCheckContext(Expr matched_expr, Map annotated_expr, - Map matched_bindings, Map> var_usages, - Map value_to_bound_var); + PatternCheckContext(Expr matched_expr, ffi::Map annotated_expr, + ffi::Map matched_bindings, + ffi::Map> var_usages, + ffi::Map value_to_bound_var); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef, - PatternCheckContextNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PatternCheckContext, ObjectRef, + PatternCheckContextNode); }; /*! @@ -503,7 +507,8 @@ class PatternCheckContext : public ObjectRef { * * \note ConvertToDataflow may need to be called first to provide dataflow blocks. */ -TVM_DLL Pass Gradient(String func_name, Optional> require_grads = std::nullopt, +TVM_DLL Pass Gradient(ffi::String func_name, + ffi::Optional> require_grads = std::nullopt, int target_index = 0); /*! @@ -526,9 +531,9 @@ TVM_DLL Pass Gradient(String func_name, Optional> require_grads = std * * \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first. 
*/ -TVM_DLL Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants = true, - bool annotate_codegen = false, - const tvm::Array& entry_function_names = {}); +TVM_DLL Pass FuseOpsByPattern(const tvm::ffi::Array& patterns, + bool bind_constants = true, bool annotate_codegen = false, + const tvm::ffi::Array& entry_function_names = {}); /*! * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new @@ -553,8 +558,9 @@ TVM_DLL Pass FuseTIR(); * \param entry_functions list of entry functions * \return The Pass. */ -TVM_DLL Pass RunCodegen(Optional>> target_options, - Array entry_functions); +TVM_DLL Pass +RunCodegen(ffi::Optional>> target_options, + ffi::Array entry_functions); /*! * \brief Decompose composite operators during inference. For example, The result of batch norm (a @@ -564,7 +570,7 @@ TVM_DLL Pass RunCodegen(Optional>> target_opti * \param func_name The name of the specified function. If not specified, the pass will run in * all functions. */ -TVM_DLL Pass DecomposeOpsForInference(Optional func_name); +TVM_DLL Pass DecomposeOpsForInference(ffi::Optional func_name); /*! * \brief Decompose composite operators during training. For example, The result of batch norm (a @@ -574,7 +580,7 @@ TVM_DLL Pass DecomposeOpsForInference(Optional func_name); * \param func_name The name of the specified function. If not specified, the pass will run in * all functions. */ -TVM_DLL Pass DecomposeOpsForTraining(Optional func_name); +TVM_DLL Pass DecomposeOpsForTraining(ffi::Optional func_name); /*! * \brief Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in \p @@ -590,10 +596,12 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional func_name); * \param input_axis_separators Map from kOperatorName attr to axis_separator for input buffer * \return The Pass. 
*/ -TVM_DLL Pass AlterOpImpl(const Map& op_impl_map, - const Map>& op_buffer_transforms, - const Map>>>& axis_separators, - const Map>>>& input_axis_separators); +TVM_DLL Pass AlterOpImpl( + const ffi::Map& op_impl_map, + const ffi::Map>& op_buffer_transforms, + const ffi::Map>>>& axis_separators, + const ffi::Map>>>& + input_axis_separators); /*! * \brief Layout conversion pass. @@ -601,7 +609,7 @@ TVM_DLL Pass AlterOpImpl(const Map& op_impl_map, * \return The Pass. * \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first. */ -TVM_DLL Pass ConvertLayout(Map> desired_layouts); +TVM_DLL Pass ConvertLayout(ffi::Map> desired_layouts); /*! * \brief A pass that converts consecutive dataflow operations @@ -628,7 +636,7 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); * * \return The Pass. */ -TVM_DLL Pass DeadCodeElimination(Array entry_functions = {}); +TVM_DLL Pass DeadCodeElimination(ffi::Array entry_functions = {}); /*! * \brief Pass that changes calls to operators that can be done in-place @@ -651,8 +659,9 @@ TVM_DLL Pass DataflowUseInplaceCalls(); * * \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first. */ -TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype, - Optional> fp16_input_names = std::nullopt); +TVM_DLL Pass +ToMixedPrecision(const DataType& out_dtype, + ffi::Optional> fp16_input_names = std::nullopt); /*! * \brief Rewrite a Relax module for executing with CUDA graph. This pass identifies @@ -673,6 +682,13 @@ TVM_DLL Pass RewriteCUDAGraph(); */ TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark); +/*! + * \brief This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. + * Primarily used to update the VDevice information if any changes occured from the caller. + * This pass recreates the buffers and updates the map. 
+ */ +TVM_DLL Pass SpecializePrimFuncBasedOnCallSite(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 18fd16af4d2b..8eaaf7bddc48 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -48,9 +48,7 @@ class ShapeTypeNode : public TypeNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("ndim", &ShapeTypeNode::ndim); } - - static constexpr const char* _type_key = "relax.ShapeType"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ShapeType", ShapeTypeNode, TypeNode); }; class ShapeType : public Type { @@ -58,7 +56,7 @@ class ShapeType : public Type { // TODO(relax-team): remove the default value later. TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ShapeType, Type, ShapeTypeNode); }; /*! @@ -86,9 +84,7 @@ class TensorTypeNode : public TypeNode { inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; } inline bool IsUnknownDtype() const { return dtype.is_void(); } - - static constexpr const char* _type_key = "relax.DynTensorType"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.DynTensorType", TensorTypeNode, TypeNode); }; /*! 
@@ -110,7 +106,7 @@ class TensorType : public Type { */ TVM_DLL static TensorType CreateUnknownNDim(DataType dtype, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorType, Type, TensorTypeNode); }; using TensorTypeNode = TensorTypeNode; @@ -122,16 +118,14 @@ class ObjectTypeNode : public TypeNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "relax.ObjectType"; - TVM_DECLARE_FINAL_OBJECT_INFO(ObjectTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ObjectType", ObjectTypeNode, TypeNode); }; class ObjectType : public Type { public: TVM_DLL ObjectType(Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectType, Type, ObjectTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ObjectType, Type, ObjectTypeNode); }; class PackedFuncTypeNode : public TypeNode { @@ -140,16 +134,14 @@ class PackedFuncTypeNode : public TypeNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "relax.PackedFuncType"; - TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncTypeNode, TypeNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.PackedFuncType", PackedFuncTypeNode, TypeNode); }; class PackedFuncType : public Type { public: TVM_DLL PackedFuncType(Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PackedFuncType, Type, PackedFuncTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PackedFuncType, Type, PackedFuncTypeNode); }; } // namespace relax diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index e48c1856f9fe..70ecbe4855ac 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -47,15 +47,15 @@ namespace relax { * * \return The updated expression. 
*/ -TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds, - const tvm::Map& symbolic_var_map = {}); +TVM_DLL Expr Bind(const Expr& expr, const tvm::ffi::Map& binds, + const tvm::ffi::Map& symbolic_var_map = {}); /*! * \brief Bind the symbolic variables to a StructInfo. This is a helper function usually called by * other pass functions to help optimizations. */ TVM_DLL StructInfo Bind(const StructInfo& sinfo, - const tvm::Map& symbolic_var_map); + const tvm::ffi::Map& symbolic_var_map); /*! * \brief Infer a binding map for symbolic variables @@ -74,8 +74,8 @@ TVM_DLL StructInfo Bind(const StructInfo& sinfo, * * \return A map of TIR variables to TIR expressions */ -TVM_DLL tvm::Map InferSymbolicVarMap( - const tvm::Map& binds, arith::Analyzer* analyzer); +TVM_DLL tvm::ffi::Map InferSymbolicVarMap( + const tvm::ffi::Map& binds, arith::Analyzer* analyzer); /*! * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean diff --git a/include/tvm/runtime/base.h b/include/tvm/runtime/base.h index c704decb63e9..d838966aec13 100644 --- a/include/tvm/runtime/base.h +++ b/include/tvm/runtime/base.h @@ -29,7 +29,7 @@ #include // TVM version -#define TVM_VERSION "0.21.dev0" +#define TVM_VERSION "0.23.dev0" // define extra macros for TVM DLL exprt #ifdef __EMSCRIPTEN__ diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index 4e6c2f53641a..a2eefb5b7d14 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -47,15 +47,6 @@ extern "C" { TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFFIObjectHandle* out); -/*! - * \brief Backend function to register system-wide library symbol. - * - * \param name The name of the symbol - * \param ptr The symbol address. - * \return 0 when no error is thrown, -1 when failure happens - */ -TVM_DLL int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* ptr); - /*! 
* \brief Backend function to allocate temporal workspace. * diff --git a/include/tvm/runtime/contrib/papi.h b/include/tvm/runtime/contrib/papi.h index 93c1aa274bfd..551f66726473 100644 --- a/include/tvm/runtime/contrib/papi.h +++ b/include/tvm/runtime/contrib/papi.h @@ -38,7 +38,8 @@ namespace profiling { * collected on that device. You can find the names of available metrics by * running `papi_native_avail`. */ -TVM_DLL MetricCollector CreatePAPIMetricCollector(Map> metrics); +TVM_DLL MetricCollector +CreatePAPIMetricCollector(ffi::Map> metrics); } // namespace profiling } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index e24768bde2f8..a4a01b1223f6 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -60,6 +60,7 @@ class DataType { kFloat = kDLFloat, kHandle = kDLOpaqueHandle, kBFloat = kDLBfloat, + kBool = kDLBool, kFloat8_e3m4 = kDLFloat8_e3m4, kFloat8_e4m3 = kDLFloat8_e4m3, kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz, @@ -71,7 +72,8 @@ class DataType { kFloat6_e2m3fn = kDLFloat6_e2m3fn, kFloat6_e3m2fn = kDLFloat6_e3m2fn, kFloat4_e2m1fn = kDLFloat4_e2m1fn, - kCustomBegin = 129 + kCustomBegin = 129, + kTensorFloat32 = 130 }; /*! \brief default constructor */ DataType() { data_ = DataType::Void(); } @@ -108,6 +110,9 @@ class DataType { if (code == kFloat4_e2m1fn) { ICHECK_EQ(bits, 4); } + if (code == kTensorFloat32) { + ICHECK_EQ(bits, 32); + } } /*! \return The type code. */ int code() const { return static_cast(data_.code); } @@ -137,12 +142,16 @@ class DataType { } /*! \return whether type is a scalar type. */ bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; } - /*! \return whether type is a scalar type. */ - bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } + /*! \return whether type is a bool type. */ + bool is_bool() const { return code() == DataType::kBool; } + /*! 
\return whether type can be used in a predicate expression. */ + bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() == 1); } /*! \return whether type is a float type. */ bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a bfloat type. */ bool is_bfloat() const { return code() == DataType::kBFloat; } + /*! \return whether type is a tfloat type. */ + bool is_tfloat() const { return code() == DataType::kTensorFloat32; } /*! \return whether type is any 8-bit custom Float8 variant. */ bool is_float8() const { return bits() == 8 && @@ -182,6 +191,8 @@ class DataType { bool is_float6_e3m2fn() const { return bits() == 6 && code() == DataType::kFloat6_e3m2fn; } /*! \return whether type is Float4E2M1FN. */ bool is_float4_e2m1fn() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; } + /*! \return whether type is a tfloat32 type. */ + bool is_tfloat32() const { return bits() == 32 && code() == DataType::kTensorFloat32; } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ @@ -204,9 +215,11 @@ class DataType { /*! \return whether type is a vector type. */ bool is_vector() const { return lanes() > 1; } /*! \return whether type is a bool vector type. */ - bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; } + bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && is_bool(); } /*! \return whether type is a Void type. */ - bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } + bool is_void() const { + return code() == DataType::kHandle && bits() == 0 && static_cast(data_.lanes) == 0; + } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. @@ -372,6 +385,14 @@ class DataType { * \return The constructed data type. 
*/ static DataType Float4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); } + + /*! + * \brief Construct a tensorfloat32 datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType TensorFloat32(int lanes = 1) { return DataType(kTensorFloat32, 32, lanes); } + /*! * \brief Construct a bool type. * \param lanes The number of lanes. @@ -379,7 +400,7 @@ class DataType { * \return The constructed data type. */ static DataType Bool(int lanes = 1, bool is_scalable = false) { - return DataType::UInt(1, lanes, is_scalable); + return DataType(kDLBool, 8, lanes, is_scalable); } /*! * \brief Construct a handle type. @@ -465,6 +486,7 @@ struct TypeTraits : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const runtime::DataType& src, TVMFFIAny* result) { // clear padding part to ensure the equality check can always check the v_uint64 part result->v_uint64 = 0; + result->zero_padding = 0; result->type_index = TypeIndex::kTVMFFIDataType; result->v_dtype = src; } @@ -472,6 +494,7 @@ struct TypeTraits : public TypeTraitsBase { TVM_FFI_INLINE static void MoveToAny(runtime::DataType src, TVMFFIAny* result) { // clear padding part to ensure the equality check can always check the v_uint64 part result->v_uint64 = 0; + result->zero_padding = 0; result->type_index = TypeIndex::kTVMFFIDataType; result->v_dtype = src; } @@ -493,6 +516,10 @@ struct TypeTraits : public TypeTraitsBase { } TVM_FFI_INLINE static std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; } + + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(ffi::StaticTypeKey::kTVMFFIDataType) + R"("})"; + } }; } // namespace ffi diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index bc0faf2413e5..ae119e52652b 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -21,7 +21,7 @@ #include #include -#include +#include #include 
@@ -62,15 +62,15 @@ inline std::string ReduceKind2String(ReduceKind kind) { * \param device The default device used to initialize the RelaxVM * \return The RelaxVM as a runtime Module */ -TVM_DLL ffi::Module LoadVMModule(std::string path, Optional device); +TVM_DLL ffi::Module LoadVMModule(std::string path, ffi::Optional device); /*! - * \brief Create an uninitialized empty NDArray - * \param shape The shape of the NDArray - * \param dtype The dtype of the NDArray - * \param device The device the NDArray is created on. If None, use the thread local default device - * \return The NDArray created + * \brief Create an uninitialized empty Tensor + * \param shape The shape of the Tensor + * \param dtype The dtype of the Tensor + * \param device The device the Tensor is created on. If None, use the thread local default device + * \return The Tensor created */ -TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional device); +TVM_DLL Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, ffi::Optional device); /*! * \brief Perform an allreduce operation using the underlying communication library * \param send The array send to perform allreduce on @@ -78,21 +78,21 @@ TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional send, bool in_group, NDArray recv); +TVM_DLL void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv); /*! * \brief Perform a gather operation to worker-0. * \param send The sending buffer, which must not be None. @@ -108,36 +108,36 @@ TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, NDArray r * \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The * receiving buffer will be divided into equal parts and receive from each worker accordingly. */ -TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional recv); +TVM_DLL void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv); /*! * \brief Receive a buffer from worker-0. 
No-op if the current worker is worker-0. * \param buffer The buffer to be received */ -TVM_DLL void RecvFromWorker0(NDArray buffer); +TVM_DLL void RecvFromWorker0(Tensor buffer); /*! * \brief Send a buffer to the corresponding worker in the next group. * An error is thrown if the worker is already in the last group. * \param buffer The sending buffer. */ -TVM_DLL void SendToNextGroup(NDArray buffer); +TVM_DLL void SendToNextGroup(Tensor buffer); /*! * \brief Receive a buffer from the corresponding worker in the previous group. * An error is thrown if the worker is already in the first group. * \param buffer The receiving buffer. */ -TVM_DLL void RecvFromPrevGroup(NDArray buffer); +TVM_DLL void RecvFromPrevGroup(Tensor buffer); /*! * \brief Send a buffer to the target receiver worker (globally across all groups). * \param buffer The sending buffer. * \param receiver_id The global receiver worker id. */ -TVM_DLL void SendToWorker(NDArray buffer, int receiver_id); +TVM_DLL void SendToWorker(Tensor buffer, int receiver_id); /*! * \brief Receive a buffer from the target sender worker (globally across all groups). * \param buffer The receiving buffer. * \param sender_id The global sender worker id. */ -TVM_DLL void RecvFromWorker(NDArray buffer, int sender_id); +TVM_DLL void RecvFromWorker(Tensor buffer, int sender_id); /*! \brief Get the local worker id */ TVM_DLL int WorkerId(); /*! diff --git a/include/tvm/runtime/disco/cuda_ipc_memory.h b/include/tvm/runtime/disco/cuda_ipc_memory.h index a77e06ccaef5..a6bfbd866b06 100644 --- a/include/tvm/runtime/disco/cuda_ipc_memory.h +++ b/include/tvm/runtime/disco/cuda_ipc_memory.h @@ -70,8 +70,8 @@ class CUDAIPCMemoryObj : public Object { /*! \brief The integer buffer flag for all-reduce. 
*/ int barrier_flag; - static constexpr const char* _type_key = "tvm.runtime.disco.cuda_ipc_memory"; - TVM_DECLARE_BASE_OBJECT_INFO(CUDAIPCMemoryObj, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("tvm.runtime.disco.cuda_ipc_memory", CUDAIPCMemoryObj, Object); }; /*! @@ -90,7 +90,7 @@ class CUDAIPCMemory : public ObjectRef { */ TVM_DLL static CUDAIPCMemory GetIPCMemoryFromDevicePtr(void* ptr); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAIPCMemory, ObjectRef, CUDAIPCMemoryObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CUDAIPCMemory, ObjectRef, CUDAIPCMemoryObj); }; } // namespace cuda_ipc diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h index 078c061b7b82..464efb59c01b 100644 --- a/include/tvm/runtime/disco/disco_worker.h +++ b/include/tvm/runtime/disco/disco_worker.h @@ -79,7 +79,7 @@ class DiscoWorker { /*! \brief The default device to allocate data if not specified */ Device default_device; /*! \brief The name of the underlying collective communication library. */ - String ccl; + ffi::String ccl; /*! * \brief The data shared between worker-0 and the controler. It's a nullptr if * the worker is not worker-0. diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 4fe0e72e79c1..283d75740c4f 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -46,7 +46,7 @@ * It is assumed that the controler can synchronize with and access the registers of worker-0. * The Disco session provides multiple APIs to interact specifically with the worker-0. * To shared data with other workers, a common paradigm in Disco is to copy data from the - * controler-side NDArray to the worker-0, and then copy it to other workers using primitives on + * controler-side Tensor to the worker-0, and then copy it to other workers using primitives on * the data plane, for example, `broadcast` and `send`. 
* * **Control plane.** The controler broadcasts commands to all the workers as control signals. @@ -74,8 +74,8 @@ #include #include -#include #include +#include #include #include @@ -143,16 +143,16 @@ class DRefObj : public Object { */ inline ffi::Any DebugGetFromRemote(int worker_id); /*! - * \brief Copy from the NDArray provided to a remote worker. + * \brief Copy from the Tensor provided to a remote worker. * \param worker_id The id of the worker to be copied to. - * \param source The NDArray to be copied. + * \param source The Tensor to be copied. */ inline void DebugCopyFrom(int worker_id, ffi::AnyView source); - static constexpr const char* _type_key = "runtime.disco.DRef"; static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef; static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(DRefObj, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC("runtime.disco.DRef", DRefObj, Object); /*! \brief The id of the register */ int64_t reg_id; @@ -170,7 +170,8 @@ class DRefObj : public Object { */ class DRef : public ObjectRef { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj); + explicit DRef(ObjectPtr data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DRef, ObjectRef, DRefObj); }; /*! @@ -189,7 +190,7 @@ class SessionObj : public Object { * - std::string; * - DRef. * Examples of unsupported types: - * - NDArray, DLTensor; + * - Tensor, DLTensor; * - TVM Objects, including ffi::Function, Module and String; * \param func The function to be called. * \param args The variadic arguments. @@ -209,17 +210,17 @@ class SessionObj : public Object { /*! \brief Get a global functions on workers. */ TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0; /*! 
- * \brief Copy an NDArray from worker-0 to the controler-side NDArray + * \brief Copy an Tensor from worker-0 to the controler-side Tensor * \param host_array The array to be copied to worker-0 - * \param remote_array The NDArray on worker-0 + * \param remote_array The Tensor on worker-0 */ - TVM_DLL virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0; + TVM_DLL virtual void CopyFromWorker0(const Tensor& host_array, const DRef& remote_array) = 0; /*! - * \brief Copy the controler-side NDArray to worker-0 + * \brief Copy the controler-side Tensor to worker-0 * \param host_array The array to be copied to worker-0 - * \param remote_array The NDArray on worker-0 + * \param remote_array The Tensor on worker-0 */ - TVM_DLL virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0; + TVM_DLL virtual void CopyToWorker0(const Tensor& host_array, const DRef& remote_array) = 0; /*! * \brief Synchrnoize the controler with a worker, and it will wait until worker finishes * executing this instruction. @@ -235,7 +236,7 @@ class SessionObj : public Object { * \param ccl The name of the communication backend, e.g., nccl, rccl, mpi. * \param device_ids The device ids of the workers. */ - TVM_DLL virtual void InitCCL(String ccl, IntTuple device_ids) = 0; + TVM_DLL virtual void InitCCL(ffi::String ccl, IntTuple device_ids) = 0; /*! * \brief Get the value of a register from a remote worker. * \param reg_id The id of the register to be fetched. @@ -254,8 +255,9 @@ class SessionObj : public Object { struct FFI; friend struct SessionObj::FFI; friend class DRefObj; - static constexpr const char* _type_key = "runtime.disco.Session"; - TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object); + + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("runtime.disco.Session", SessionObj, Object); protected: /*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. 
*/ @@ -287,9 +289,9 @@ class Session : public ObjectRef { * worker-0 does not exist in the process pool. */ TVM_DLL static Session ProcessSession(int num_workers, int num_groups, - String process_pool_creator, String entrypoint); + ffi::String process_pool_creator, ffi::String entrypoint); - TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Session, ObjectRef, SessionObj); }; /*! @@ -319,7 +321,7 @@ class WorkerZeroData { * \brief The host-side arrays to passed to worker-0 for special uses, for example, * copy-to-worker0 and copy-from-worker0 */ - std::queue host_arrays; + std::queue host_arrays; /*! \brief The mutex that guards `host_arrays` */ std::mutex queue_mutex_; }; diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index da715848e09a..f39a07b3d968 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -206,7 +206,7 @@ class InternalError : public Error { */ InternalError(std::string file, int lineno, std::string message) : Error(DetectKind(message), DetectMessage(message), - TVMFFITraceback(file.c_str(), lineno, "")) {} + TVMFFIBacktrace(file.c_str(), lineno, "", 0)) {} private: // try to detect the kind of error from the message when the error type diff --git a/include/tvm/runtime/memory/memory_manager.h b/include/tvm/runtime/memory/memory_manager.h index f103c6f30ac8..8d2de7791af0 100644 --- a/include/tvm/runtime/memory/memory_manager.h +++ b/include/tvm/runtime/memory/memory_manager.h @@ -25,8 +25,8 @@ #define TVM_RUNTIME_MEMORY_MEMORY_MANAGER_H_ #include -#include #include +#include #include #include @@ -59,15 +59,15 @@ class Allocator { public: explicit Allocator(AllocatorType type) : type_(type) {} virtual ~Allocator() = default; - /*! \brief Allocate an empty NDArray using from the allocator. - * \param shape The shape of the NDArray. - * \param dtype The datatype of the NDArray. + /*! 
\brief Allocate an empty Tensor from the allocator. + * \param shape The shape of the Tensor. + * \param dtype The datatype of the Tensor. * \param dev The device where the array is allocated. * \param mem_scope The device memory scope hint. - * \return The empty NDArray. + * \return The empty Tensor. */ - TVM_DLL NDArray Empty(ffi::Shape shape, DLDataType dtype, Device dev, - Optional mem_scope = std::nullopt); + TVM_DLL Tensor Empty(ffi::Shape shape, DLDataType dtype, Device dev, + ffi::Optional mem_scope = std::nullopt); /*! \brief Return the allocator type. */ inline AllocatorType type() const { return type_; } /*! \brief Allocate a buffer given a size, alignment and type. @@ -163,12 +163,12 @@ class StorageObj : public Object { /*! \brief The allocator where the storage buffer is allocated from. */ Allocator* allocator = nullptr; - /*! \brief Allocate an NDArray from a given piece of storage. */ - TVM_DLL NDArray AllocNDArray(int64_t offset, ffi::Shape shape, DLDataType dtype); + /*! \brief Allocate a Tensor from a given piece of storage. */ + TVM_DLL Tensor AllocTensor(int64_t offset, ffi::Shape shape, DLDataType dtype); - /*! \brief Allocate an NDArray with memory scope from a given piece of storage. */ - TVM_DLL NDArray AllocNDArrayScoped(int64_t offset, ffi::Shape shape, DLDataType dtype, - String scope = "global"); + /*! \brief Allocate a Tensor with memory scope from a given piece of storage. */ + TVM_DLL Tensor AllocTensorScoped(int64_t offset, ffi::Shape shape, DLDataType dtype, + ffi::String scope = "global"); ~StorageObj() { if (allocator) { @@ -176,8 +176,8 @@ class StorageObj : public Object { } } - static constexpr const char* _type_key = "vm.Storage"; - TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("vm.Storage", StorageObj, Object); }; /*! \brief reference to storage.
*/ @@ -185,7 +185,7 @@ class Storage : public ObjectRef { public: TVM_DLL explicit Storage(Buffer buffer, Allocator* allocator); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Storage, ObjectRef, StorageObj); }; } // namespace memory diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index f805ec988d37..1e0e7039448b 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -45,7 +45,7 @@ namespace runtime { * \param target The target module name. * \return Whether runtime is enabled. */ -TVM_DLL bool RuntimeEnabled(const String& target); +TVM_DLL bool RuntimeEnabled(const ffi::String& target); /*! \brief namespace for constant symbols */ namespace symbol { @@ -105,11 +105,11 @@ struct ModuleVTableEntryHelper { } // namespace runtime } // namespace tvm -#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ - const char* kind() const final { return TypeKey; } \ - ::tvm::ffi::Optional<::tvm::ffi::Function> GetFunction(const String& _name) override { \ - using SelfPtr = std::remove_cv_t; \ - ::tvm::ffi::ObjectPtr<::tvm::ffi::Object> _self = \ +#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ + const char* kind() const final { return TypeKey; } \ + ::tvm::ffi::Optional<::tvm::ffi::Function> GetFunction(const ffi::String& _name) override { \ + using SelfPtr = std::remove_cv_t; \ + ::tvm::ffi::ObjectPtr<::tvm::ffi::Object> _self = \ ::tvm::ffi::GetObjectPtr<::tvm::ffi::Object>(this); #define TVM_MODULE_VTABLE_END() \ return std::nullopt; \ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 302b161b6fd7..d60b5712c78d 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -52,8 +52,8 @@ enum TypeIndex : int32_t { // Frontends can take benefit of these constants. /*! \brief runtime::Module. */ kRuntimeModule = TVMFFITypeIndex::kTVMFFIModule, - /*! \brief runtime::NDArray. 
*/ - kRuntimeNDArray = TVMFFITypeIndex::kTVMFFINDArray, + /*! \brief runtime::Tensor. */ + kRuntimeTensor = TVMFFITypeIndex::kTVMFFITensor, /*! \brief runtime::Shape. */ kRuntimeShape = TVMFFITypeIndex::kTVMFFIShape, // Extra builtin static index here @@ -106,18 +106,18 @@ static_assert(static_cast(TypeIndex::kCustomStaticIndex) >= * * \endcode */ -#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ - static_assert(ObjectName::_type_final, \ - "TVM's CopyOnWrite may only be used for " \ - "Object types that are declared as final, " \ - "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \ - ObjectName* CopyOnWrite() { \ - ICHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - auto n = make_object(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + static_assert(ObjectName::_type_final, \ + "TVM's CopyOnWrite may only be used for " \ + "Object types that are declared as final, " \ + "using the TVM_FFI_DECLARE_OBJECT_INFO_FINAL macro."); \ + ObjectName* CopyOnWrite() { \ + ICHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = ::tvm::ffi::make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ } /* @@ -126,23 +126,14 @@ static_assert(static_cast(TypeIndex::kCustomStaticIndex) >= * \param ParentType The parent type of the objectref * \param ObjectName The type name of the object. 
*/ -#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \ - ObjectName) \ - explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ +#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE_WITHOUT_DEFAULT_CONSTRUCTOR( \ + TypeName, ParentType, ObjectName) \ + explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(n) {} \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ using ContainerType = ObjectName; -#define TVM_DECLARE_BASE_OBJECT_INFO TVM_FFI_DECLARE_BASE_OBJECT_INFO -#define TVM_DECLARE_FINAL_OBJECT_INFO TVM_FFI_DECLARE_FINAL_OBJECT_INFO -#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS - -#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS -#define TVM_DEFINE_OBJECT_REF_METHODS TVM_FFI_DEFINE_OBJECT_REF_METHODS -#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS \ - TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS - #define TVM_STR_CONCAT_(__x, __y) __x##__y #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 9f25b6775c13..c04310d9db20 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -30,8 +30,8 @@ #include #include #include -#include #include +#include #include #include @@ -75,8 +75,8 @@ class TimerNode : public Object { virtual ~TimerNode() {} - static constexpr const char* _type_key = "runtime.TimerNode"; - TVM_DECLARE_BASE_OBJECT_INFO(TimerNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("runtime.TimerNode", TimerNode, Object); }; /*! \brief Timer for a specific device. 
@@ -126,7 +126,7 @@ class Timer : public ObjectRef { * virtual ~CPUTimerNode() {} * * static constexpr const char* _type_key = "runtime.CPUTimerNode"; - * TVM_DECLARE_FINAL_OBJECT_INFO(CPUTimerNode, TimerNode); + * TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.CPUTimerNode", CPUTimerNode, TimerNode); * * private: * std::chrono::high_resolution_clock::time_point start_; @@ -134,17 +134,17 @@ class Timer : public ObjectRef { * }; * * - * TVM_FFI_STATIC_INIT_BLOCK({ + * TVM_FFI_STATIC_INIT_BLOCK() { * namespace refl = tvm::ffi::reflection; * refl::GlobalDef().def("profiling.timer.cpu", [](Device dev) { - * return Timer(make_object()); + * return Timer(ffi::make_object()); * }); - * }); + * } * \endcode */ static TVM_DLL Timer Start(Device dev); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Timer, ObjectRef, TimerNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Timer, ObjectRef, TimerNode); }; /*! @@ -166,16 +166,14 @@ struct DeviceWrapperNode : public Object { /*! Constructor */ explicit DeviceWrapperNode(Device device) : device(device) {} - - static constexpr const char* _type_key = "runtime.profiling.DeviceWrapper"; - TVM_DECLARE_BASE_OBJECT_INFO(DeviceWrapperNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("runtime.profiling.DeviceWrapper", DeviceWrapperNode, Object); }; /*! \brief Wrapper for `Device`. */ class DeviceWrapper : public ObjectRef { public: - explicit DeviceWrapper(Device dev) { data_ = make_object(dev); } - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DeviceWrapper, ObjectRef, DeviceWrapperNode); + explicit DeviceWrapper(Device dev) { data_ = ffi::make_object(dev); } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DeviceWrapper, ObjectRef, DeviceWrapperNode); }; /*! \brief Data collected from a profiling run. Includes per-call metrics and per-device metrics. @@ -189,7 +187,7 @@ class ReportNode : public Object { * and "Duration (us)". Values are one of `String`, `PercentNode`, * `DurationNode`, or `CountNode`. */ - Array> calls; + ffi::Array> calls; /*!
\brief Metrics collected for the entire run of the model on a per-device basis. * * `device_metrics` is indexed by device name then metric. @@ -197,17 +195,17 @@ class ReportNode : public Object { * These metrics may be larger than the sum of the same metric in `calls` * because these metrics include the overhead of the executor. */ - Map> device_metrics; + ffi::Map> device_metrics; /*! Configuration used for this profiling run. Includes number of threads, executor. * * Values must be an object type that can be used with device_metrics. */ - Map configuration; + ffi::Map configuration; /*! \brief Output `calls` in CSV format. * * Note that this does not include `device_metrics`, it only includes per-call metrics. */ - String AsCSV() const; + ffi::String AsCSV() const; /*! \brief Create a human readable table of profiling metrics. * * \param aggregate Whether or not to join multiple calls to the @@ -222,7 +220,7 @@ class ReportNode : public Object { * the Count, Duation, and Percent columns. * */ - String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const; + ffi::String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const; /*! \brief Convert this report to JSON. * * Output JSON will be of this format: @@ -255,10 +253,8 @@ class ReportNode : public Object { * } * \endcode */ - String AsJSON() const; - - static constexpr const char* _type_key = "runtime.profiling.Report"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReportNode, Object); + ffi::String AsJSON() const; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.profiling.Report", ReportNode, Object); }; class Report : public ObjectRef { @@ -268,16 +264,16 @@ class Report : public ObjectRef { * \param device_metrics Per-device metrics for overall execution. * \param configuration Configuration data specific to this profiling run. 
*/ - explicit Report(Array> calls, - Map> device_metrics, - Map configuration); + explicit Report(ffi::Array> calls, + ffi::Map> device_metrics, + ffi::Map configuration); /*! Deserialize a Report from a JSON object. Needed for sending the report over RPC. * \param json Serialized json report from `ReportNode::AsJSON`. * \returns A Report. */ - static Report FromJSON(String json); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Report, ObjectRef, ReportNode); + static Report FromJSON(ffi::String json); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Report, ObjectRef, ReportNode); }; /*! \brief Interface for user defined profiling metric collection. @@ -304,7 +300,7 @@ class MetricCollectorNode : public Object { * expensive precomputation should happen here. * \param devs The list of devices this collector will be run on. */ - virtual void Init(Array devs) = 0; + virtual void Init(ffi::Array devs) = 0; /*! \brief Start colling metrics for a function call. * \param dev The device the call will be run on. * \returns An object used to maintain state of the metric collection. This @@ -317,18 +313,18 @@ class MetricCollectorNode : public Object { * \returns A set of metric names and the associated values. Values must be * one of DurationNode, PercentNode, CountNode, or String. */ - virtual Map Stop(ffi::ObjectRef obj) = 0; + virtual ffi::Map Stop(ffi::ObjectRef obj) = 0; virtual ~MetricCollectorNode() {} - static constexpr const char* _type_key = "runtime.profiling.MetricCollector"; - TVM_DECLARE_BASE_OBJECT_INFO(MetricCollectorNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("runtime.profiling.MetricCollector", MetricCollectorNode, Object); }; /*! \brief Wrapper for `MetricCollectorNode`. 
*/ class MetricCollector : public ObjectRef { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetricCollector, ObjectRef, MetricCollectorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MetricCollector, ObjectRef, MetricCollectorNode); }; /*! Information about a single function or operator call. */ @@ -336,7 +332,7 @@ struct CallFrame { /*! Device on which the call was made */ Device dev; /*! Name of the function or op */ - String name; + ffi::String name; /*! Runtime of the function or op */ Timer timer; /*! Extra performance metrics */ @@ -382,7 +378,7 @@ class Profiler { * \param configuration Additional configuration data to add to the outputted profiling report. */ explicit Profiler(std::vector devs, std::vector metric_collectors, - std::unordered_map configuration = {}); + std::unordered_map configuration = {}); /*! \brief Start the profiler. * * This function should only be called once per object. @@ -403,7 +399,7 @@ class Profiler { * `StopCall`. Function calls are stopped in LIFO order, so calls to * `StartCall` and `StopCall` must be nested properly. */ - void StartCall(String name, Device dev, + void StartCall(ffi::String name, Device dev, std::unordered_map extra_metrics = {}); /*! \brief Stop the last `StartCall`. * \param extra_metrics Optional additional profiling information to add to @@ -427,7 +423,7 @@ class Profiler { std::vector calls_; std::stack in_flight_; std::vector collectors_; - std::unordered_map configuration_; + std::unordered_map configuration_; }; /* \brief A duration in time. */ @@ -440,9 +436,7 @@ class DurationNode : public Object { * \param a The duration in microseconds. 
*/ explicit DurationNode(double a) : microseconds(a) {} - - static constexpr const char* _type_key = "runtime.profiling.Duration"; - TVM_DECLARE_FINAL_OBJECT_INFO(DurationNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.profiling.Duration", DurationNode, Object); }; /* A percentage of something */ @@ -455,9 +449,7 @@ class PercentNode : public Object { * \param a The percentage out of 100. */ explicit PercentNode(double a) : percent(a) {} - - static constexpr const char* _type_key = "runtime.profiling.Percent"; - TVM_DECLARE_FINAL_OBJECT_INFO(PercentNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.profiling.Percent", PercentNode, Object); }; /* A count of something */ @@ -470,9 +462,7 @@ class CountNode : public Object { * \param a The count. */ explicit CountNode(int64_t a) : value(a) {} - - static constexpr const char* _type_key = "runtime.profiling.Count"; - TVM_DECLARE_FINAL_OBJECT_INFO(CountNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.profiling.Count", CountNode, Object); }; /* \brief A ratio of two things. */ @@ -485,28 +475,26 @@ class RatioNode : public Object { * \param a The ratio. */ explicit RatioNode(double a) : ratio(a) {} - - static constexpr const char* _type_key = "runtime.profiling.Ratio"; - TVM_DECLARE_FINAL_OBJECT_INFO(RatioNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.profiling.Ratio", RatioNode, Object); }; -/*! \brief String representation of an array of NDArray shapes - * \param shapes Array of NDArrays to get the shapes of. +/*! \brief ffi::String representation of an array of Tensor shapes + * \param shapes Array of Tensors to get the shapes of. * \return A textual representation of the shapes. For example: `float32[2], int64[1, 2]`. */ -String ShapeString(const std::vector& shapes); -/*! \brief String representation of shape encoded as an NDArray - * \param shape NDArray containing the shape. +ffi::String ShapeString(const std::vector& shapes); +/*! 
\brief ffi::String representation of a shape encoded as a Tensor + * \param shape Tensor containing the shape. * \param dtype The dtype of the shape. * \return A textual representation of the shape. For example: `float32[2]`. */ -String ShapeString(NDArray shape, DLDataType dtype); -/*! \brief String representation of a shape encoded as a vector +ffi::String ShapeString(Tensor shape, DLDataType dtype); +/*! \brief ffi::String representation of a shape encoded as a vector * \param shape Shape as a vector of integers. * \param dtype The dtype of the shape. * \return A textual representation of the shape. For example: `float32[2]`. */ -String ShapeString(const std::vector& shape, DLDataType dtype); +ffi::String ShapeString(const std::vector& shape, DLDataType dtype); /*! \brief Collect performance information of a function execution. Usually * used with a compiled PrimFunc (via tvm.compile). @@ -536,11 +524,12 @@ String ShapeString(const std::vector& shape, DLDataType dtype); * \param collectors List of different * ways to collect metrics. See MetricCollector. * \returns A ffi::Function which takes the same arguments as the `mod[func_name]` - * and returns performance metrics as a `Map` where + * and returns performance metrics as a `ffi::Map` where * values can be `CountNode`, `DurationNode`, `PercentNode`. */ ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device_type, - int device_id, int warmup_iters, Array collectors); + int device_id, int warmup_iters, + ffi::Array collectors); /*! * \brief Wrap a timer function to measure the time cost of a given packed function.
diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h index 2cfd1de44dde..c8e9d3c435f0 100644 --- a/include/tvm/runtime/serializer.h +++ b/include/tvm/runtime/serializer.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include namespace dmlc { namespace serializer { diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/tensor.h similarity index 74% rename from include/tvm/runtime/ndarray.h rename to include/tvm/runtime/tensor.h index 6eebe49ff135..615cfd8cccfe 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/tensor.h @@ -18,14 +18,14 @@ */ /*! - * \file tvm/runtime/ndarray.h - * \brief A device-independent managed NDArray abstraction. + * \file tvm/runtime/tensor.h + * \brief A device-independent managed Tensor abstraction. */ -#ifndef TVM_RUNTIME_NDARRAY_H_ -#define TVM_RUNTIME_NDARRAY_H_ +#ifndef TVM_RUNTIME_TENSOR_H_ +#define TVM_RUNTIME_TENSOR_H_ -#include #include +#include #include #include #include @@ -47,32 +47,34 @@ using ffi::IsAligned; using ffi::IsContiguous; /*! - * \brief Managed NDArray. + * \brief Managed Tensor. * The array is backed by reference counted blocks. */ -class NDArray : public tvm::ffi::NDArray { +class Tensor : public tvm::ffi::Tensor { public: - using Container = ffi::NDArrayObj; - NDArray() = default; + using Container = ffi::TensorObj; + Tensor() = default; /*! * \brief constructor. * \param data ObjectPtr to the data container. 
*/ - explicit NDArray(ObjectPtr data) : tvm::ffi::NDArray(data) {} - NDArray(ffi::NDArray&& other) : tvm::ffi::NDArray(std::move(other)) {} // NOLINT(*) - NDArray(const ffi::NDArray& other) : tvm::ffi::NDArray(other) {} // NOLINT(*) + explicit Tensor(ObjectPtr data) : tvm::ffi::Tensor(data) {} + explicit Tensor(ffi::UnsafeInit tag) : tvm::ffi::Tensor(tag) {} + Tensor(ffi::Tensor&& other) : tvm::ffi::Tensor(std::move(other)) {} // NOLINT(*) + Tensor(const ffi::Tensor& other) : tvm::ffi::Tensor(other) {} // NOLINT(*) - ffi::Shape Shape() const { return this->shape(); } + ffi::ShapeView Shape() const { return this->shape(); } runtime::DataType DataType() const { return runtime::DataType(this->dtype()); } // DLPack handling - static NDArray FromDLPack(DLManagedTensor* tensor) { - return tvm::ffi::NDArray::FromDLPack(tensor, kAllocAlignment, true); + static Tensor FromDLPack(DLManagedTensor* tensor) { + return tvm::ffi::Tensor::FromDLPack(tensor, kAllocAlignment, true); } - static NDArray FromDLPackVersioned(DLManagedTensorVersioned* tensor) { - return tvm::ffi::NDArray::FromDLPackVersioned(tensor, kAllocAlignment, true); + static Tensor FromDLPackVersioned(DLManagedTensorVersioned* tensor) { + return tvm::ffi::Tensor::FromDLPackVersioned(tensor, kAllocAlignment, true); } + inline const DLTensor* operator->() const { return this->get(); } /*! * \brief Copy data content from another array. * \param other The source array to be copied from. @@ -80,12 +82,12 @@ class NDArray : public tvm::ffi::NDArray { * TVMSynchronize is necessary. */ inline void CopyFrom(const DLTensor* other); - inline void CopyFrom(const NDArray& other); + inline void CopyFrom(const Tensor& other); /*! * \brief Copy data content from a byte buffer. * \param data The source bytes to be copied from. * \param nbytes The size of the buffer in bytes - * Must be equal to the size of the NDArray. + * Must be equal to the size of the Tensor. * \note The copy always triggers a TVMSynchronize. 
*/ TVM_DLL void CopyFromBytes(const void* data, size_t nbytes); @@ -96,12 +98,12 @@ class NDArray : public tvm::ffi::NDArray { * TVMSynchronize is necessary. */ inline void CopyTo(DLTensor* other) const; - inline void CopyTo(const NDArray& other) const; + inline void CopyTo(const Tensor& other) const; /*! * \brief Copy data content into another array. * \param data The source bytes to be copied from. * \param nbytes The size of the data buffer. - * Must be equal to the size of the NDArray. + * Must be equal to the size of the Tensor. * \note The copy always triggers a TVMSynchronize. */ TVM_DLL void CopyToBytes(void* data, size_t nbytes) const; @@ -112,27 +114,28 @@ class NDArray : public tvm::ffi::NDArray { * \return The array under another device. * \note The copy always triggers a TVMSynchronize. */ - TVM_DLL NDArray CopyTo(const Device& dev, Optional mem_scope = std::nullopt) const; + TVM_DLL Tensor CopyTo(const Device& dev, + ffi::Optional mem_scope = std::nullopt) const; /*! - * \brief Load NDArray from stream + * \brief Load Tensor from stream * \param stream The input data stream * \return Whether load is successful */ inline bool Load(dmlc::Stream* stream); /*! - * \brief Save NDArray to stream + * \brief Save Tensor to stream * \param stream The output data stream */ inline void Save(dmlc::Stream* stream) const; /*! - * \brief Create a NDArray that shares the data memory with the current one. + * \brief Create a Tensor that shares the data memory with the current one. * * \param shape The shape of the new array. * * \param dtype The data type of the new array. * - * \param relative_byte_offset The offset of the output NDArray, + * \param relative_byte_offset The offset of the output Tensor, * relative to the current byte offset. * * By default, the offset of the view is the same as the offset @@ -145,18 +148,18 @@ class NDArray : public tvm::ffi::NDArray { * outside the bounds of the current array, this function will * raise an exception. 
*/ - TVM_DLL NDArray CreateView(ffi::Shape shape, DLDataType dtype, - uint64_t relative_byte_offset = 0) const; + TVM_DLL Tensor CreateView(ffi::Shape shape, DLDataType dtype, + uint64_t relative_byte_offset = 0) const; /*! - * \brief Create an empty NDArray. + * \brief Create an empty Tensor. * \param shape The shape of the new array. * \param dtype The data type of the new array. * \param dev The device of the array. * \param mem_scope The memory scope of the array. * \return The created Array */ - TVM_DLL static NDArray Empty(ffi::Shape shape, DLDataType dtype, Device dev, - Optional mem_scope = std::nullopt); + TVM_DLL static Tensor Empty(ffi::Shape shape, DLDataType dtype, Device dev, + ffi::Optional mem_scope = std::nullopt); /*! * \brief Function to copy data from one array to another. * \param from The source array. @@ -175,6 +178,16 @@ class NDArray : public tvm::ffi::NDArray { */ TVM_DLL static void CopyToBytes(const DLTensor* from, void* to, size_t nbytes, TVMStreamHandle stream = nullptr); + + /*! + * \brief Function to copy data from one array to a byte buffer. + * \param from The source array. + * \param to The target byte buffer. + * \param nbytes The size of the data buffer. + * \param stream The stream used in copy. + */ + TVM_DLL static void CopyFromBytes(const DLTensor* to, void* from, size_t nbytes, + TVMStreamHandle stream = nullptr); }; /*! 
@@ -184,33 +197,33 @@ class NDArray : public tvm::ffi::NDArray { */ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); -inline void NDArray::CopyFrom(const DLTensor* other) { +inline void Tensor::CopyFrom(const DLTensor* other) { ICHECK(data_ != nullptr); CopyFromTo(other, get_mutable()); } -inline void NDArray::CopyFrom(const NDArray& other) { +inline void Tensor::CopyFrom(const Tensor& other) { ICHECK(data_ != nullptr); ICHECK(other.data_ != nullptr); CopyFromTo(other.get_mutable(), get_mutable()); } -inline void NDArray::CopyTo(DLTensor* other) const { +inline void Tensor::CopyTo(DLTensor* other) const { ICHECK(data_ != nullptr); CopyFromTo(get_mutable(), other); } -inline void NDArray::CopyTo(const NDArray& other) const { +inline void Tensor::CopyTo(const Tensor& other) const { ICHECK(data_ != nullptr); ICHECK(other.data_ != nullptr); CopyFromTo(get_mutable(), other.get_mutable()); } -/*! \brief Magic number for NDArray file */ -constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; +/*! 
\brief Magic number for Tensor file */ +constexpr uint64_t kTVMTensorMagic = 0xDD5E40F096B4A13F; inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { - uint64_t header = kTVMNDArrayMagic, reserved = 0; + uint64_t header = kTVMTensorMagic, reserved = 0; strm->Write(header); strm->Write(reserved); // Always save data as CPU context @@ -239,12 +252,12 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { strm->Write(data_byte_size); if (DMLC_IO_NO_ENDIAN_SWAP && tensor->device.device_type == kDLCPU && - tensor->strides == nullptr && tensor->byte_offset == 0) { + ffi::IsContiguous(*tensor) && tensor->byte_offset == 0) { // quick path strm->Write(tensor->data, data_byte_size); } else { std::vector bytes(data_byte_size); - NDArray::CopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size); + Tensor::CopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size); if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); } @@ -253,13 +266,13 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { return true; } -inline void NDArray::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->()); } +inline void Tensor::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->()); } -inline bool NDArray::Load(dmlc::Stream* strm) { +inline bool Tensor::Load(dmlc::Stream* strm) { uint64_t header, reserved; ICHECK(strm->Read(&header)) << "Invalid DLTensor file format"; ICHECK(strm->Read(&reserved)) << "Invalid DLTensor file format"; - ICHECK(header == kTVMNDArrayMagic) << "Invalid DLTensor file format"; + ICHECK(header == kTVMTensorMagic) << "Invalid DLTensor file format"; Device dev; int ndim; DLDataType dtype; @@ -271,7 +284,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) { if (ndim != 0) { ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; } - NDArray ret = NDArray::Empty(ffi::Shape(shape), dtype, dev); + Tensor ret = 
Tensor::Empty(ffi::Shape(shape), dtype, dev); int64_t num_elems = 1; int elem_bytes = (ret->dtype.bits + 7) / 8; for (int i = 0; i < ret->ndim; ++i) { @@ -328,4 +341,4 @@ struct equal_to { }; } // namespace std -#endif // TVM_RUNTIME_NDARRAY_H_ +#endif // TVM_RUNTIME_TENSOR_H_ diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 6dfc2b0c50be..37488ff31f52 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -113,12 +113,12 @@ class VMExecutable : public ffi::ModuleObj { * \brief Print the instructions as text format. * \return The text format of the instructions. */ - String AsText() const; + ffi::String AsText() const; /*! * \brief Print the instructions as python program. * \return The python program of the instructions, represented by a string. */ - String AsPython() const; + ffi::String AsPython() const; /*! * \brief Write the VMExecutable to the binary stream in serialized form. * \return The binary bytes that save the executable to. @@ -135,19 +135,19 @@ class VMExecutable : public ffi::ModuleObj { * \param file_name The name of the file to write the serialized data to. * \param format The target format of the saved file. */ - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; /*! \brief Create a Relax virtual machine and load `this` as the executable. */ ffi::Module VMLoadExecutable() const; /*! \brief Create a Relax virtual machine with profiler and load `this` as the executable. */ ffi::Module VMProfilerLoadExecutable() const; /*! \brief Check if the VMExecutable contains a specific function. */ - bool HasFunction(const String& name) const; + bool HasFunction(const ffi::String& name) const; /*! * \brief Load VMExecutable from the file. * \param file_name The path of the file that load the executable from. 
* \return The loaded executable, in the form of a `runtime::Module`. */ - static ffi::Module LoadFromFile(const String& file_name); + static ffi::Module LoadFromFile(const ffi::String& file_name); /*! \brief The virtual machine's function table. */ std::vector func_table; diff --git a/include/tvm/runtime/vm/ndarray_cache_support.h b/include/tvm/runtime/vm/tensor_cache_support.h similarity index 68% rename from include/tvm/runtime/vm/ndarray_cache_support.h rename to include/tvm/runtime/vm/tensor_cache_support.h index 3ab08df04389..c489064792e7 100644 --- a/include/tvm/runtime/vm/ndarray_cache_support.h +++ b/include/tvm/runtime/vm/tensor_cache_support.h @@ -16,12 +16,12 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_RUNTIME_VM_NDARRAY_CACHE_SUPPORT_H_ -#define TVM_RUNTIME_VM_NDARRAY_CACHE_SUPPORT_H_ +#ifndef TVM_RUNTIME_VM_TENSOR_CACHE_SUPPORT_H_ +#define TVM_RUNTIME_VM_TENSOR_CACHE_SUPPORT_H_ #include #include -#include +#include #include #include @@ -32,10 +32,10 @@ namespace runtime { namespace vm { /*! - * \brief Metadata for NDArray cache, which by default, is named as "ndarray-cache.json". + * \brief Metadata for Tensor cache, which by default, is named as "tensor-cache.json". */ -struct NDArrayCacheMetadata { - /*! \brief Each shard of NDArray cache, which by default, is named as "params_shard_x.bin". */ +struct TensorCacheMetadata { + /*! \brief Each shard of Tensor cache, which by default, is named as "params_shard_x.bin". */ struct FileRecord { /*! \brief Metadata of each parameter */ struct ParamRecord { @@ -46,8 +46,8 @@ struct NDArrayCacheMetadata { * \param staging_buffer The buffer to be used to avoid extra OpenCL copies. Pass in a nullptr * in other cases */ - TVM_DLL NDArray Load(Device device, const std::string* raw_data, - Optional* staging_buffer = nullptr) const; + TVM_DLL Tensor Load(Device device, const std::string* raw_data, + ffi::Optional* staging_buffer = nullptr) const; /*! 
\brief Name of the parameter */ std::string name; @@ -64,10 +64,10 @@ struct NDArrayCacheMetadata { }; /*! \brief Load a FileRecord into memory */ - TVM_DLL Array Load(Device device, // - const std::string& path_prefix, // - std::string* raw_data_buffer, // - Optional* staging_buffer = nullptr) const; + TVM_DLL ffi::Array Load(Device device, // + const std::string& path_prefix, // + std::string* raw_data_buffer, // + ffi::Optional* staging_buffer = nullptr) const; /*! \brief Relative path to the bin file */ std::string data_path; @@ -78,19 +78,19 @@ struct NDArrayCacheMetadata { /*! \brief The parameters in the file */ std::vector records; }; - /*! \brief The files in the NDArray cache */ + /*! \brief The files in the Tensor cache */ std::vector records; - /*! \brief The path to the `ndarray-cache.json` file */ + /*! \brief The path to the `tensor-cache.json` file */ std::string path; /*! \brief Load the metadata from a specific directory */ - TVM_DLL static NDArrayCacheMetadata Load(const std::string& path); + TVM_DLL static TensorCacheMetadata Load(const std::string& path); /*! \brief Load the metadata from a given JSON string */ - static NDArrayCacheMetadata LoadFromStr(const std::string& json_str, const std::string& path); + static TensorCacheMetadata LoadFromStr(const std::string& json_str, const std::string& path); }; } // namespace vm } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_VM_NDARRAY_CACHE_SUPPORT_H_ +#endif // TVM_RUNTIME_VM_TENSOR_CACHE_SUPPORT_H_ diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 3a0b7418b946..335d77f1966d 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -68,7 +68,7 @@ class VMClosureObj : public Object { * \brief The function name. The function could be any * function object that is compatible to the VM runtime. */ - String func_name; + ffi::String func_name; /*! * \brief The implementation of the Closure. 
@@ -77,16 +77,14 @@ class VMClosureObj : public Object { * the same arguments as the normal function call. */ ffi::Function impl; - - static constexpr const char* _type_key = "relax.vm.Closure"; - TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.vm.Closure", VMClosureObj, Object); }; /*! \brief reference to closure. */ class VMClosure : public ObjectRef { public: - VMClosure(String func_name, ffi::Function impl); - TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, ObjectRef, VMClosureObj); + VMClosure(ffi::String func_name, ffi::Function impl); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VMClosure, ObjectRef, VMClosureObj); /*! * \brief Create another ffi::Function with last arguments already bound to last_args. @@ -109,14 +107,13 @@ class VMClosure : public ObjectRef { */ class VMExtensionNode : public Object { protected: - static constexpr const char* _type_key = "runtime.VMExtension"; - TVM_DECLARE_BASE_OBJECT_INFO(VMExtensionNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("runtime.VMExtension", VMExtensionNode, Object); }; /*! \brief Managed reference to VM extension. */ class VMExtension : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(VMExtension, ObjectRef, VMExtensionNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VMExtension, ObjectRef, VMExtensionNode); }; /*! @@ -149,7 +146,7 @@ class VirtualMachine : public ffi::ModuleObj { * \param func_name The name of the function. * \return The closure */ - virtual VMClosure GetClosure(const String& func_name) = 0; + virtual VMClosure GetClosure(const ffi::String& func_name) = 0; /*! * \brief Invoke closure or packed function using ffi::Function convention. * \param closure_or_packedfunc A VM closure or a packed_func. 
diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 0c9e54eaf113..8c5209982b10 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -73,8 +73,9 @@ class IRBuilderFrameNode : public runtime::Object { // `callbacks` is not registered as it's not visited. } - static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame"; - TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("script.ir_builder.IRBuilderFrame", IRBuilderFrameNode, + runtime::Object); public: /*! \brief Default destructor. */ @@ -102,11 +103,12 @@ class IRBuilderFrameNode : public runtime::Object { */ class IRBuilderFrame : public runtime::ObjectRef { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IRBuilderFrame, ObjectRef, IRBuilderFrameNode); protected: /*! \brief Disallow direct construction of this object. */ IRBuilderFrame() = default; + explicit IRBuilderFrame(ObjectPtr data) : ObjectRef(data) {} public: /*! @@ -157,9 +159,9 @@ class IRBuilderFrame : public runtime::ObjectRef { class IRBuilderNode : public runtime::Object { public: /*! \brief A stack of context frames in the IRBuilder */ - Array frames; + ffi::Array frames; /*! \brief The outcome of IR construction */ - Optional result; + ffi::Optional result; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -168,8 +170,8 @@ class IRBuilderNode : public runtime::Object { .def_ro("result", &IRBuilderNode::result); } - static constexpr const char* _type_key = "script.ir_builder.IRBuilder"; - TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.IRBuilder", IRBuilderNode, runtime::Object); public: /*! 
@@ -178,7 +180,7 @@ class IRBuilderNode : public runtime::Object { * \return The frame if found, otherwise std::nullopt. */ template - inline Optional FindFrame() const; + inline ffi::Optional FindFrame() const; /*! * \brief Get the frame on top of the stack `this->frames` if its type is `TFrame`. * \tparam TFrame The assumed type of the last frame on stack. @@ -186,7 +188,7 @@ class IRBuilderNode : public runtime::Object { * Otherwise std::nullopt. */ template - inline Optional GetLastFrame() const; + inline ffi::Optional GetLastFrame() const; /*! * \brief Get the IR being constructed. * \tparam TObjectRef The type of the IR being constructed. @@ -204,7 +206,7 @@ class IRBuilder : public runtime::ObjectRef { public: /*! \brief Creates an IRBuilder. */ IRBuilder(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IRBuilder, ObjectRef, IRBuilderNode); public: /*! @@ -249,7 +251,7 @@ class IRBuilder : public runtime::ObjectRef { * \param obj The object to name. 
*/ template - inline static TObjectRef Name(String name, TObjectRef obj); + inline static TObjectRef Name(ffi::String name, TObjectRef obj); }; ////////////////////////////// Details ////////////////////////////// @@ -258,32 +260,32 @@ namespace details { class Namer { public: - using FType = NodeFunctor; + using FType = NodeFunctor; static FType& vtable(); - static void Name(ObjectRef node, String name); + static void Name(ObjectRef node, ffi::String name); }; } // namespace details template -inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) { +inline TObjectRef IRBuilder::Name(ffi::String name, TObjectRef obj) { details::Namer::Name(obj, name); return Downcast(obj); } template -inline Optional IRBuilderNode::FindFrame() const { +inline ffi::Optional IRBuilderNode::FindFrame() const { using TFrameNode = typename TFrame::ContainerType; for (auto it = frames.rbegin(); it != frames.rend(); ++it) { if (const TFrameNode* p = (*it).template as()) { - return GetRef(p); + return ffi::GetRef(p); } } return std::nullopt; } template -inline Optional IRBuilderNode::GetLastFrame() const { +inline ffi::Optional IRBuilderNode::GetLastFrame() const { using TFrameNode = typename TFrame::ContainerType; if (!frames.empty() && frames.back()->IsInstance()) { return Downcast(frames.back()); @@ -297,7 +299,7 @@ inline TObjectRef IRBuilderNode::Get() const { CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet"; const auto* n = result.as(); CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key; - return GetRef(n); + return ffi::GetRef(n); } } // namespace ir_builder diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index b009338cf0d4..53efc9df7f2b 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -41,16 +41,16 @@ namespace ir { class IRModuleFrameNode : public IRBuilderFrameNode { public: /*! 
\brief A map from string names to global variables that ensures global uniqueness. */ - Map global_var_map; + ffi::Map global_var_map; /*! * \brief A map from GlobalVar to all global functions. * \note Only defined functions are in the map, while declared functions are not included. */ - Map functions; + ffi::Map functions; /*! \brief IRModule's attributes. */ - Map attrs; + ffi::Map attrs; /*! \brief IRModule's global_infos */ - Map> global_infos; + ffi::Map> global_infos; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -60,9 +60,8 @@ class IRModuleFrameNode : public IRBuilderFrameNode { .def_ro("attrs", &IRModuleFrameNode::attrs) .def_ro("global_infos", &IRModuleFrameNode::global_infos); } - - static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.IRModuleFrame", IRModuleFrameNode, + IRBuilderFrameNode); public: void ExitWithScope() final; @@ -75,8 +74,10 @@ class IRModuleFrameNode : public IRBuilderFrameNode { */ class IRModuleFrame : public IRBuilderFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame, - IRModuleFrameNode); + explicit IRModuleFrame(ObjectPtr data) : IRBuilderFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IRModuleFrame, IRBuilderFrame, IRModuleFrameNode); }; } // namespace ir diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index 49bdcf60e6fb..9fe3d7e1ac65 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -45,14 +45,14 @@ TVM_DLL IRModuleFrame IRModule(); * (i.e. func params and func return type/shape). * \return The corresponding GlobalVar. 
*/ -TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); +TVM_DLL GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signature); /*! * \brief Define the function which is declared before. * \param func_name The function unique name. * \param func The given function implementation */ -TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); +TVM_DLL void DefFunction(const ffi::String& func_name, const BaseFunc& func); } // namespace ir } // namespace ir_builder diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index f729d19a14dd..5d6bcc8a2c2f 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -26,6 +26,8 @@ #include #include +#include + namespace tvm { namespace script { namespace ir_builder { @@ -38,14 +40,17 @@ class RelaxFrameNode : public IRBuilderFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.ir_builder.relax.RelaxFrame"; - TVM_DECLARE_BASE_OBJECT_INFO(RelaxFrameNode, IRBuilderFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO("script.ir_builder.relax.RelaxFrame", RelaxFrameNode, + IRBuilderFrameNode); }; class RelaxFrame : public IRBuilderFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode); + explicit RelaxFrame(ObjectPtr data) : IRBuilderFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RelaxFrame, IRBuilderFrame, RelaxFrameNode); protected: RelaxFrame() = default; @@ -57,9 +62,9 @@ class RelaxFrame : public IRBuilderFrame { class SeqExprFrameNode : public RelaxFrameNode { public: /*! \brief The binding blocks inside the frame. */ - Array binding_blocks; + ffi::Array binding_blocks; /*! \brief The frame output expr. `std::nullopt` when undefined. 
*/ - Optional output; + ffi::Optional output; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -67,9 +72,8 @@ class SeqExprFrameNode : public RelaxFrameNode { .def_ro("binding_blocks", &SeqExprFrameNode::binding_blocks) .def_ro("output", &SeqExprFrameNode::output); } - - static constexpr const char* _type_key = "script.ir_builder.relax.SeqExprFrame"; - TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO("script.ir_builder.relax.SeqExprFrame", SeqExprFrameNode, + RelaxFrameNode); public: void EnterWithScope() override; @@ -78,7 +82,10 @@ class SeqExprFrameNode : public RelaxFrameNode { class SeqExprFrame : public RelaxFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode); + explicit SeqExprFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SeqExprFrame, RelaxFrame, SeqExprFrameNode); }; /*! \brief The ir_builder frame for the relax function. */ @@ -89,9 +96,9 @@ class FunctionFrameNode : public SeqExprFrameNode { * \note The name will not be specified in constructor, so it is "Optional", * However, we must specify the name by `R.func_name` before exit this frame. */ - Optional name; + ffi::Optional name; /*! \brief The function params. */ - Array params; + ffi::Array params; /*! * \brief The function return struct info. * \note Usually the function return type can be deduced by the function body. @@ -101,13 +108,13 @@ class FunctionFrameNode : public SeqExprFrameNode { * if we ret_struct_info is base of body.struct_info. If not, we will * take the specified `ret_struct_info`. */ - Optional ret_struct_info; + ffi::Optional ret_struct_info; /*! \brief Whether the function is annotated as pure */ - Optional is_pure; + ffi::Optional is_pure; /*! \brief Whether the function is annotated as private */ - Optional is_private; + ffi::Optional is_private; /*! 
\brief The function attributes. */ - Map attrs; + ffi::Map attrs; /*! \brief The block builder to create Relax function. */ tvm::relax::BlockBuilder block_builder; @@ -123,9 +130,8 @@ class FunctionFrameNode : public SeqExprFrameNode { .def_ro("output", &FunctionFrameNode::output); // `block_builder` is not registered as it's not visited. } - - static constexpr const char* _type_key = "script.ir_builder.relax.FunctionFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.FunctionFrame", FunctionFrameNode, + SeqExprFrameNode); public: void EnterWithScope() final; @@ -134,7 +140,10 @@ class FunctionFrameNode : public SeqExprFrameNode { class FunctionFrame : public SeqExprFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode); + explicit FunctionFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FunctionFrame, SeqExprFrame, FunctionFrameNode); }; /*! \brief The ir_builder frame for relax binding blocks. */ @@ -143,7 +152,7 @@ class BlockFrameNode : public RelaxFrameNode { /*! \brief The flag that indicates whether the block is a dataflow block. */ bool is_dataflow; /*! \brief The variables emitted in this block. */ - Array emitted_vars; + ffi::Array emitted_vars; /*! * \brief A boolean indicating if the dataflow block is ended of construction. * If it is true, any new binding trying to be emitted into this block will cause an error. @@ -154,7 +163,7 @@ class BlockFrameNode : public RelaxFrameNode { * \brief The output vars of the dataflow block. * \note Only used for a dataflow block. 
*/ - Array output_vars; + ffi::Array output_vars; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -164,9 +173,8 @@ class BlockFrameNode : public RelaxFrameNode { .def_ro("output_vars", &BlockFrameNode::output_vars); // `block_ended` is not registered as it's not visited. } - - static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.BlockFrame", BlockFrameNode, + RelaxFrameNode); public: void EnterWithScope() final; @@ -175,7 +183,10 @@ class BlockFrameNode : public RelaxFrameNode { class BlockFrame : public RelaxFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode); + explicit BlockFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockFrame, RelaxFrame, BlockFrameNode); }; /*! @@ -188,13 +199,13 @@ class IfFrameNode : public RelaxFrameNode { /*! \brief The condition of the if statement. */ tvm::relax::Expr condition; /*! \brief The Bindings in the true branch. */ - Optional then_expr; + ffi::Optional then_expr; /*! \brief The Bindings in the false branch. */ - Optional else_expr; + ffi::Optional else_expr; /*! \brief The Binding var. */ tvm::relax::Var var; /*! \brief The binding var name. */ - String var_name; + ffi::String var_name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -205,9 +216,7 @@ class IfFrameNode : public RelaxFrameNode { .def_ro("var", &IfFrameNode::var) .def_ro("var_name", &IfFrameNode::var_name); } - - static constexpr const char* _type_key = "script.ir_builder.relax.IfFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, RelaxFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.IfFrame", IfFrameNode, RelaxFrameNode); public: /*! 
@@ -229,7 +238,10 @@ class IfFrameNode : public RelaxFrameNode { */ class IfFrame : public RelaxFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode); + explicit IfFrame(ObjectPtr data) : RelaxFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IfFrame, RelaxFrame, IfFrameNode); }; /*! @@ -243,9 +255,8 @@ class ThenFrameNode : public SeqExprFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.ir_builder.relax.ThenFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, SeqExprFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.ThenFrame", ThenFrameNode, + SeqExprFrameNode); public: /*! @@ -267,7 +278,10 @@ class ThenFrameNode : public SeqExprFrameNode { */ class ThenFrame : public SeqExprFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode); + explicit ThenFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ThenFrame, SeqExprFrame, ThenFrameNode); }; /*! @@ -281,9 +295,8 @@ class ElseFrameNode : public SeqExprFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.ir_builder.relax.ElseFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, SeqExprFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.ElseFrame", ElseFrameNode, + SeqExprFrameNode); public: /*! 
@@ -305,7 +318,10 @@ class ElseFrameNode : public SeqExprFrameNode { */ class ElseFrame : public SeqExprFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode); + explicit ElseFrame(ObjectPtr data) : SeqExprFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ElseFrame, SeqExprFrame, ElseFrameNode); }; } // namespace relax diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 49bc1a2851d3..80b70daffd0b 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -45,19 +45,19 @@ TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private); * \param struct_info The struct_info of the parameter. * \return The created function parameter var. */ -TVM_DLL tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info); +TVM_DLL tvm::relax::Var Arg(const ffi::String& name, const tvm::relax::StructInfo& struct_info); /*! * \brief Specify the name of the last function frame. * \param name The function name. */ -TVM_DLL void FuncName(const String& name); +TVM_DLL void FuncName(const ffi::String& name); /*! * \brief Specify the attrs of the last function frame. * \param attrs The function attrs. */ -TVM_DLL void FuncAttrs(Map attrs); +TVM_DLL void FuncAttrs(ffi::Map attrs); /*! * \brief Specify the return struct info of the last function frame. 
@@ -89,7 +89,7 @@ TVM_DLL BlockFrame Dataflow(); * \brief Expose the dataflow block output variables as global ones * \param vars The output variables of a dataflow block */ -TVM_DLL void DataflowBlockOutput(const Array& vars); +TVM_DLL void DataflowBlockOutput(const ffi::Array& vars); ////////////////////////////// Bindings //////////////////////////////// @@ -101,7 +101,7 @@ TVM_DLL void DataflowBlockOutput(const Array& vars); */ TVM_DLL tvm::relax::Var Emit( const tvm::relax::Expr& value, - const Optional& annotate_struct_info = std::nullopt); + const ffi::Optional& annotate_struct_info = std::nullopt); /*! * \brief Emit a match_cast binding to the last binding block frame. diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 1e205edc43f3..4be475c09419 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -23,6 +23,8 @@ #include #include +#include + namespace tvm { namespace script { namespace ir_builder { @@ -36,15 +38,13 @@ namespace tir { class TIRFrameNode : public IRBuilderFrameNode { public: /*! \brief The Stmt within in this frame. */ - Array stmts; + ffi::Array stmts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("stmts", &TIRFrameNode::stmts); } - - static constexpr const char* _type_key = "script.ir_builder.tir.TIRFrame"; - TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO("script.ir_builder.tir.TIRFrame", TIRFrameNode, IRBuilderFrameNode); }; /*! @@ -54,10 +54,11 @@ class TIRFrameNode : public IRBuilderFrameNode { */ class TIRFrame : public IRBuilderFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TIRFrame, IRBuilderFrame, TIRFrameNode); protected: TIRFrame() = default; + explicit TIRFrame(ObjectPtr data) : IRBuilderFrame(data) {} }; /*! 
@@ -68,21 +69,21 @@ class TIRFrame : public IRBuilderFrame { class PrimFuncFrameNode : public TIRFrameNode { public: /*! \brief The name of the block. */ - Optional name; + ffi::Optional name; /*! \brief Function parameters. */ - Array args; + ffi::Array args; /*! \brief Whether the PrimFunc is annotated as private. */ bool is_private; /*! \brief The return type of the function. */ - Optional ret_type; + ffi::Optional ret_type; /*! \brief Maps some parameters to specific Buffer data structures. */ - Map buffer_map; + ffi::Map buffer_map; /*! \brief Additional attributes storing the meta-data */ - Map attrs; + ffi::Map attrs; /*! \brief The variable map bound to thread env. */ - Map env_threads; + ffi::Map env_threads; /*! \brief The buffer allocated in root block. */ - Array root_alloc_buffers; + ffi::Array root_alloc_buffers; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -96,9 +97,8 @@ class PrimFuncFrameNode : public TIRFrameNode { .def_ro("env_threads", &PrimFuncFrameNode::env_threads) .def_ro("root_alloc_buffers", &PrimFuncFrameNode::root_alloc_buffers); } - - static constexpr const char* _type_key = "script.ir_builder.tir.PrimFuncFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.PrimFuncFrame", PrimFuncFrameNode, + TIRFrameNode); public: /*! @@ -115,7 +115,11 @@ class PrimFuncFrameNode : public TIRFrameNode { */ class PrimFuncFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); + explicit PrimFuncFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); }; /*! @@ -126,28 +130,28 @@ class PrimFuncFrame : public TIRFrame { class BlockFrameNode : public TIRFrameNode { public: /*! \brief The name of the block. 
*/ - String name; + ffi::String name; /*! \brief The variables of the block. */ - Array iter_vars; + ffi::Array iter_vars; /*! \brief The read buffer regions of the block. */ - Optional> reads; + ffi::Optional> reads; /*! \brief The write buffer regions of the block. */ - Optional> writes; + ffi::Optional> writes; /*! \brief The init statement of the bolck. */ - Optional init; + ffi::Optional init; /*! \brief The buffer allocated in the block. */ - Array alloc_buffers; + ffi::Array alloc_buffers; /*! \brief The match buffer regions. */ - Array match_buffers; + ffi::Array match_buffers; /*! \brief The annotation of the block. */ - Optional> annotations; + ffi::Optional> annotations; /*! \brief The corresponding values of the iter vars. */ - Array iter_values; + ffi::Array iter_values; /*! * \brief The predicate of the block realization, the block will only be executed when the * predicate is true. */ - Optional predicate; + ffi::Optional predicate; /*! \brief The flag whether to construct BlockRealize or Block. */ bool no_realize; @@ -166,9 +170,8 @@ class BlockFrameNode : public TIRFrameNode { .def_ro("predicate", &BlockFrameNode::predicate) .def_ro("no_realize", &BlockFrameNode::no_realize); } - - static constexpr const char* _type_key = "script.ir_builder.tir.BlockFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.BlockFrame", BlockFrameNode, + TIRFrameNode); public: /*! @@ -186,7 +189,11 @@ class BlockFrameNode : public TIRFrameNode { class BlockFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); + explicit BlockFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockFrame, TIRFrame, BlockFrameNode); }; /*! 
@@ -200,9 +207,8 @@ class BlockInitFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.ir_builder.tir.BlockInitFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.BlockInitFrame", BlockInitFrameNode, + TIRFrameNode); public: /*! @@ -224,7 +230,11 @@ class BlockInitFrameNode : public TIRFrameNode { */ class BlockInitFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode); + explicit BlockInitFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockInitFrame, TIRFrame, BlockInitFrameNode); }; /*! @@ -242,11 +252,14 @@ class ForFrameNode : public TIRFrameNode { * \return A stmt, the loop nest */ using FMakeForLoop = ffi::TypedFunction loop_vars, Array loop_extents, tvm::tir::Stmt loop_body)>; + ffi::Array loop_vars, ffi::Array loop_extents, + ffi::Array> loop_steps, tvm::tir::Stmt loop_body)>; /*! \brief The loop variable. */ - Array vars; + ffi::Array vars; /*! \brief The domains of iteration. */ - Array doms; + ffi::Array doms; + /*! \brief The optional steps of iteration. */ + ffi::Array> steps; /*! \brief The for loop generating function. */ FMakeForLoop f_make_for_loop; @@ -257,9 +270,7 @@ class ForFrameNode : public TIRFrameNode { .def_ro("doms", &ForFrameNode::doms); // `f_make_for_loop` is not registered as it's not visited. } - - static constexpr const char* _type_key = "script.ir_builder.tir.ForFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.ForFrame", ForFrameNode, TIRFrameNode); public: /*! 
@@ -276,7 +287,11 @@ class ForFrameNode : public TIRFrameNode { */ class ForFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); + explicit ForFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ForFrame, TIRFrame, ForFrameNode); }; /*! @@ -298,9 +313,8 @@ class AssertFrameNode : public TIRFrameNode { .def_ro("condition", &AssertFrameNode::condition) .def_ro("message", &AssertFrameNode::message); } - - static constexpr const char* _type_key = "script.ir_builder.tir.AssertFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.AssertFrame", AssertFrameNode, + TIRFrameNode); public: /*! @@ -317,7 +331,11 @@ class AssertFrameNode : public TIRFrameNode { */ class AssertFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode); + explicit AssertFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AssertFrame, TIRFrame, AssertFrameNode); }; /*! @@ -332,15 +350,14 @@ class LetFrameNode : public TIRFrameNode { /*! \brief The value we bind var to */ PrimExpr value; + static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("var", &LetFrameNode::var) - .def_ro("value", &LetFrameNode::value); + .def_rw("value", &LetFrameNode::value); } - - static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.LetFrame", LetFrameNode, TIRFrameNode); public: /*! 
@@ -357,7 +374,11 @@ class LetFrameNode : public TIRFrameNode { */ class LetFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode); + explicit LetFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LetFrame, TIRFrame, LetFrameNode); }; /*! @@ -369,7 +390,7 @@ class LaunchThreadFrameNode : public TIRFrameNode { /*! \brief The extent of environment thread. */ PrimExpr extent; /*! \brief The attribute key, could be either virtual_thread or thread_extent. */ - String attr_key; + ffi::String attr_key; /*! \brief The iteration variable. */ tvm::tir::IterVar iter_var; @@ -380,9 +401,8 @@ class LaunchThreadFrameNode : public TIRFrameNode { .def_ro("attr_key", &LaunchThreadFrameNode::attr_key) .def_ro("iter_var", &LaunchThreadFrameNode::iter_var); } - - static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.LaunchThreadFrame", + LaunchThreadFrameNode, TIRFrameNode); public: /*! @@ -399,8 +419,11 @@ class LaunchThreadFrameNode : public TIRFrameNode { */ class LaunchThreadFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame, - LaunchThreadFrameNode); + explicit LaunchThreadFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LaunchThreadFrame, TIRFrame, LaunchThreadFrameNode); }; /*! @@ -413,7 +436,7 @@ class RealizeFrameNode : public TIRFrameNode { /*! \brief The region of buffer access. */ tvm::tir::BufferRegion buffer_slice; /*! \brief The storage scope associated with this realization. */ - String storage_scope; + ffi::String storage_scope; /*! 
\brief The condition expression. */ PrimExpr condition; @@ -424,9 +447,8 @@ class RealizeFrameNode : public TIRFrameNode { .def_ro("storage_scope", &RealizeFrameNode::storage_scope) .def_ro("condition", &RealizeFrameNode::condition); } - - static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.RealizeFrame", RealizeFrameNode, + TIRFrameNode); public: /*! @@ -443,7 +465,11 @@ class RealizeFrameNode : public TIRFrameNode { */ class RealizeFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); + explicit RealizeFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RealizeFrame, TIRFrame, RealizeFrameNode); }; /*! @@ -454,15 +480,15 @@ class RealizeFrame : public TIRFrame { class AllocateFrameNode : public TIRFrameNode { public: /*! \brief The extents of the allocate. */ - Array extents; + ffi::Array extents; /*! \brief The data type of the buffer. */ DataType dtype; /*! \brief The storage scope. */ - String storage_scope; + ffi::String storage_scope; /*! \brief The condition. */ PrimExpr condition; /*! \brief Additional annotation hints. */ - Map annotations; + ffi::Map annotations; /*! \brief The buffer var. */ tvm::tir::Var buffer_var; @@ -476,9 +502,8 @@ class AllocateFrameNode : public TIRFrameNode { .def_ro("annotations", &AllocateFrameNode::annotations) .def_ro("buffer_var", &AllocateFrameNode::buffer_var); } - - static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.AllocateFrame", AllocateFrameNode, + TIRFrameNode); public: /*! 
@@ -495,7 +520,11 @@ class AllocateFrameNode : public TIRFrameNode { */ class AllocateFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode); + explicit AllocateFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AllocateFrame, TIRFrame, AllocateFrameNode); }; /*! @@ -508,13 +537,13 @@ class AllocateConstFrameNode : public TIRFrameNode { /*! \brief The data type of the buffer. */ DataType dtype; /*! \brief The extents of the allocate. */ - Array extents; + ffi::Array extents; /*! \brief The data associated with the constant. */ - tvm::runtime::NDArray data; + tvm::runtime::Tensor data; /*! \brief The buffer var */ tvm::tir::Var buffer_var; /*! \brief Additional annotations about the allocation. */ - Map annotations; + ffi::Map annotations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -525,9 +554,8 @@ class AllocateConstFrameNode : public TIRFrameNode { .def_ro("buffer_var", &AllocateConstFrameNode::buffer_var) .def_ro("annotations", &AllocateConstFrameNode::annotations); } - - static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.AllocateConstFrame", + AllocateConstFrameNode, TIRFrameNode); public: /*! 
@@ -544,8 +572,13 @@ class AllocateConstFrameNode : public TIRFrameNode { */ class AllocateConstFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame, - AllocateConstFrameNode); + explicit AllocateConstFrame(ObjectPtr data) + : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AllocateConstFrame, TIRFrame, + AllocateConstFrameNode); }; /*! * \brief A frame that represents attribute node. @@ -557,7 +590,7 @@ class AttrFrameNode : public TIRFrameNode { /*! \brief The node to annotate the attribute. */ Any node; /*! \brief Attribute type key. */ - String attr_key; + ffi::String attr_key; /*! \brief The value of the attribute. */ PrimExpr value; @@ -568,9 +601,7 @@ class AttrFrameNode : public TIRFrameNode { .def_ro("attr_key", &AttrFrameNode::attr_key) .def_ro("value", &AttrFrameNode::value); } - - static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.AttrFrame", AttrFrameNode, TIRFrameNode); public: /*! @@ -587,7 +618,11 @@ class AttrFrameNode : public TIRFrameNode { */ class AttrFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode); + explicit AttrFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AttrFrame, TIRFrame, AttrFrameNode); }; /*! 
@@ -604,9 +639,8 @@ class WhileFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("condition", &WhileFrameNode::condition); } - - static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.WhileFrame", WhileFrameNode, + TIRFrameNode); public: /*! @@ -623,7 +657,11 @@ class WhileFrameNode : public TIRFrameNode { */ class WhileFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode); + explicit WhileFrame(ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { + TVM_FFI_ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WhileFrame, TIRFrame, WhileFrameNode); }; /*! @@ -636,9 +674,9 @@ class IfFrameNode : public TIRFrameNode { /*! \brief The condition of the if statement. */ PrimExpr condition; /*! \brief The statements in the true branch. */ - Optional> then_stmts; + ffi::Optional> then_stmts; /*! \brief The stetements in the false branch. */ - Optional> else_stmts; + ffi::Optional> else_stmts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -647,9 +685,7 @@ class IfFrameNode : public TIRFrameNode { .def_ro("then_stmts", &IfFrameNode::then_stmts) .def_ro("else_stmts", &IfFrameNode::else_stmts); } - - static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.IfFrame", IfFrameNode, TIRFrameNode); public: /*! 
@@ -666,7 +702,10 @@ class IfFrameNode : public TIRFrameNode { */ class IfFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode); + explicit IfFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IfFrame, TIRFrame, IfFrameNode); }; /*! @@ -680,9 +719,7 @@ class ThenFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.ThenFrame", ThenFrameNode, TIRFrameNode); public: /*! @@ -704,7 +741,10 @@ class ThenFrameNode : public TIRFrameNode { */ class ThenFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode); + explicit ThenFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ThenFrame, TIRFrame, ThenFrameNode); }; /*! @@ -718,9 +758,7 @@ class ElseFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.ElseFrame", ElseFrameNode, TIRFrameNode); public: /*! 
@@ -742,7 +780,11 @@ class ElseFrameNode : public TIRFrameNode { */ class ElseFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode); + explicit ElseFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ElseFrame, TIRFrame, ElseFrameNode); }; class DeclBufferFrameNode : public TIRFrameNode { @@ -758,9 +800,8 @@ class DeclBufferFrameNode : public TIRFrameNode { .def_ro("buffer", &DeclBufferFrameNode::buffer) .def_ro("allocated", &DeclBufferFrameNode::allocated); } - - static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.DeclBufferFrame", DeclBufferFrameNode, + TIRFrameNode); public: void ExitWithScope() final; @@ -768,7 +809,10 @@ class DeclBufferFrameNode : public TIRFrameNode { class DeclBufferFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode); + explicit DeclBufferFrame(ObjectPtr data) : TIRFrame(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DeclBufferFrame, TIRFrame, DeclBufferFrameNode); }; } // namespace tir diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 30b5bb3382f4..174d0b9c63c6 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -28,7 +28,7 @@ namespace script { namespace ir_builder { namespace tir { -using tvm::runtime::NDArray; +using tvm::runtime::Tensor; using tvm::tir::Buffer; using tvm::tir::Var; @@ -47,10 +47,11 @@ using tvm::tir::Var; * \param axis_separators The separators between input axes when generating flattened output axes. * \return The declared buffer. 
*/ -Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Optional data, - Optional> strides, Optional elem_offset, - String storage_scope, int align, int offset_factor, String buffer_type, - Optional> axis_separators); +Buffer BufferDecl(ffi::Array shape, DataType dtype, ffi::String buffer_name, + ffi::Optional data, ffi::Optional> strides, + ffi::Optional elem_offset, ffi::String storage_scope, int align, + int offset_factor, ffi::String buffer_type, + ffi::Optional> axis_separators); /*! * \brief The primitive function statement. @@ -64,7 +65,7 @@ PrimFuncFrame PrimFunc(bool is_private); * \param var The variable argument. * \return The variable. */ -Var Arg(String name, Var var); +Var Arg(ffi::String name, Var var); /*! * \brief The PrimFunc buffer arguments adding function. @@ -72,19 +73,19 @@ Var Arg(String name, Var var); * \param buffer The buffer argument. * \return The buffer. */ -Buffer Arg(String name, Buffer buffer); +Buffer Arg(ffi::String name, Buffer buffer); /*! * \brief The PrimFunc naming statement. * \param name The name of the PrimFunc. */ -void FuncName(String name); +void FuncName(ffi::String name); /*! * \brief The PrimFunc annotation statement. * \param attrs The annotations of the PrimFunc. */ -void FuncAttrs(Map attrs); +void FuncAttrs(ffi::Map attrs); /*! * \brief The PrimFunc return type statement. @@ -108,11 +109,12 @@ Type FuncRet(Type ret_type); * \param axis_separators The separators between input axes when generating flattened output axes. * \return The matched buffer. 
*/ -Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = DataType::Float(32), - Optional data = std::nullopt, Array strides = {}, - PrimExpr elem_offset = PrimExpr(), String storage_scope = "global", - int align = -1, int offset_factor = 0, String buffer_type = "default", - Optional> axis_separators = std::nullopt); +Buffer MatchBuffer(ObjectRef param, ffi::Array shape, + DataType dtype = DataType::Float(32), ffi::Optional data = std::nullopt, + ffi::Array strides = {}, PrimExpr elem_offset = PrimExpr(), + ffi::String storage_scope = "global", int align = -1, int offset_factor = 0, + ffi::String buffer_type = "default", + ffi::Optional> axis_separators = std::nullopt); /*! * \brief The block declaration statement. @@ -120,7 +122,7 @@ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = Data * \param no_realize The flag whether to construct BlockRealize or Block. * \return The BlockFrame. */ -BlockFrame Block(String name, bool no_realize = false); +BlockFrame Block(ffi::String name, bool no_realize = false); /*! * \brief The block initialization statement. @@ -138,19 +140,19 @@ void Where(PrimExpr predicate); * \brief The block buffer region reading statement. * \param buffer_slices The array of buffer regions to read. */ -void Reads(Array buffer_slices); +void Reads(ffi::Array buffer_slices); /*! * \brief The block buffer region writing statement. * \param buffer_slices The array of buffer regions to write. */ -void Writes(Array buffer_slices); +void Writes(ffi::Array buffer_slices); /*! * \brief The block annotation statement. * \param attrs The annotation of the block. */ -void BlockAttrs(Map attrs); +void BlockAttrs(ffi::Map attrs); /*! * \brief The buffer allocation function. @@ -166,11 +168,11 @@ void BlockAttrs(Map attrs); * \param axis_separators The separators between input axes when generating flattened output axes. * \return The allocated buffer. 
*/ -Buffer AllocBuffer(Array shape, DataType dtype = DataType::Float(32), - Optional data = std::nullopt, Array strides = {}, - PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1, - int offset_factor = 0, String buffer_type = "default", - Optional> axis_separators = std::nullopt); +Buffer AllocBuffer(ffi::Array shape, DataType dtype = DataType::Float(32), + ffi::Optional data = std::nullopt, ffi::Array strides = {}, + PrimExpr elem_offset = PrimExpr(), ffi::String storage_scope = "", + int align = -1, int offset_factor = 0, ffi::String buffer_type = "default", + ffi::Optional> axis_separators = std::nullopt); namespace axis { /*! @@ -216,7 +218,8 @@ Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \param dtype The data types of the iteration variables. * \return The iteration variables. */ -Array Remap(String kinds, Array bindings, DataType dtype = DataType::Int(32)); +ffi::Array Remap(ffi::String kinds, ffi::Array bindings, + DataType dtype = DataType::Int(32)); } // namespace axis @@ -225,37 +228,45 @@ Array Remap(String kinds, Array bindings, DataType dtype = DataTy * \param start The minimum value of iteration. * \param stop The maximum value of iteration. * \param annotations The optional annotations of the For statement. + * \param step The optional step value of iteration. * \return The ForFrame. */ ForFrame Serial(PrimExpr start, PrimExpr stop, - Optional> annotations = std::nullopt); + ffi::Optional> annotations = std::nullopt, + ffi::Optional step = std::nullopt); /*! * \brief The parallel For statement. * \param start The minimum value of iteration. * \param stop The maximum value of iteration. * \param annotations The optional annotations of the For statement. + * \param step The optional step value of iteration. * \return The ForFrame. 
*/ ForFrame Parallel(PrimExpr start, PrimExpr stop, - Optional> annotations = std::nullopt); + ffi::Optional> annotations = std::nullopt, + ffi::Optional step = std::nullopt); /*! * \brief The vectorized For statement. * \param start The minimum value of iteration. * \param stop The maximum value of iteration. * \param annotations The optional annotations of the For statement. + * \param step The optional step value of iteration. * \return The ForFrame. */ ForFrame Vectorized(PrimExpr start, PrimExpr stop, - Optional> annotations = std::nullopt); + ffi::Optional> annotations = std::nullopt, + ffi::Optional step = std::nullopt); /*! * \brief The unrolled For statement. * \param start The minimum value of iteration. * \param stop The maximum value of iteration. * \param annotations The optional annotations of the For statement. + * \param step The optional step value of iteration. * \return The ForFrame. */ ForFrame Unroll(PrimExpr start, PrimExpr stop, - Optional> annotations = std::nullopt); + ffi::Optional> annotations = std::nullopt, + ffi::Optional step = std::nullopt); /*! * \brief The thread-binding For statement. * \param start The minimum value of iteration. @@ -264,14 +275,14 @@ ForFrame Unroll(PrimExpr start, PrimExpr stop, * \param annotations The optional annotations of the For statement. * \return The ForFrame. */ -ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, - Optional> annotations = std::nullopt); +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, + ffi::Optional> annotations = std::nullopt); /*! * \brief The grid For statement. * \param extents The extents of the iteration. * \return The ForFrame. */ -ForFrame Grid(Array extents); +ForFrame Grid(ffi::Array extents); /*! * \brief The assertion statement. @@ -279,7 +290,7 @@ ForFrame Grid(Array extents); * \param message The error message when the assertion fails. * \return The AssertFrame. 
*/ -AssertFrame Assert(PrimExpr condition, String message); +AssertFrame Assert(PrimExpr condition, ffi::String message); /*! * \brief The let binding. @@ -290,8 +301,8 @@ AssertFrame Assert(PrimExpr condition, String message); * \param var The variable to be bound. If not specified, a new variable will be created. * \return The created LetFrame. */ -LetFrame LetStmt(PrimExpr value, Optional type_annotation = std::nullopt, - Optional var = std::nullopt); +LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation = std::nullopt, + ffi::Optional var = std::nullopt); /*! * \brief The realization. @@ -300,7 +311,8 @@ LetFrame LetStmt(PrimExpr value, Optional type_annotation = std::nullopt, * \param condition The condition expression. * \return The result RealizeFrame. */ -RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition); +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, ffi::String storage_scope, + PrimExpr condition); /*! * \brief The allocate node. @@ -311,9 +323,9 @@ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, * \param annotations Additional annotation hints. * \return The created AllocateFrame. */ -AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope = "", - Optional condition = std::nullopt, - Optional> annotations = std::nullopt); +AllocateFrame Allocate(ffi::Array extents, DataType dtype, ffi::String storage_scope = "", + ffi::Optional condition = std::nullopt, + ffi::Optional> annotations = std::nullopt); /*! * \brief The allocate constant node. @@ -323,8 +335,9 @@ AllocateFrame Allocate(Array extents, DataType dtype, String storage_s * \param annotations Additional annotation hints. * \return The created AllocateConstFrame. 
*/ -AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array extents, - Optional> annotations = std::nullopt); +AllocateConstFrame AllocateConst( + Tensor data, DataType dtype, ffi::Array extents, + ffi::Optional> annotations = std::nullopt); /*! * \brief Create an attribute. @@ -333,7 +346,7 @@ AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array e * \param value The value of the attribute. * \return The result AttrFrame. */ -AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value); +AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value); /*! * \brief Create a while loop. @@ -376,11 +389,11 @@ ElseFrame Else(); * \param axis_separators The separators between input axes when generating flattened output axes. * \return The declared buffer. */ -DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name, - Optional data, Optional> strides, - Optional elem_offset, String storage_scope, int align, - int offset_factor, String buffer_type, - Optional> axis_separators); +DeclBufferFrame DeclBuffer(ffi::Array shape, DataType dtype, ffi::String buffer_name, + ffi::Optional data, ffi::Optional> strides, + ffi::Optional elem_offset, ffi::String storage_scope, + int align, int offset_factor, ffi::String buffer_type, + ffi::Optional> axis_separators); /*! * \brief Launch a thread. @@ -396,7 +409,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent); * \param extent The extent of environment thread. * \return The result LaunchThreadFrame. */ -LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent); +LaunchThreadFrame LaunchThread(ffi::String thread_tag, PrimExpr extent); /*! * \brief Bind a var to thread env. @@ -404,7 +417,7 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent); * \param dtype The data type of the variable. * \return The result variable which gets bound to the thread env. 
*/ -Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); +Var EnvThread(ffi::String thread_tag, DataType dtype = DataType::Int(32)); /*! * \brief Store data in a buffer. @@ -414,8 +427,8 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be * stored. The number lanes of the mask must be equal to the number of lanes in value. */ -void BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate); +void BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate); /*! * \brief Evaluate the input expression. @@ -441,7 +454,7 @@ void Evaluate(PrimExpr value); * \return The pointer. */ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), - String storage_scope = "global", bool is_size_var = false, + ffi::String storage_scope = "global", bool is_size_var = false, bool is_unknown_type = false) { Type type_annotation{nullptr}; if (is_unknown_type && storage_scope == "global") { @@ -454,12 +467,13 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), inline Var TensormapHandle() { return tvm::tir::Var("", PointerType(TensorMapType())); } -#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ - inline PrimExpr FuncName(Optional expr = std::nullopt, bool is_size_var = false) { \ - DataType dtype = DType; \ - return expr.defined() \ - ? tvm::cast(dtype, expr.value()) \ - : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ + inline PrimExpr FuncName(ffi::Optional expr = std::nullopt, \ + bool is_size_var = false) { \ + DataType dtype = DType; \ + return expr.defined() \ + ? tvm::cast(dtype, expr.value()) \ + : (is_size_var ? 
tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ } #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ @@ -474,6 +488,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x2, FDType(Size, 2)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \ @@ -493,6 +508,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x2, FDType(2)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \ @@ -513,6 +529,8 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::Float TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::Float4E2M1FN); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(TensorFloat32, DataType::TensorFloat32); + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index b045ee00315b..9ce980d268df 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -42,7 +42,7 @@ class Doc; * \param doc Doc to be converted * \param cfg The configuration of the printer */ -String DocToPythonScript(Doc doc, const PrinterConfig& cfg); +ffi::String DocToPythonScript(Doc doc, const PrinterConfig& cfg); /*! * \brief The base class of all Doc. 
@@ -64,17 +64,16 @@ class DocNode : public Object { * this Doc is generated, in order to position the diagnostic * message. */ - mutable Array source_paths; + mutable ffi::Array source_paths; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_rw("source_paths", &DocNode::source_paths); } - static constexpr const char* _type_key = "script.printer.Doc"; static constexpr bool _type_mutable = true; - TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("script.printer.Doc", DocNode, Object); public: virtual ~DocNode() = default; @@ -88,9 +87,10 @@ class DocNode : public Object { class Doc : public ObjectRef { protected: Doc() = default; + explicit Doc(ObjectPtr data) : ObjectRef(data) {} public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Doc, ObjectRef, DocNode); }; class ExprDoc; @@ -106,19 +106,19 @@ class ExprDocNode : public DocNode { * \brief Create a doc representing attribute access on the current ExprDoc * \param attr The attribute to access. */ - ExprDoc Attr(String attr) const; + ExprDoc Attr(ffi::String attr) const; /*! * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. */ - ExprDoc operator[](Array indices) const; + ExprDoc operator[](ffi::Array indices) const; /*! * \brief Create a doc representing calling the current ExprDoc * \param args The positional arguments of the function call. */ - ExprDoc Call(Array args) const; + ExprDoc Call(ffi::Array args) const; /*! * \brief Create a doc representing attribute access on the current ExprDoc @@ -126,18 +126,15 @@ class ExprDocNode : public DocNode { * \param kwargs_keys Keys of keywords arguments of the function call. * \param kwargs_values Values of keywords arguments of the function call. 
*/ - ExprDoc Call(Array args, // - Array kwargs_keys, // - Array kwargs_values) const; + ExprDoc Call(ffi::Array args, // + ffi::Array kwargs_keys, // + ffi::Array kwargs_values) const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.printer.ExprDoc"; - - TVM_DECLARE_BASE_OBJECT_INFO(ExprDocNode, DocNode); + TVM_FFI_DECLARE_OBJECT_INFO("script.printer.ExprDoc", ExprDocNode, DocNode); }; /*! @@ -154,9 +151,11 @@ class ExprDoc : public Doc { * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. */ - ExprDoc operator[](Array indices) const; + ExprDoc operator[](ffi::Array indices) const; - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode); + explicit ExprDoc(ObjectPtr data) : Doc(data) { TVM_FFI_ICHECK(data != nullptr); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ExprDoc, Doc, ExprDocNode); }; /*! @@ -174,16 +173,13 @@ class StmtDocNode : public DocNode { * line as the statement, or the line above, or inside the statement * if it spans over multiple lines. * */ - mutable Optional comment{std::nullopt}; + mutable ffi::Optional comment{std::nullopt}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_rw("comment", &StmtDocNode::comment); } - - static constexpr const char* _type_key = "script.printer.StmtDoc"; - - TVM_DECLARE_BASE_OBJECT_INFO(StmtDocNode, DocNode); + TVM_FFI_DECLARE_OBJECT_INFO("script.printer.StmtDoc", StmtDocNode, DocNode); }; /*! @@ -196,7 +192,7 @@ class StmtDoc : public Doc { StmtDoc() = default; public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtDoc, Doc, StmtDocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StmtDoc, Doc, StmtDocNode); }; /*! @@ -208,16 +204,13 @@ class StmtDoc : public Doc { class StmtBlockDocNode : public DocNode { public: /*! \brief The list of statements. 
*/ - Array stmts; + ffi::Array stmts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("stmts", &StmtBlockDocNode::stmts); } - - static constexpr const char* _type_key = "script.printer.StmtBlockDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(StmtBlockDocNode, DocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.StmtBlockDoc", StmtBlockDocNode, DocNode); }; /*! @@ -230,8 +223,8 @@ class StmtBlockDoc : public Doc { * \brief Constructor of StmtBlockDoc. * \param stmts The list of statements. */ - explicit StmtBlockDoc(Array stmts); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtBlockDoc, Doc, StmtBlockDocNode); + explicit StmtBlockDoc(ffi::Array stmts); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StmtBlockDoc, Doc, StmtBlockDocNode); }; /*! @@ -256,10 +249,7 @@ class LiteralDocNode : public ExprDocNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &LiteralDocNode::value); } - - static constexpr const char* _type_key = "script.printer.LiteralDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(LiteralDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.LiteralDoc", LiteralDocNode, ExprDocNode); }; /*! @@ -269,20 +259,22 @@ class LiteralDocNode : public ExprDocNode { */ class LiteralDoc : public ExprDoc { protected: - explicit LiteralDoc(ffi::Any value, const Optional& object_path); + explicit LiteralDoc(ffi::Any value, const ffi::Optional& object_path); public: /*! * \brief Create a LiteralDoc to represent None/null/empty value. * \param p The object path */ - static LiteralDoc None(const Optional& p) { return LiteralDoc(ffi::Any(nullptr), p); } + static LiteralDoc None(const ffi::Optional& p) { + return LiteralDoc(ffi::Any(nullptr), p); + } /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. 
* \param p The object path */ - static LiteralDoc Int(int64_t v, const Optional& p) { + static LiteralDoc Int(int64_t v, const ffi::Optional& p) { return LiteralDoc(IntImm(DataType::Int(64), v), p); } /*! @@ -290,7 +282,7 @@ class LiteralDoc : public ExprDoc { * \param v The boolean value. * \param p The object path */ - static LiteralDoc Boolean(bool v, const Optional& p) { + static LiteralDoc Boolean(bool v, const ffi::Optional& p) { return LiteralDoc(IntImm(DataType::Bool(), v), p); } /*! @@ -298,7 +290,7 @@ class LiteralDoc : public ExprDoc { * \param v The float value. * \param p The object path */ - static LiteralDoc Float(double v, const Optional& p) { + static LiteralDoc Float(double v, const ffi::Optional& p) { return LiteralDoc(FloatImm(DataType::Float(64), v), p); } /*! @@ -306,13 +298,15 @@ class LiteralDoc : public ExprDoc { * \param v The string value. * \param p The object path */ - static LiteralDoc Str(const String& v, const Optional& p) { return LiteralDoc(v, p); } + static LiteralDoc Str(const ffi::String& v, const ffi::Optional& p) { + return LiteralDoc(v, p); + } /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. * \param p The object path */ - static LiteralDoc DataType(const runtime::DataType& v, const Optional& p) { + static LiteralDoc DataType(const runtime::DataType& v, const ffi::Optional& p) { std::string dtype = v.is_void() ? "void" : runtime::DLDataTypeToString(v); return LiteralDoc::Str(dtype, p); } @@ -321,13 +315,13 @@ class LiteralDoc : public ExprDoc { * \param v The device. 
* \param p The object path */ - static LiteralDoc Device(const DLDevice& v, const Optional& p) { + static LiteralDoc Device(const DLDevice& v, const ffi::Optional& p) { std::ostringstream os; runtime::operator<<(os, v); return LiteralDoc::Str(os.str(), p); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LiteralDoc, ExprDoc, LiteralDocNode); }; /*! @@ -338,16 +332,13 @@ class LiteralDoc : public ExprDoc { class IdDocNode : public ExprDocNode { public: /*! \brief The name of the identifier */ - String name; + ffi::String name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("name", &IdDocNode::name); } - - static constexpr const char* _type_key = "script.printer.IdDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(IdDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.IdDoc", IdDocNode, ExprDocNode); }; /*! @@ -361,9 +352,9 @@ class IdDoc : public ExprDoc { * \brief Constructor of IdDoc. * \param name The name of identifier. */ - explicit IdDoc(String name); + explicit IdDoc(ffi::String name); explicit IdDoc(std::nullptr_t) : ExprDoc(nullptr) {} - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IdDoc, ExprDoc, IdDocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IdDoc, ExprDoc, IdDocNode); }; /*! @@ -374,9 +365,9 @@ class IdDoc : public ExprDoc { class AttrAccessDocNode : public ExprDocNode { public: /*! \brief The target expression to be accessed */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; /*! 
\brief The attribute to be accessed */ - String name; + ffi::String name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -384,10 +375,7 @@ class AttrAccessDocNode : public ExprDocNode { .def_ro("value", &AttrAccessDocNode::value) .def_ro("name", &AttrAccessDocNode::name); } - - static constexpr const char* _type_key = "script.printer.AttrAccessDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(AttrAccessDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.AttrAccessDoc", AttrAccessDocNode, ExprDocNode); }; /*! @@ -402,8 +390,8 @@ class AttrAccessDoc : public ExprDoc { * \param value The target expression of attribute access. * \param name The name of attribute to access. */ - explicit AttrAccessDoc(ExprDoc value, String name); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttrAccessDoc, ExprDoc, AttrAccessDocNode); + explicit AttrAccessDoc(ExprDoc value, ffi::String name); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AttrAccessDoc, ExprDoc, AttrAccessDocNode); }; /*! @@ -414,7 +402,7 @@ class AttrAccessDoc : public ExprDoc { class IndexDocNode : public ExprDocNode { public: /*! \brief The container value to be accessed */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; /*! 
* \brief The indices to access * @@ -422,7 +410,7 @@ class IndexDocNode : public ExprDocNode { * - ExprDoc (single point access like a[1, 2]) * - SliceDoc (slice access like a[1:5, 2]) */ - Array indices; // Each element is union of: Slice / ExprDoc + ffi::Array indices; // Each element is union of: Slice / ExprDoc static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -430,10 +418,7 @@ class IndexDocNode : public ExprDocNode { .def_ro("value", &IndexDocNode::value) .def_ro("indices", &IndexDocNode::indices); } - - static constexpr const char* _type_key = "script.printer.IndexDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(IndexDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.IndexDoc", IndexDocNode, ExprDocNode); }; /*! @@ -448,8 +433,8 @@ class IndexDoc : public ExprDoc { * \param value The target expression of index access. * \param indices The indices to access. */ - explicit IndexDoc(ExprDoc value, Array indices); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IndexDoc, ExprDoc, IndexDocNode); + explicit IndexDoc(ExprDoc value, ffi::Array indices); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IndexDoc, ExprDoc, IndexDocNode); }; /*! @@ -460,18 +445,18 @@ class IndexDoc : public ExprDoc { class CallDocNode : public ExprDocNode { public: /*! \brief The callee of this function call */ - ExprDoc callee{nullptr}; + ExprDoc callee{ffi::UnsafeInit()}; /*! \brief The positional arguments */ - Array args; + ffi::Array args; /*! \brief The keys of keyword arguments */ - Array kwargs_keys; + ffi::Array kwargs_keys; /*! * \brief The values of keyword arguments. * * The i-th element is the value of the i-th key in `kwargs_keys`. * It must have the same length as `kwargs_keys`. 
*/ - Array kwargs_values; + ffi::Array kwargs_values; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -481,10 +466,7 @@ class CallDocNode : public ExprDocNode { .def_ro("kwargs_keys", &CallDocNode::kwargs_keys) .def_ro("kwargs_values", &CallDocNode::kwargs_values); } - - static constexpr const char* _type_key = "script.printer.CallDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(CallDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.CallDoc", CallDocNode, ExprDocNode); }; /*! @@ -501,9 +483,9 @@ class CallDoc : public ExprDoc { * \param kwargs_keys Keys of keyword arguments. * \param kwargs_values Values of keyword arguments, must have the same length as `kwargs_keys. */ - CallDoc(ExprDoc callee, Array args, Array kwargs_keys, - Array kwargs_values); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CallDoc, ExprDoc, CallDocNode); + CallDoc(ExprDoc callee, ffi::Array args, ffi::Array kwargs_keys, + ffi::Array kwargs_values); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(CallDoc, ExprDoc, CallDocNode); }; /*! @@ -557,7 +539,7 @@ class OperationDocNode : public ExprDocNode { /*! \brief The kind of operation (operator) */ Kind kind; /*! \brief Operands of this expression */ - Array operands; + ffi::Array operands; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -565,10 +547,7 @@ class OperationDocNode : public ExprDocNode { .def_ro("kind", &OperationDocNode::kind) .def_ro("operands", &OperationDocNode::operands); } - - static constexpr const char* _type_key = "script.printer.OperationDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(OperationDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.OperationDoc", OperationDocNode, ExprDocNode); }; /*! @@ -583,8 +562,8 @@ class OperationDoc : public ExprDoc { * \param kind The kind of operation. * \param operands Operands of this expression. 
*/ - explicit OperationDoc(OperationDocNode::Kind kind, Array operands); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(OperationDoc, ExprDoc, OperationDocNode); + explicit OperationDoc(OperationDocNode::Kind kind, ffi::Array operands); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(OperationDoc, ExprDoc, OperationDocNode); }; /*! @@ -598,9 +577,9 @@ class OperationDoc : public ExprDoc { class LambdaDocNode : public ExprDocNode { public: /*! \brief The arguments of this anonymous function */ - Array args; + ffi::Array args; /*! \brief The body of this anonymous function */ - ExprDoc body{nullptr}; + ExprDoc body{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -608,10 +587,7 @@ class LambdaDocNode : public ExprDocNode { .def_ro("args", &LambdaDocNode::args) .def_ro("body", &LambdaDocNode::body); } - - static constexpr const char* _type_key = "script.printer.LambdaDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.LambdaDoc", LambdaDocNode, ExprDocNode); }; /*! @@ -626,8 +602,8 @@ class LambdaDoc : public ExprDoc { * \param args Arguments of this function. * \param body Body expression of this function. */ - explicit LambdaDoc(Array args, ExprDoc body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LambdaDoc, ExprDoc, LambdaDocNode); + explicit LambdaDoc(ffi::Array args, ExprDoc body); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LambdaDoc, ExprDoc, LambdaDocNode); }; /*! @@ -638,16 +614,13 @@ class LambdaDoc : public ExprDoc { class TupleDocNode : public ExprDocNode { public: /*! 
\brief Elements of tuple */ - Array elements; + ffi::Array elements; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("elements", &TupleDocNode::elements); } - - static constexpr const char* _type_key = "script.printer.TupleDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(TupleDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.TupleDoc", TupleDocNode, ExprDocNode); }; /*! @@ -660,13 +633,13 @@ class TupleDoc : public ExprDoc { /*! * \brief Create an empty TupleDoc */ - TupleDoc() : TupleDoc(ffi::make_object()) {} + TupleDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of TupleDoc * \param elements Elements of tuple. */ - explicit TupleDoc(Array elements); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleDoc, ExprDoc, TupleDocNode); + explicit TupleDoc(ffi::Array elements); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TupleDoc, ExprDoc, TupleDocNode); }; /*! @@ -677,16 +650,13 @@ class TupleDoc : public ExprDoc { class ListDocNode : public ExprDocNode { public: /*! \brief Elements of list */ - Array elements; + ffi::Array elements; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("elements", &ListDocNode::elements); } - - static constexpr const char* _type_key = "script.printer.ListDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ListDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.ListDoc", ListDocNode, ExprDocNode); }; /*! @@ -699,13 +669,13 @@ class ListDoc : public ExprDoc { /*! * \brief Create an empty ListDoc */ - ListDoc() : ListDoc(ffi::make_object()) {} + ListDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of ListDoc * \param elements Elements of list. */ - explicit ListDoc(Array elements); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ListDoc, ExprDoc, ListDocNode); + explicit ListDoc(ffi::Array elements); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ListDoc, ExprDoc, ListDocNode); }; /*! 
@@ -716,14 +686,14 @@ class ListDoc : public ExprDoc { class DictDocNode : public ExprDocNode { public: /*! \brief keys of dictionary */ - Array keys; + ffi::Array keys; /*! * \brief Values of dictionary * * The i-th element is the value of the i-th element of `keys`. * It must have the same length as `keys`. */ - Array values; + ffi::Array values; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -731,10 +701,7 @@ class DictDocNode : public ExprDocNode { .def_ro("keys", &DictDocNode::keys) .def_ro("values", &DictDocNode::values); } - - static constexpr const char* _type_key = "script.printer.DictDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(DictDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.DictDoc", DictDocNode, ExprDocNode); }; /*! @@ -747,14 +714,14 @@ class DictDoc : public ExprDoc { /*! * \brief Create an empty dictionary */ - DictDoc() : DictDoc(ffi::make_object()) {} + DictDoc() : ExprDoc(ffi::make_object()) {} /*! * \brief Constructor of DictDoc * \param keys Keys of dictionary. * \param values Values of dictionary, must have same length as `keys`. */ - explicit DictDoc(Array keys, Array values); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DictDoc, ExprDoc, DictDocNode); + explicit DictDoc(ffi::Array keys, ffi::Array values); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DictDoc, ExprDoc, DictDocNode); }; /*! @@ -767,11 +734,11 @@ class DictDoc : public ExprDoc { class SliceDocNode : public DocNode { public: /*! \brief The start of slice */ - Optional start; + ffi::Optional start; /*! \brief The exclusive end of slice */ - Optional stop; + ffi::Optional stop; /*! 
\brief The step of slice */ - Optional step; + ffi::Optional step; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -780,10 +747,7 @@ class SliceDocNode : public DocNode { .def_ro("stop", &SliceDocNode::stop) .def_ro("step", &SliceDocNode::step); } - - static constexpr const char* _type_key = "script.printer.SliceDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(SliceDocNode, DocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.SliceDoc", SliceDocNode, DocNode); }; /*! @@ -799,8 +763,9 @@ class SliceDoc : public Doc { * \param stop The exclusive end of slice. * \param step The step of slice. */ - explicit SliceDoc(Optional start, Optional stop, Optional step); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode); + explicit SliceDoc(ffi::Optional start, ffi::Optional stop, + ffi::Optional step); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SliceDoc, Doc, SliceDocNode); }; /*! @@ -811,15 +776,15 @@ class SliceDoc : public Doc { class AssignDocNode : public StmtDocNode { public: /*! \brief The left hand side of the assignment */ - ExprDoc lhs{nullptr}; + ExprDoc lhs{ffi::UnsafeInit()}; /*! * \brief The right hand side of the assignment. * * If null, this doc represents declaration, e.g. `A: T.Buffer((1,2))` * */ - Optional rhs; + ffi::Optional rhs; /*! \brief The type annotation of this assignment. */ - Optional annotation; + ffi::Optional annotation; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -828,10 +793,7 @@ class AssignDocNode : public StmtDocNode { .def_ro("rhs", &AssignDocNode::rhs) .def_ro("annotation", &AssignDocNode::annotation); } - - static constexpr const char* _type_key = "script.printer.AssignDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(AssignDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.AssignDoc", AssignDocNode, StmtDocNode); }; /*! @@ -847,8 +809,8 @@ class AssignDoc : public StmtDoc { * \param rhs The right hand side of the assignment. 
* \param annotation The type annotation of this assignment. */ - explicit AssignDoc(ExprDoc lhs, Optional rhs, Optional annotation); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssignDoc, StmtDoc, AssignDocNode); + explicit AssignDoc(ExprDoc lhs, ffi::Optional rhs, ffi::Optional annotation); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AssignDoc, StmtDoc, AssignDocNode); }; /*! @@ -859,11 +821,11 @@ class AssignDoc : public StmtDoc { class IfDocNode : public StmtDocNode { public: /*! \brief The predicate of the if-then-else statement. */ - ExprDoc predicate{nullptr}; + ExprDoc predicate{ffi::UnsafeInit()}; /*! \brief The then branch of the if-then-else statement. */ - Array then_branch; + ffi::Array then_branch; /*! \brief The else branch of the if-then-else statement. */ - Array else_branch; + ffi::Array else_branch; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -872,10 +834,7 @@ class IfDocNode : public StmtDocNode { .def_ro("then_branch", &IfDocNode::then_branch) .def_ro("else_branch", &IfDocNode::else_branch); } - - static constexpr const char* _type_key = "script.printer.IfDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(IfDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.IfDoc", IfDocNode, StmtDocNode); }; /*! @@ -891,8 +850,9 @@ class IfDoc : public StmtDoc { * \param then_branch The then branch of the if-then-else statement. * \param else_branch The else branch of the if-then-else statement. */ - explicit IfDoc(ExprDoc predicate, Array then_branch, Array else_branch); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IfDoc, StmtDoc, IfDocNode); + explicit IfDoc(ExprDoc predicate, ffi::Array then_branch, + ffi::Array else_branch); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IfDoc, StmtDoc, IfDocNode); }; /*! @@ -903,9 +863,9 @@ class IfDoc : public StmtDoc { class WhileDocNode : public StmtDocNode { public: /*! \brief The predicate of the while statement. 
*/ - ExprDoc predicate{nullptr}; + ExprDoc predicate{ffi::UnsafeInit()}; /*! \brief The body of the while statement. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -913,10 +873,7 @@ class WhileDocNode : public StmtDocNode { .def_ro("predicate", &WhileDocNode::predicate) .def_ro("body", &WhileDocNode::body); } - - static constexpr const char* _type_key = "script.printer.WhileDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(WhileDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.WhileDoc", WhileDocNode, StmtDocNode); }; /*! @@ -931,8 +888,8 @@ class WhileDoc : public StmtDoc { * \param predicate The predicate of the while statement. * \param body The body of the while statement. */ - explicit WhileDoc(ExprDoc predicate, Array body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WhileDoc, StmtDoc, WhileDocNode); + explicit WhileDoc(ExprDoc predicate, ffi::Array body); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WhileDoc, StmtDoc, WhileDocNode); }; /*! @@ -947,11 +904,11 @@ class WhileDoc : public StmtDoc { class ForDocNode : public StmtDocNode { public: /*! \brief The left hand side of the assignment of iterating variable. */ - ExprDoc lhs{nullptr}; + ExprDoc lhs{ffi::UnsafeInit()}; /*! \brief The right hand side of the assignment of iterating variable. */ - ExprDoc rhs{nullptr}; + ExprDoc rhs{ffi::UnsafeInit()}; /*! \brief The body of the for statement. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -960,10 +917,7 @@ class ForDocNode : public StmtDocNode { .def_ro("rhs", &ForDocNode::rhs) .def_ro("body", &ForDocNode::body); } - - static constexpr const char* _type_key = "script.printer.ForDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ForDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.ForDoc", ForDocNode, StmtDocNode); }; /*! 
@@ -979,8 +933,8 @@ class ForDoc : public StmtDoc { * \param rhs The right hand side of the assignment of iterating variable. * \param body The body of the for statement. */ - explicit ForDoc(ExprDoc lhs, ExprDoc rhs, Array body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ForDoc, StmtDoc, ForDocNode); + explicit ForDoc(ExprDoc lhs, ExprDoc rhs, ffi::Array body); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ForDoc, StmtDoc, ForDocNode); }; /*! @@ -996,11 +950,11 @@ class ForDoc : public StmtDoc { class ScopeDocNode : public StmtDocNode { public: /*! \brief The name of the scoped variable. */ - Optional lhs{std::nullopt}; + ffi::Optional lhs{std::nullopt}; /*! \brief The value of the scoped variable. */ - ExprDoc rhs{nullptr}; + ExprDoc rhs{ffi::UnsafeInit()}; /*! \brief The body of the scope doc. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1009,10 +963,7 @@ class ScopeDocNode : public StmtDocNode { .def_ro("rhs", &ScopeDocNode::rhs) .def_ro("body", &ScopeDocNode::body); } - - static constexpr const char* _type_key = "script.printer.ScopeDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ScopeDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.ScopeDoc", ScopeDocNode, StmtDocNode); }; /*! @@ -1028,16 +979,16 @@ class ScopeDoc : public StmtDoc { * \param rhs The value of the scoped variable. * \param body The body of the scope doc. */ - explicit ScopeDoc(Optional lhs, ExprDoc rhs, Array body); + explicit ScopeDoc(ffi::Optional lhs, ExprDoc rhs, ffi::Array body); /*! * \brief Constructor of ScopeDoc. * \param rhs The value of the scoped variable. * \param body The body of the scope doc. */ - explicit ScopeDoc(ExprDoc rhs, Array body); + explicit ScopeDoc(ExprDoc rhs, ffi::Array body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ScopeDoc, StmtDoc, ScopeDocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ScopeDoc, StmtDoc, ScopeDocNode); }; /*! 
@@ -1048,16 +999,13 @@ class ScopeDoc : public StmtDoc { class ExprStmtDocNode : public StmtDocNode { public: /*! \brief The expression represented by this doc. */ - ExprDoc expr{nullptr}; + ExprDoc expr{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("expr", &ExprStmtDocNode::expr); } - - static constexpr const char* _type_key = "script.printer.ExprStmtDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ExprStmtDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.ExprStmtDoc", ExprStmtDocNode, StmtDocNode); }; /*! @@ -1072,7 +1020,7 @@ class ExprStmtDoc : public StmtDoc { * \param expr The expression represented by this doc. */ explicit ExprStmtDoc(ExprDoc expr); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprStmtDoc, StmtDoc, ExprStmtDocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ExprStmtDoc, StmtDoc, ExprStmtDocNode); }; /*! @@ -1083,9 +1031,9 @@ class ExprStmtDoc : public StmtDoc { class AssertDocNode : public StmtDocNode { public: /*! \brief The expression to test. */ - ExprDoc test{nullptr}; + ExprDoc test{ffi::UnsafeInit()}; /*! \brief The optional error message when assertion failed. */ - Optional msg{std::nullopt}; + ffi::Optional msg{std::nullopt}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1093,10 +1041,7 @@ class AssertDocNode : public StmtDocNode { .def_ro("test", &AssertDocNode::test) .def_ro("msg", &AssertDocNode::msg); } - - static constexpr const char* _type_key = "script.printer.AssertDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(AssertDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.AssertDoc", AssertDocNode, StmtDocNode); }; /*! @@ -1111,8 +1056,8 @@ class AssertDoc : public StmtDoc { * \param test The expression to test. * \param msg The optional error message when assertion failed. 
*/ - explicit AssertDoc(ExprDoc test, Optional msg = std::nullopt); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssertDoc, StmtDoc, AssertDocNode); + explicit AssertDoc(ExprDoc test, ffi::Optional msg = std::nullopt); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AssertDoc, StmtDoc, AssertDocNode); }; /*! @@ -1123,16 +1068,13 @@ class AssertDoc : public StmtDoc { class ReturnDocNode : public StmtDocNode { public: /*! \brief The value to return. */ - ExprDoc value{nullptr}; + ExprDoc value{ffi::UnsafeInit()}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &ReturnDocNode::value); } - - static constexpr const char* _type_key = "script.printer.ReturnDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ReturnDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.ReturnDoc", ReturnDocNode, StmtDocNode); }; /*! @@ -1147,7 +1089,7 @@ class ReturnDoc : public StmtDoc { * \param value The value to return. */ explicit ReturnDoc(ExprDoc value); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReturnDoc, StmtDoc, ReturnDocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ReturnDoc, StmtDoc, ReturnDocNode); }; /*! @@ -1158,7 +1100,7 @@ class ReturnDoc : public StmtDoc { class FunctionDocNode : public StmtDocNode { public: /*! \brief The name of function. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of function. * @@ -1166,13 +1108,13 @@ class FunctionDocNode : public StmtDocNode { * `annotation` means argument type, * and `rhs` means default value. */ - Array args; + ffi::Array args; /*! \brief Decorators of function. */ - Array decorators; + ffi::Array decorators; /*! \brief The return type of function. */ - Optional return_type{std::nullopt}; + ffi::Optional return_type{std::nullopt}; /*! \brief The body of function. 
*/ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1183,10 +1125,7 @@ class FunctionDocNode : public StmtDocNode { .def_ro("return_type", &FunctionDocNode::return_type) .def_ro("body", &FunctionDocNode::body); } - - static constexpr const char* _type_key = "script.printer.FunctionDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.FunctionDoc", FunctionDocNode, StmtDocNode); }; /*! @@ -1204,9 +1143,9 @@ class FunctionDoc : public StmtDoc { * \param return_type The return type of function. * \param body The body of function. */ - explicit FunctionDoc(IdDoc name, Array args, Array decorators, - Optional return_type, Array body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode); + explicit FunctionDoc(IdDoc name, ffi::Array args, ffi::Array decorators, + ffi::Optional return_type, ffi::Array body); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FunctionDoc, StmtDoc, FunctionDocNode); }; /*! @@ -1217,11 +1156,11 @@ class FunctionDoc : public StmtDoc { class ClassDocNode : public StmtDocNode { public: /*! \brief The name of class. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! \brief Decorators of class. */ - Array decorators; + ffi::Array decorators; /*! \brief The body of class. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1230,10 +1169,7 @@ class ClassDocNode : public StmtDocNode { .def_ro("decorators", &ClassDocNode::decorators) .def_ro("body", &ClassDocNode::body); } - - static constexpr const char* _type_key = "script.printer.ClassDoc"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ClassDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.ClassDoc", ClassDocNode, StmtDocNode); }; /*! @@ -1249,8 +1185,8 @@ class ClassDoc : public StmtDoc { * \param decorators The decorator of class. 
* \param body The body of class. */ - explicit ClassDoc(IdDoc name, Array decorators, Array body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode); + explicit ClassDoc(IdDoc name, ffi::Array decorators, ffi::Array body); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ClassDoc, StmtDoc, ClassDocNode); }; /*! @@ -1264,9 +1200,7 @@ class CommentDocNode : public StmtDocNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.printer.CommentDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(CommentDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.CommentDoc", CommentDocNode, StmtDocNode); }; /*! @@ -1276,8 +1210,8 @@ class CommentDocNode : public StmtDocNode { */ class CommentDoc : public StmtDoc { public: - explicit CommentDoc(String comment); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CommentDoc, StmtDoc, CommentDocNode); + explicit CommentDoc(ffi::String comment); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(CommentDoc, StmtDoc, CommentDocNode); }; /*! @@ -1291,9 +1225,7 @@ class DocStringDocNode : public StmtDocNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "script.printer.DocStringDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(DocStringDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.DocStringDoc", DocStringDocNode, StmtDocNode); }; /*! 
@@ -1303,8 +1235,8 @@ class DocStringDocNode : public StmtDocNode { */ class DocStringDoc : public StmtDoc { public: - explicit DocStringDoc(String docs); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DocStringDoc, StmtDoc, DocStringDocNode); + explicit DocStringDoc(ffi::String docs); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DocStringDoc, StmtDoc, DocStringDocNode); }; } // namespace printer diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index dd7eaff7cc69..b5d50d89019b 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -50,7 +50,7 @@ class IRDocsifierNode; class FrameNode : public Object { public: /*! The docs generated in the frame */ - Array stmts; + ffi::Array stmts; /*! The corresponding IRDocsifier */ IRDocsifierNode* d; /*! The callbacks that are going to be invoked when the frame exits */ @@ -61,9 +61,8 @@ class FrameNode : public Object { refl::ObjectDef().def_ro("stmts", &FrameNode::stmts); } - static constexpr const char* _type_key = "script.printer.Frame"; - - TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("script.printer.Frame", FrameNode, Object); public: virtual ~FrameNode() = default; @@ -82,7 +81,7 @@ class FrameNode : public Object { * \param d The docsifier. * \param token The token to be added. */ - void AddDispatchToken(const IRDocsifier& d, const String& token); + void AddDispatchToken(const IRDocsifier& d, const ffi::String& token); /*! * \brief Method that's called when Frame enters the scope. */ @@ -109,7 +108,7 @@ class Frame : public ObjectRef { /*! \brief Method that's called when Frame exits the scope. 
*/ void ExitWithScope() { get()->ExitWithScope(); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Frame, ObjectRef, FrameNode); }; //////////////////////// IRDocsifier //////////////////////// @@ -129,30 +128,30 @@ class IRDocsifierNode : public Object { /*! \brief The creator */ DocCreator creator; /*! \brief The name of the variable */ - Optional name; + ffi::Optional name; }; /*! \brief The configuration of the printer */ - PrinterConfig cfg{nullptr}; + PrinterConfig cfg{ffi::UnsafeInit()}; /*! * \brief The stack of frames. * \sa FrameNode */ - Array frames; + ffi::Array frames; /*! * \brief The stack of dispatch tokens. * * The dispatch token on the top decides which dispatch function to use * when converting IR node object to Doc. */ - Array dispatch_tokens; + ffi::Array dispatch_tokens; /*! \brief Mapping from a var to its info */ std::unordered_map obj2info; /*! \brief Metadata printing */ - std::unordered_map> metadata; + std::unordered_map> metadata; /*! \brief GlobalInfo printing */ - std::unordered_map> global_infos; + std::unordered_map> global_infos; /*! \brief The variable names used already */ - std::unordered_set defined_names; + std::unordered_set defined_names; /*! \brief Common prefixes of variable usages */ std::unordered_map> common_prefix; /*! \brief The IR usages for headers printing */ @@ -165,9 +164,8 @@ class IRDocsifierNode : public Object { .def_ro("dispatch_tokens", &IRDocsifierNode::dispatch_tokens); } - static constexpr const char* _type_key = "script.printer.IRDocsifier"; - - TVM_DECLARE_FINAL_OBJECT_INFO(IRDocsifierNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.IRDocsifier", IRDocsifierNode, Object); public: /*! @@ -181,7 +179,7 @@ class IRDocsifierNode : public Object { * This function will rename the variable to avoid name conflict with other variables * in the table. 
*/ - IdDoc Define(const ObjectRef& obj, const Frame& frame, const String& name_hint); + IdDoc Define(const ObjectRef& obj, const Frame& frame, const ffi::String& name_hint); /*! * \brief Define variable by doc factory. @@ -207,14 +205,14 @@ class IRDocsifierNode : public Object { * * \return The doc for variable, if it exists in the table. Otherwise it returns std::nullopt. */ - Optional GetVarDoc(const ObjectRef& obj) const; + ffi::Optional GetVarDoc(const ObjectRef& obj) const; /*! \brief Add a TVM object to the metadata section*/ ExprDoc AddMetadata(const ffi::Any& obj); /*! \brief Add a GlobalInfo to the global_infos map. * \param name The name of key of global_infos. * \param ginfo The GlobalInfo to be added. */ - void AddGlobalInfo(const String& name, const GlobalInfo& ginfo); + void AddGlobalInfo(const ffi::String& name, const GlobalInfo& ginfo); /*! * \brief Check if a variable exists in the table. * \param obj The variable object. @@ -252,14 +250,14 @@ class IRDocsifier : public ObjectRef { /*! \brief The registration table for IRDocsifier. 
*/ TVM_DLL static FType& vtable(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IRDocsifier, ObjectRef, IRDocsifierNode); }; //////////////////////// Implementation //////////////////////// inline void FrameNode::EnterWithScope() { if (d != nullptr) { - d->frames.push_back(GetRef(this)); + d->frames.push_back(ffi::GetRef(this)); } } @@ -295,7 +293,7 @@ inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const Ac } for (const auto& pair : cfg->path_to_annotate) { AccessPath p = pair.first; - String attn = pair.second; + ffi::String attn = pair.second; if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) { if (const auto* stmt = d.as()) { if (stmt->comment.has_value()) { @@ -340,7 +338,8 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) con default: { if (auto opt_obj = value.as()) { ObjectRef obj = opt_obj.value(); - Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef(this)); + Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, + ffi::GetRef(this)); d->source_paths.push_back(path); AddDocDecoration(d, obj, path, cfg); return Downcast(d); @@ -352,7 +351,7 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) con } } -inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) { +inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const ffi::String& token) { d->dispatch_tokens.push_back(token); this->AddExitCallback([doc = d.get()]() { doc->dispatch_tokens.pop_back(); }); } diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index e4be2d31aa57..4500a7d8607b 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -61,7 +61,7 @@ class IRDocsifierFunctor { * dispatch function for TObjectRef with the default 
dispatch token (empty string). */ template - R operator()(const String& token, TObjectRef obj, Args... args) const { + R operator()(const ffi::String& token, TObjectRef obj, Args... args) const { uint32_t type_index = obj.defined() ? obj->type_index() : 0; const ffi::Function* pf = nullptr; if ((pf = LookupDispatchTable(token, type_index)) != nullptr) { @@ -91,7 +91,7 @@ class IRDocsifierFunctor { * This takes a type-erased packed function as input. It should be used * through FFI boundary, for example, registering dispatch function from Python. */ - TSelf& set_dispatch(String token, uint32_t type_index, ffi::Function f) { + TSelf& set_dispatch(ffi::String token, uint32_t type_index, ffi::Function f) { std::vector* table = &dispatch_table_[token]; if (table->size() <= type_index) { table->resize(type_index + 1, nullptr); @@ -120,7 +120,7 @@ class IRDocsifierFunctor { */ template ::value>> - TSelf& set_dispatch(String token, TCallable f) { + TSelf& set_dispatch(ffi::String token, TCallable f) { return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(), ffi::TypedFunction(f)); } @@ -140,7 +140,7 @@ class IRDocsifierFunctor { * This is useful when dispatch function comes from other language's runtime, and * those function should be removed before that language runtime shuts down. */ - void remove_dispatch(String token, uint32_t type_index) { + void remove_dispatch(ffi::String token, uint32_t type_index) { std::vector* table = &dispatch_table_[token]; if (table->size() <= type_index) { return; @@ -155,7 +155,7 @@ class IRDocsifierFunctor { * \param type_index The TVM object type index. * \return Returns the functor if the lookup succeeds, nullptr otherwise. 
*/ - const ffi::Function* LookupDispatchTable(const String& token, uint32_t type_index) const { + const ffi::Function* LookupDispatchTable(const ffi::String& token, uint32_t type_index) const { auto it = dispatch_table_.find(token); if (it == dispatch_table_.end()) { return nullptr; diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h index 26111caa079a..59a13ae572ab 100644 --- a/include/tvm/target/tag.h +++ b/include/tvm/target/tag.h @@ -37,9 +37,9 @@ namespace tvm { class TargetTagNode : public Object { public: /*! \brief Name of the target */ - String name; + ffi::String name; /*! \brief Config map to generate the target */ - Map config; + ffi::Map config; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -47,16 +47,13 @@ class TargetTagNode : public Object { .def_ro("name", &TargetTagNode::name) .def_ro("config", &TargetTagNode::config); } - - static constexpr const char* _type_key = "target.TargetTag"; - - TVM_DECLARE_FINAL_OBJECT_INFO(TargetTagNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.TargetTag", TargetTagNode, Object); private: /*! \brief Return the index stored in attr registry */ uint32_t AttrRegistryIndex() const { return index_; } /*! \brief Return the name stored in attr registry */ - String AttrRegistryName() const { return name; } + ffi::String AttrRegistryName() const { return name; } /*! \brief Index used for internal lookup of attribute registry */ uint32_t index_; @@ -78,12 +75,12 @@ class TargetTag : public ObjectRef { * \param target_tag_name Name of the target tag * \return The Target requested */ - TVM_DLL static Optional Get(const String& target_tag_name); + TVM_DLL static ffi::Optional Get(const ffi::String& target_tag_name); /*! * \brief List all names of the existing target tags * \return A dictionary that maps tag name to the concrete target it corresponds to */ - TVM_DLL static Map ListTags(); + TVM_DLL static ffi::Map ListTags(); /*! 
* \brief Add a tag into the registry * \param name Name of the tag @@ -91,9 +88,9 @@ class TargetTag : public ObjectRef { * \param override Allow overriding existing tags * \return Target created with the tag */ - TVM_DLL static Target AddTag(String name, Map config, bool override); + TVM_DLL static Target AddTag(ffi::String name, ffi::Map config, bool override); - TVM_DEFINE_OBJECT_REF_METHODS(TargetTag, ObjectRef, TargetTagNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TargetTag, ObjectRef, TargetTagNode); private: /*! \brief Mutable access to the container class */ @@ -107,13 +104,13 @@ class TargetTagRegEntry { * \brief Set the config dict corresponding to the target tag * \param config The config dict for target creation */ - inline TargetTagRegEntry& set_config(Map config); + inline TargetTagRegEntry& set_config(ffi::Map config); /*! * \brief Add a key-value pair to the config dict * \param key The attribute name * \param value The attribute value */ - inline TargetTagRegEntry& with_config(String key, Any value); + inline TargetTagRegEntry& with_config(ffi::String key, Any value); /*! \brief Set name of the TargetTag to be the same as registry if it is empty */ inline TargetTagRegEntry& set_name(); /*! @@ -121,14 +118,14 @@ class TargetTagRegEntry { * \param target_tag_name The name of the TargetTag. * \return the corresponding entry. */ - TVM_DLL static TargetTagRegEntry& RegisterOrGet(const String& target_tag_name); + TVM_DLL static TargetTagRegEntry& RegisterOrGet(const ffi::String& target_tag_name); private: TargetTag tag_; - String name; + ffi::String name; /*! 
\brief private constructor */ - explicit TargetTagRegEntry(uint32_t reg_index) : tag_(make_object()) { + explicit TargetTagRegEntry(uint32_t reg_index) : tag_(ffi::make_object()) { tag_->index_ = reg_index; } template @@ -136,12 +133,12 @@ class TargetTagRegEntry { friend class TargetTag; }; -inline TargetTagRegEntry& TargetTagRegEntry::set_config(Map config) { +inline TargetTagRegEntry& TargetTagRegEntry::set_config(ffi::Map config) { tag_->config = std::move(config); return *this; } -inline TargetTagRegEntry& TargetTagRegEntry::with_config(String key, ffi::Any value) { +inline TargetTagRegEntry& TargetTagRegEntry::with_config(ffi::String key, ffi::Any value) { tag_->config.Set(key, value); return *this; } diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 678d36aeceda..78d4d102f431 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -51,15 +51,15 @@ class TargetNode : public Object { /*! \brief The kind of the target device */ TargetKind kind; /*! \brief Target host information, must be Target type */ - Optional host; + ffi::Optional host; /*! \brief Tag of the target, can be empty */ - String tag; + ffi::String tag; /*! \brief Keys for this target */ - Array keys; + ffi::Array keys; /*! \brief Collection of attributes */ - Map attrs; + ffi::Map attrs; /*! \brief Target features */ - Map features; + ffi::Map features; /*! * \brief The raw string representation of the target @@ -68,9 +68,9 @@ class TargetNode : public Object { */ TVM_DLL const std::string& str() const; /*! \return Export target to JSON-like configuration */ - TVM_DLL Map Export() const; - /*! \return The Optional typed target host of the TargetNode */ - TVM_DLL Optional GetHost() const; + TVM_DLL ffi::Map Export() const; + /*! \return The ffi::Optional typed target host of the TargetNode */ + TVM_DLL ffi::Optional GetHost() const; /*! 
\return The device type for this target */ TVM_DLL int GetTargetDeviceType() const; @@ -91,7 +91,7 @@ class TargetNode : public Object { * TODO(mbs): The ReprPrinter version should perhaps switch to this form, however currently * code depends on str() and << being the same. */ - String ToDebugString() const; + ffi::String ToDebugString() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -112,12 +112,12 @@ class TargetNode : public Object { * \return An optional, std::nullopt if not found, otherwise the value found */ template - Optional GetAttr( + ffi::Optional GetAttr( const std::string& attr_key, - Optional default_value = Optional(std::nullopt)) const { + ffi::Optional default_value = ffi::Optional(std::nullopt)) const { auto it = attrs.find(attr_key); if (it != attrs.end()) { - return Downcast>((*it).second); + return Downcast>((*it).second); } else { return default_value; } @@ -130,8 +130,8 @@ class TargetNode : public Object { * \return An optional, std::nullopt if not found, otherwise the value found */ template - Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { - return GetAttr(attr_key, Optional(default_value)); + ffi::Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, ffi::Optional(default_value)); } /*! @@ -154,8 +154,9 @@ class TargetNode : public Object { * \endcode */ template - Optional GetFeature(const std::string& feature_key, - Optional default_value = std::nullopt) const { + ffi::Optional GetFeature( + const std::string& feature_key, + ffi::Optional default_value = std::nullopt) const { if (auto feature = features.Get(feature_key)) { return Downcast(feature.value()); } else { @@ -164,8 +165,9 @@ class TargetNode : public Object { } // variant that uses TObjectRef to enable implicit conversion to default value. 
template - Optional GetFeature(const std::string& attr_key, TObjectRef default_value) const { - return GetFeature(attr_key, Optional(default_value)); + ffi::Optional GetFeature(const std::string& attr_key, + TObjectRef default_value) const { + return GetFeature(attr_key, ffi::Optional(default_value)); } /*! \brief Get the keys for this target as a vector of string */ @@ -173,9 +175,8 @@ class TargetNode : public Object { /*! \brief Get the keys for this target as an unordered_set of string */ TVM_DLL std::unordered_set GetLibs() const; - static constexpr const char* _type_key = "target.Target"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.Target", TargetNode, Object); private: /*! \brief Internal string repr. */ @@ -196,12 +197,12 @@ class Target : public ObjectRef { * \brief Construct a Target given a string * \param tag_or_config_or_target_str the string to parse for target */ - TVM_DLL explicit Target(const String& tag_or_config_or_target_str); + TVM_DLL explicit Target(const ffi::String& tag_or_config_or_target_str); /*! * \brief Construct a Target using a JSON-like configuration * \param config The JSON-like configuration for target */ - TVM_DLL explicit Target(const Map& config); + TVM_DLL explicit Target(const ffi::Map& config); /*! * \brief Get the current target context from thread local storage. * \param allow_not_defined If the context stack is empty and this is set to true, an @@ -217,7 +218,7 @@ class Target : public ObjectRef { * \param host The Target typed object for target host */ TVM_DLL explicit Target(Target target, Target host); - TVM_DEFINE_OBJECT_REF_METHODS(Target, ObjectRef, TargetNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Target, ObjectRef, TargetNode); /*! * \brief Create a new Target object with given target (w.o host) and target host. 
* \param target The current Target typed object target, with or without host field. @@ -230,8 +231,8 @@ class Target : public ObjectRef { Target WithoutHost() const; private: - Target(TargetKind kind, Optional host, String tag, Array keys, - Map attrs); + Target(TargetKind kind, ffi::Optional host, ffi::String tag, + ffi::Array keys, ffi::Map attrs); // enable with syntax. friend class TargetInternal; diff --git a/include/tvm/target/target_info.h b/include/tvm/target/target_info.h index e1b4a1c7cd7d..c4e12ac532f8 100644 --- a/include/tvm/target/target_info.h +++ b/include/tvm/target/target_info.h @@ -57,16 +57,13 @@ class MemoryInfoNode : public Object { .def_ro("max_simd_bits", &MemoryInfoNode::max_simd_bits) .def_ro("head_address", &MemoryInfoNode::head_address); } - - static constexpr const char* _type_key = "target.MemoryInfo"; - - TVM_DECLARE_FINAL_OBJECT_INFO(MemoryInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.MemoryInfo", MemoryInfoNode, Object); }; /*! \brief Defines memory info */ class MemoryInfo : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(MemoryInfo, ObjectRef, MemoryInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MemoryInfo, ObjectRef, MemoryInfoNode); }; /*! diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index d89148964bcd..7722211b3e61 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -41,7 +41,7 @@ class Target; /*! * \brief Map containing parsed features of a specific Target */ -using TargetFeatures = Map; +using TargetFeatures = ffi::Map; /*! * \brief TargetParser to apply on instantiation of a given TargetKind @@ -50,7 +50,7 @@ using TargetFeatures = Map; * * \return The transformed Target JSON object. */ -using TargetJSON = Map; +using TargetJSON = ffi::Map; using FTVMTargetParser = ffi::TypedFunction; namespace detail { @@ -67,11 +67,11 @@ class TargetKindAttrMap; class TargetKindNode : public Object { public: /*! 
\brief Name of the target kind */ - String name; + ffi::String name; /*! \brief Device type of target kind */ int default_device_type; /*! \brief Default keys of the target */ - Array default_keys; + ffi::Array default_keys; /*! \brief Function used to preprocess on target creation */ ffi::Function preprocessor; /*! \brief Function used to parse a JSON target during creation */ @@ -88,25 +88,24 @@ class TargetKindNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance; - static constexpr const char* _type_key = "target.TargetKind"; - TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.TargetKind", TargetKindNode, Object); private: /*! \brief Return the index stored in attr registry */ uint32_t AttrRegistryIndex() const { return index_; } /*! \brief Return the name stored in attr registry */ - String AttrRegistryName() const { return name; } + ffi::String AttrRegistryName() const { return name; } /*! \brief Stores the required type_key and type_index of a specific attr of a target */ struct ValueTypeInfo { - String type_key; + ffi::String type_key; int32_t type_index; std::unique_ptr key; std::unique_ptr val; }; /*! \brief A hash table that stores the type information of each attr of the target key */ - std::unordered_map key2vtype_; + std::unordered_map key2vtype_; /*! \brief A hash table that stores the default value of each attr of the target key */ - std::unordered_map key2default_; + std::unordered_map key2default_; /*! \brief Index used for internal lookup of attribute registry */ uint32_t index_; @@ -127,29 +126,32 @@ class TargetKindNode : public Object { class TargetKind : public ObjectRef { public: TargetKind() = default; + explicit TargetKind(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! 
\brief Get the attribute map given the attribute name */ template - static inline TargetKindAttrMap GetAttrMap(const String& attr_name); + static inline TargetKindAttrMap GetAttrMap(const ffi::String& attr_name); /*! * \brief Retrieve the TargetKind given its name * \param target_kind_name Name of the target kind * \return The TargetKind requested */ - TVM_DLL static Optional Get(const String& target_kind_name); + TVM_DLL static ffi::Optional Get(const ffi::String& target_kind_name); /*! \brief Mutable access to the container class */ TargetKindNode* operator->() { return static_cast(data_.get()); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TargetKind, ObjectRef, TargetKindNode); private: TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer( - const String& attr_name); + const ffi::String& attr_name); friend class TargetKindRegEntry; friend class TargetInternal; }; /*! - * \brief Map used to store meta-information about TargetKind + * \brief ffi::Map used to store meta-information about TargetKind * \tparam ValueType The type of the value stored in map */ template @@ -188,7 +190,7 @@ class TargetKindRegEntry { * \tparam ValueType The type of the value to be set. */ template - inline TargetKindRegEntry& set_attr(const String& attr_name, const ValueType& value, + inline TargetKindRegEntry& set_attr(const ffi::String& attr_name, const ValueType& value, int plevel = 10); /*! * \brief Set DLPack's device_type the target @@ -199,7 +201,7 @@ class TargetKindRegEntry { * \brief Set DLPack's device_type the target * \param keys The default keys */ - inline TargetKindRegEntry& set_default_keys(std::vector keys); + inline TargetKindRegEntry& set_default_keys(std::vector keys); /*! 
* \brief Set the pre-processing function applied upon target creation * \tparam FLambda Type of the function @@ -218,7 +220,7 @@ class TargetKindRegEntry { * \tparam ValueType The value type to be registered */ template - inline TargetKindRegEntry& add_attr_option(const String& key); + inline TargetKindRegEntry& add_attr_option(const ffi::String& key); /*! * \brief Register a valid configuration option and its ValueType for validation * \param key The configuration key @@ -226,33 +228,33 @@ class TargetKindRegEntry { * \tparam ValueType The value type to be registered */ template - inline TargetKindRegEntry& add_attr_option(const String& key, ffi::Any default_value); + inline TargetKindRegEntry& add_attr_option(const ffi::String& key, ffi::Any default_value); /*! \brief Set name of the TargetKind to be the same as registry if it is empty */ inline TargetKindRegEntry& set_name(); /*! * \brief List all the entry names in the registry. * \return The entry names. */ - TVM_DLL static Array ListTargetKinds(); + TVM_DLL static ffi::Array ListTargetKinds(); /*! * \brief Get all supported option names and types for a given Target kind. * \return Map of option name to type */ - TVM_DLL static Map ListTargetKindOptions(const TargetKind& kind); + TVM_DLL static ffi::Map ListTargetKindOptions(const TargetKind& kind); /*! * \brief Register or get a new entry. * \param target_kind_name The name of the TargetKind. * \return the corresponding entry. */ - TVM_DLL static TargetKindRegEntry& RegisterOrGet(const String& target_kind_name); + TVM_DLL static TargetKindRegEntry& RegisterOrGet(const ffi::String& target_kind_name); private: TargetKind kind_; - String name; + ffi::String name; /*! \brief private constructor */ - explicit TargetKindRegEntry(uint32_t reg_index) : kind_(make_object()) { + explicit TargetKindRegEntry(uint32_t reg_index) : kind_(ffi::make_object()) { kind_->index_ = reg_index; } /*! 
@@ -261,7 +263,7 @@ class TargetKindRegEntry { * \param value The value to be set * \param plevel The priority level */ - TVM_DLL void UpdateAttr(const String& key, ffi::Any value, int plevel); + TVM_DLL void UpdateAttr(const ffi::String& key, ffi::Any value, int plevel); template friend class AttrRegistry; friend class TargetKind; @@ -278,8 +280,9 @@ struct is_specialized, Container> : std::true_type { using type = std::true_type; }; -template ::type, - typename IsMap = typename is_specialized::type> +template ::type, + typename IsMap = typename is_specialized::type> struct ValueTypeInfoMaker {}; template @@ -295,7 +298,7 @@ struct ValueTypeInfoMaker { info.type_index = tindex; info.type_key = runtime::Object::TypeIndex2Key(tindex); return info; - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { // special handle string since it can be backed by multiple types. info.type_index = ffi::TypeIndex::kTVMFFIStr; info.type_key = ffi::TypeTraits::TypeStr(); @@ -346,12 +349,12 @@ struct ValueTypeInfoMaker { } // namespace detail template -inline TargetKindAttrMap TargetKind::GetAttrMap(const String& attr_name) { +inline TargetKindAttrMap TargetKind::GetAttrMap(const ffi::String& attr_name) { return TargetKindAttrMap(GetAttrMapContainer(attr_name)); } template -inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const String& attr_name, +inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const ffi::String& attr_name, const ValueType& value, int plevel) { ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; ffi::Any rv; @@ -365,7 +368,7 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_default_device_type(int devic return *this; } -inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector keys) { +inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector keys) { kind_->default_keys = keys; return *this; } @@ -383,7 +386,7 @@ inline TargetKindRegEntry& 
TargetKindRegEntry::set_target_parser(FTVMTargetParse } template -inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) { +inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const ffi::String& key) { ICHECK(!kind_->key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key << "' has been set once"; kind_->key2vtype_[key] = detail::ValueTypeInfoMaker()(); @@ -391,7 +394,7 @@ inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key } template -inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key, +inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const ffi::String& key, Any default_value) { add_attr_option(key); kind_->key2default_[key] = default_value; @@ -420,8 +423,8 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() { * TVM_REGISTER_TARGET_KIND("llvm") * .set_attr("TPreCodegenPass", a-pre-codegen-pass) * .add_attr_option("system_lib") - * .add_attr_option("mtriple") - * .add_attr_option("mattr"); + * .add_attr_option("mtriple") + * .add_attr_option("mattr"); * * \endcode */ @@ -430,11 +433,11 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() { ::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName) \ .set_name() \ .set_default_device_type(DeviceType) \ - .add_attr_option>("keys") \ - .add_attr_option("tag") \ - .add_attr_option("device") \ - .add_attr_option("model") \ - .add_attr_option>("libs") \ + .add_attr_option>("keys") \ + .add_attr_option("tag") \ + .add_attr_option("device") \ + .add_attr_option("model") \ + .add_attr_option>("libs") \ .add_attr_option("host") \ .add_attr_option("from_device") \ .add_attr_option("target_device_type") diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index aabd3a2ecaf2..ebe5eb39f580 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -39,10 +39,10 @@ namespace tvm { * Abstract label for an area 
of memory. * * Currently uninterpreted and arbitrary. Likely to be replaced by a structured representation - * of a memory pool in the future. Please try to use this alias instead of String to aid future + * of a memory pool in the future. Please try to use this alias instead of ffi::String to aid future * code migration. */ -using MemoryScope = String; +using MemoryScope = ffi::String; // NOTE: cannot use enum as they are out of bound of the original enum // and results in an undefined behavior @@ -257,9 +257,7 @@ class VirtualDeviceNode : public AttrsNodeReflAdapter { "The area of memory w.r.t. the virtual device where data is stored.", refl::DefaultValue("")); } - - static constexpr const char* _type_key = "target.VirtualDevice"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(VirtualDeviceNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("target.VirtualDevice", VirtualDeviceNode, BaseAttrsNode); friend class VirtualDevice; }; @@ -333,7 +331,7 @@ class VirtualDevice : public ObjectRef { * \p lhs and \p rhs on all their constrained fields. Returns the null optional if no such * join exists, ie there's disagreement on at least one constrained field. */ - static Optional Join(const VirtualDevice& lhs, const VirtualDevice& rhs); + static ffi::Optional Join(const VirtualDevice& lhs, const VirtualDevice& rhs); /*! * \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any @@ -341,7 +339,7 @@ class VirtualDevice : public ObjectRef { */ static VirtualDevice Default(const VirtualDevice& lhs, const VirtualDevice& rhs); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VirtualDevice, ObjectRef, VirtualDeviceNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(VirtualDevice, ObjectRef, VirtualDeviceNode); friend class VirtualDeviceCache; // Private implementation helper. 
}; diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 6c1ea6195f5e..17de92c8be36 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -60,7 +60,7 @@ class TVM_DLL OperationNode : public Object { /*! \brief optional tag of the operation */ std::string tag; /*! \brief additional attributes of the operation*/ - Map attrs; + ffi::Map attrs; // virtual destructor. virtual ~OperationNode() {} /*! \return number of outputs */ @@ -76,12 +76,12 @@ class TVM_DLL OperationNode : public Object { * \param i The output index. * \return shape of i-th output. */ - virtual Array output_shape(size_t i) const = 0; + virtual ffi::Array output_shape(size_t i) const = 0; /*! * \brief List all the input Tensors. * \return List of input tensors. */ - virtual Array InputTensors() const = 0; + virtual ffi::Array InputTensors() const = 0; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -90,10 +90,7 @@ class TVM_DLL OperationNode : public Object { .def_ro("tag", &OperationNode::tag) .def_ro("attrs", &OperationNode::attrs); } - - static constexpr const char* _type_key = "te.Operation"; - - TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("te.Operation", OperationNode, Object); }; /*! @@ -102,14 +99,14 @@ class TVM_DLL OperationNode : public Object { class PlaceholderOpNode : public OperationNode { public: /*! \brief The shape of the input */ - Array shape; + ffi::Array shape; /*! \brief The data type of the input. */ DataType dtype; // override behavior. 
int num_outputs() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; - Array InputTensors() const final; + ffi::Array output_shape(size_t i) const final; + ffi::Array InputTensors() const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -117,10 +114,7 @@ class PlaceholderOpNode : public OperationNode { .def_ro("shape", &PlaceholderOpNode::shape) .def_ro("dtype", &PlaceholderOpNode::dtype); } - - static constexpr const char* _type_key = "te.PlaceholderOp"; - - TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode); + TVM_FFI_DECLARE_OBJECT_INFO("te.PlaceholderOp", PlaceholderOpNode, OperationNode); }; /*! @@ -129,9 +123,9 @@ class PlaceholderOpNode : public OperationNode { */ class PlaceholderOp : public Operation { public: - TVM_DLL PlaceholderOp(std::string name, Array shape, DataType dtype); + TVM_DLL PlaceholderOp(std::string name, ffi::Array shape, DataType dtype); - TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PlaceholderOp, Operation, PlaceholderOpNode); }; /*! @@ -141,11 +135,11 @@ class PlaceholderOp : public Operation { class TVM_DLL BaseComputeOpNode : public OperationNode { public: /*! \brief IterVar on each axis */ - Array axis; + ffi::Array axis; /*! 
\brief IterVar on each reduction axis, if the body is a Reduce */ - Array reduce_axis; + ffi::Array reduce_axis; // override functions - Array output_shape(size_t idx) const final; + ffi::Array output_shape(size_t idx) const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -153,10 +147,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { .def_ro("axis", &BaseComputeOpNode::axis) .def_ro("reduce_axis", &BaseComputeOpNode::reduce_axis); } - - static constexpr const char* _type_key = "te.BaseComputeOp"; - - TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); + TVM_FFI_DECLARE_OBJECT_INFO("te.BaseComputeOp", BaseComputeOpNode, OperationNode); }; /*! @@ -165,22 +156,19 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { class TVM_DLL ComputeOpNode : public BaseComputeOpNode { public: /*! \brief the compute expression */ - Array body; + ffi::Array body; /*! \brief constructor */ ComputeOpNode() {} // override functions int num_outputs() const final; DataType output_dtype(size_t i) const final; - Array InputTensors() const final; + ffi::Array InputTensors() const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("body", &ComputeOpNode::body); } - - static constexpr const char* _type_key = "te.ComputeOp"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ComputeOp", ComputeOpNode, BaseComputeOpNode); }; /*! 
@@ -189,10 +177,10 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { */ class ComputeOp : public Operation { public: - TVM_DLL ComputeOp(std::string name, std::string tag, Map attrs, - Array axis, Array body); + TVM_DLL ComputeOp(std::string name, std::string tag, ffi::Map attrs, + ffi::Array axis, ffi::Array body); - TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ComputeOp, Operation, ComputeOpNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode); }; @@ -204,16 +192,16 @@ class ScanOpNode : public OperationNode { /*! \brief IterVar to scan over */ IterVar scan_axis; /*! \brief the initialization tensors */ - Array init; + ffi::Array init; /*! \brief the update function represented by tensor */ - Array update; + ffi::Array update; /*! \brief The placeholder to refer as states in update. */ - Array state_placeholder; + ffi::Array state_placeholder; /*! * \brief the inputs to the scan, these are optionally provided * But they can be helpful to provide hints to speedup get of scan body. */ - Array inputs; + ffi::Array inputs; /*! * \brief Spatial axis to indicate spatial dimension of each output. * They corresponds to flattened spatial axis of the outputs. @@ -223,14 +211,14 @@ class ScanOpNode : public OperationNode { * They do not corresponds to splittable iterations, thus the name comes * with underscore. */ - Array spatial_axis_; + ffi::Array spatial_axis_; /*! \brief constructor */ ScanOpNode() {} // override behavior. 
int num_outputs() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; - Array InputTensors() const final; + ffi::Array output_shape(size_t i) const final; + ffi::Array InputTensors() const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -242,10 +230,7 @@ class ScanOpNode : public OperationNode { .def_ro("inputs", &ScanOpNode::inputs) .def_ro("spatial_axis_", &ScanOpNode::spatial_axis_); } - - static constexpr const char* _type_key = "te.ScanOp"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ScanOp", ScanOpNode, OperationNode); }; /*! @@ -254,11 +239,12 @@ class ScanOpNode : public OperationNode { */ class ScanOp : public Operation { public: - TVM_DLL ScanOp(std::string name, std::string tag, Optional> attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array input); + TVM_DLL ScanOp(std::string name, std::string tag, + ffi::Optional> attrs, IterVar axis, + ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, ffi::Array input); - TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScanOp, Operation, ScanOpNode); }; /*! @@ -267,11 +253,11 @@ class ScanOp : public Operation { class ExternOpNode : public OperationNode { public: /*! \brief The input tensors */ - Array inputs; + ffi::Array inputs; /*! \brief Symbolic placeholder representation of inputs */ - Array input_placeholders; + ffi::Array input_placeholders; /*! \brief Symbolic placeholder representation of outputs */ - Array output_placeholders; + ffi::Array output_placeholders; /*! \brief the statement that generates the computation. 
*/ Stmt body; @@ -280,8 +266,8 @@ class ExternOpNode : public OperationNode { // override functions int num_outputs() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; - Array InputTensors() const final; + ffi::Array output_shape(size_t i) const final; + ffi::Array InputTensors() const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -291,10 +277,7 @@ class ExternOpNode : public OperationNode { .def_ro("output_placeholders", &ExternOpNode::output_placeholders) .def_ro("body", &ExternOpNode::body); } - - static constexpr const char* _type_key = "te.ExternOp"; - - TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ExternOp", ExternOpNode, OperationNode); }; /*! @@ -303,11 +286,11 @@ class ExternOpNode : public OperationNode { */ class ExternOp : public Operation { public: - TVM_DLL ExternOp(std::string name, std::string tag, Map attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body); + TVM_DLL ExternOp(std::string name, std::string tag, ffi::Map attrs, + ffi::Array inputs, ffi::Array input_placeholders, + ffi::Array output_placeholders, Stmt body); - TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternOp, Operation, ExternOpNode); }; /*! @@ -334,10 +317,10 @@ TVM_DLL IterVar thread_axis(Range dom, std::string tag); TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); /*! \brief The compute function to specify the input source of a Tensor */ -using FCompute = std::function& i)>; +using FCompute = std::function& i)>; /*! \brief The compute function to specify the inputs source of Tensors */ -using FBatchCompute = std::function(const Array& i)>; +using FBatchCompute = std::function(const ffi::Array& i)>; /*! * \brief create a place holder tensor. 
@@ -345,7 +328,7 @@ using FBatchCompute = std::function(const Array& i)>; * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -TVM_DLL Tensor placeholder(Array shape, DataType dtype = DataType::Float(32), +TVM_DLL Tensor placeholder(ffi::Array shape, DataType dtype = DataType::Float(32), std::string name = "placeholder"); /*! @@ -357,8 +340,8 @@ TVM_DLL Tensor placeholder(Array shape, DataType dtype = DataType::Flo * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor", - std::string tag = "", Map attrs = {}); +TVM_DLL Tensor compute(ffi::Array shape, FCompute fcompute, std::string name = "tensor", + std::string tag = "", ffi::Map attrs = {}); /*! * \brief Construct a new tensor by computing over shape, @@ -369,9 +352,9 @@ TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string nam * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array compute(Array shape, FBatchCompute fcompute, - std::string name = "tensor", std::string tag = "", - Map attrs = {}); +TVM_DLL ffi::Array compute(ffi::Array shape, FBatchCompute fcompute, + std::string name = "tensor", std::string tag = "", + ffi::Map attrs = {}); /*! * \brief Construct new tensors by scan. @@ -385,34 +368,35 @@ TVM_DLL Array compute(Array shape, FBatchCompute fcompute, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. 
*/ -TVM_DLL Array scan(Array init, Array update, - Array state_placeholder, Array inputs = Array(), - std::string name = "scan", std::string tag = "", - Map attrs = {}); +TVM_DLL ffi::Array scan(ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, + ffi::Array inputs = ffi::Array(), + std::string name = "scan", std::string tag = "", + ffi::Map attrs = {}); // same as compute, specialized for different fcompute function -inline Tensor compute(Array shape, std::function f, +inline Tensor compute(ffi::Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { - FCompute fc = [f](const Array& i) { return f(i[0]); }; + ffi::Map attrs = {}) { + FCompute fc = [f](const ffi::Array& i) { return f(i[0]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, std::function f, +inline Tensor compute(ffi::Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { - FCompute fc = [f](const Array& i) { return f(i[0], i[1]); }; + ffi::Map attrs = {}) { + FCompute fc = [f](const ffi::Array& i) { return f(i[0], i[1]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, std::function f, +inline Tensor compute(ffi::Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { - FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2]); }; + ffi::Map attrs = {}) { + FCompute fc = [f](const ffi::Array& i) { return f(i[0], i[1], i[2]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, std::function f, +inline Tensor compute(ffi::Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { - FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2], i[3]); }; + ffi::Map attrs = {}) { + FCompute fc = [f](const ffi::Array& i) { return f(i[0], i[1], i[2], i[3]); }; return compute(shape, fc, name, tag, attrs); 
} diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index f45a96df63d8..501b5b062b52 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -50,6 +50,7 @@ class Operation : public ObjectRef { /*! \brief default constructor */ Operation() {} explicit Operation(ObjectPtr n) : ObjectRef(n) {} + explicit Operation(ffi::UnsafeInit tag) : ObjectRef(tag) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -69,7 +70,7 @@ class Operation : public ObjectRef { class TensorNode : public DataProducerNode { public: /*! \brief The shape of the tensor */ - Array shape; + ffi::Array shape; /*! \brief data type in the content of the tensor */ DataType dtype; /*! \brief the source operation, can be None */ @@ -79,18 +80,17 @@ class TensorNode : public DataProducerNode { static void RegisterReflection(); - Array GetShape() const final { return shape; } + ffi::Array GetShape() const final { return shape; } DataType GetDataType() const final { return dtype; } TVM_DLL PrimExpr ToPrimExpr() const final; - TVM_DLL String GetNameHint() const final; + TVM_DLL ffi::String GetNameHint() const final; - static constexpr const char* _type_key = "te.Tensor"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.Tensor", TensorNode, DataProducerNode); }; /*! @@ -105,10 +105,10 @@ class Tensor : public DataProducer { * \param support_negative_indices Whether to normalize indices in the case of negative indices. * \return the result expression representing tensor read. 
*/ - inline PrimExpr IndexTensor(Array indices, bool support_negative_indices) const; + inline PrimExpr IndexTensor(ffi::Array indices, bool support_negative_indices) const; public: - TVM_DLL Tensor(Array shape, DataType dtype, Operation op, int value_index); + TVM_DLL Tensor(ffi::Array shape, DataType dtype, Operation op, int value_index); /*! * \brief check if two tensors equals each other. * \param other tensor to be checked. @@ -130,7 +130,7 @@ class Tensor : public DataProducer { */ template inline PrimExpr operator()(Args&&... args) const { - Array indices{std::forward(args)...}; + ffi::Array indices{std::forward(args)...}; return operator()(indices); } /*! @@ -138,13 +138,13 @@ class Tensor : public DataProducer { * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(ffi::Array indices) const; /*! * \brief Take elements from the tensor * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(ffi::Array indices) const; /*! * \brief Take elements from the tensor with support for negative indices. * \param args The indices @@ -152,7 +152,7 @@ class Tensor : public DataProducer { */ template TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const { - Array indices{std::forward(args)...}; + ffi::Array indices{std::forward(args)...}; return IndexWithNegativeIndices(indices); } /*! @@ -160,13 +160,13 @@ class Tensor : public DataProducer { * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr IndexWithNegativeIndices(Array indices) const; + TVM_DLL PrimExpr IndexWithNegativeIndices(ffi::Array indices) const; /*! * \brief Take elements from the tensor with support for negative indices. * \param indices the indices. * \return the result expression representing tensor read. 
*/ - TVM_DLL PrimExpr IndexWithNegativeIndices(Array indices) const; + TVM_DLL PrimExpr IndexWithNegativeIndices(ffi::Array indices) const; /*! * \brief data structure to represent a slice that fixes first k coordinates. @@ -205,7 +205,7 @@ class Tensor : public DataProducer { */ inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } - TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, DataProducer, TensorNode); }; // Implementations of inline functions diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index a21112b7d6f6..0f4b6afd62fb 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -99,14 +99,14 @@ TVM_DLL double EstimateTIRFlops(const IRModule& mod); * \param defs The vars that is defined. * \return Array of undefined vars. */ -TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); +TVM_DLL ffi::Array UndefinedVars(const Stmt& stmt, const ffi::Array& defs); /*! * \brief Find undefined vars in the expression. * \param expr The expression to be checked. * \return Array of undefined vars. */ -TVM_DLL Array UndefinedVars(const PrimExpr& expr); +TVM_DLL ffi::Array UndefinedVars(const PrimExpr& expr); /*! * \brief Find undefined vars in the expression. @@ -114,7 +114,7 @@ TVM_DLL Array UndefinedVars(const PrimExpr& expr); * \param defs The vars that is defined. * \return Array of undefined vars. */ -TVM_DLL Array UndefinedVars(const PrimExpr& expr, const Array& defs); +TVM_DLL ffi::Array UndefinedVars(const PrimExpr& expr, const ffi::Array& defs); /*! 
* \brief Analyze the side effect of an expression @@ -195,7 +195,7 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); * \return valid Whether it is a valid GPU code * */ -TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); +TVM_DLL bool VerifyGPUCode(const PrimFunc& func, ffi::Map constraints); /** * @brief Utility function to get the list of lowering passes to be applied to calculate the @@ -203,7 +203,7 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constrain * * @return returns list of passes */ -TVM_DLL Array GetVTCMCompactionPasses(); +TVM_DLL ffi::Array GetVTCMCompactionPasses(); /*! * \brief Verifies that the VTCM usage for all prim_funcs in the given IRModule @@ -233,8 +233,8 @@ TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit); * - second: write regions * - third: opaque regions */ -TVM_DLL Array> GetBlockAccessRegion(const Block& block, - const Map& buffer_var_map); +TVM_DLL ffi::Array> GetBlockAccessRegion( + const Block& block, const ffi::Map& buffer_var_map); /*! * \brief Auto detect the block read/write region according to its body stmt. An opaque access will @@ -244,8 +244,8 @@ TVM_DLL Array> GetBlockAccessRegion(const Block& block, * It is a map from buffer var to the buffer * \return An array only consisting of the read regions and write regions of the input block */ -TVM_DLL Array> GetBlockReadWriteRegion(const Block& block, - const Map& buffer_var_map); +TVM_DLL ffi::Array> GetBlockReadWriteRegion( + const Block& block, const ffi::Map& buffer_var_map); /*! \brief Helper struct for return value of IdentifyMemCpy * @@ -298,7 +298,8 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, * \return Allocated memory size per scope in bytes inside the PrimFunc returned as a Map with * key "main" and a Map of allocated sizes as values. */ -TVM_DLL tvm::Map> CalculateAllocatedBytes(const PrimFunc& func); +TVM_DLL tvm::ffi::Map> CalculateAllocatedBytes( + const PrimFunc& func); /*! 
* \brief Calculate the allocated memory per scope in bytes for each function inside the module @@ -306,7 +307,8 @@ TVM_DLL tvm::Map> CalculateAllocatedBytes(cons * \return Allocated memory size per scope in bytes for each function in the IRModule returned as a Map with function names as keys and a Map of allocated sizes as values. */ -TVM_DLL tvm::Map> CalculateAllocatedBytes(const IRModule& mod); +TVM_DLL tvm::ffi::Map> CalculateAllocatedBytes( + const IRModule& mod); /*! * \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level @@ -316,7 +318,7 @@ TVM_DLL tvm::Map> CalculateAllocatedBytes(cons * \return The Map from buffer to the LCA of all access to it. The lca is function root if the * return stmt is std::nullopt. */ -TVM_DLL Map> DetectBufferAccessLCA(const PrimFunc& func); +TVM_DLL ffi::Map> DetectBufferAccessLCA(const PrimFunc& func); /*! * \brief Verify if the given TIR is well-formed. The verification includes: @@ -410,7 +412,7 @@ TVM_DLL Pass VerifyMemory(); * \returns The pass. * \sa tvm::tir::VerifyGPUCode */ -TVM_DLL Pass VerifyGPUCode(Map constraints); +TVM_DLL Pass VerifyGPUCode(ffi::Map constraints); /*! * \brief Pass to checks if the size of the allocated vtcm memory satisfies the limit @@ -421,7 +423,7 @@ TVM_DLL Pass VerifyGPUCode(Map constraints); * \returns The pass. * \sa tvm::tir::CalculateAllocatedBytes */ -TVM_DLL Pass VerifyVTCMLimit(Optional target = std::nullopt); +TVM_DLL Pass VerifyVTCMLimit(ffi::Optional target = std::nullopt); /*! * \brief Statically check TIR code for out of bounds array access. 
diff --git a/include/tvm/tir/block_dependence_info.h b/include/tvm/tir/block_dependence_info.h index 7b00894ea805..b1fd8998645a 100644 --- a/include/tvm/tir/block_dependence_info.h +++ b/include/tvm/tir/block_dependence_info.h @@ -65,9 +65,7 @@ class BlockDependenceInfoNode : public Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "tir.BlockDependenceInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(BlockDependenceInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockDependenceInfo", BlockDependenceInfoNode, Object); /*! * \brief Get the BlockScope corresponding to the sref of scope root block @@ -78,7 +76,7 @@ class BlockDependenceInfoNode : public Object { auto it = sref2scope.find(scope_root); CHECK(it != sref2scope.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" - << GetRef(scope_root->stmt); + << ffi::GetRef(scope_root->stmt); return it->second; } }; @@ -97,8 +95,8 @@ class BlockDependenceInfo : public ObjectRef { */ TVM_DLL BlockDependenceInfo(IRModule mod); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockDependenceInfo, ObjectRef, - BlockDependenceInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockDependenceInfo, ObjectRef, + BlockDependenceInfoNode); }; } // namespace tir diff --git a/include/tvm/tir/block_scope.h b/include/tvm/tir/block_scope.h index 9ea77d7b9b46..f1120c7837ff 100644 --- a/include/tvm/tir/block_scope.h +++ b/include/tvm/tir/block_scope.h @@ -72,8 +72,8 @@ class StmtSRefNode : public Object { refl::ObjectDef().def_ro("seq_index", &StmtSRefNode::seq_index); } - static constexpr const char* _type_key = "tir.StmtSRef"; - TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.StmtSRef", StmtSRefNode, Object); /*! 
\brief Reset the object inplace to the invalid state */ void Reset() { @@ -114,10 +114,7 @@ class StmtSRef : public ObjectRef { */ TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index); - /*! \return The mutable pointer to the StmtSRefNode */ - StmtSRefNode* get() const { return static_cast(data_.get()); } - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StmtSRef, ObjectRef, StmtSRefNode); public: /*! @@ -226,9 +223,7 @@ class DependencyNode : public Object { .def_ro("dst", &DependencyNode::dst) .def_ro("kind", &DependencyNode::kind); } - - static constexpr const char* _type_key = "tir.Dependency"; - TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Dependency", DependencyNode, Object); }; /*! @@ -239,7 +234,7 @@ class Dependency : public ObjectRef { public: /*! \brief Constructor */ TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Dependency, ObjectRef, DependencyNode); }; /*! @@ -262,18 +257,17 @@ class BlockScopeNode : public Object { * \note We intentionally didn't use tvm::Map as the data structure, because we need the values * inside to be mutable so that they could be further maintained properly during transformations. */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> src2deps; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> src2deps; /*! \brief Lookup table for the `dst` of dependencies */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dst2deps; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dst2deps; /*! 
\brief The mapping from the buffer to the blocks who write it */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; static void RegisterReflection() { - // No fields to register as they are not visited + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - - static constexpr const char* _type_key = "tir.BlockScope"; - TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockScope", BlockScopeNode, Object); public: /******** Dependency ********/ @@ -282,13 +276,13 @@ class BlockScopeNode : public Object { * \param src The queried block * \return The dependencies */ - TVM_DLL Array GetDepsBySrc(const StmtSRef& src) const; + TVM_DLL ffi::Array GetDepsBySrc(const StmtSRef& src) const; /*! * \brief Get all dependencies whose `dst` equals `dst` * \param dst The queried block * \return The dependencies */ - TVM_DLL Array GetDepsByDst(const StmtSRef& dst) const; + TVM_DLL ffi::Array GetDepsByDst(const StmtSRef& dst) const; }; /*! @@ -297,6 +291,13 @@ class BlockScopeNode : public Object { */ class BlockScope : public ObjectRef { public: + /*! + * \brief Constructor from ObjectPtr. + * \param data The object pointer. + */ + explicit BlockScope(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! \brief The constructor creating an empty block scope with on dependency information */ TVM_DLL BlockScope(); /*! 
@@ -305,9 +306,9 @@ class BlockScope : public ObjectRef { * \param child_block_srefs The srefs to the leaf blocks * \note We assume the leaf blocks are given in pre-DFS order */ - TVM_DLL explicit BlockScope(const Array& child_block_srefs); + TVM_DLL explicit BlockScope(const ffi::Array& child_block_srefs); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockScope, ObjectRef, BlockScopeNode); }; } // namespace tir diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 3cc988f49e38..1075693bb541 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -75,7 +75,7 @@ class BufferNode : public Object { * BufferLoad/BufferStore nodes, and used by the low-level code * generators. */ - Array shape; + ffi::Array shape; /*! * \brief Separators between input axes when generating flattened output axes * @@ -84,17 +84,17 @@ class BufferNode : public Object { * non-flat memory, each entry in axis_separators should be the * first input axis that is part of a new flattened axis. */ - Array axis_separators; + ffi::Array axis_separators; /*! * \brief The strides of each dimension * This can be an empty array, indicating array is contiguous */ - Array strides; + ffi::Array strides; /*! \brief The offset in terms of number of dtype elements (including lanes) */ PrimExpr elem_offset; // Meta data /*! \brief optional name of the buffer */ - String name; + ffi::String name; /*! \brief Alignment requirement of data pointer in bytes. */ int data_alignment; /*! @@ -140,12 +140,11 @@ class BufferNode : public Object { * without adjusting for number of lanes. (e.g. The number of * float16x4 elements in a buffer of type float16x4.) 
*/ - Array ElemOffset(Array index) const; + ffi::Array ElemOffset(ffi::Array index) const; - static constexpr const char* _type_key = "tir.Buffer"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Buffer", BufferNode, Object); TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); }; @@ -158,9 +157,10 @@ class Buffer : public ObjectRef { public: // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. - TVM_DLL Buffer(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, int data_alignment, int offset_factor, - BufferType buffer_type, Array axis_separators = {}, Span span = Span()); + TVM_DLL Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array strides, + PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, + BufferType buffer_type, ffi::Array axis_separators = {}, + Span span = Span()); /*! * \brief Return a new buffer that is equivalent with current one @@ -176,7 +176,7 @@ class Buffer : public ObjectRef { * If stride is not needed in the slice, it won't be presented * \return the result buffer. */ - TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; + TVM_DLL Buffer MakeSlice(ffi::Array begins, ffi::Array extents) const; /*! * \brief Get access ptr to the entire buffer. * \param access_mask The access mask @@ -187,7 +187,7 @@ class Buffer : public ObjectRef { */ TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0), - Optional input_extent = std::nullopt) const; + ffi::Optional input_extent = std::nullopt) const; /*! * \brief Create an Expr that does a vector load at begin index. 
* \param begin The beginning index @@ -195,8 +195,8 @@ class Buffer : public ObjectRef { * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be * loaded. The number lanes of the mask must be equal to the number of lanes in being loaded. */ - TVM_DLL PrimExpr vload(Array begin, DataType dtype, - Optional predicate = std::nullopt) const; + TVM_DLL PrimExpr vload(ffi::Array begin, DataType dtype, + ffi::Optional predicate = std::nullopt) const; /*! * \brief Create a Stmt that does a vector store at begin index. * \param begin The beginning index @@ -204,8 +204,8 @@ class Buffer : public ObjectRef { * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be * stored. The number lanes of the mask must be equal to the number of lanes in value. */ - TVM_DLL Stmt vstore(Array begin, PrimExpr value, - Optional predicate = std::nullopt) const; + TVM_DLL Stmt vstore(ffi::Array begin, PrimExpr value, + ffi::Optional predicate = std::nullopt) const; /*! * \brief Get a flattened version of the buffer @@ -218,14 +218,14 @@ class Buffer : public ObjectRef { * without adjusting for number of lanes. (e.g. The number of * float16x4 elements in a buffer of type float16x4.) */ - Array OffsetOf(Array index) const; + ffi::Array OffsetOf(ffi::Array index) const; /*! * \brief Return the storage scope associated with this buffer. */ - TVM_DLL String scope() const; + TVM_DLL ffi::String scope() const; - TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Buffer, ObjectRef, BufferNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode); }; @@ -240,9 +240,9 @@ class Buffer : public ObjectRef { * \return The created buffer. * \sa Buffer for complete constructor. 
*/ -TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "", - Optional> axis_separators = std::nullopt, +TVM_DLL Buffer decl_buffer(ffi::Array shape, DataType dtype = DataType::Float(32), + ffi::String name = "buffer", ffi::String storage_scope = "", + ffi::Optional> axis_separators = std::nullopt, Span span = Span()); /*! @@ -265,7 +265,7 @@ class DataProducerNode : public PrimExprConvertibleNode { * \brief Get the shape of the result. * \return The shape. */ - virtual Array GetShape() const = 0; + virtual ffi::Array GetShape() const = 0; /*! * \brief Get the data type of the result. * \return The data type. @@ -275,10 +275,8 @@ class DataProducerNode : public PrimExprConvertibleNode { * \brief Get the name hint of the data producer. * \return The data type. */ - virtual String GetNameHint() const = 0; - - static constexpr const char* _type_key = "tir.DataProducer"; - TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, PrimExprConvertibleNode); + virtual ffi::String GetNameHint() const = 0; + TVM_FFI_DECLARE_OBJECT_INFO("tir.DataProducer", DataProducerNode, PrimExprConvertibleNode); }; /*! @@ -287,7 +285,7 @@ class DataProducerNode : public PrimExprConvertibleNode { */ class DataProducer : public PrimExprConvertible { public: - TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, PrimExprConvertible, DataProducerNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataProducer, PrimExprConvertible, DataProducerNode); }; /*! @@ -303,7 +301,7 @@ class DataProducer : public PrimExprConvertible { * \param compact If the statement has already bound to a compact buffer. 
* \param memory_scope memory scope of the buffer */ -TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, +TVM_DLL tir::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope = ""); diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index d3573c925daf..e7b8cac9be15 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -49,6 +49,14 @@ TVM_DLL const Op& ret(); * \brief Return from a GPU thread. */ TVM_DLL const Op& thread_return(); +/*! + * \brief Loop continue. + */ +TVM_DLL const Op& continue_loop(); +/*! + * \brief Loop break. + */ +TVM_DLL const Op& break_loop(); /*! * \brief Reinterpret the value using the target type. */ @@ -145,6 +153,11 @@ TVM_DLL const Op& isnullptr(); */ TVM_DLL const Op& isnan(); +/*! + * \brief Check if value is finite + */ +TVM_DLL const Op& isfinite(); + /*! * \brief Popcount */ @@ -298,7 +311,7 @@ TVM_DLL const Op& tvm_struct_set(); /*! * \brief See pseudo code - * Type lookup_param(String param_name) { + * Type lookup_param(ffi::String param_name) { * return __tvm_param__param_name; * } */ @@ -337,7 +350,7 @@ TVM_DLL const Op& tvm_stack_alloca(); TVM_DLL const Op& tvm_stack_make_shape(); /*! - * \brief Allocate a NDArray(DLTensor) on stack, return the handle. + * \brief Allocate a Tensor(DLTensor) on stack, return the handle. * * Type tvm_stack_make_array(Expr data, * Expr shape, diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 1395c2b6817b..4f2a4452b89f 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -99,14 +99,14 @@ class LayoutAxis { class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ - String name; + ffi::String name; /*! \brief specify each axis of the layout, * in which the variable name is the name of the axis. 
* The IterVar's extent indicates the size of the axis, * it is a variable for a primal axis, but a constant for a subordinate axis. * Empty for scalar's layout. */ - Array axes; + ffi::Array axes; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -114,9 +114,7 @@ class LayoutNode : public Object { .def_ro("name", &LayoutNode::name) .def_ro("axes", &LayoutNode::axes); } - - static constexpr const char* _type_key = "tir.Layout"; - TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Layout", LayoutNode, Object); }; /*! @@ -125,10 +123,10 @@ class LayoutNode : public Object { */ class Layout : public ObjectRef { public: - explicit Layout(const Array& axes); + explicit Layout(const ffi::Array& axes); /*! \brief construct from a string */ - Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*) + Layout(const tvm::ffi::String& name) : Layout(name.operator std::string()) {} // NOLINT(*) /*! \brief construct from a string */ Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) @@ -291,7 +289,7 @@ class Layout : public ObjectRef { return os; } - TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode); }; // Internal node container BijectiveLayout @@ -300,13 +298,13 @@ class BijectiveLayoutNode : public Object { /*! \brief Describes how source axes can be mapped to the destination axes, * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n */ - Array index_forward_rule; + ffi::Array index_forward_rule; /*! \brief Describes how destination axes can be mapped to the source axes */ - Array index_backward_rule; + ffi::Array index_backward_rule; /*! \brief Describes how source shapes can be mapped to the destination shapes */ - Array shape_forward_rule; + ffi::Array shape_forward_rule; /*! 
\brief Describes how destination shapes can be mapped to the source shapes */ - Array shape_backward_rule; + ffi::Array shape_backward_rule; /*! \brief The source layout */ Layout src_layout; @@ -323,9 +321,7 @@ class BijectiveLayoutNode : public Object { .def_ro("shape_forward_rule", &BijectiveLayoutNode::shape_forward_rule) .def_ro("shape_backward_rule", &BijectiveLayoutNode::shape_backward_rule); } - - static constexpr const char* _type_key = "tir.BijectiveLayout"; - TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BijectiveLayout", BijectiveLayoutNode, Object); }; /*! @@ -344,15 +340,15 @@ class BijectiveLayout : public ObjectRef { TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout); // Given the source shape, infer the destination shape. - TVM_DLL Array ForwardShape(const Array& shape) const; + TVM_DLL ffi::Array ForwardShape(const ffi::Array& shape) const; // Given the destination shape, recover the source shape. - TVM_DLL Array BackwardShape(const Array& dst_shape) const; + TVM_DLL ffi::Array BackwardShape(const ffi::Array& dst_shape) const; // Given the destination indices, infer the destination indices. - TVM_DLL Array ForwardIndex(const Array& index) const; + TVM_DLL ffi::Array ForwardIndex(const ffi::Array& index) const; // Given the destination indices, recover the source indices. 
- TVM_DLL Array BackwardIndex(const Array& dst_index) const; + TVM_DLL ffi::Array BackwardIndex(const ffi::Array& dst_index) const; - TVM_DEFINE_OBJECT_REF_METHODS(BijectiveLayout, ObjectRef, BijectiveLayoutNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BijectiveLayout, ObjectRef, BijectiveLayoutNode); }; } // namespace tir diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index a9185e97af69..88398cf06f06 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -106,7 +106,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const AttrStmtNode* op) override; PrimExpr VisitExpr_(const BufferLoadNode* op) override; - Array VisitIndices(Array indices); + ffi::Array VisitIndices(ffi::Array indices); Stmt VisitStmt_(const IfThenElseNode* op) override; Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; @@ -124,7 +124,8 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { Buffer VisitBuffer(const Buffer& buffer); Buffer GetRemappedBuffer(const Buffer& buffer); - Map VisitBlockAnnotations(const Map& annotations); + ffi::Map VisitBlockAnnotations( + const ffi::Map& annotations); BufferRegion VisitBufferRegion(const BufferRegion& region); IterVar VisitIterVar(const IterVar& iter_var); // indicator of index expr to rewrite @@ -132,7 +133,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { // indicator of condition bool is_condition_{false}; - Map buffer_remap_; + ffi::Map buffer_remap_; }; /*! diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 1b419b569311..b615ab503522 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -49,19 +49,17 @@ namespace tir { using IntImmNode = tvm::IntImmNode; using FloatImmNode = tvm::FloatImmNode; -/*! \brief String constants, only used in asserts. */ +/*! 
\brief ffi::String constants, only used in asserts. */ class StringImmNode : public PrimExprNode { public: /*! \brief The constant value content. */ - String value; + ffi::String value; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &StringImmNode::value); } - - static constexpr const char* _type_key = "tir.StringImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.StringImm", StringImmNode, PrimExprNode); }; /*! @@ -70,8 +68,8 @@ class StringImmNode : public PrimExprNode { */ class StringImm : public PrimExpr { public: - TVM_DLL StringImm(String value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); + TVM_DLL StringImm(ffi::String value, Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, PrimExpr, StringImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); }; @@ -88,9 +86,7 @@ class CastNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &CastNode::value); } - - static constexpr const char* _type_key = "tir.Cast"; - TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Cast", CastNode, PrimExprNode); }; /*! 
@@ -100,7 +96,7 @@ class CastNode : public PrimExprNode { class Cast : public PrimExpr { public: TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode); }; @@ -121,7 +117,9 @@ class BinaryOpNode : public PrimExprNode { refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); } - TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); + static const constexpr int _type_child_slots [[maybe_unused]] = 0; + static const constexpr bool _type_final [[maybe_unused]] = true; + TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode); }; /*! \brief a + b */ @@ -137,7 +135,7 @@ class AddNode : public BinaryOpNode { class Add : public PrimExpr { public: TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Add, PrimExpr, AddNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode); }; @@ -155,7 +153,7 @@ class Sub : public PrimExpr { public: TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sub, PrimExpr, SubNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode); }; @@ -172,7 +170,7 @@ class MulNode : public BinaryOpNode { class Mul : public PrimExpr { public: TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mul, PrimExpr, MulNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode); }; @@ -192,7 +190,7 @@ class DivNode : public BinaryOpNode { class Div : public PrimExpr { public: TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Div, PrimExpr, DivNode); 
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode); }; @@ -212,7 +210,7 @@ class ModNode : public BinaryOpNode { class Mod : public PrimExpr { public: TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mod, PrimExpr, ModNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode); }; @@ -229,7 +227,7 @@ class FloorDivNode : public BinaryOpNode { class FloorDiv : public PrimExpr { public: TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorDiv, PrimExpr, FloorDivNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode); }; @@ -246,7 +244,7 @@ class FloorModNode : public BinaryOpNode { class FloorMod : public PrimExpr { public: TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorMod, PrimExpr, FloorModNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode); }; @@ -263,7 +261,7 @@ class MinNode : public BinaryOpNode { class Min : public PrimExpr { public: TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Min, PrimExpr, MinNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode); }; @@ -280,7 +278,7 @@ class MaxNode : public BinaryOpNode { class Max : public PrimExpr { public: TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Max, PrimExpr, MaxNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode); }; @@ -301,7 +299,9 @@ class CmpOpNode : public PrimExprNode { refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); } - TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); + static const constexpr int _type_child_slots [[maybe_unused]] = 0; + 
static const constexpr bool _type_final [[maybe_unused]] = true; + TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode); }; /*! \brief a == b */ @@ -317,7 +317,7 @@ class EQNode : public CmpOpNode { class EQ : public PrimExpr { public: TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EQ, PrimExpr, EQNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode); }; @@ -334,7 +334,7 @@ class NENode : public CmpOpNode { class NE : public PrimExpr { public: TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NE, PrimExpr, NENode); TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode); }; @@ -351,7 +351,7 @@ class LTNode : public CmpOpNode { class LT : public PrimExpr { public: TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LT, PrimExpr, LTNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode); }; @@ -368,7 +368,7 @@ struct LENode : public CmpOpNode { class LE : public PrimExpr { public: TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LE, PrimExpr, LENode); TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode); }; @@ -385,7 +385,7 @@ class GTNode : public CmpOpNode { class GT : public PrimExpr { public: TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GT, PrimExpr, GTNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode); }; @@ -402,7 +402,7 @@ class GENode : public CmpOpNode { class GE : public PrimExpr { public: TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GE, PrimExpr, 
GENode); TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode); }; @@ -418,9 +418,7 @@ class AndNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b); } - - static constexpr const char* _type_key = "tir.And"; - TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.And", AndNode, PrimExprNode); }; /*! @@ -430,7 +428,7 @@ class AndNode : public PrimExprNode { class And : public PrimExpr { public: TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(And, PrimExpr, AndNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode); }; @@ -446,9 +444,7 @@ class OrNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b); } - - static constexpr const char* _type_key = "tir.Or"; - TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Or", OrNode, PrimExprNode); }; /*! @@ -458,7 +454,7 @@ class OrNode : public PrimExprNode { class Or : public PrimExpr { public: TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Or, PrimExpr, OrNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode); }; @@ -472,9 +468,7 @@ class NotNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &NotNode::a); } - - static constexpr const char* _type_key = "tir.Not"; - TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Not", NotNode, PrimExprNode); }; /*! 
@@ -484,7 +478,7 @@ class NotNode : public PrimExprNode { class Not : public PrimExpr { public: TVM_DLL Not(PrimExpr a, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Not, PrimExpr, NotNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode); }; @@ -511,9 +505,7 @@ class SelectNode : public PrimExprNode { .def_ro("true_value", &SelectNode::true_value) .def_ro("false_value", &SelectNode::false_value); } - - static constexpr const char* _type_key = "tir.Select"; - TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Select", SelectNode, PrimExprNode); }; /*! @@ -524,7 +516,7 @@ class Select : public PrimExpr { public: TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Select, PrimExpr, SelectNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode); }; @@ -543,9 +535,9 @@ class BufferLoadNode : public PrimExprNode { /*! \brief The buffer variable. */ Buffer buffer; /*! \brief The indices location to be loaded. */ - Array indices; + ffi::Array indices; /*! \brief The predicate mask for loading values. */ - Optional predicate; + ffi::Optional predicate; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -554,9 +546,7 @@ class BufferLoadNode : public PrimExprNode { .def_ro("indices", &BufferLoadNode::indices) .def_ro("predicate", &BufferLoadNode::predicate); } - - static constexpr const char* _type_key = "tir.BufferLoad"; - TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferLoad", BufferLoadNode, PrimExprNode); private: /*! 
\brief Set the dtype based on the buffer/indices @@ -581,9 +571,9 @@ class BufferLoadNode : public PrimExprNode { */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, - Optional predicate = std::nullopt, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); + TVM_DLL explicit BufferLoad(Buffer buffer, ffi::Array indices, + ffi::Optional predicate = std::nullopt, Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; @@ -601,7 +591,7 @@ class ProducerLoadNode : public PrimExprNode { /*! \brief The buffer producer. */ DataProducer producer; /*! \brief The location arguments. */ - Array indices; + ffi::Array indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -609,9 +599,7 @@ class ProducerLoadNode : public PrimExprNode { .def_ro("producer", &ProducerLoadNode::producer) .def_ro("indices", &ProducerLoadNode::indices); } - - static constexpr const char* _type_key = "tir.ProducerLoad"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.ProducerLoad", ProducerLoadNode, PrimExprNode); }; /*! 
@@ -620,9 +608,10 @@ class ProducerLoadNode : public PrimExprNode { */ class ProducerLoad : public PrimExpr { public: - TVM_DLL explicit ProducerLoad(DataProducer producer, Array indices, Span span = Span()); + TVM_DLL explicit ProducerLoad(DataProducer producer, ffi::Array indices, + Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ProducerLoad, PrimExpr, ProducerLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode); }; @@ -651,9 +640,7 @@ class RampNode : public PrimExprNode { .def_ro("stride", &RampNode::stride) .def_ro("lanes", &RampNode::lanes); } - - static constexpr const char* _type_key = "tir.Ramp"; - TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Ramp", RampNode, PrimExprNode); }; /*! @@ -663,7 +650,7 @@ class RampNode : public PrimExprNode { class Ramp : public PrimExpr { public: TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Ramp, PrimExpr, RampNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode); }; @@ -681,9 +668,7 @@ class BroadcastNode : public PrimExprNode { .def_ro("value", &BroadcastNode::value) .def_ro("lanes", &BroadcastNode::lanes); } - - static constexpr const char* _type_key = "tir.Broadcast"; - TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Broadcast", BroadcastNode, PrimExprNode); }; /*! 
@@ -693,7 +678,7 @@ class BroadcastNode : public PrimExprNode { class Broadcast : public PrimExpr { public: TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, PrimExpr, BroadcastNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode); }; @@ -716,9 +701,7 @@ class LetNode : public PrimExprNode { .def_ro("value", &LetNode::value) .def_ro("body", &LetNode::body); } - - static constexpr const char* _type_key = "tir.Let"; - TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Let", LetNode, PrimExprNode); }; /*! @@ -728,7 +711,7 @@ class LetNode : public PrimExprNode { class Let : public PrimExpr { public: TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Let, PrimExpr, LetNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode); }; @@ -746,15 +729,25 @@ class CallNode : public PrimExprNode { RelaxExpr op; /*! \brief The arguments. */ - Array args; + ffi::Array args; + + /*! + * \brief Additional annotations about the call. + * + * These annotations can be used to pass additional metadata + * to lowering passes. For tile operators, this can include + * coalesced_width, disable_tma, eviction_policy, etc. + */ + ffi::Map annotations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args); + refl::ObjectDef() + .def_ro("op", &CallNode::op) + .def_ro("args", &CallNode::args) + .def_ro("annotations", &CallNode::annotations); } - - static constexpr const char* _type_key = "tir.Call"; - TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Call", CallNode, PrimExprNode); }; /*! 
@@ -763,8 +756,10 @@ class CallNode : public PrimExprNode { */ class Call : public PrimExpr { public: - TVM_DLL Call(DataType dtype, RelaxExpr op, Array args, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); + TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array args, + ffi::Map annotations = {}, + Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); }; @@ -776,9 +771,9 @@ class Call : public PrimExpr { class ShuffleNode : public PrimExprNode { public: /*! \brief the input vectors. */ - Array vectors; + ffi::Array vectors; /*! \brief The indices of each element. */ - Array indices; + ffi::Array indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -786,9 +781,7 @@ class ShuffleNode : public PrimExprNode { .def_ro("vectors", &ShuffleNode::vectors) .def_ro("indices", &ShuffleNode::indices); } - - static constexpr const char* _type_key = "tir.Shuffle"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Shuffle", ShuffleNode, PrimExprNode); }; /*! @@ -797,11 +790,11 @@ class ShuffleNode : public PrimExprNode { */ class Shuffle : public PrimExpr { public: - TVM_DLL Shuffle(Array vectors, Array indices, Span span = Span()); - TVM_DLL static PrimExpr Concat(Array vectors, Span span = Span()); + TVM_DLL Shuffle(ffi::Array vectors, ffi::Array indices, Span span = Span()); + TVM_DLL static PrimExpr Concat(ffi::Array vectors, Span span = Span()); TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Shuffle, PrimExpr, ShuffleNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode); }; @@ -813,19 +806,19 @@ class Shuffle : public PrimExpr { class CommReducerNode : public Object { public: /*! 
\brief The left argument of reducer */ - Array lhs; + ffi::Array lhs; /*! \brief The right argument of reducer */ - Array rhs; + ffi::Array rhs; /*! \brief The result of reducer */ - Array result; + ffi::Array result; /*! * \brief The identity element of reducer, which leaves other * elements unchanged when combined with it, with respect to * the binary operation of this reducer uses. */ - Array identity_element; + ffi::Array identity_element; /*! \brief Function call operator to combine a and b */ - Array operator()(Array a, Array b) const; + ffi::Array operator()(ffi::Array a, ffi::Array b) const; /*! * \brief Span that points to the original source code. * Reserved debug information. @@ -842,9 +835,8 @@ class CommReducerNode : public Object { .def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore()); } - static constexpr const char* _type_key = "tir.CommReducer"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.CommReducer", CommReducerNode, Object); }; /*! @@ -853,10 +845,10 @@ class CommReducerNode : public Object { */ class CommReducer : public ObjectRef { public: - TVM_DLL CommReducer(Array lhs, Array rhs, Array result, - Array identity_element, Span span = Span()); + TVM_DLL CommReducer(ffi::Array lhs, ffi::Array rhs, ffi::Array result, + ffi::Array identity_element, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CommReducer, ObjectRef, CommReducerNode); }; /*! \brief Reduction operator */ @@ -865,11 +857,11 @@ class ReduceNode : public PrimExprNode { /*! \brief The commutative combiner */ CommReducer combiner; /*! \brief The source operand */ - Array source; + ffi::Array source; /*! \brief The init operand */ - Array init; + ffi::Array init; /*! \brief The reduction axis */ - Array axis; + ffi::Array axis; /*! 
* \brief Predicate on the reduction * Only add the body to reduction if condition is true. @@ -888,9 +880,7 @@ class ReduceNode : public PrimExprNode { .def_ro("condition", &ReduceNode::condition) .def_ro("value_index", &ReduceNode::value_index); } - - static constexpr const char* _type_key = "tir.Reduce"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Reduce", ReduceNode, PrimExprNode); }; /*! @@ -899,10 +889,11 @@ class ReduceNode : public PrimExprNode { */ class Reduce : public PrimExpr { public: - TVM_DLL Reduce(CommReducer combiner, Array src, Array rdom, PrimExpr condition, - int value_index, Array init, Span span = Span()); + TVM_DLL Reduce(CommReducer combiner, ffi::Array src, ffi::Array rdom, + PrimExpr condition, int value_index, ffi::Array init, + Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Reduce, PrimExpr, ReduceNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode); }; @@ -915,7 +906,7 @@ class Reduce : public PrimExpr { * \tparam V the value of the Map. 
*/ template -inline std::unordered_map as_unordered_map(const Map& dmap) { +inline std::unordered_map as_unordered_map(const ffi::Map& dmap) { std::unordered_map ret; for (auto kv : dmap) { ret[kv.first] = kv.second; @@ -931,8 +922,8 @@ inline constexpr bool use_default_type_traits_v = false; template <> struct TypeTraits - : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static tvm::tir::StringImm ConvertFallbackValue(String value) { + : public ObjectRefWithFallbackTraitsBase { + TVM_FFI_INLINE static tvm::tir::StringImm ConvertFallbackValue(ffi::String value) { return tvm::tir::StringImm(value); } }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 6ea50e9ae0f0..97701d16b097 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include @@ -48,7 +48,7 @@ namespace tir { class PrimFuncNode : public BaseFuncNode { public: /*! \brief Function parameters */ - Array params; + ffi::Array params; /*! \brief The return type of the function. */ Type ret_type; /*! @@ -96,7 +96,7 @@ class PrimFuncNode : public BaseFuncNode { * all usage in the body of the function is done through a * flattened alias of the buffer. */ - Map buffer_map; + ffi::Map buffer_map; /*! \brief The body of the function */ tir::Stmt body; @@ -119,9 +119,7 @@ class PrimFuncNode : public BaseFuncNode { TVM_DLL FuncType func_type_annotation() const; TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - - static constexpr const char* _type_key = "tir.PrimFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.PrimFunc", PrimFuncNode, BaseFuncNode); }; /*! @@ -148,11 +146,11 @@ class PrimFunc : public BaseFunc { * * \param span The location of this object in the source code. 
*/ - TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), - Map buffer_map = Map(), + TVM_DLL PrimFunc(ffi::Array params, Stmt body, Type ret_type = VoidType(), + ffi::Map buffer_map = ffi::Map(), DictAttrs attrs = DictAttrs(), Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimFunc, BaseFunc, PrimFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); }; @@ -172,9 +170,7 @@ class TensorIntrinNode : public Object { .def_ro("desc", &TensorIntrinNode::desc) .def_ro("impl", &TensorIntrinNode::impl); } - - static constexpr const char* _type_key = "tir.TensorIntrin"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.TensorIntrin", TensorIntrinNode, Object); }; /*! @@ -198,7 +194,7 @@ class TensorIntrin : public ObjectRef { * \throws This method throws an exception if the TensorIntrin with the specified name already * exists. */ - TVM_DLL static void Register(String name, TensorIntrin intrin, bool override = false); + TVM_DLL static void Register(ffi::String name, TensorIntrin intrin, bool override = false); /*! * \brief Look up TensorIntrin by name. Raises an exception if not found. @@ -209,9 +205,9 @@ class TensorIntrin : public ObjectRef { * \throws This method throws an exception if the TensorIntrin does not exist and allow_missing is * false. */ - TVM_DLL static Optional Get(String name, bool allow_missing = false); + TVM_DLL static ffi::Optional Get(ffi::String name, bool allow_missing = false); - TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorIntrin, ObjectRef, TensorIntrinNode); }; /*! @@ -252,7 +248,7 @@ class TensorIntrin : public ObjectRef { * B[vi, vj] = A[vi, vj] * \endcode */ -PrimFunc Specialize(PrimFunc func, const Map>& param_map); +PrimFunc Specialize(PrimFunc func, const ffi::Map>& param_map); /*! 
* \brief PrimFunc specific attribute names. @@ -264,7 +260,7 @@ namespace attr { /*! * \brief List of thread IterVar that a DeviceLaunch function corresponds to. * - * Type: Array + * Type: ffi::Array * * We call a device kernel launch function f using the following convention: * diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 518d7602f562..6866431ee487 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -56,7 +56,7 @@ class IndexMapNode : public Object { * If initial_indices is empty, then final_indices should also be * empty, and no mapping is applied. */ - Array initial_indices; + ffi::Array initial_indices; /*! * \brief Expressions defining the indices after remapping. @@ -68,7 +68,7 @@ class IndexMapNode : public Object { * If final_indices is empty, then initial_indices should also be * empty, and the map is an identity function. */ - Array final_indices; + ffi::Array final_indices; /*! * \brief The inverse index map. @@ -80,7 +80,7 @@ class IndexMapNode : public Object { * * \note ObjectRef is used here instead of IndexMap to avoid circular reference. */ - Optional inverse_index_map; + ffi::Optional inverse_index_map; /*! * \brief Default constructor @@ -102,7 +102,8 @@ class IndexMapNode : public Object { * \returns The indices in the output space. Contains one value for * each expression in `final_indices`. */ - Array MapIndices(const Array& indices, arith::Analyzer* analyzer) const; + ffi::Array MapIndices(const ffi::Array& indices, + arith::Analyzer* analyzer) const; /*! \brief Map a memory range to the output space * @@ -120,7 +121,7 @@ class IndexMapNode : public Object { * \returns The ranges in the output space. Contains one value for * each expression in `final_indices`. */ - Array MapRanges(const Array& ranges, arith::Analyzer* analyzer) const; + ffi::Array MapRanges(const ffi::Array& ranges, arith::Analyzer* analyzer) const; /*! 
\brief Map a buffer shape to the output space * @@ -133,23 +134,23 @@ class IndexMapNode : public Object { * \returns The buffer shape in the output space. Contains one * value for each expression in `final_indices`. */ - Array MapShape(const Array& shape, arith::Analyzer* analyzer) const; + ffi::Array MapShape(const ffi::Array& shape, arith::Analyzer* analyzer) const; - /* \brief Map an NDArray according to this index map + /* \brief Map a Tensor according to this index map * - * \param arr_src The NDArray whose layout is transformed by this index map. + * \param arr_src The Tensor whose layout is transformed by this index map. * - * \returns The transformed NDArray. + * \returns The transformed Tensor. */ - runtime::NDArray MapNDArray(runtime::NDArray arr_src) const; + runtime::Tensor MapTensor(runtime::Tensor arr_src) const; /*! * \brief Convert to string representation in Python. * \param f_name_map Optional function to specify the stringified name of the variables. * \return The stringified lambda expression in Python. */ - String ToPythonString( - const std::function(const Var& var)>& f_name_map = nullptr) const; + ffi::String ToPythonString( + const std::function(const Var& var)>& f_name_map = nullptr) const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -162,8 +163,7 @@ class IndexMapNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "tir.IndexMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IndexMap", IndexMapNode, Object); }; class IndexMap : public ObjectRef { @@ -174,8 +174,8 @@ class IndexMap : public ObjectRef { * \param final_indices Expressions defining the indices after remapping.
* \param inverse_index_map The optional pre-defined inverse index map */ - IndexMap(Array initial_indices, Array final_indices, - Optional inverse_index_map = std::nullopt); + IndexMap(ffi::Array initial_indices, ffi::Array final_indices, + ffi::Optional inverse_index_map = std::nullopt); /*! * \brief Create an index map from a packed function @@ -184,8 +184,8 @@ class IndexMap : public ObjectRef { * \param inverse_index_map The optional pre-defined inverse index map * \return The created index map */ - static IndexMap FromFunc(int ndim, ffi::TypedFunction(Array)> func, - Optional inverse_index_map = std::nullopt); + static IndexMap FromFunc(int ndim, ffi::TypedFunction(ffi::Array)> func, + ffi::Optional inverse_index_map = std::nullopt); /*! \brief Generate the inverse mapping. * @@ -195,7 +195,7 @@ class IndexMap : public ObjectRef { * If the user has supplied an `inverse_index_map`, that map is * assumed to be correct and bijective, and is returned. */ - IndexMap Inverse(Array initial_ranges, arith::Analyzer* analyzer) const; + IndexMap Inverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const; /*! \brief Rename the variables in the index map and ensure the names are unique. * @@ -206,7 +206,7 @@ class IndexMap : public ObjectRef { * \return The renamed index map. */ IndexMap RenameVariables( - const std::function(const Var& var)>& f_name_map = nullptr) const; + const std::function(const Var& var)>& f_name_map = nullptr) const; /*! \brief Generate the inverse mapping. * @@ -217,10 +217,10 @@ class IndexMap : public ObjectRef { * \return The inverted index map, along with the predicate for * which the inverse maps to a valid range. */ - std::pair NonSurjectiveInverse(Array initial_ranges, + std::pair NonSurjectiveInverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const; - TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IndexMap, ObjectRef, IndexMapNode); }; /*! 
\brief Substitute variables in an index map. @@ -229,7 +229,7 @@ class IndexMap : public ObjectRef { * \param f_subst The substitution function */ IndexMap Substitute(const IndexMap& index_map, - std::function(const Var& var)> f_subst); + std::function(const Var& var)> f_subst); } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 3dda3f7c63c5..005e8f5532ee 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -99,6 +99,20 @@ TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span()); */ TVM_DLL PrimExpr thread_return(Span span = Span()); +/*! + * \brief Continue current loop. + * \param span The location of this operation in the source. + * \return The continue loop expression. + */ +TVM_DLL PrimExpr continue_loop(Span span = Span()); + +/*! + * \brief Break current loop. + * \param span The location of this operation in the source. + * \return The break loop expression. + */ +TVM_DLL PrimExpr break_loop(Span span = Span()); + /*! * Query the maximum possible value of dtype. * \param dtype The data type. @@ -566,7 +580,7 @@ TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span()); * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr sum(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr sum(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -576,7 +590,7 @@ TVM_DLL PrimExpr sum(PrimExpr source, Array axis, Array * \param init The value with which to initialize the output. * \param span The location of this operation in the source. */ -TVM_DLL PrimExpr all(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr all(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -587,7 +601,7 @@ TVM_DLL PrimExpr all(PrimExpr source, Array axis, Array * \param span The location of this operation in the source. * \return The result. 
*/ -TVM_DLL PrimExpr any(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr any(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -598,7 +612,7 @@ TVM_DLL PrimExpr any(PrimExpr source, Array axis, Array * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr max(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr max(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -609,7 +623,7 @@ TVM_DLL PrimExpr max(PrimExpr source, Array axis, Array * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr min(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr min(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -620,8 +634,8 @@ TVM_DLL PrimExpr min(PrimExpr source, Array axis, Array * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr prod(PrimExpr source, Array axis, Array init = {}, - Span span = Span()); +TVM_DLL PrimExpr prod(PrimExpr source, ffi::Array axis, + ffi::Array init = {}, Span span = Span()); /*! * \brief Calculate floor(x) @@ -708,17 +722,17 @@ TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits); // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op& op = Op::Get("tir." #OpName); \ - if (x.dtype().is_bfloat16()) { \ - DataType bf16_dtype = x.dtype(); \ - DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ - PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \ - PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \ - return tir::Cast(bf16_dtype, {result_fp32}, span); \ - } else { \ - return tir::Call(x.dtype(), op, {x}, span); \ - } \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op& op = Op::Get("tir." 
#OpName); \ + if (x.dtype().is_bfloat16()) { \ + DataType bf16_dtype = x.dtype(); \ + DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ + PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \ + PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, {}, span); \ + return tir::Cast(bf16_dtype, {result_fp32}, span); \ + } else { \ + return tir::Call(x.dtype(), op, {x}, {}, span); \ + } \ } TVM_DECLARE_INTRIN_UNARY(exp); @@ -750,7 +764,7 @@ TVM_DECLARE_INTRIN_UNARY(clz); #define TVM_DECLARE_INTRIN_BINARY(OpName) \ inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \ static const Op& op = Op::Get("tir." #OpName); \ - return tir::Call(x.dtype(), op, {x, y}, span); \ + return tir::Call(x.dtype(), op, {x, y}, {}, span); \ } TVM_DECLARE_INTRIN_BINARY(atan2); @@ -802,7 +816,7 @@ inline PrimExpr make_zero(DataType t, Span span = Span()); * \return The result expression. */ inline PrimExpr const_true(int lanes = 1, Span span = Span()) { - return make_const(DataType::UInt(1, lanes), 1); + return make_const(DataType::Bool(lanes), 1); } /*! * \brief Make a constant false expression. @@ -811,7 +825,7 @@ inline PrimExpr const_true(int lanes = 1, Span span = Span()) { * \return The result expression. */ inline PrimExpr const_false(int lanes = 1, Span span = Span()) { - return make_const(DataType::UInt(1, lanes), 0); + return make_const(DataType::Bool(lanes), 0); } /*! * \brief Get x as constant int expression. @@ -883,7 +897,7 @@ inline bool is_const_number(const PrimExpr& x); * \tparam FReduce The type of the reduction. 
*/ template -inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array& values, +inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const ffi::Array& values, Span span = Span()) { for (PrimExpr val : values) { init_value = freduce(init_value, val, span); @@ -943,7 +957,7 @@ inline bool is_no_op(const tir::Stmt& stmt) { template inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) { - if (t.is_int()) return IntImm(t, static_cast(value), span); + if (t.is_int() || t.is_bool()) return IntImm(t, static_cast(value), span); if (t.is_uint()) { // Use IntImm if it is a small integer uint64_t uval = static_cast(value); diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 883477dd645e..c87ccd741a5e 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -39,7 +39,7 @@ namespace tir { /*! * \brief Global symbol of the op after lowering. */ -using TGlobalSymbol = String; +using TGlobalSymbol = ffi::String; /*! * \brief Whether the op is overloaded for vector form. @@ -59,7 +59,7 @@ using FLegalize = ffi::TypedFunction; /*! * \brief The operator's name in TVMScript printer */ -using TScriptPrinterName = String; +using TScriptPrinterName = ffi::String; /*! * \brief Specifies that TVMScript printer prints the dtype as the first/last argument. 
diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index 146d3e8ec9bb..b6e283f400fb 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -42,8 +42,9 @@ class Schedule; * \param decision Decisions made on the instruction * \return The functor returns an array of output random variables */ -using FInstructionApply = ffi::TypedFunction( - Schedule sch, const Array& inputs, const Array& attrs, const Any& decision)>; +using FInstructionApply = + ffi::TypedFunction(Schedule sch, const ffi::Array& inputs, + const ffi::Array& attrs, const Any& decision)>; /*! * \brief Type of the functor that converts the instruction to a statement in python syntax @@ -54,8 +55,8 @@ using FInstructionApply = ffi::TypedFunction( * \return A string representing the python api call */ using FInstructionAsPython = - ffi::TypedFunction& inputs, const Array& attrs, - const Any& decision, const Array& outputs)>; + ffi::TypedFunction& inputs, const ffi::Array& attrs, + const Any& decision, const ffi::Array& outputs)>; /*! * \brief Type of the functor that serialize its attributes to JSON @@ -63,7 +64,7 @@ using FInstructionAsPython = * \return An array, serialized attributes * \note This functor is nullable */ -using FInstructionAttrsAsJSON = ffi::TypedFunction attrs)>; +using FInstructionAttrsAsJSON = ffi::TypedFunction attrs)>; /*! * \brief Type of the functor that deserialize its attributes from JSON @@ -71,7 +72,7 @@ using FInstructionAttrsAsJSON = ffi::TypedFunction attrs)>; * \return An array, deserialized attributes * \note This functor is nullable */ -using FInstructionAttrsFromJSON = ffi::TypedFunction(ObjectRef json_attrs)>; +using FInstructionAttrsFromJSON = ffi::TypedFunction(ObjectRef json_attrs)>; /*! * \brief Kind of an instruction, e.g. Split, Reorder, etc. 
@@ -88,7 +89,7 @@ using FInstructionAttrsFromJSON = ffi::TypedFunction(ObjectRef json_a class InstructionKindNode : public runtime::Object { public: /*! \brief The name of a kind of instructions */ - String name; + ffi::String name; /*! * \brief Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule * state. For example, the instruction `GetBlock` is pure because it changes @@ -120,9 +121,7 @@ class InstructionKindNode : public runtime::Object { /*! \brief Checks if the instruction kind is EnterPostproc */ bool IsPostproc() const; - - static constexpr const char* _type_key = "tir.InstructionKind"; - TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.InstructionKind", InstructionKindNode, runtime::Object); }; /*! @@ -136,8 +135,9 @@ class InstructionKind : public runtime::ObjectRef { * \param name The registered name of the InstructionKind * \return The InstructionKind retrieved */ - static InstructionKind Get(const String& name); - TVM_DEFINE_OBJECT_REF_METHODS(InstructionKind, runtime::ObjectRef, InstructionKindNode); + static InstructionKind Get(const ffi::String& name); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(InstructionKind, runtime::ObjectRef, + InstructionKindNode); }; /*! \brief Schedule instructions each corresponds to a schedule primitive */ @@ -156,20 +156,20 @@ class InstructionNode : public runtime::Object { * - String * - null pointer */ - Array inputs; + ffi::Array inputs; /*! * \brief The attributes of the instruction. Similar to attributes of an operator, * attributes of an instruction are arbitrary constant metadata required by the instructions. * For example, the name of the block to be retrieved in `GetBlock`. */ - Array attrs; + ffi::Array attrs; /*! 
\brief The output random variables of the instruction, and the type of each element can be one * of the following: * - BlockRV * - LoopRV * - ExprRV, atomic variables only, won't be constants or composite PrimExpr */ - Array outputs; + ffi::Array outputs; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -179,9 +179,7 @@ class InstructionNode : public runtime::Object { .def_ro("attrs", &InstructionNode::attrs) .def_ro("outputs", &InstructionNode::outputs); } - - static constexpr const char* _type_key = "tir.Instruction"; - TVM_DECLARE_FINAL_OBJECT_INFO(InstructionNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Instruction", InstructionNode, runtime::Object); }; /*! @@ -197,10 +195,10 @@ class Instruction : public runtime::ObjectRef { * \param attrs The attributes of the instruction * \param outputs The output random variables of the instruction */ - explicit Instruction(InstructionKind kind, Array inputs, Array attrs, - Array outputs); + explicit Instruction(InstructionKind kind, ffi::Array inputs, ffi::Array attrs, + ffi::Array outputs); - TVM_DEFINE_OBJECT_REF_METHODS(Instruction, runtime::ObjectRef, InstructionNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Instruction, runtime::ObjectRef, InstructionNode); }; /*! @@ -235,7 +233,7 @@ class Instruction : public runtime::ObjectRef { /*! \brief An entry in the registry of InstructionKind */ class InstructionKindRegEntry { public: - static InstructionKindRegEntry& RegisterOrGet(const String& name); + static InstructionKindRegEntry& RegisterOrGet(const ffi::String& name); InstructionKindRegEntry& set_name() { get_mutable()->name = this->name; @@ -276,7 +274,7 @@ class InstructionKindRegEntry { } /*! \brief The name of the registry entry */ - String name; + ffi::String name; /*! 
\brief The instruction kind */ InstructionKind inst_kind_; template diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9fbb9981e55c..12b2e66429ce 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -51,11 +51,10 @@ enum class BufferIndexType : int32_t { class BlockRVNode : public runtime::Object { public: static void RegisterReflection() { - // No fields to register as they are not visited + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - - static constexpr const char* _type_key = "tir.BlockRV"; - TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockRV", BlockRVNode, runtime::Object); }; /*! @@ -66,7 +65,7 @@ class BlockRV : public runtime::ObjectRef { public: /*! \brief Constructor */ TVM_DLL BlockRV(); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockRV, runtime::ObjectRef, BlockRVNode); }; /**************** Random variable: LoopRV ****************/ @@ -75,11 +74,10 @@ class BlockRV : public runtime::ObjectRef { class LoopRVNode : public runtime::Object { public: static void RegisterReflection() { - // No fields to register as they are not visited + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - - static constexpr const char* _type_key = "tir.LoopRV"; - TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LoopRV", LoopRVNode, runtime::Object); }; /*! @@ -90,7 +88,7 @@ class LoopRV : public runtime::ObjectRef { public: /*! 
\brief Constructor */ TVM_DLL LoopRV(); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LoopRV, runtime::ObjectRef, LoopRVNode); }; /**************** Random variable: ExprRV ****************/ @@ -111,8 +109,8 @@ class ScheduleNode : public runtime::Object { public: virtual ~ScheduleNode() = default; - static constexpr const char* _type_key = "tir.Schedule"; - TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Schedule", ScheduleNode, runtime::Object); public: /*! \brief Get the IRModule associated with this schedule. */ @@ -120,9 +118,9 @@ class ScheduleNode : public runtime::Object { /*! \return The internal state of scheduling */ virtual ScheduleState state() const = 0; /*! \return The internally maintained trace of scheduling program execution */ - virtual Optional trace() const = 0; + virtual ffi::Optional trace() const = 0; /*! \return The GlobalVar of the func that the schedule is currently working on */ - virtual Optional func_working_on() const = 0; + virtual ffi::Optional func_working_on() const = 0; /*! * \brief Instruct the schedule to work on a function in the IRModule. * @@ -137,7 +135,7 @@ class ScheduleNode : public runtime::Object { * * \sa GetBlock */ - virtual void WorkOn(const String& func_name) = 0; + virtual void WorkOn(const ffi::String& func_name) = 0; /*! 
* \brief Returns a copy of the schedule, including both its state and its symbol table, * guaranteeing that @@ -230,8 +228,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return The random variable sampled from candidates */ - virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = std::nullopt) = 0; + virtual ExprRV SampleCategorical(const ffi::Array& candidates, + const ffi::Array& probs, + ffi::Optional decision = std::nullopt) = 0; /*! * \brief Sample the factors to perfect tile a specific loop * \param loop_rv The loop to be tiled @@ -240,8 +239,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return A list of length `n`, the random perfect tile sizes sampled */ - virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = std::nullopt) = 0; + virtual ffi::Array SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision = std::nullopt) = 0; /*! * \brief Sample the factors to a partitioned tile for a specific loop * @@ -257,9 +257,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return A list of length `n`, the random partitioned tile sizes sampled */ - virtual Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, - int innerpart_factor, - Optional> decision = std::nullopt) = 0; + virtual ffi::Array SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision = std::nullopt) = 0; /*! 
* \brief Sample a compute-at location of the given block * \param block_rv The block whose compute-at location is to be sampled @@ -267,7 +267,7 @@ class ScheduleNode : public runtime::Object { * \return The sampled loop where the input block is to be computed at */ virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, - Optional decision = std::nullopt) = 0; + ffi::Optional decision = std::nullopt) = 0; /******** Schedule: Get blocks & loops ********/ /*! @@ -284,40 +284,40 @@ class ScheduleNode : public runtime::Object { * * \sa WorkOn */ - virtual BlockRV GetBlock(const String& name, - const Optional& func_name = std::nullopt) = 0; + virtual BlockRV GetBlock(const ffi::String& name, + const ffi::Optional& func_name = std::nullopt) = 0; /*! * \brief Get the parent loops of the block in its scope, from outer to inner * \param block_rv The query block * \return A list of loops above the given block in its scope, from outer to inner */ - virtual Array GetLoops(const BlockRV& block_rv) = 0; + virtual ffi::Array GetLoops(const BlockRV& block_rv) = 0; /*! * \brief Get the leaf blocks of a specific scope * \param block_rv The block where the scope is rooted * \return A list of child blocks */ - virtual Array GetChildBlocks(const BlockRV& block_rv) = 0; + virtual ffi::Array GetChildBlocks(const BlockRV& block_rv) = 0; /*! * \brief Get the leaf blocks of under a specific loop * \param loop_rv The loop under which collecting is conducted * \return A list of child blocks */ - virtual Array GetChildBlocks(const LoopRV& loop_rv) = 0; + virtual ffi::Array GetChildBlocks(const LoopRV& loop_rv) = 0; /*! * \brief Get the producer of a specific block, under the same block scope * \param block_rv The block in the query * \return A list of blocks, the producers of the given block under the same scope of the given * block */ - virtual Array GetProducers(const BlockRV& block_rv) = 0; + virtual ffi::Array GetProducers(const BlockRV& block_rv) = 0; /*! 
* \brief Get the consumers of a specific block, under the same block scope * \param block_rv The block to be queried * \return A list of blocks, the consumers of the given block under the same scope of the given * block */ - virtual Array GetConsumers(const BlockRV& block_rv) = 0; + virtual ffi::Array GetConsumers(const BlockRV& block_rv) = 0; /*! * \brief Get the list of output blocks within the given scope * An output block is a block which has atleast one buffer being written @@ -326,7 +326,7 @@ class ScheduleNode : public runtime::Object { * \return A list of all blocks that write to some output buffer * block */ - virtual Array GetOutputBlocks(const BlockRV& scope_block_rv) = 0; + virtual ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Merge a list of loops into one. The loops under their LCA requires: @@ -337,7 +337,7 @@ class ScheduleNode : public runtime::Object { * \param loop_rvs The loops to be merged * \return The new loop after merge */ - virtual LoopRV Merge(const Array& loop_rvs) = 0; + virtual LoopRV Merge(const ffi::Array& loop_rvs) = 0; /*! * \brief Fuse a list of consecutive loops into one. It requires: * 1) The loops can't have annotations or thread bindings. @@ -348,7 +348,7 @@ class ScheduleNode : public runtime::Object { * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The new loop after fusion */ - virtual LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters = true) = 0; + virtual LoopRV Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters = true) = 0; /*! * \brief Split a loop into a list of consecutive loops. It requires: * 1) The loop can't have annotation or thread binding. @@ -361,9 +361,10 @@ class ScheduleNode : public runtime::Object { * schedule writer knows are divisible by the loop bound. Warning: enabling this feature may * result in incorrect code generation if not used carefully. 
\return The new loops after split. */ - virtual Array Split(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters = true, - bool disable_predication = false) = 0; + virtual ffi::Array Split(const LoopRV& loop_rv, + const ffi::Array>& factors, + bool preserve_unit_iters = true, + bool disable_predication = false) = 0; /*! * \brief Partition the loops into sequence of multiple loops * 1) The loop can't have annotation or thread binding. @@ -373,8 +374,9 @@ class ScheduleNode : public runtime::Object { * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The new loops after partition */ - virtual Array LoopPartition(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters = true) = 0; + virtual ffi::Array LoopPartition(const LoopRV& loop_rv, + const ffi::Array>& factors, + bool preserve_unit_iters = true) = 0; /*! * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. * It requires: @@ -387,13 +389,14 @@ class ScheduleNode : public runtime::Object { * 4) No duplicated loops are allowed in the arguments. * \param ordered_loop_rvs The loops in the new order */ - virtual void Reorder(const Array& ordered_loop_rvs) = 0; + virtual void Reorder(const ffi::Array& ordered_loop_rvs) = 0; /*! * \brief Reorder the itervars inside a block. * \param block_rv The block to be transformed. * \param new_order The new itervar order. */ - virtual void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) = 0; + virtual void ReorderBlockIterVar(const BlockRV& block_rv, + const ffi::Array new_order) = 0; /*! * \brief Create a new unit loop on top of the specific block. 
* \param block_rv The block above which the new loop is created @@ -438,7 +441,7 @@ class ScheduleNode : public runtime::Object { * \param loop_rv The loop to be bound to the thread axis * \param thread_axis The thread axis to be bound to the loop */ - virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0; + virtual void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) = 0; /*! * \brief Unroll the input loop. It requires nothing * \param loop_rv The loop to be unrolled @@ -456,8 +459,8 @@ class ScheduleNode : public runtime::Object { * \return The cache stage block. */ virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, - const Array consumer_blocks = {}) = 0; + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) = 0; /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block who writes the target buffer. @@ -469,8 +472,8 @@ class ScheduleNode : public runtime::Object { * \return The cache stage block. */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, - const Array consumer_blocks = {}) = 0; + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) = 0; /*! * \brief Create a block that reads a buffer region into a read cache. It requires: * 1) There is at most one block who writes the buffer in the scope. @@ -484,7 +487,7 @@ class ScheduleNode : public runtime::Object { * \return The cache stage block. */ virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map) = 0; + const ffi::String& storage_scope, const IndexMap& index_map) = 0; /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block who writes the target buffer. 
@@ -498,7 +501,8 @@ class ScheduleNode : public runtime::Object { * \return The cache stage block. */ virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map) = 0; + const ffi::String& storage_scope, + const IndexMap& index_map) = 0; /*! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. * It requires the target block both read & write the target buffer. @@ -507,8 +511,8 @@ class ScheduleNode : public runtime::Object { * \param storage_scope The target storage scope * \return The cache stage blocks, cache read block together with cache write block. */ - virtual Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) = 0; + virtual ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) = 0; /*! * \brief Create a block to cache precomputed index for later use. * if there is no index computation, keep unchanged. @@ -517,8 +521,8 @@ class ScheduleNode : public runtime::Object { * \param cse_thresh The repeat threshold that determines a common sub expr * \return The cache stage blocks. */ - virtual Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, - int cse_thresh) = 0; + virtual ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, + int cse_thresh) = 0; /*! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. * The layout of the cache will be the same as by the iterators of the block that reads/writes the @@ -531,12 +535,12 @@ class ScheduleNode : public runtime::Object { * \return The reindex stage block. 
*/ virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) = 0; + BufferIndexType buffer_index_type, bool skip_simplify = false) = 0; /******** Schedule: Data movement ********/ virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) = 0; + const ffi::String& storage_scope) = 0; virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) = 0; + const ffi::String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -604,6 +608,13 @@ class ScheduleNode : public runtime::Object { * \param block The block to be inlined to its producer */ virtual void ReverseComputeInline(const BlockRV& block) = 0; + /*! + * \brief Fuse an epilogue block into a reduction block + * \param reduction_block The reduction block (e.g., matmul) + * \param epilogue_block The epilogue block to be fused (e.g., bias add) + */ + virtual void FuseReductionEpilogue(const BlockRV& reduction_block, + const BlockRV& epilogue_block) = 0; /******** Schedule: Reduction ********/ /*! * \brief Decompose a reduction block into two separate blocks. @@ -661,7 +672,8 @@ class ScheduleNode : public runtime::Object { * \param buffer_index The index of the buffer in block's write region * \param storage_scope The storage scope to be set */ - virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; + virtual void SetScope(const BlockRV& block_rv, int buffer_index, + const ffi::String& storage_scope) = 0; /*! 
* \brief Set the data type of a buffer, where the buffer is specified by a block and a * write-index @@ -671,7 +683,8 @@ class ScheduleNode : public runtime::Object { * \param buffer_index the index of the buffer in block's write region * \param dtype The data type to be set */ - virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0; + virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, + const ffi::String& dtype) = 0; /******** Schedule: Blockize & Tensorize ********/ /*! * \brief Convert the subtree rooted at a specific loop into a block. @@ -686,14 +699,14 @@ class ScheduleNode : public runtime::Object { * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return the new block */ - virtual BlockRV Blockize(const Array& blocks, bool preserve_unit_iters = true) = 0; + virtual BlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters = true) = 0; /*! * \brief Tensorize the computation enclosed by loop with the tensor intrin. * \param loop_rv The loop to be tensorized * \param intrin Name of the tensor intrinsic * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ - virtual void Tensorize(const LoopRV& loop_rv, const String& intrin, + virtual void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters = true) = 0; /*! * \brief Tensorize the computation enclosed by loop with the tensor intrin. 
@@ -701,7 +714,7 @@ class ScheduleNode : public runtime::Object { * \param intrin Name of the tensor intrinsic * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ - virtual void Tensorize(const BlockRV& block_rv, const String& intrin, + virtual void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters = true) = 0; /******** Schedule: Annotation ********/ @@ -711,26 +724,27 @@ class ScheduleNode : public runtime::Object { * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ - virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const Any& ann_val) = 0; + virtual void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) = 0; /*! * \brief Annotate a block with a key value pair * \param block_rv The block to be annotated * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ - virtual void Annotate(const BlockRV& block_rv, const String& ann_key, const Any& ann_val) = 0; + virtual void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, + const Any& ann_val) = 0; /*! * \brief Unannotate a loop's annotation with key ann_key * \param loop_rv The loop to be unannotated * \param ann_key The annotation key */ - virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0; + virtual void Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) = 0; /*! * \brief Unannotate a block's annotation with key ann_key * \param block_rv The block to be unannotated * \param ann_key The annotation key */ - virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; + virtual void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) = 0; /******** Schedule: Layout transformation ********/ /*! 
@@ -766,7 +780,7 @@ class ScheduleNode : public runtime::Object { */ virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value = std::nullopt, + const ffi::Optional& pad_value = std::nullopt, bool assume_injective_transform = false) = 0; /*! @@ -789,7 +803,7 @@ class ScheduleNode : public runtime::Object { */ virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) = 0; + const ffi::Array& axis_separators) = 0; /******** Schedule: Padding ********/ /*! @@ -818,7 +832,7 @@ class ScheduleNode : public runtime::Object { * The size of the producer buffers are infered from the padding size of the Einsum computation. * The producer buffers are padded by the initial value of the corresponding reduction. */ - virtual void PadEinsum(const BlockRV& block_rv, const Array& padding) = 0; + virtual void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) = 0; /******** Schedule: Buffer transformation ********/ /*! @@ -858,8 +872,8 @@ class ScheduleNode : public runtime::Object { * \param buf_type The buffer type: read/write * \param buf_index_array The array of buffer indices we hide access. */ - virtual void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) = 0; + virtual void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + const ffi::Array& buf_index_array) = 0; }; /*! 
@@ -912,7 +926,7 @@ class Schedule : public runtime::ObjectRef { TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check = true); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Schedule, runtime::ObjectRef, ScheduleNode); }; } // namespace tir diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 99994d2bf68a..4467463912e8 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -43,7 +43,7 @@ namespace tir { */ struct BlockInfo { /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */ - BlockScope scope{nullptr}; + BlockScope scope{ffi::UnsafeInit()}; // The properties below are information about the current block realization under its parent scope /*! \brief Property of a block, indicating the block realization binding is quasi-affine */ bool affine_binding{false}; @@ -147,7 +147,7 @@ class ScheduleStateNode : public Object { * \note The reuse of loop srefs are detected automatically according to the reuse of loop vars. */ TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, - const Map& block_sref_reuse); + const ffi::Map& block_sref_reuse); /*! * \brief Trigger the verification according to the `debug_mask` bitmask. * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. @@ -156,8 +156,8 @@ class ScheduleStateNode : public Object { */ TVM_DLL void DebugVerify() const; - static constexpr const char* _type_key = "tir.ScheduleState"; - TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleStateNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.ScheduleState", ScheduleStateNode, Object); /******** Property of blocks ********/ /*! 
\brief Returns the BlockInfo correpsonding to the block sref */ @@ -218,10 +218,7 @@ class ScheduleState : public ObjectRef { */ TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0, bool enable_check = true); - /*! \return The mutable pointer to the ScheduleStateNode */ - ScheduleStateNode* get() const { return static_cast(data_.get()); } - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleState, ObjectRef, ScheduleStateNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScheduleState, ObjectRef, ScheduleStateNode); }; } // namespace tir diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index 6e3dd29551ef..f5aa7cb5ffd6 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -37,8 +37,8 @@ class Trace; * \return A new decision */ using FTraceDecisionProvider = - ffi::TypedFunction& inputs, - const Array& attrs, const Any& decision)>; + ffi::TypedFunction& inputs, + const ffi::Array& attrs, const Any& decision)>; /*! * \brief An execution trace of a scheduling program @@ -58,9 +58,9 @@ using FTraceDecisionProvider = class TraceNode : public runtime::Object { public: /*! \brief The instructions invoked so far in the program execution */ - Array insts; + ffi::Array insts; /*! \brief The random decisions made upon those instructions */ - Map decisions; + ffi::Map decisions; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -69,8 +69,8 @@ class TraceNode : public runtime::Object { .def_ro("decisions", &TraceNode::decisions); } - static constexpr const char* _type_key = "tir.Trace"; - TVM_DECLARE_FINAL_OBJECT_INFO(TraceNode, runtime::Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Trace", TraceNode, runtime::Object); public: /*! 
@@ -89,14 +89,14 @@ class TraceNode : public runtime::Object { * \param inst The new instruction to be appended * \param decision The random decision made on this instruction * The type of `decision` depends on the instruction, e.g. - * the decision of `SamplePerfectTile` has type `Array` + * the decision of `SamplePerfectTile` has type `ffi::Array` */ void Append(Instruction inst, Any decision); /*! * \brief Remove the last instruction, along with the decision made on that instruction, if any * \return The instruction removed; std::nullopt if the trace is empty */ - Optional Pop(); + ffi::Optional Pop(); /*! * \brief Apply the trace to a TensorIR schedule * \param sch The schedule to be applied onto @@ -118,7 +118,7 @@ class TraceNode : public runtime::Object { * \param remove_postproc If postprocessing instructions are removed * \return A sequence of python statements */ - Array AsPython(bool remove_postproc) const; + ffi::Array AsPython(bool remove_postproc) const; /*! * \brief Create a new trace with an instruction whose decision is changed, * assuming this instruction exists in the resulting trace @@ -149,7 +149,7 @@ class Trace : public runtime::ObjectRef { * \param insts The instructions used * \param decisions The decisions made in sampling */ - explicit Trace(Array insts, Map decisions); + explicit Trace(ffi::Array insts, ffi::Map decisions); /*! 
* \brief Apply a JSON-serialized trace to a TensorIR schedule * \param json The JSON-serialized trace @@ -157,7 +157,7 @@ class Trace : public runtime::ObjectRef { */ static void ApplyJSONToSchedule(ObjectRef json, Schedule sch); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, runtime::ObjectRef, TraceNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Trace, runtime::ObjectRef, TraceNode); }; } // namespace tir diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index bbdb7c272ed8..75ba37b43fb8 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -53,17 +53,16 @@ class StmtNode : public Object { TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - static constexpr const char* _type_key = "tir.Stmt"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const uint32_t _type_child_slots = 15; - TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("tir.Stmt", StmtNode, Object); }; /*! \brief Container of all statements */ class Stmt : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Stmt, ObjectRef, StmtNode); }; /*! @@ -85,9 +84,7 @@ class LetStmtNode : public StmtNode { .def_ro("value", &LetStmtNode::value) .def_ro("body", &LetStmtNode::body); } - - static constexpr const char* _type_key = "tir.LetStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.LetStmt", LetStmtNode, StmtNode); }; /*! @@ -98,7 +95,7 @@ class LetStmt : public Stmt { public: TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LetStmt, Stmt, LetStmtNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode); }; @@ -117,7 +114,7 @@ class AttrStmtNode : public StmtNode { /*! \brief this is attribute about certain node */ ffi::Any node; /*! 
\brief the type key of the attribute */ - String attr_key; + ffi::String attr_key; /*! \brief The attribute value, value is well defined at current scope. */ PrimExpr value; /*! \brief The body statement to be executed */ @@ -131,9 +128,7 @@ class AttrStmtNode : public StmtNode { .def_ro("value", &AttrStmtNode::value) .def_ro("body", &AttrStmtNode::body); } - - static constexpr const char* _type_key = "tir.AttrStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AttrStmt", AttrStmtNode, StmtNode); }; /*! @@ -142,9 +137,10 @@ class AttrStmtNode : public StmtNode { */ class AttrStmt : public Stmt { public: - TVM_DLL AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span = Span()); + TVM_DLL AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, + Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrStmt, Stmt, AttrStmtNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode); }; @@ -170,9 +166,7 @@ class AssertStmtNode : public StmtNode { .def_ro("message", &AssertStmtNode::message) .def_ro("body", &AssertStmtNode::body); } - - static constexpr const char* _type_key = "tir.AssertStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AssertStmt", AssertStmtNode, StmtNode); }; /*! @@ -183,7 +177,7 @@ class AssertStmt : public Stmt { public: TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AssertStmt, Stmt, AssertStmtNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode); }; @@ -204,9 +198,9 @@ class BufferStoreNode : public StmtNode { /*! \brief The value to be stored. */ PrimExpr value; /*! \brief The indices location to be stored. */ - Array indices; + ffi::Array indices; /*! 
\brief The predicate mask for storing values. */ - Optional predicate; + ffi::Optional predicate; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -216,9 +210,7 @@ class BufferStoreNode : public StmtNode { .def_ro("indices", &BufferStoreNode::indices) .def_ro("predicate", &BufferStoreNode::predicate); } - - static constexpr const char* _type_key = "tir.BufferStore"; - TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferStore", BufferStoreNode, StmtNode); }; /*! @@ -227,10 +219,11 @@ class BufferStoreNode : public StmtNode { */ class BufferStore : public Stmt { public: - TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate = std::nullopt, Span span = Span()); + TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate = std::nullopt, + Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); }; @@ -250,7 +243,7 @@ class BufferRealizeNode : public StmtNode { /*! \brief The buffer variable. */ Buffer buffer; /*! \brief Bounds to be realized */ - Array bounds; + ffi::Array bounds; /*! \brief Only realize if condition holds. */ PrimExpr condition; /*! \brief The body of realization. 
*/ @@ -266,12 +259,10 @@ class BufferRealizeNode : public StmtNode { } BufferRealizeNode() = default; - BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, + BufferRealizeNode(Buffer buffer, ffi::Array bounds, PrimExpr condition, Stmt body, Span span = Span()) : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {} - - static constexpr const char* _type_key = "tir.BufferRealize"; - TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferRealize", BufferRealizeNode, StmtNode); }; /*! @@ -280,10 +271,10 @@ class BufferRealizeNode : public StmtNode { */ class BufferRealize : public Stmt { public: - TVM_DLL explicit BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, - Span span = Span()); + TVM_DLL explicit BufferRealize(Buffer buffer, ffi::Array bounds, PrimExpr condition, + Stmt body, Span span = Span()); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BufferRealize, Stmt, BufferRealizeNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode); }; @@ -297,7 +288,7 @@ class AllocateNode : public StmtNode { /*! \brief The type of the buffer. */ DataType dtype; /*! \brief The extents of the buffer. */ - Array extents; + ffi::Array extents; /*! \brief Only allocate buffer when condition is satisfied. */ PrimExpr condition; /*! \brief The body to be executed. */ @@ -308,7 +299,7 @@ class AllocateNode : public StmtNode { * These annotations can be used as auxiliary hint * to future transformations. */ - Map annotations; + ffi::Map annotations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -333,11 +324,8 @@ class AllocateNode : public StmtNode { * \param extents The extents of the buffer. * \return The result. 
*/ - TVM_DLL static int64_t ConstantAllocationSize(const Array& extents); - - static constexpr const char* _type_key = "tir.Allocate"; - - TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); + TVM_DLL static int64_t ConstantAllocationSize(const ffi::Array& extents); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Allocate", AllocateNode, StmtNode); }; /*! @@ -346,11 +334,12 @@ class AllocateNode : public StmtNode { */ class Allocate : public Stmt { public: - TVM_DLL Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body, Map annotations = Map(), + TVM_DLL Allocate(Var buffer_var, DataType dtype, ffi::Array extents, PrimExpr condition, + Stmt body, + ffi::Map annotations = ffi::Map(), Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Allocate, Stmt, AllocateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode); }; @@ -363,16 +352,16 @@ class AllocateConstNode : public StmtNode { Var buffer_var; /*! \brief The optional data associated to the constant. */ - Optional data; + ffi::Optional data; /*! * \brief If the PrimFunc containing the Stmt is added to IRModule, this is an optional index - * to indicate the index within "constants" attribute, that is a Array of IRModule. + * to indicate the index within "constants" attribute, that is a ffi::Array of IRModule. */ - Optional irmod_storage_idx; + ffi::Optional irmod_storage_idx; /*! \brief The type of the buffer. */ DataType dtype; /*! \brief The extents of the buffer. */ - Array extents; + ffi::Array extents; /*! \brief The body to be executed. */ Stmt body; /*! @@ -381,7 +370,7 @@ class AllocateConstNode : public StmtNode { * These annotations can be used as auxiliary hint * to future transformations. 
*/ - Map annotations; + ffi::Map annotations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -407,10 +396,8 @@ class AllocateConstNode : public StmtNode { * \param extents The extents of the buffer. * \return The result. */ - TVM_DLL static int64_t ConstantAllocationSize(const Array& extents); - - static constexpr const char* _type_key = "tir.AllocateConst"; - TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode); + TVM_DLL static int64_t ConstantAllocationSize(const ffi::Array& extents); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AllocateConst", AllocateConstNode, StmtNode); }; /*! @@ -423,11 +410,11 @@ class AllocateConst : public Stmt { * depending on the type of ObjectRef, it will either * create AllocateConstNode with irmod_storage_idx or data */ - TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array extents, - ObjectRef data_or_idx, Stmt body, - Map annotations = Map(), - Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); + TVM_DLL AllocateConst( + Var buffer_var, DataType dtype, ffi::Array extents, ObjectRef data_or_idx, + Stmt body, ffi::Map annotations = ffi::Map(), + Span span = Span()); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllocateConst, Stmt, AllocateConstNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode); }; @@ -445,16 +432,14 @@ class DeclBufferNode : public StmtNode { .def_ro("buffer", &DeclBufferNode::buffer) .def_ro("body", &DeclBufferNode::body); } - - static constexpr const char* _type_key = "tir.DeclBuffer"; - TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.DeclBuffer", DeclBufferNode, StmtNode); }; /*! 
\brief Managed reference to DeclBufferNode */ class DeclBuffer : public Stmt { public: TVM_DLL DeclBuffer(Buffer buffer, Stmt body, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(DeclBuffer, Stmt, DeclBufferNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DeclBuffer, Stmt, DeclBufferNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DeclBufferNode); }; @@ -465,7 +450,7 @@ class DeclBuffer : public Stmt { class SeqStmtNode : public StmtNode { public: /*! \brief internal sequence content. */ - Array seq; + ffi::Array seq; /*! \return get the size of the sequence */ size_t size() const { return seq.size(); } @@ -478,9 +463,7 @@ class SeqStmtNode : public StmtNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("seq", &SeqStmtNode::seq); } - - static constexpr const char* _type_key = "tir.SeqStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SeqStmt", SeqStmtNode, StmtNode); }; /*! @@ -498,9 +481,7 @@ class EvaluateNode : public StmtNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &EvaluateNode::value); } - - static constexpr const char* _type_key = "tir.Evaluate"; - TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Evaluate", EvaluateNode, StmtNode); }; /*! @@ -513,7 +494,7 @@ class Evaluate : public Stmt { explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {} - TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Evaluate, Stmt, EvaluateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode); }; @@ -525,7 +506,7 @@ class SeqStmt : public Stmt { * \param seq The sequence. * \param span The location of this object in the source code. */ - TVM_DLL explicit SeqStmt(Array seq, Span span = Span()); + TVM_DLL explicit SeqStmt(ffi::Array seq, Span span = Span()); /*! 
\return get the size of the sequence */ size_t size() const { return operator->()->size(); } @@ -555,7 +536,7 @@ class SeqStmt : public Stmt { */ template static Stmt Flatten(Args&&... seq_args) { - Array seq; + ffi::Array seq; ffi::details::for_each(Flattener(&seq), std::forward(seq_args)...); @@ -593,10 +574,10 @@ class SeqStmt : public Stmt { /*! \brief Helper class to flatten sequence of arguments into Array. */ class Flattener { public: - explicit Flattener(Array* seq) : seq_(seq) {} + explicit Flattener(ffi::Array* seq) : seq_(seq) {} template - static Optional AsSeqStmt(const T& t) { + static ffi::Optional AsSeqStmt(const T& t) { if constexpr (std::is_same_v) { return t; } @@ -605,7 +586,7 @@ class SeqStmt : public Stmt { } if constexpr (std::is_base_of_v) { if (const SeqStmtNode* ptr = t.template as()) { - return GetRef(ptr); + return ffi::GetRef(ptr); } else { return std::nullopt; } @@ -661,10 +642,10 @@ class SeqStmt : public Stmt { } private: - Array* seq_; + ffi::Array* seq_; }; - TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SeqStmt, Stmt, SeqStmtNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode); }; @@ -678,7 +659,7 @@ class IfThenElseNode : public StmtNode { /*! \brief The branch to be executed when condition is true. */ Stmt then_case; /*! \brief The branch to be executed when condition is false, can be null. */ - Optional else_case; + ffi::Optional else_case; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -687,9 +668,7 @@ class IfThenElseNode : public StmtNode { .def_ro("then_case", &IfThenElseNode::then_case) .def_ro("else_case", &IfThenElseNode::else_case); } - - static constexpr const char* _type_key = "tir.IfThenElse"; - TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IfThenElse", IfThenElseNode, StmtNode); }; /*! 
@@ -698,10 +677,10 @@ class IfThenElseNode : public StmtNode { */ class IfThenElse : public Stmt { public: - TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional else_case = std::nullopt, - Span span = Span()); + TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, + ffi::Optional else_case = std::nullopt, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IfThenElse, Stmt, IfThenElseNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode); }; @@ -738,7 +717,7 @@ enum class ForKind : int { * * \code * - * for (loop_var = min; loop_var < min + extent; ++loop_var) { + * for (loop_var = min; loop_var < min + extent; loop_var += step) { * // body * } * \endcode @@ -759,7 +738,7 @@ class ForNode : public StmtNode { * \brief Only valid when kind == ForKind::kThreadBinding * The context thread that this loop variable bounds to. */ - Optional thread_binding; + ffi::Optional thread_binding; /*! * \brief Additional annotations about the loop. * @@ -768,7 +747,11 @@ class ForNode : public StmtNode { * not change the control flow semantics of the loop * and can be ignored in most passes. */ - Map annotations; + ffi::Map annotations; + /*! + * \brief The loop step. It is one if not specified. + */ + ffi::Optional step; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -779,11 +762,14 @@ class ForNode : public StmtNode { .def_ro("kind", &ForNode::kind) .def_ro("body", &ForNode::body) .def_ro("thread_binding", &ForNode::thread_binding) - .def_ro("annotations", &ForNode::annotations); + .def_ro("annotations", &ForNode::annotations) + .def_ro("step", &ForNode::step); } - static constexpr const char* _type_key = "tir.For"; - TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); + /*! \brief Check it is a loop without nontrivial loop step. */ + bool HasTrivialStep() const; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode); }; /*! 
@@ -793,10 +779,11 @@ class ForNode : public StmtNode { class For : public Stmt { public: TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding = std::nullopt, - Map annotations = Map(), Span span = Span()); + ffi::Optional thread_binding = std::nullopt, + ffi::Map annotations = {}, + ffi::Optional step = std::nullopt, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode); }; @@ -823,9 +810,7 @@ class WhileNode : public StmtNode { .def_ro("condition", &WhileNode::condition) .def_ro("body", &WhileNode::body); } - - static constexpr const char* _type_key = "tir.While"; - TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.While", WhileNode, StmtNode); }; /*! @@ -836,7 +821,7 @@ class While : public Stmt { public: TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(While, Stmt, WhileNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode); }; @@ -848,7 +833,7 @@ class BufferRegionNode : public PrimExprConvertibleNode { /*! \brief The buffer of the buffer region. */ Buffer buffer; /*! \brief The region array of the buffer region. */ - Array region; + ffi::Array region; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -859,9 +844,8 @@ class BufferRegionNode : public PrimExprConvertibleNode { TVM_DLL PrimExpr ToPrimExpr() const final; - static constexpr const char* _type_key = "tir.BufferRegion"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, PrimExprConvertibleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferRegion", BufferRegionNode, PrimExprConvertibleNode); }; /*! 
@@ -870,7 +854,7 @@ class BufferRegionNode : public PrimExprConvertibleNode { */ class BufferRegion : public PrimExprConvertible { public: - TVM_DLL explicit BufferRegion(Buffer buffer, Array region); + TVM_DLL explicit BufferRegion(Buffer buffer, ffi::Array region); /*! * \brief Create a BufferRegion which is full region of the given buffer. @@ -885,9 +869,9 @@ class BufferRegion : public PrimExprConvertible { * \param indices The access point indices of the buffer * \return The BufferRegion which is the single point of the given buffer. */ - TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array indices); + TVM_DLL static BufferRegion FromPoint(Buffer buffer, ffi::Array indices); - TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, PrimExprConvertible, BufferRegionNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferRegion, PrimExprConvertible, BufferRegionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode); }; @@ -914,9 +898,8 @@ class MatchBufferRegionNode : public Object { .def_ro("source", &MatchBufferRegionNode::source); } - static constexpr const char* _type_key = "tir.MatchBufferRegion"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.MatchBufferRegion", MatchBufferRegionNode, Object); }; /*! @@ -927,7 +910,7 @@ class MatchBufferRegion : public ObjectRef { public: TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source); - TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchBufferRegion, ObjectRef, MatchBufferRegionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode); }; @@ -955,19 +938,19 @@ class MatchBufferRegion : public ObjectRef { class BlockNode : public StmtNode { public: /*! \brief The variables of the block. */ - Array iter_vars; + ffi::Array iter_vars; /*! 
\brief The read buffer regions of the block. */ - Array reads; + ffi::Array reads; /*! \brief The write buffer regions of the block. */ - Array writes; + ffi::Array writes; /*! \brief The name_hint of the block. */ - String name_hint; + ffi::String name_hint; /*! \brief The buffer allocated in the block. */ - Array alloc_buffers; + ffi::Array alloc_buffers; /*! \brief The match buffer regions. */ - Array match_buffers; + ffi::Array match_buffers; /*! \brief The annotation of the block. */ - Map annotations; + ffi::Map annotations; /*! * \brief The init statement is executed during the first iteration of reduction loops in a * reduction block. The optional init field allows us to represent initialization and @@ -975,7 +958,7 @@ class BlockNode : public StmtNode { * We also provide primitives to decompose the init into a separate block during scheduling. * Init field is `std::nullopt` if there is no reduction iter_vars */ - Optional init; + ffi::Optional init; /*! \brief The body of the block. */ Stmt body; @@ -992,9 +975,7 @@ class BlockNode : public StmtNode { .def_ro("init", &BlockNode::init) .def_ro("body", &BlockNode::body); } - - static constexpr const char* _type_key = "tir.Block"; - TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Block", BlockNode, StmtNode); }; /*! 
@@ -1003,15 +984,16 @@ class BlockNode : public StmtNode { */ class Block : public Stmt { public: - TVM_DLL explicit Block(Array iter_vars, Array reads, - Array writes, String name_hint, Stmt body, - Optional init = std::nullopt, - Array alloc_buffers = Array(), - Array match_buffers = Array(), - Map annotations = Map(), - Span span = Span()); - - TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode); + TVM_DLL explicit Block( + ffi::Array iter_vars, ffi::Array reads, + ffi::Array writes, ffi::String name_hint, Stmt body, + ffi::Optional init = std::nullopt, + ffi::Array alloc_buffers = ffi::Array(), + ffi::Array match_buffers = ffi::Array(), + ffi::Map annotations = ffi::Map(), + Span span = Span()); + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Block, Stmt, BlockNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode); }; @@ -1021,7 +1003,7 @@ class Block : public Stmt { class BlockRealizeNode : public StmtNode { public: /*! \brief The corresponding values of the iter vars. */ - Array iter_values; + ffi::Array iter_values; /*! * \brief The predicate of the block realization, the block will only be executed when the * predicate is true. @@ -1037,9 +1019,7 @@ class BlockRealizeNode : public StmtNode { .def_ro("predicate", &BlockRealizeNode::predicate) .def_ro("block", &BlockRealizeNode::block); } - - static constexpr const char* _type_key = "tir.BlockRealize"; - TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BlockRealize", BlockRealizeNode, StmtNode); }; /*! 
@@ -1048,10 +1028,10 @@ class BlockRealizeNode : public StmtNode { */ class BlockRealize : public Stmt { public: - TVM_DLL explicit BlockRealize(Array iter_values, PrimExpr predicate, Block block, + TVM_DLL explicit BlockRealize(ffi::Array iter_values, PrimExpr predicate, Block block, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BlockRealize, Stmt, BlockRealizeNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); }; @@ -1146,7 +1126,7 @@ constexpr const char* buffer_dim_align = "buffer_dim_align"; constexpr const char* buffer_bound = "buffer_bound"; /*! * \brief Bind the buffer specification to the region of the op - * When this scope occurs, the stmt.node is a Array = [buffer, tensor] + * When this scope occurs, the stmt.node is a ffi::Array = [buffer, tensor] * stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...). * The scope represents that we need to bind the storage region of tensor to buffer. * This will affect replacement of some variables inside the scope that @@ -1339,6 +1319,11 @@ constexpr const char* explicit_read_region = "explicit_read_region"; */ constexpr const char* explicit_write_region = "explicit_write_region"; +constexpr const char* tilelang_assume = "tl.assume"; + +/*! \brief Mark a ForNode as representing an irregular loop of non-structural control flow edges. */ +constexpr const char* irregular_loop_mark = "irregular_loop_mark"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 23747a7e936c..b3c43bdc1459 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -325,7 +325,7 @@ class StmtExprMutator : public StmtMutator, public ExprMutator { * when the IRNode's type key is in the list. 
*/ TVM_DLL Stmt IRTransform(Stmt stmt, const ffi::Function& preorder, const ffi::Function& postorder, - Optional> only_enable = std::nullopt); + ffi::Optional> only_enable = std::nullopt); /*! * \brief Recursively visit the ir in post DFS order node, apply fvisit @@ -341,7 +341,7 @@ TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function(const Var& var)> vmap); +TVM_DLL Stmt Substitute(Stmt stmt, std::function(const Var& var)> vmap); /*! * \brief Substitute the var specified by vmap. @@ -349,7 +349,8 @@ TVM_DLL Stmt Substitute(Stmt stmt, std::function(const Var& v * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. * \return The result. */ -TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(const Var& var)> vmap); +TVM_DLL PrimExpr Substitute(PrimExpr expr, + std::function(const Var& var)> vmap); /*! * \brief Substitute the var specified by vmap. @@ -358,7 +359,8 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(cons * \return The result. */ template -Array Substitute(const Array& arr, std::function(const Var& var)> vmap) { +ffi::Array Substitute(const ffi::Array& arr, + std::function(const Var& var)> vmap) { return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); }); } @@ -369,7 +371,7 @@ Array Substitute(const Array& arr, std::function(const * \return The modified Range. */ inline Range Substitute(const Range& range, - std::function(const Var& var)> vmap) { + std::function(const Var& var)> vmap) { return Range::FromMinExtent(Substitute(range->min, vmap), Substitute(range->extent, vmap)); } @@ -385,8 +387,8 @@ inline Range Substitute(const Range& range, * \return The modified object. 
*/ template -auto Substitute(Obj&& obj, const Map& vmap) { - auto func = [&vmap](const Var& var) -> Optional { return vmap.Get(var); }; +auto Substitute(Obj&& obj, const ffi::Map& vmap) { + auto func = [&vmap](const Var& var) -> ffi::Optional { return vmap.Get(var); }; return Substitute(std::forward(obj), func); } @@ -401,8 +403,8 @@ auto Substitute(Obj&& obj, const Map& vmap) { */ template >> -auto Substitute(Obj&& obj, const Map& vmap) { - auto func = [&vmap](const Var& var) -> Optional { +auto Substitute(Obj&& obj, const ffi::Map& vmap) { + auto func = [&vmap](const Var& var) -> ffi::Optional { if (auto opt = vmap.Get(var)) { return opt.value(); } else { @@ -424,7 +426,7 @@ auto Substitute(Obj&& obj, const Map& vmap) { template >> auto Substitute(Obj&& obj, const std::unordered_map& vmap) { - auto func = [&vmap](const Var& var) -> Optional { + auto func = [&vmap](const Var& var) -> ffi::Optional { if (auto it = vmap.find(var.get()); it != vmap.end()) { return it->second; } else { @@ -446,7 +448,7 @@ auto Substitute(Obj&& obj, const std::unordered_map& vmap) template >> auto Substitute(Obj&& obj, const std::unordered_map& vmap) { - auto func = [&vmap](const Var& var) -> Optional { + auto func = [&vmap](const Var& var) -> ffi::Optional { if (auto it = vmap.find(var); it != vmap.end()) { return it->second; } else { @@ -473,7 +475,7 @@ auto Substitute(Obj&& obj, const std::unordered_map& iter_vmap) { vmap[iter_var->var.get()] = expr; } - auto func = [&vmap](const Var& var) -> Optional { + auto func = [&vmap](const Var& var) -> ffi::Optional { if (auto it = vmap.find(var.get()); it != vmap.end()) { return it->second; } else { @@ -493,8 +495,8 @@ auto Substitute(Obj&& obj, const std::unordered_map& iter_vmap) { * \sa Substitute * \return The result. */ -TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt, - std::function(const Var&)> vmap); +TVM_DLL Stmt SubstituteWithDataTypeLegalization( + Stmt stmt, std::function(const Var&)> vmap); /*! 
* \brief Substitute the var specified by vmap and legalize data types after substitution. @@ -507,7 +509,7 @@ TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt, * \return The result. */ TVM_DLL PrimExpr SubstituteWithDataTypeLegalization( - PrimExpr expr, std::function(const Var&)> vmap); + PrimExpr expr, std::function(const Var&)> vmap); /*! * \brief Recursively visit the IR in pre DFS order node, apply fvisit. diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index eb64d87f9518..bf100dc49c4c 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -56,8 +56,8 @@ using tvm::transform::Sequential; * \return The created function pass. */ TVM_DLL Pass CreatePrimFuncPass(std::function pass_func, - int opt_level, String name, tvm::Array required, - bool traceable = false); + int opt_level, ffi::String name, + tvm::ffi::Array required, bool traceable = false); /*! * \brief partition loops in the stmt. @@ -197,7 +197,7 @@ TVM_DLL Pass MakeUnpackedAPI(); * * \return The pass. */ -TVM_DLL Pass RemapThreadAxis(Map axis_map); +TVM_DLL Pass RemapThreadAxis(ffi::Map axis_map); /*! * \brief Lower custom datatypes. @@ -273,7 +273,7 @@ TVM_DLL Pass SkipAssert(); * \param storage_scope The storage scope considered. * \return The pass. */ -TVM_DLL Pass ThreadSync(String storage_scope); +TVM_DLL Pass ThreadSync(ffi::String storage_scope); /*! * \brief Lower cross thread alleduce. @@ -357,11 +357,11 @@ TVM_DLL Pass BF16ComputeLegalize(); /*! * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32 * before Ops, then add a cast back to fp8. - * \param promote_dtype_str The data type used for type promotion, defaults to float16 + * \param promote_dtype The data type used for type promotion, defaults to float16 * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs * \return The pass. 
*/ -TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16"); +TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16"); /*! * \brief Legalize bf16 storage types to u16. @@ -676,7 +676,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); */ TVM_DLL Pass InjectSoftwarePipeline(); -TVM_DLL Pass BindParams(const Array& constants); +TVM_DLL Pass BindParams(const ffi::Array& constants); /*! * \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute. @@ -729,17 +729,17 @@ TVM_DLL Pass InjectPTXLDG32(bool enable_ptx_ldg32 = true); /*! * \brief Remove the weight layout rewrite block - * \param skip_ndarray_rewrite If True, exact rewrite of NDArray, according to the given index map, - * will be skipped. Only the shape of the NDArray is transformed correctly, and the content of + * \param skip_tensor_rewrite If True, exact rewrite of Tensor, according to the given index map, + * will be skipped. Only the shape of the Tensor is transformed correctly, and the content of * the destination array will be filled with random values. * - * When this pass is called many times during MetaSchedule tuning, the raw data of NDArray, - * before and after rewrite, does not matter. Since NDArray layout rewrite, using IndexMap's - * MapNDArray, is currently slow, skipping the exact rewrite is sometimes necessary. + * When this pass is called many times during MetaSchedule tuning, the raw data of Tensor, + * before and after rewrite, does not matter. Since Tensor layout rewrite, using IndexMap's + * MapTensor, is currently slow, skipping the exact rewrite is sometimes necessary. * * \return The pass. */ -TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false); +TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite = false); /*! * \brief Add the explicit local stage for the shared memory access on GPU. 
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 7bf29265ceea..521b03a4728b 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -51,7 +51,7 @@ class VarNode : public PrimExprNode { * \brief The hint to the variable name. * \note Each variable is uniquely identified by its address. */ - String name_hint; + ffi::String name_hint; /*! * \brief type annotation of the variable. * @@ -69,22 +69,22 @@ class VarNode : public PrimExprNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; - static constexpr const char* _type_key = "tir.Var"; static constexpr const uint32_t _type_child_slots = 1; - TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("tir.Var", VarNode, PrimExprNode); }; /*! \brief a named variable in TIR */ class Var : public PrimExpr { public: - explicit Var(ObjectPtr n) : PrimExpr(n) {} + explicit Var(ffi::UnsafeInit tag) : PrimExpr(tag) {} + explicit Var(ObjectPtr n) : PrimExpr(n) {} /*! * \brief Constructor * \param name_hint variable name * \param dtype data type * \param span The location of this object in the source code. */ - TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32), + TVM_DLL explicit Var(ffi::String name_hint = "v", DataType dtype = DataType::Int(32), Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. @@ -92,19 +92,19 @@ class Var : public PrimExpr { * \param type_annotation The type annotation. * \param span The location of this object in the source code. */ - TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span()); + TVM_DLL explicit Var(ffi::String name_hint, Type type_annotation, Span span = Span()); /*! * \brief Make a new copy of var with same type, but a different nam * \param name The new name to be used. 
* \return the new Var copy */ - TVM_DLL Var copy_with_name(const String& name) const; + TVM_DLL Var copy_with_name(const ffi::String& name) const; /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. * \return the new Var copy */ - TVM_DLL Var copy_with_suffix(const String& suffix) const; + TVM_DLL Var copy_with_suffix(const ffi::String& suffix) const; /*! * \brief Make a new copy of the variable with specified dtype * \param dtype The specified dtype @@ -136,21 +136,21 @@ class SizeVarNode : public VarNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - static constexpr const char* _type_key = "tir.SizeVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SizeVar", SizeVarNode, VarNode); }; /*! \brief a named variable represents a tensor index size */ class SizeVar : public Var { public: - explicit SizeVar(ObjectPtr n) : Var(n) {} + explicit SizeVar(ObjectPtr n) : Var(n) {} + explicit SizeVar(ffi::UnsafeInit tag) : Var(tag) {} /*! * \brief constructor * \param name_hint variable name * \param t data type * \param span The location of this object in the source code. */ - TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32), + TVM_DLL explicit SizeVar(ffi::String name_hint = "s", DataType t = DataType::Int(32), Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. @@ -158,7 +158,7 @@ class SizeVar : public Var { * \param type_annotation The type annotation. * \param span The location of this object in the source code. */ - TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span()); + TVM_DLL explicit SizeVar(ffi::String name_hint, Type type_annotation, Span span = Span()); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. 
@@ -173,7 +173,7 @@ class SizeVar : public Var { using ContainerType = SizeVarNode; }; -using Region = Array; +using Region = ffi::Array; /*! * \brief Type of iteration variable. @@ -266,7 +266,7 @@ class IterVarNode : public PrimExprConvertibleNode { * \brief additional tag on the iteration variable, * set this if this is bound already to a known thread tag. */ - String thread_tag; + ffi::String thread_tag; /*! * \brief Span that points to the original source code. * Reserved debug information. @@ -284,9 +284,8 @@ class IterVarNode : public PrimExprConvertibleNode { .def_ro("thread_tag", &IterVarNode::thread_tag); } - static constexpr const char* _type_key = "tir.IterVar"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, PrimExprConvertibleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IterVar", IterVarNode, PrimExprConvertibleNode); }; /*! @@ -297,14 +296,14 @@ class IterVarNode : public PrimExprConvertibleNode { */ class IterVar : public PrimExprConvertible { public: - TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "", + TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, ffi::String thread_tag = "", Span span = Span()); /*! * \return the corresponding var in the IterVar. 
*/ inline operator PrimExpr() const; - TVM_DEFINE_OBJECT_REF_METHODS(IterVar, PrimExprConvertible, IterVarNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IterVar, PrimExprConvertible, IterVarNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode); }; diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index 9be7256b446e..2aedef4c58b6 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -46,7 +46,7 @@ namespace topi { * \return A Tensor whose op member is a broadcast operation */ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, - const tvm::Array& output_shape, + const tvm::ffi::Array& output_shape, std::string name = "T_broadcast_to", std::string tag = kBroadcast) { ICHECK_GE(output_shape.size(), t->shape.size()) @@ -54,7 +54,7 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, << "\nvs\ninput: " << t; auto bh = detail::BroadcastShape(output_shape, t->shape); ICHECK_EQ(output_shape.size(), bh.common_shape.size()); - Array oshape; + ffi::Array oshape; for (size_t i = 0; i < output_shape.size(); ++i) { if (output_shape[i].as() == nullptr) { oshape.push_back(output_shape[i]); @@ -63,30 +63,32 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, oshape.push_back(bh.common_shape[i]); } } - auto l = [&](tvm::Array ovars) { + auto l = [&](tvm::ffi::Array ovars) { return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; return tvm::te::compute(oshape, l, name, tag); } -#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ - inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, std::string tag = kBroadcast) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return detail::WithBroadcast(l, A, B, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ 
- std::string name = "T_" #Name, std::string tag = kElementWise) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute( \ - A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, std::string tag = kElementWise) { \ - auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute( \ - B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \ +#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ + inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kBroadcast) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return detail::WithBroadcast(l, A, B, name, tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + A->shape, [&](const ::tvm::ffi::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, \ + tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + B->shape, [&](const ::tvm::ffi::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, \ + tag); \ } #define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ diff --git a/include/tvm/topi/contrib/cublas.h b/include/tvm/topi/contrib/cublas.h index 3032643ed700..3590b7a54458 100644 --- a/include/tvm/topi/contrib/cublas.h +++ b/include/tvm/topi/contrib/cublas.h @@ 
-49,7 +49,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, b return make_extern( {{n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { + [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, @@ -74,7 +74,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool tra return make_extern( {{b, n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { + [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, diff --git a/include/tvm/topi/contrib/rocblas.h b/include/tvm/topi/contrib/rocblas.h index 4f0b887fb178..e29b135b7d2c 100644 --- a/include/tvm/topi/contrib/rocblas.h +++ b/include/tvm/topi/contrib/rocblas.h @@ -48,7 +48,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, return make_extern( {{n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { + [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, @@ -71,7 +71,7 @@ inline Tensor rocblas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool tr return make_extern( {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { + [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, diff --git a/include/tvm/topi/detail/array_utils.h b/include/tvm/topi/detail/array_utils.h index 89c985695865..f10eff6f61cb 100644 --- a/include/tvm/topi/detail/array_utils.h +++ b/include/tvm/topi/detail/array_utils.h @@ -41,7 +41,7 @@ using namespace tvm::te; * \return True 
iff the given array contains the given item. */ template -inline bool contains(Array array, T item) { +inline bool contains(ffi::Array array, T item) { for (auto& i : array) { if (i == item) { return true; diff --git a/include/tvm/topi/detail/broadcast.h b/include/tvm/topi/detail/broadcast.h index c861fbb71b2a..aab6fea22d2c 100644 --- a/include/tvm/topi/detail/broadcast.h +++ b/include/tvm/topi/detail/broadcast.h @@ -48,8 +48,8 @@ static inline DataType CommonType(DataType type1, DataType type2) { return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1); } -inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, - const tvm::Array& shape2) { +inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shape1, + const tvm::ffi::Array& shape2) { BroadcastHelper bh; int s1_size = shape1.size(); int s2_size = shape2.size(); @@ -94,8 +94,8 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, } else { ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " << shape2[s2_size - i] - << " in: " << tvm::Array(shape1.begin(), shape1.end()) << " and " - << tvm::Array(shape2.begin(), shape2.end()); + << " in: " << tvm::ffi::Array(shape1.begin(), shape1.end()) + << " and " << tvm::ffi::Array(shape2.begin(), shape2.end()); } } // Remaining dimensions whether on shape1 or shape2 can always be completed @@ -110,10 +110,10 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, return bh; } -inline tvm::Array InputIndexFromBroadcast( - const tvm::Array& ovars, const tvm::te::Tensor& T, +inline tvm::ffi::Array InputIndexFromBroadcast( + const tvm::ffi::Array& ovars, const tvm::te::Tensor& T, const std::deque& my_vars, const std::deque& all_vars) { - tvm::Array ivars; + tvm::ffi::Array ivars; ICHECK_EQ(ovars.size(), all_vars.size()); // N^2, could use a map but NBD. 
size_t expected_dims = T->shape.size(); @@ -141,12 +141,12 @@ inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A, const tvm::te::Tensor& B, const std::string& name = "tensor", const std::string& tag = "") { auto bh = BroadcastShape(A->shape, B->shape); - auto l = [&](tvm::Array ovars) { + auto l = [&](tvm::ffi::Array ovars) { return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)), B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars))); }; - return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, name, tag); + return tvm::te::compute( + tvm::ffi::Array(bh.common_shape.begin(), bh.common_shape.end()), l, name, tag); } } // namespace detail diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 95e68f5f6d61..74b4ce143cad 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -55,7 +55,7 @@ inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance array) { +inline bool IsConstIntArray(ffi::Array array) { bool is_const_int = true; for (auto const& elem : array) { is_const_int &= !elem.defined() || elem->IsInstance(); @@ -88,7 +88,7 @@ inline int64_t GetConstInt(PrimExpr expr) { * * \return A vector of the integer values */ -inline std::vector GetConstIntValues(Array exprs, const std::string& var_name) { +inline std::vector GetConstIntValues(ffi::Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { @@ -107,7 +107,7 @@ inline std::vector GetConstIntValues(Array exprs, const std::stri * * \return A vector of the int64_t values */ -inline std::vector GetConstInt64Values(Array exprs, +inline std::vector GetConstInt64Values(ffi::Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index 
e54169ea2934..05543f74a50b 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -41,7 +41,7 @@ using namespace tvm::te; * function. The function expects two arguments: an array of Buffers holding the input * tensor values, and a pre-allocated array of Buffers to be filled with the outputs. */ -using FExtern = std::function, Array)>; +using FExtern = std::function, ffi::Array)>; /*! * \brief Create tensors representing the result of invoking an external function. @@ -60,18 +60,19 @@ using FExtern = std::function, Array)>; * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding * element of out_types. */ -inline Array make_extern(const Array>& out_shapes, - const std::vector& out_types, - const Array& inputs, FExtern fextern, std::string name, - std::string tag, ::tvm::Map attrs) { +inline ffi::Array make_extern(const ffi::Array>& out_shapes, + const std::vector& out_types, + const ffi::Array& inputs, FExtern fextern, + std::string name, std::string tag, + ::tvm::ffi::Map attrs) { ICHECK_EQ(out_shapes.size(), out_types.size()) << "make_extern: out_shapes and out_types must have equal size"; - Array input_placeholders; + ffi::Array input_placeholders; for (auto t : inputs) { input_placeholders.push_back(tvm::tir::decl_buffer(t->shape, t->dtype, t->op->name)); } - Array output_placeholders; + ffi::Array output_placeholders; for (size_t i = 0; i < out_shapes.size(); ++i) { output_placeholders.push_back(tvm::tir::decl_buffer(out_shapes[i], out_types[i], name)); } @@ -81,7 +82,7 @@ inline Array make_extern(const Array>& out_shapes, auto op = ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body_stmt); - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < output_placeholders.size(); ++i) { outputs.push_back(op.output(i)); } @@ -107,12 +108,13 @@ inline PrimExpr pack_buffer(Buffer buf) { } else { strides = 0; } - Array pack_args{buf->data, - shape, - strides, - 
make_const(DataType::Int(32), static_cast(buf->shape.size())), - make_const(buf->dtype, 0), - buf->elem_offset}; + ffi::Array pack_args{ + buf->data, + shape, + strides, + make_const(DataType::Int(32), static_cast(buf->shape.size())), + make_const(buf->dtype, 0), + buf->elem_offset}; return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args); } @@ -125,7 +127,7 @@ inline PrimExpr pack_buffer(Buffer buf) { * * \return An expression representing the invocation */ -inline PrimExpr call_packed(Array args) { +inline PrimExpr call_packed(ffi::Array args) { return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args); } diff --git a/include/tvm/topi/detail/fuse.h b/include/tvm/topi/detail/fuse.h index 7305ccef9b1d..993a837ca46c 100644 --- a/include/tvm/topi/detail/fuse.h +++ b/include/tvm/topi/detail/fuse.h @@ -40,7 +40,7 @@ using namespace tvm::te; * * \return The fused iteration variable */ -inline IterVar Fuse(Stage stage, const Array& args) { +inline IterVar Fuse(Stage stage, const ffi::Array& args) { IterVar res; stage.fuse(args, &res); return res; diff --git a/include/tvm/topi/detail/pad_utils.h b/include/tvm/topi/detail/pad_utils.h index 96eb49a505e4..dfb9542e7655 100644 --- a/include/tvm/topi/detail/pad_utils.h +++ b/include/tvm/topi/detail/pad_utils.h @@ -45,7 +45,7 @@ using namespace tvm::te; * \return An array of 4 elements, representing padding sizes for * each individual side. 
The array is in the order { top, left, bottom, right } */ -inline Array GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) { +inline ffi::Array GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) { pad_h *= 2; pad_w *= 2; diff --git a/include/tvm/topi/detail/ravel_unravel.h b/include/tvm/topi/detail/ravel_unravel.h index e91d6afb666a..27d2f9180251 100644 --- a/include/tvm/topi/detail/ravel_unravel.h +++ b/include/tvm/topi/detail/ravel_unravel.h @@ -42,7 +42,7 @@ using namespace tvm::te; * * \return The index after flattening */ -inline PrimExpr RavelIndex(Array indices, Array shape) { +inline PrimExpr RavelIndex(ffi::Array indices, ffi::Array shape) { ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; if (indices.size() == 0U) { return 0; @@ -66,7 +66,7 @@ inline PrimExpr RavelIndex(Array indices, Array shape) { * * \return The coordinate corresponding to the 1D index */ -inline Array UnravelIndex(PrimExpr idx, Array shape) { +inline ffi::Array UnravelIndex(PrimExpr idx, ffi::Array shape) { std::vector indices; for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index f2e021ed98bc..e75aeed8b97d 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -50,8 +50,8 @@ inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) } inline std::tuple, std::vector, std::vector> ConvertToVec( - const Array& begin, const Array& end, const Array& strides, - std::string slice_mode) { + const ffi::Array& begin, const ffi::Array& end, + const ffi::Array& strides, std::string slice_mode) { std::vector stride_vec(strides.size(), 1); if (slice_mode == "end") { for (size_t i = 0; i < strides.size(); ++i) { @@ -88,12 +88,13 @@ inline std::tuple, std::vector, std::vector StridedSliceCanonicalizeBegin(const Array& ishape, - const std::vector& begin, - const std::vector& strides, - const Array& axes, 
DataType dtype, - std::string slice_mode = "end") { - Array begin_expr; +inline ffi::Array StridedSliceCanonicalizeBegin(const ffi::Array& ishape, + const std::vector& begin, + const std::vector& strides, + const ffi::Array& axes, + DataType dtype, + std::string slice_mode = "end") { + ffi::Array begin_expr; for (size_t i = 0; i < axes.size(); ++i) { if (ishape[axes[i].IntValue()]->IsInstance()) { int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]); @@ -115,16 +116,14 @@ inline Array StridedSliceCanonicalizeBegin(const Array& isha return begin_expr; } -inline Array StridedSliceOutputShape(const Array& ishape, - const std::vector& begin, - const std::vector& end, - const std::vector& strides, - const Array& axes, std::string slice_mode, - const Array& begin_canonicalized, - bool use_any = false) { +inline ffi::Array StridedSliceOutputShape( + const ffi::Array& ishape, const std::vector& begin, + const std::vector& end, const std::vector& strides, + const ffi::Array& axes, std::string slice_mode, + const ffi::Array& begin_canonicalized, bool use_any = false) { ICHECK(!use_any) << "StridedSliceOutputShape does not legacy use_any"; const size_t src_tensor_dim = ishape.size(); - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < src_tensor_dim; ++i) { out_shape.push_back(ishape[i]); } diff --git a/include/tvm/topi/detail/tensor_utils.h b/include/tvm/topi/detail/tensor_utils.h index 397c70c9451e..d67ad6359434 100644 --- a/include/tvm/topi/detail/tensor_utils.h +++ b/include/tvm/topi/detail/tensor_utils.h @@ -40,7 +40,7 @@ using namespace tvm::te; * * \return True if the input shape is empty. */ -inline bool is_empty_shape(const Array& x) { +inline bool is_empty_shape(const ffi::Array& x) { bool is_empty = false; for (const auto& dim : x) { if (auto int_dim = dim.as()) { @@ -63,7 +63,7 @@ inline bool is_empty_shape(const Array& x) { * * \return The interpolated value in the given index. 
*/ -inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& indices, +inline PrimExpr bilinear_sample_nchw(const Tensor& input, const ffi::Array& indices, const PrimExpr max_y, const PrimExpr max_x) { auto batch_id = indices[0]; auto channel_id = indices[1]; @@ -107,7 +107,7 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& * * \return The interpolated value in the given index. */ -inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array& indices, +inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const ffi::Array& indices, const PrimExpr max_y, const PrimExpr max_x) { auto batch_id = indices[0]; auto channel_id = indices[3]; diff --git a/include/tvm/topi/einsum.h b/include/tvm/topi/einsum.h index 5e7813f8431b..44f01b0a967c 100644 --- a/include/tvm/topi/einsum.h +++ b/include/tvm/topi/einsum.h @@ -56,8 +56,8 @@ using namespace topi::detail; * * \return the shape of the output. */ -Array InferEinsumShape(const std::string& subscripts, - const std::vector>& operands); +ffi::Array InferEinsumShape(const std::string& subscripts, + const std::vector>& operands); /*! * \brief Evaluates the Einstein summation convention on the operands. @@ -70,7 +70,7 @@ Array InferEinsumShape(const std::string& subscripts, * * \return The calculation based on the Einstein summation convention. 
*/ -Tensor einsum(const std::string& subscripts_str, const Array inputs, +Tensor einsum(const std::string& subscripts_str, const ffi::Array inputs, std::string name = "T_einsum", std::string tag = kEinsum); struct EinsumEquation { diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 806ddcb662f9..0ed082b0c140 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -40,11 +40,11 @@ namespace topi { using namespace tvm::te; // Unary intrinsic operators -#define TOPI_DECLARE_UNARY_OP(OpName) \ - inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName, \ - std::string tag = kElementWise) { \ - return compute( \ - x->shape, [&](const Array& i) { return ::tvm::OpName(x(i)); }, name, tag); \ +#define TOPI_DECLARE_UNARY_OP(OpName) \ + inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName, \ + std::string tag = kElementWise) { \ + return compute( \ + x->shape, [&](const ffi::Array& i) { return ::tvm::OpName(x(i)); }, name, tag); \ } TOPI_DECLARE_UNARY_OP(exp); @@ -101,7 +101,7 @@ inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string ta return compute( x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { auto x2 = x(i) * x(i); auto p = x2 * alpha_13 + alpha_11; p = x2 * p + alpha_9; @@ -136,7 +136,7 @@ inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", } else { // fallback to default implementation return compute( - x->shape, [&](const Array& i) { return ::tvm::tanh(x(i)); }, name, tag); + x->shape, [&](const ffi::Array& i) { return ::tvm::tanh(x(i)); }, name, tag); } } @@ -152,7 +152,7 @@ inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", inline Tensor identity(const Tensor& x, std::string name = "T_identity", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return x(i); }, name, tag); + x->shape, [&](const ffi::Array& i) { return x(i); }, name, tag); } /*! 
@@ -167,7 +167,7 @@ inline Tensor identity(const Tensor& x, std::string name = "T_identity", inline Tensor negative(const Tensor& x, std::string name = "T_negative", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return -x(i); }, name, tag); + x->shape, [&](const ffi::Array& i) { return -x(i); }, name, tag); } /*! @@ -182,7 +182,7 @@ inline Tensor negative(const Tensor& x, std::string name = "T_negative", inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return !x(i); }, name, tag); + x->shape, [&](const ffi::Array& i) { return !x(i); }, name, tag); } /*! @@ -197,7 +197,7 @@ inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not", inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return ~x(i); }, name, tag); + x->shape, [&](const ffi::Array& i) { return ~x(i); }, name, tag); } /*! 
@@ -212,7 +212,7 @@ inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not", inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { PrimExpr zero = make_zero(x->dtype); PrimExpr one = make_const(x->dtype, 1); PrimExpr minus_one = make_const(x->dtype, -1); @@ -235,7 +235,7 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { PrimExpr one = make_const(x->dtype, 1); return one / tvm::sqrt(x(i)); }, @@ -258,7 +258,7 @@ inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max std::string name = "T_clip", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { auto min_val = tvm::cast(x->dtype, a_min); auto max_val = tvm::cast(x->dtype, a_max); return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) @@ -282,7 +282,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) -> PrimExpr { + [&](const ffi::Array& i) -> PrimExpr { auto expr = x(i); if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { if (expr.dtype().lanes() == type.lanes()) { @@ -310,7 +310,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return reinterpret(type, x(i)); }, name, tag); + x->shape, [&](const ffi::Array& i) { return reinterpret(type, x(i)); }, name, tag); } /*! 
@@ -322,12 +322,12 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te * * \return A Tensor whose op member is the sum operation */ -inline Tensor elemwise_sum(const Array& xs, std::string name = "T_elemwise_sum", +inline Tensor elemwise_sum(const ffi::Array& xs, std::string name = "T_elemwise_sum", std::string tag = kElementWise) { ICHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor."; return compute( xs[0]->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { auto sum_expr = xs[0](i); for (size_t j = 1; j < xs.size(); j++) { sum_expr = sum_expr + xs[j](i); @@ -348,14 +348,14 @@ inline Tensor elemwise_sum(const Array& xs, std::string name = "T_elemwi * * \return A Tensor whose op member is the full operation */ -inline Tensor full(const Array& shape, DataType dtype, const PrimExpr fill_value, +inline Tensor full(const ffi::Array& shape, DataType dtype, const PrimExpr fill_value, std::string name = "T_full", std::string tag = kElementWise) { PrimExpr ev = cast(dtype, fill_value); if (!ev.defined()) { LOG(ERROR) << "Can't cast fill_value to " << dtype; } return compute( - shape, [&](const Array& i) { return ev; }, name, tag); + shape, [&](const ffi::Array& i) { return ev; }, name, tag); } /*! @@ -373,7 +373,7 @@ inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, std::string name = "T_full_like", std::string tag = kElementWise) { PrimExpr ev = cast(x->dtype, fill_value); return compute( - x->shape, [&](const Array& i) { return ev; }, name, tag); + x->shape, [&](const ffi::Array& i) { return ev; }, name, tag); } /*! 
@@ -414,7 +414,7 @@ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string t return compute( _x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { // clamp x auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); // integer part @@ -448,7 +448,7 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", return ret; } else { return compute( - x->shape, [&](const Array& i) { return ::tvm::exp(x(i)); }, name, tag); + x->shape, [&](const ffi::Array& i) { return ::tvm::exp(x(i)); }, name, tag); } } @@ -457,7 +457,7 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", */ inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) { return compute( - data->shape, [&](const Array& i) { return fast_erf_float_expr(data(i), 32); }, name, + data->shape, [&](const ffi::Array& i) { return fast_erf_float_expr(data(i), 32); }, name, tag); } @@ -466,7 +466,7 @@ inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string */ inline Tensor fast_erf_float16(const Tensor& data, std::string name, std::string tag) { return compute( - data->shape, [&](const Array& i) { return fast_erf_float_expr(data(i), 16); }, name, + data->shape, [&](const ffi::Array& i) { return fast_erf_float_expr(data(i), 16); }, name, tag); } diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 6bef5d0f1c2a..36ce8594b3db 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -56,7 +56,7 @@ inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast< std::string name = "T_relu", std::string tag = kElementWise) { return tvm::te::compute( t->shape, - [&](const tvm::Array& i) { + [&](const tvm::ffi::Array& i) { auto threshold_const = tvm::tir::make_const(t->dtype, threshold); return tvm::max(t(i), threshold_const); }, @@ -78,7 +78,7 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, std::string tag = kElementWise) { return 
tvm::te::compute( t->shape, - [&](const tvm::Array& i) { + [&](const tvm::ffi::Array& i) { auto value = t(i); auto calpha = tvm::tir::make_const(value.dtype(), alpha); return tvm::tir::Select(value > 0, value, value * calpha); @@ -106,7 +106,7 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl return tvm::te::compute( x->shape, - [&](const tvm::Array& indices) { + [&](const tvm::ffi::Array& indices) { auto xval = x(indices); return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis])); }, @@ -152,11 +152,11 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl * * */ -inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array& pad_before, - tvm::Array pad_after = tvm::Array(), - PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", - std::string tag = kElementWise, std::string pad_mode = "constant", - const Array* dyn_output_shape = nullptr) { +inline tvm::te::Tensor pad( + const tvm::te::Tensor& t, const tvm::ffi::Array& pad_before, + tvm::ffi::Array pad_after = tvm::ffi::Array(), + PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", std::string tag = kElementWise, + std::string pad_mode = "constant", const ffi::Array* dyn_output_shape = nullptr) { if (pad_after.size() < pad_before.size()) { for (size_t i = pad_after.size(); i < pad_before.size(); ++i) { pad_after.push_back(pad_before[i]); @@ -166,8 +166,8 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array pad_before_int32; - tvm::Array pad_after_int32; + tvm::ffi::Array pad_before_int32; + tvm::ffi::Array pad_after_int32; for (const auto& ele : pad_before) { pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); @@ -176,7 +176,7 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array output_shape; + tvm::ffi::Array output_shape; if (dyn_output_shape == nullptr) { for (size_t i = 0; i < t->shape.size(); ++i) { if (i >= pad_before.size()) { @@ -196,10 +196,10 @@ 
inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Arraydtype, 0); } - auto l = [&](tvm::Array ovars) { - tvm::Array indices; - tvm::Array sel; - tvm::Array pad_idx; + auto l = [&](tvm::ffi::Array ovars) { + tvm::ffi::Array indices; + tvm::ffi::Array sel; + tvm::ffi::Array pad_idx; for (size_t i = 0; i < t->shape.size(); ++i) { if (i >= pad_before_int32.size()) { indices.push_back(ovars[i]); @@ -273,7 +273,7 @@ inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tens ICHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ I->shape[0], // B W->shape[0], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H @@ -317,7 +317,7 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tens ICHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W I->shape[2], // B @@ -363,7 +363,7 @@ inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm auto pH = I->shape[2]; auto pW = I->shape[3]; auto pCM = W->shape[1]; // channel_multiplier - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ I->shape[0], // B W->shape[1], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H @@ -392,7 +392,7 @@ inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm auto pH = I->shape[1]; auto pW = I->shape[2]; auto pCM = W->shape[1]; // channel_multiplier - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ I->shape[0], // B indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W @@ -440,7 +440,7 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const 
tvm::t ICHECK_EQ(5, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ I->shape[0], // B I->shape[1], // G W->shape[2], // O @@ -454,7 +454,7 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); - auto l = [&](tvm::Array args) { + auto l = [&](tvm::ffi::Array args) { tvm::tir::Var b = args[0]; tvm::tir::Var g = args[1]; tvm::tir::Var o = args[2]; @@ -480,9 +480,9 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t * \return A Tensor whose op member is the space_to_batch_nd operation */ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, - const tvm::Array& block_shape, - const tvm::Array& pad_before, - const tvm::Array& pad_after, + const tvm::ffi::Array& block_shape, + const tvm::ffi::Array& pad_before, + const tvm::ffi::Array& pad_after, PrimExpr pad_value = PrimExpr(), std::string name = "space_to_batch_nd", std::string tag = kInjective) { @@ -490,8 +490,8 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, CHECK_EQ(pad_before.size(), pad_after.size()); CHECK_EQ(block_shape.size(), pad_before.size()) << "Paddings must be provided for each spatial dimension"; - tvm::Array pad_before_int32; - tvm::Array pad_after_int32; + tvm::ffi::Array pad_before_int32; + tvm::ffi::Array pad_after_int32; // pad size for batch dimension is 0 pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0)); @@ -514,9 +514,9 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, auto padded_shape = padded_t->shape; // infer shapes - tvm::Array r_shape; - tvm::Array axis; - tvm::Array o_shape; + tvm::ffi::Array r_shape; + tvm::ffi::Array axis; + tvm::ffi::Array o_shape; size_t num_block_dims = block_shape.size(); int batch = static_cast(GetConstInt(input_shape[0])); @@ -576,15 
+576,15 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, * \return A Tensor whose op member is the batch_to_space_nd operation */ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, - const tvm::Array& block_shape, - const tvm::Array& crop_begin_list, - const tvm::Array& crop_end_list, + const tvm::ffi::Array& block_shape, + const tvm::ffi::Array& crop_begin_list, + const tvm::ffi::Array& crop_end_list, std::string name = "batch_to_space_nd", std::string tag = kInjective) { // Construct shapes for reshape and transpose operation - Array in_shape = data->shape; - Array r_shape; - Array axis; + ffi::Array in_shape = data->shape; + ffi::Array r_shape; + ffi::Array axis; size_t num_block_dims = block_shape.size(); size_t num_input_dims = in_shape.size(); tvm::PrimExpr block_shape_prod(1); @@ -605,7 +605,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, r_shape.push_back(in_shape[i]); } - Array r_p_shape; + ffi::Array r_p_shape; r_p_shape.push_back(batch / block_shape_prod); for (size_t i = 1; i <= num_block_dims; i++) { r_p_shape.push_back(in_shape[i] * block_shape[i - 1]); @@ -620,7 +620,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - Array begin_idx, end_idx, strides; + ffi::Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { strides.push_back(Integer(1)); if (i > 0 && i <= num_block_dims) { @@ -665,7 +665,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T // prediction->shape = (C,), targets->shape = (), weights->shape = (C,) auto T = tvm::te::compute( {}, - [&](const tvm::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tir::Select(c != ignore_index, -predictions(c) * weights(c), tvm::tir::make_const(predictions->dtype, 0)); @@ -674,7 +674,7 @@ inline Tensor nll_loss(const 
Tensor& predictions, const Tensor& targets, const T if (reduction == "mean") { auto W = tvm::te::compute( {}, - [&](const tvm::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tir::Select(c != ignore_index, weights(c), tvm::tir::make_const(predictions->dtype, 0)); @@ -687,9 +687,9 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T } auto T = tvm::te::compute( targets->shape, - [&](const tvm::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); - tvm::Array pred_indices; + tvm::ffi::Array pred_indices; pred_indices.push_back(target_indices[0]); // batch index pred_indices.push_back(c); // class index for (size_t i = 1; i < target_indices.size(); i++) { @@ -703,16 +703,16 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T if (reduction == "mean") { auto W = tvm::te::compute( targets->shape, - [&](const tvm::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); return tvm::tir::Select(c != ignore_index, weights(c), tvm::tir::make_const(predictions->dtype, 0)); }, name, tag); - return topi::divide(topi::sum(T, tvm::Array(nullptr)), - topi::sum(W, tvm::Array(nullptr))); + return topi::divide(topi::sum(T, tvm::ffi::Array(nullptr)), + topi::sum(W, tvm::ffi::Array(nullptr))); } else if (reduction == "sum") { - return topi::sum(T, tvm::Array(nullptr)); + return topi::sum(T, tvm::ffi::Array(nullptr)); } else { // reduction == "none" return T; } diff --git a/include/tvm/topi/nn/bnn.h b/include/tvm/topi/nn/bnn.h index 815b8a23c998..2cc494eaa9d4 100644 --- a/include/tvm/topi/nn/bnn.h +++ b/include/tvm/topi/nn/bnn.h @@ -57,7 +57,7 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, arith::Analyzer analyzer; auto n = ishape.size(); - Array oshape; + ffi::Array oshape; for (size_t i = 0; i < n; ++i) { oshape.push_back(i == 
static_cast(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32)) : ishape[i]); @@ -65,15 +65,15 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, return tvm::te::compute( oshape, - [&](const Array& indices) { - Array start_idx; + [&](const ffi::Array& indices) { + ffi::Array start_idx; for (size_t i = 0; i < n; ++i) { start_idx.push_back(i == static_cast(axis) ? indices[i] * 32 : static_cast(indices[i])); } auto packed = make_const(DataType::UInt(32), 0); for (size_t j = 0; j < 32; ++j) { - Array idx; + ffi::Array idx; for (size_t i = 0; i < n; ++i) { idx.push_back(i == static_cast(axis) ? start_idx[i] + static_cast(j) : start_idx[i]); diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h index 74c46e2694b3..816d489c400e 100644 --- a/include/tvm/topi/nn/dilate.h +++ b/include/tvm/topi/nn/dilate.h @@ -44,7 +44,7 @@ using namespace tvm::te; * * \return The logical conjunction expression */ -PrimExpr all(Array args) { +PrimExpr all(ffi::Array args) { ICHECK_GT(args.size(), 0) << "all requires at least one argument"; PrimExpr ret = args[0]; @@ -67,13 +67,13 @@ PrimExpr all(Array args) { * * \return The output tensor. 
*/ -inline Tensor dilate(const Tensor& x, Array strides, double dilation_value, +inline Tensor dilate(const Tensor& x, ffi::Array strides, double dilation_value, std::string name = "tensor", std::string tag = kInjective) { auto n = x->shape.size(); ICHECK_EQ(n, strides.size()) << "strides size (" << strides.size() << ") must match dimension of x (" << n << ")"; - Array out_shape; + ffi::Array out_shape; arith::Analyzer analyzer; for (size_t i = 0; i < n; ++i) { out_shape.push_back(analyzer.Simplify((x->shape[i] - 1) * (strides[i] + 1))); @@ -81,9 +81,9 @@ inline Tensor dilate(const Tensor& x, Array strides, double dilation_v return tvm::te::compute( out_shape, - [&](const Array& indices) { - Array not_zero; - Array index_tuple; + [&](const ffi::Array& indices) { + ffi::Array not_zero; + ffi::Array index_tuple; for (size_t i = 0; i < n; ++i) { if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { index_tuple.push_back(indices[i]); diff --git a/include/tvm/topi/nn/flatten.h b/include/tvm/topi/nn/flatten.h index cd96d303b920..e60ae1e1d641 100644 --- a/include/tvm/topi/nn/flatten.h +++ b/include/tvm/topi/nn/flatten.h @@ -54,7 +54,7 @@ inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string dim = dim * ishape[i]; } - Array oshape({ishape[0], dim}); + ffi::Array oshape({ishape[0], dim}); std::vector extra_shape; for (size_t i = 1; i < ishape.size(); ++i) { diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h index 9dcc1dda9e43..9c03b682407d 100644 --- a/include/tvm/topi/nn/group_norm.h +++ b/include/tvm/topi/nn/group_norm.h @@ -37,7 +37,7 @@ namespace nn { using namespace tvm::te; inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - int num_groups, int channel_axis, const Array& axes, + int num_groups, int channel_axis, const ffi::Array& axes, double epsilon, std::string name = "T_group_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; @@ -50,11 +50,11 
@@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& bool is_float16 = data_type == DataType::Float(16); // reshape data C -> G, C/G int ndim = data->shape.size(); - channel_axis = GetRealAxis(static_cast(ndim), Array({channel_axis}))[0]; + channel_axis = GetRealAxis(static_cast(ndim), ffi::Array({channel_axis}))[0]; auto shape = data->shape; auto group_size = floordiv(shape[channel_axis], num_groups); - auto new_shape = Array(); + auto new_shape = ffi::Array(); for (int i = 0; i < ndim; ++i) { if (i == channel_axis) { new_shape.push_back(num_groups); @@ -82,7 +82,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& // get the new axes to normalize after reshape std::vector new_axes{channel_axis + 1}; for (auto axis : axes) { - int new_axis = GetRealAxis(static_cast(ndim), Array({axis}))[0]; + int new_axis = GetRealAxis(static_cast(ndim), ffi::Array({axis}))[0]; if (new_axis < channel_axis) { new_axes.push_back(new_axis); } else if (new_axis > channel_axis) { @@ -100,8 +100,9 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& MakeReduceTargetShape(new_axes, data_reshaped, /*keepdims=*/false, /*atleast1d=*/true); auto func = MakeTupleSumReducer(); - auto compute = [ndim, &new_axes, &reduce_axes, &func, &data_reshaped](const Array& indices) { - Array eval_range; + auto compute = [ndim, &new_axes, &reduce_axes, &func, + &data_reshaped](const ffi::Array& indices) { + ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -129,8 +130,8 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& for (auto axis : new_axes) { reduce_extent *= data_reshaped->shape[axis]; } - auto group_norm_func = [&](const Array& indices) { - Array reduce_indices, non_reduce_indices, gamma_indices; + auto group_norm_func = [&](const ffi::Array& indices) { + ffi::Array reduce_indices, non_reduce_indices, gamma_indices; for (int i = 0, n = 
static_cast(indices.size()); i < n; ++i) { if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) { reduce_indices.push_back(indices[i]); diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index d400721215ec..c6a10ec89f0a 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -51,7 +51,7 @@ using namespace tvm::te; * \return The normalized tensor, with the same shape as data. */ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - int channel_axis, const Array& axis, double epsilon, + int channel_axis, const ffi::Array& axis, double epsilon, std::string name = "T_instance_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; @@ -71,8 +71,8 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso auto func = MakeTupleSumReducer(); auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, - &data](const Array& indices) { - Array eval_range; + &data](const ffi::Array& indices) { + ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -110,8 +110,8 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso for (int i : real_axis) { reduce_extent *= data->shape[i]; } - auto instance_norm_func = [&](const Array& indices) { - Array reduce_indices, non_reduce_indices; + auto instance_norm_func = [&](const ffi::Array& indices) { + ffi::Array reduce_indices, non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index f1b0e4ac9eaa..7caa30b0a23b 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -49,7 +49,7 @@ using namespace tvm::te; * \return The normalized tensor, 
with the same shape as data. */ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - const Array& axis, double epsilon, + const ffi::Array& axis, double epsilon, std::string name = "T_layer_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; @@ -69,8 +69,8 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto func = MakeTupleSumReducer(); auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, - &data](const Array& indices) { - Array eval_range; + &data](const ffi::Array& indices) { + ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -108,8 +108,8 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& for (int i : real_axis) { reduce_extent *= data->shape[i]; } - auto layer_norm_func = [&](const Array& indices) { - Array reduce_indices, non_reduce_indices; + auto layer_norm_func = [&](const ffi::Array& indices) { + ffi::Array reduce_indices, non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { reduce_indices.push_back(indices[i]); diff --git a/include/tvm/topi/nn/local_response_norm.h b/include/tvm/topi/nn/local_response_norm.h index a9d72250bbb0..119ab0c19eb0 100644 --- a/include/tvm/topi/nn/local_response_norm.h +++ b/include/tvm/topi/nn/local_response_norm.h @@ -57,8 +57,8 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 ICHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; ICHECK(data->dtype.is_float()) << "datatype should be float"; auto input_shape = data->shape; - Array pad_before{0, 0, 0, 0}; - Array pad_after{0, 0, 0, 0}; + ffi::Array pad_before{0, 0, 0, 0}; + ffi::Array pad_after{0, 0, 0, 0}; pad_before.Set(axis, static_cast(size / 2)); pad_after.Set(axis, static_cast(size / 2)); 
auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data"); diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index 8e13ae49afdf..b977a54a5920 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -47,8 +47,9 @@ enum PoolType : int { }; inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, - const Array& kernel_size, const Array& stride_size, - const Array& padding_size, PoolType pool_type, + const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad) { ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)"; @@ -77,11 +78,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, pad_right += stride_width - 1; } - Array pad_before(std::vector(x->shape.size(), 0)); + ffi::Array pad_before(std::vector(x->shape.size(), 0)); pad_before.Set(height_axis, pad_top); pad_before.Set(width_axis, pad_left); - Array pad_after(std::vector(x->shape.size(), 0)); + ffi::Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); arith::Analyzer analyzer; @@ -93,8 +94,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh"); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw"); - Array data_shape = x->shape; - Array out_shape = data_shape; + ffi::Array data_shape = x->shape; + ffi::Array out_shape = data_shape; out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); @@ -106,7 +107,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); if (pool_type == kMaxPool) { - Array ravel_shape{data_shape.begin(), data_shape.end()}; + ffi::Array 
ravel_shape{data_shape.begin(), data_shape.end()}; ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom); ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right); @@ -120,8 +121,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto mp_argmax = tvm::te::compute( out_shape, - [&](const Array& inds) { - Array window_inds{inds.begin(), inds.end()}; + [&](const ffi::Array& inds) { + ffi::Array window_inds{inds.begin(), inds.end()}; window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); auto idx = detail::RavelIndex(window_inds, ravel_shape); @@ -133,13 +134,13 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, return tvm::te::compute( data_shape, - [&](const Array& inds) { - Array pad_inds{inds.begin(), inds.end()}; + [&](const ffi::Array& inds) { + ffi::Array pad_inds{inds.begin(), inds.end()}; pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top); pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left); auto idx = detail::RavelIndex(pad_inds, ravel_shape); - Array out_idx{inds.begin(), inds.end()}; + ffi::Array out_idx{inds.begin(), inds.end()}; out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh); out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); @@ -165,12 +166,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww"); return tvm::te::compute( data_shape, - [&](const Array& inds) { + [&](const ffi::Array& inds) { PrimExpr pad_h_idx = inds[height_axis] + pad_top; PrimExpr pad_w_idx = inds[width_axis] + pad_left; // output indices whose pooling windows cover current input element (can be out-of-bound) - Array out_idx{inds.begin(), inds.end()}; + ffi::Array out_idx{inds.begin(), inds.end()}; out_idx.Set(height_axis, 
(pad_h_idx / stride_height - windowh)); out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); @@ -290,9 +291,11 @@ inline bool find_width(const std::string& layout, int* width_axis) { * * \return The output tensor in the same layout */ -inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& padding_size, - PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", +inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, + const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& padding_size, PoolType pool_type, + bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; @@ -319,24 +322,24 @@ inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const Prim * * \return The output tensor in same layout order */ -inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_size, +inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::vector& axes) { const auto n_dim = output_size.size(); ICHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension"; - Array data_shape = x->shape; - Array out_shape = data_shape; - Array in_size, out_size; + ffi::Array data_shape = x->shape; + ffi::Array out_shape = data_shape; + ffi::Array in_size, out_size; for (size_t i = 0; i < n_dim; ++i) { in_size.push_back(data_shape[axes[i]]); out_size.push_back(output_size[i]); out_shape.Set(axes[i], out_size[i]); } - auto get_iter_vars = [=](const Array& output, bool reduce_indices) { - Array indices; + auto get_iter_vars = [=](const ffi::Array& output, bool reduce_indices) { + ffi::Array indices; for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]); - Array reduce_axes; + 
ffi::Array reduce_axes; for (size_t i = 0; i < n_dim; ++i) { auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]); auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]); @@ -350,25 +353,25 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ return std::make_tuple(indices, reduce_axes); }; - Map attrs; + ffi::Map attrs; if (pool_type == kMaxPool) { - attrs.Set("schedule_rule", tvm::String("meta_schedule.adaptive_pool_max")); + attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.adaptive_pool_max")); return tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; - Array reduce_axes; + [&](const ffi::Array& output) { + ffi::Array indices; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, true); return tvm::max(x(indices), reduce_axes); // NOLINT(*) }, "adaptive_pool_max", "adaptive_pool_max", attrs); } else if (pool_type == kAvgPool) { - attrs.Set("schedule_rule", tvm::String("meta_schedule.adaptive_pool_avg")); + attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.adaptive_pool_avg")); auto pool_sum = tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; - Array reduce_axes; + [&](const ffi::Array& output) { + ffi::Array indices; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, true); return tvm::sum(x(indices), reduce_axes); }, @@ -376,9 +379,9 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ return tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; - Array reduce_axes; + [&](const ffi::Array& output) { + ffi::Array indices; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, false); PrimExpr divide_factor = tvm::cast(x->dtype, 1); @@ -421,8 +424,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ * * \return The output tensor in same layout order */ -inline Tensor adaptive_pool(const Tensor& x, const 
Array& output_size, PoolType pool_type, - const std::string& layout = "NCHW") { +inline Tensor adaptive_pool(const Tensor& x, const ffi::Array& output_size, + PoolType pool_type, const std::string& layout = "NCHW") { int height_axis = -1, width_axis = -1; ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis}); @@ -436,7 +439,7 @@ inline Tensor adaptive_pool(const Tensor& x, const Array& output_size, * \param pool_type The type of pooling operator * \param layout The input layout. The default is "NCDHW". */ -inline Tensor adaptive_pool3d(const Tensor& x, const Array& output_size, +inline Tensor adaptive_pool3d(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::string& layout = "NCDHW") { int depth_axis = -1, height_axis = -1, width_axis = -1; ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) @@ -452,7 +455,7 @@ inline Tensor adaptive_pool3d(const Tensor& x, const Array& output_siz * \param pool_type The type of pooling operator * \param layout The input layout. The default is "NCW". */ -inline Tensor adaptive_pool1d(const Tensor& x, const Array& output_size, +inline Tensor adaptive_pool1d(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::string& layout = "NCW") { int width_axis = -1; ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; @@ -485,7 +488,7 @@ inline Tensor adaptive_pool1d(const Tensor& x, const Array& output_siz * e.g., for NCHW, the output shape will be [batch, channel, 1, 1] */ inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") { - return adaptive_pool(x, Array{1, 1}, pool_type, layout); + return adaptive_pool(x, ffi::Array{1, 1}, pool_type, layout); } /*! 
@@ -504,10 +507,11 @@ inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string * * \return The output tensor in same layout order */ -inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& dilation_size, - const Array& padding_size, PoolType pool_type, bool ceil_mode, - const std::vector& axis, bool count_include_pad) { +inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& dilation_size, + const ffi::Array& padding_size, PoolType pool_type, + bool ceil_mode, const std::vector& axis, bool count_include_pad) { int k_size = kernel_size.size(); int x_size = x->shape.size(); ICHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel"; @@ -515,17 +519,17 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, " kernel"; ICHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; - Array daxis; + ffi::Array daxis; std::vector kernel(k_size); std::vector stride(k_size); std::vector dilation(k_size); std::vector pad_head(k_size); std::vector pad_tail(k_size); std::vector offset(k_size, 0); - Array pad_before(std::vector(x_size, 0)); - Array pad_after(std::vector(x_size, 0)); - Array data_shape = x->shape; - Array out_shape = data_shape; + ffi::Array pad_before(std::vector(x_size, 0)); + ffi::Array pad_after(std::vector(x_size, 0)); + ffi::Array data_shape = x->shape; + ffi::Array out_shape = data_shape; bool do_pad = false; for (int i = 0; i < k_size; i++) { @@ -563,14 +567,14 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, out_shape.Set(ii, out_dim); } - Map attrs; + ffi::Map attrs; if (pool_type == kMaxPool) { auto temp = do_pad ? 
pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - attrs.Set("schedule_rule", tvm::String("meta_schedule.pool_max")); + attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.pool_max")); return tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; + [&](const ffi::Array& output) { + ffi::Array indices; for (const Var& var : output) indices.push_back(var); for (int i = 0; i < k_size; i++) { @@ -581,15 +585,15 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, }, "pool_max", "pool_max", attrs); } else if (pool_type == kAvgPool) { - attrs.Set("schedule_rule", tvm::String("meta_schedule.pool_avg")); + attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.pool_avg")); // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; // TVM compute for summing the pooling window. auto pool_sum = tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; + [&](const ffi::Array& output) { + ffi::Array indices; for (const Var& var : output) indices.push_back(var); for (int i = 0; i < k_size; i++) { @@ -603,8 +607,8 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, // TVM compute for dividing the reduced window sum by kernel size. 
return tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; + [&](const ffi::Array& output) { + ffi::Array indices; for (const Var& var : output) indices.push_back(var); if (count_include_pad) { std::vector start(k_size); @@ -687,9 +691,10 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, * * \return The output tensor in the same layout */ -inline Tensor pool1d(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& dilation_size, - const Array& padding_size, PoolType pool_type, bool ceil_mode, +inline Tensor pool1d(const Tensor& x, const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& dilation_size, + const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW", bool count_include_pad = true) { int width_axis = -1; ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; @@ -728,9 +733,10 @@ inline Tensor pool1d(const Tensor& x, const Array& kernel_size, * * \return The output tensor in the same layout */ -inline Tensor pool2d(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& dilation_size, - const Array& padding_size, PoolType pool_type, bool ceil_mode, +inline Tensor pool2d(const Tensor& x, const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& dilation_size, + const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; @@ -770,9 +776,10 @@ inline Tensor pool2d(const Tensor& x, const Array& kernel_size, * * \return The output tensor in the same layout */ -inline Tensor pool3d(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& dilation_size, - const Array& padding_size, PoolType pool_type, bool 
ceil_mode, +inline Tensor pool3d(const Tensor& x, const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& dilation_size, + const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW", bool count_include_pad = true) { int depth_axis = -1, height_axis = -1, width_axis = -1; ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index 7e95000f1ee2..66a2ae62dfec 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -47,7 +47,7 @@ using namespace tvm::te; * \param tag The tag to mark the operation. * \return The normalized tensor, with the same shape as data. */ -inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array& axis, +inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Array& axis, double epsilon, std::string name = "T_rms_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; @@ -67,8 +67,8 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Arrayshape[i]; } - auto rsqrt_func = [&](const Array& indices) { - Array non_reduce_indices; + auto rsqrt_func = [&](const ffi::Array& indices) { + ffi::Array non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) { non_reduce_indices.push_back(indices[i]); @@ -78,7 +78,7 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array(); + auto rsqrt_shape = ffi::Array(); for (int i = 0, n = static_cast(data_fp32->shape.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) { rsqrt_shape.push_back(data_fp32->shape[i]); @@ -86,8 +86,8 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array& indices) { - Array reduce_indices, non_reduce_indices; + auto 
rms_norm_func = [&](const ffi::Array& indices) { + ffi::Array reduce_indices, non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { reduce_indices.push_back(indices[i]); diff --git a/include/tvm/topi/nn/softmax.h b/include/tvm/topi/nn/softmax.h index 6679b84c8d03..f58d66ece139 100644 --- a/include/tvm/topi/nn/softmax.h +++ b/include/tvm/topi/nn/softmax.h @@ -60,11 +60,12 @@ inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2"); auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false); - tvm::Map attrs; + tvm::ffi::Map attrs; attrs.Set("axis", Integer(axis)); - auto insert_reduce_index = [axis, ndim](const Array& indices, const IterVar& reduce_index) { - Array eval_range; + auto insert_reduce_index = [axis, ndim](const ffi::Array& indices, + const IterVar& reduce_index) { + ffi::Array eval_range; int arg_counter = 0; for (size_t i = 0; i < ndim; ++i) { if (static_cast(i) == axis) { @@ -76,41 +77,41 @@ inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor return eval_range; }; - auto get_non_reduce_indices = [axis, ndim](const Array& indices) { - Array non_reduce_indices; + auto get_non_reduce_indices = [axis, ndim](const ffi::Array& indices) { + ffi::Array non_reduce_indices; for (size_t i = 0; i < ndim; ++i) { if (static_cast(i) != axis) non_reduce_indices.push_back(indices[i]); } return non_reduce_indices; }; - auto _compute_max = [&](const Array& indices) { + auto _compute_max = [&](const ffi::Array& indices) { auto eval_range = insert_reduce_index(indices, k1); return topi::MaxOp(x(eval_range), {k1}); }; - auto _compute_exp = [&](const Tensor& max_elem, const Array& indices) { + auto _compute_exp = [&](const Tensor& max_elem, const ffi::Array& indices) { auto non_reduce_indices = get_non_reduce_indices(indices); return 
tvm::exp(x(indices) - max_elem(non_reduce_indices)); }; - auto _compute_expsum = [&](const Tensor& exp, const Array& indices) { + auto _compute_expsum = [&](const Tensor& exp, const ffi::Array& indices) { auto eval_range = insert_reduce_index(indices, k2); return tvm::sum(exp(eval_range), {k2}); }; - auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array& indices) { + auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const ffi::Array& indices) { auto non_reduce_indices = get_non_reduce_indices(indices); return exp(indices) / expsum(non_reduce_indices); }; auto max_elem = tvm::te::compute(reduced_shape, _compute_max); auto exp = tvm::te::compute( - input_shape, [&](const Array& indices) { return _compute_exp(max_elem, indices); }); + input_shape, [&](const ffi::Array& indices) { return _compute_exp(max_elem, indices); }); auto expsum = tvm::te::compute( - reduced_shape, [&](const Array& indices) { return _compute_expsum(exp, indices); }); + reduced_shape, [&](const ffi::Array& indices) { return _compute_expsum(exp, indices); }); return tvm::te::compute( - input_shape, [&](const Array& indices) { return _normalize(exp, expsum, indices); }, + input_shape, [&](const ffi::Array& indices) { return _normalize(exp, expsum, indices); }, name, tag, attrs); } @@ -132,7 +133,7 @@ inline Tensor log_softmax(const Tensor& x, std::string name = "tensor", auto k = tvm::te::reduce_axis(Range(0, n), "k"); auto max_elem = - tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array{k}); }); + tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), ffi::Array{k}); }); k = tvm::te::reduce_axis(Range(0, n), "k"); auto expsum = diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 277de68e972e..fda754061bbe 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -43,12 +43,12 @@ namespace topi { using namespace tvm::te; /*! 
\brief The operation to use for CommReduce */ -using FReduce = std::function& axis, - Array init, Span span)>; +using FReduce = std::function& axis, + ffi::Array init, Span span)>; /*! \brief The operation to use for CommReduceIdx */ -using FCommReduce = std::function(Array exprs, const Array& axis, - PrimExpr* condition)>; +using FCommReduce = std::function( + ffi::Array exprs, const ffi::Array& axis, PrimExpr* condition)>; /*! * \brief Convert a reduction axis which could be empty or have negative @@ -62,7 +62,7 @@ using FCommReduce = std::function(Array exprs, const A * If any input element is negative, it will be treated as an offset from the * last dimension (same as python indexing rules). */ -inline std::vector GetRealAxis(int ndim, const Optional>& axis) { +inline std::vector GetRealAxis(int ndim, const ffi::Optional>& axis) { std::vector real_axis; if (!axis.has_value()) { for (int i = 0; i < ndim; ++i) { @@ -86,8 +86,8 @@ inline std::vector GetRealAxis(int ndim, const Optional>& ax } /*! \brief Enumerate the axes for a reduce op */ -inline Array MakeReduceAxes(const std::vector& real_axis, const Tensor& data) { - Array reduce_axes; +inline ffi::Array MakeReduceAxes(const std::vector& real_axis, const Tensor& data) { + ffi::Array reduce_axes; for (auto i : real_axis) { std::string name = "k" + std::to_string(i); reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name)); @@ -96,10 +96,11 @@ inline Array MakeReduceAxes(const std::vector& real_axis, const Te } /*! 
\brief Calculate the target shape for a reduce op */ -inline Array MakeReduceTargetShape(const std::vector& real_axis, const Tensor& data, - bool keepdims, bool atleast1d) { +inline ffi::Array MakeReduceTargetShape(const std::vector& real_axis, + const Tensor& data, bool keepdims, + bool atleast1d) { auto ndim = data->shape.size(); - Array target_shape; + ffi::Array target_shape; if (keepdims) { for (size_t i = 0; i < ndim; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { @@ -136,13 +137,14 @@ inline Array MakeReduceTargetShape(const std::vector& real_axis, * * \return The result tensor. */ -inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array& target_shape, +inline Tensor DoCommReduce(const Tensor& data, FReduce func, + const ffi::Array& target_shape, const std::vector& reduce_axes, const std::vector& squeeze_axes, Span span = Span()) { auto r_axes = MakeReduceAxes(reduce_axes, data); - auto compute = [&](const Array& indices) { - Array eval_range; - Array eval_indices; + auto compute = [&](const ffi::Array& indices) { + ffi::Array eval_range; + ffi::Array eval_indices; int arg_counter = 0; int red_counter = 0; @@ -179,8 +181,8 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array>& axis, FReduce func, - bool keepdims, bool atleast1d) { +inline Tensor CommReduce(const Tensor& data, const ffi::Optional>& axis, + FReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); @@ -202,7 +204,7 @@ inline Tensor CommReduce(const Tensor& data, const Optional>& axi * * \return The result tensor. 
*/ -inline Tensor CommReduceIdx(const Tensor& data, const Optional>& axis, +inline Tensor CommReduceIdx(const Tensor& data, const ffi::Optional>& axis, FCommReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; @@ -211,9 +213,9 @@ inline Tensor CommReduceIdx(const Tensor& data, const Optional>& auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, - &data](const Array& indices) { - Array eval_range; - Array eval_indices; + &data](const ffi::Array& indices) { + ffi::Array eval_range; + ffi::Array eval_indices; int arg_counter = 0; int red_counter = 0; @@ -233,7 +235,7 @@ inline Tensor CommReduceIdx(const Tensor& data, const Optional>& } } - Array ravel_shape; + ffi::Array ravel_shape; for (auto i : real_axis) { ravel_shape.push_back(data->shape[i]); } @@ -246,15 +248,15 @@ inline Tensor CommReduceIdx(const Tensor& data, const Optional>& auto temp_idx = temp_idx_val[0]; auto temp_val = temp_idx_val[1]; return tvm::te::compute( - target_shape, [&temp_idx](const Array& indices) { return temp_idx(indices); }, + target_shape, [&temp_idx](const ffi::Array& indices) { return temp_idx(indices); }, data->op->name + "_red", kCommReduceIdx); } /*! \brief A combiner function for a reduction */ -using FCombine = std::function(Array lhs, Array rhs)>; +using FCombine = std::function(ffi::Array lhs, ffi::Array rhs)>; /*! \brief An initializer function for a reduction */ -using FIdentity = std::function(std::vector types)>; +using FIdentity = std::function(std::vector types)>; /*! 
* \brief Create a commutative reducer for a reduction @@ -267,9 +269,9 @@ using FIdentity = std::function(std::vector types)>; */ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, std::string name = "reduce") { - return [fcombine, fidentity, name](Array exprs, const Array& axis, + return [fcombine, fidentity, name](ffi::Array exprs, const ffi::Array& axis, PrimExpr* condition) { - Array lhs, rhs; + ffi::Array lhs, rhs; std::vector dtypes; for (size_t i = 0; i < exprs.size(); ++i) { @@ -284,7 +286,7 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, auto cond = condition != nullptr ? *condition : tir::const_true(); auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem); - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < exprs.size(); ++i) { outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast(i), {})); } @@ -293,19 +295,19 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, } /*! \brief Wrap tvm::min to ensure we get the correct overload */ -inline PrimExpr MinOp(PrimExpr source, Array axis, Array init = {}, +inline PrimExpr MinOp(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()) { return tvm::min(source, axis, init, span); } /*! \brief Wrap tvm::max to ensure we get the correct overload */ -inline PrimExpr MaxOp(PrimExpr source, Array axis, Array init = {}, +inline PrimExpr MaxOp(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()) { return tvm::max(source, axis, init, span); // NOLINT(*) } /*! 
\brief Wrap tvm::prod to ensure we get the correct overload */ -inline PrimExpr ProdOp(PrimExpr source, Array axis, Array init = {}, +inline PrimExpr ProdOp(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()) { return tvm::prod(source, axis, init, span); // NOLINT(*) } @@ -323,8 +325,8 @@ inline PrimExpr ProdOp(PrimExpr source, Array axis, Array ini * * \return A Tensor whose op member is the sum operation */ -inline Tensor sum(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor sum(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { if (data->dtype.is_bool()) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } else { @@ -332,7 +334,7 @@ inline Tensor sum(const Tensor& data, const Optional>& axis, bool } } -inline Tensor collapse_sum(const Tensor& data, Array target_shape) { +inline Tensor collapse_sum(const Tensor& data, ffi::Array target_shape) { const auto& ishape = data->shape; const auto& oshape = target_shape; int isize = data->shape.size(); @@ -380,8 +382,8 @@ inline Tensor collapse_sum(const Tensor& data, Array target_shape) { * * \return A Tensor whose op member is the all operation */ -inline Tensor all(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor all(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::all, keepdims, atleast1d); } @@ -399,8 +401,8 @@ inline Tensor all(const Tensor& data, const Optional>& axis, bool * * \return A Tensor whose op member is the all operation */ -inline Tensor any(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor any(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } @@ 
-418,8 +420,8 @@ inline Tensor any(const Tensor& data, const Optional>& axis, bool * * \return A Tensor whose op member is the min operation */ -inline Tensor min(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor min(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MinOp, keepdims, atleast1d); } @@ -437,15 +439,15 @@ inline Tensor min(const Tensor& data, const Optional>& axis, bool * * \return A Tensor whose op member is the max operation */ -inline Tensor max(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor max(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. 
- auto fcombine = [=](Array lhs, Array rhs) { - Array result; + auto fcombine = [=](ffi::Array lhs, ffi::Array rhs) { + ffi::Array result; // Casting to avoid operator ambiguity PrimExpr lhs_idx = static_cast(lhs[0]); @@ -473,7 +475,7 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { return result; }; auto fidentity = [&](std::vector types) { - Array result; + ffi::Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx result.push_back(tvm::max_value(types[1])); // val return result; @@ -497,7 +499,7 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { * * \return A Tensor whose op member is the argmin operation */ -inline Tensor argmin(const Tensor& data, const Optional>& axis, +inline Tensor argmin(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false, bool select_last_index = false) { auto reducer = MakeArgminReducer(select_last_index); @@ -506,8 +508,8 @@ inline Tensor argmin(const Tensor& data, const Optional>& axis, inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. - auto fcombine = [=](Array lhs, Array rhs) { - Array result; + auto fcombine = [=](ffi::Array lhs, ffi::Array rhs) { + ffi::Array result; // Casting to avoid operator ambiguity PrimExpr lhs_idx = static_cast(lhs[0]); @@ -535,7 +537,7 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { return result; }; auto fidentity = [&](std::vector types) { - Array result; + ffi::Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx result.push_back(tvm::min_value(types[1])); // val return result; @@ -558,7 +560,7 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { * appears multiple times, else select the first index. 
* \return A Tensor whose op member is the argmax operation */ -inline Tensor argmax(const Tensor& data, const Optional>& axis, +inline Tensor argmax(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false, bool select_last_index = false) { auto reducer = MakeArgmaxReducer(select_last_index); @@ -578,8 +580,8 @@ inline Tensor argmax(const Tensor& data, const Optional>& axis, * * \return A Tensor whose op member is the prod operation */ -inline Tensor prod(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor prod(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, ProdOp, keepdims, atleast1d); } @@ -587,8 +589,8 @@ inline Tensor prod(const Tensor& data, const Optional>& axis, boo * \brief Create communitive reducer summing over tuples */ inline FCommReduce MakeTupleSumReducer() { - auto fcombine = [](Array lhs, Array rhs) { - Array result; + auto fcombine = [](ffi::Array lhs, ffi::Array rhs) { + ffi::Array result; ICHECK_EQ(lhs.size(), rhs.size()); result.reserve(lhs.size()); for (size_t i = 0; i < lhs.size(); ++i) { @@ -597,7 +599,7 @@ inline FCommReduce MakeTupleSumReducer() { return result; }; auto fidentity = [](std::vector types) { - Array result; + ffi::Array result; for (size_t i = 0; i < types.size(); ++i) { result.push_back(tvm::tir::make_const(types[i], 0)); } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index df637f6f5862..4d0678099582 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -73,8 +73,8 @@ using namespace topi::detail; * * \return A Tensor whose op member is the sliding_window operation */ -inline Tensor sliding_window(const Tensor& x, int axis, Array window_shape, - Array strides, std::string name = "T_sliding_window", +inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array window_shape, + ffi::Array strides, 
std::string name = "T_sliding_window", std::string tag = "") { CHECK_GE(axis, 0); auto _axis = size_t(axis); @@ -85,7 +85,7 @@ inline Tensor sliding_window(const Tensor& x, int axis, Array window_sh CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length."; // Compute the new shape. - Array new_shape; + ffi::Array new_shape; // Dimensions up until `axis` remain the same. for (size_t i = 0; i < _axis; ++i) { new_shape.push_back(x->shape[i]); @@ -113,9 +113,9 @@ inline Tensor sliding_window(const Tensor& x, int axis, Array window_sh return compute( new_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { // The index at which to index the old tensor x. - Array idx; + ffi::Array idx; // Dimensions up until `axis` remain the same. for (size_t i = 0; i < _axis; ++i) { @@ -164,7 +164,7 @@ inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, // Calculate offset from last dimension axis = ndim + axis + 1; } - Array new_shape; + ffi::Array new_shape; for (size_t i = 0; i < static_cast(axis); ++i) { new_shape.push_back(x->shape[i]); } @@ -177,8 +177,8 @@ inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, return compute( new_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -201,16 +201,16 @@ inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, * * \return A Tensor whose op member is the transpose operation */ -inline Tensor transpose(const Tensor& x, Optional> opt_axes, +inline Tensor transpose(const Tensor& x, ffi::Optional> opt_axes, std::string name = "T_transpose", std::string tag = kInjective) { - Array axes = opt_axes.value_or({}); + ffi::Array axes = opt_axes.value_or({}); if (axes.size() == 0) { for (int i = static_cast(x->shape.size()) - 1; i >= 0; --i) { axes.push_back(i); } } - Array new_shape; + ffi::Array 
new_shape; for (size_t i = 0; i < axes.size(); ++i) { int axis = static_cast(axes[i]->value); int new_axis = axis; @@ -232,7 +232,7 @@ inline Tensor transpose(const Tensor& x, Optional> opt_axes, return compute( new_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { std::vector idx; for (size_t i = 0; i < axes.size(); ++i) { idx.push_back(1); @@ -292,8 +292,8 @@ inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int s << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast(x->shape.size()) << "-dimensional input tensor"; - auto func = [&](const Array& indices) { - Array real_indices; + auto func = [&](const ffi::Array& indices) { + ffi::Array real_indices; for (size_t i = 0; i < src_tensor_dim; ++i) { if (i == static_cast(seq_axis)) { if (seq_lengths.defined()) { @@ -325,10 +325,10 @@ inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int s * * \return A Tensor whose op member is the reshape operation */ -inline Tensor reshape(const Tensor& x, Array newshape, std::string name = "T_reshape", - std::string tag = kInjective) { +inline Tensor reshape(const Tensor& x, ffi::Array newshape, + std::string name = "T_reshape", std::string tag = kInjective) { auto x_shape = x->shape; - Array target_shape; + ffi::Array target_shape; for (const auto& ele : newshape) { target_shape.push_back(ele); @@ -337,13 +337,15 @@ inline Tensor reshape(const Tensor& x, Array newshape, std::string nam // If either the input shape or the target shape contains a zero, return an empty tensor. 
if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) { return compute( - target_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); + target_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, + tag); } else { return compute( target_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { return x(UnravelIndex( - RavelIndex(Array{indices.begin(), indices.end()}, target_shape), x_shape)); + RavelIndex(ffi::Array{indices.begin(), indices.end()}, target_shape), + x_shape)); }, name, tag); } @@ -365,13 +367,13 @@ inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string na auto x_shape = x->shape; auto shape_shape = shape->shape; - Array oshape; + ffi::Array oshape; oshape.push_back(shape_shape[0]); if (x_shape.size() != 0) { oshape.push_back(x_shape[0]); } - auto func = [&](const Array& indices) { + auto func = [&](const ffi::Array& indices) { auto i = indices[0]; std::vector indices_divs; PrimExpr ret = 0; @@ -408,8 +410,9 @@ inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string na * * \return A Tensor whose op member is the squeeze operation */ -inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool atleast1d = false, - std::string name = "T_squeeze", std::string tag = kInjective) { +inline Tensor squeeze(const Tensor& x, ffi::Optional> opt_axes, + bool atleast1d = false, std::string name = "T_squeeze", + std::string tag = kInjective) { auto ndim = x->shape.size(); std::vector axis_val; if (!opt_axes.has_value()) { @@ -419,22 +422,23 @@ inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool a } } } else { - Array axis = *std::move(opt_axes); + ffi::Array axis = *std::move(opt_axes); for (size_t i = 0; i < axis.size(); ++i) { int64_t val = axis[i]->value; if (val < 0) { val += static_cast(x->shape.size()); } - if (IsConstInt(x->shape[val])) { - ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 
1"; + // If a dimension is not 1, silently skip it (no-op). + bool is_const = IsConstInt(x->shape[val]); + if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) { + axis_val.push_back(val); } - axis_val.push_back(val); } } std::unordered_set axis_set(axis_val.begin(), axis_val.end()); - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < ndim; ++i) { if (axis_set.count(static_cast(i)) == 0) { out_shape.push_back(x->shape[i]); @@ -446,8 +450,8 @@ inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool a return compute( out_shape, - [&](const Array& indices) { - Array real_indices; + [&](const ffi::Array& indices) { + ffi::Array real_indices; int flag = 0; for (size_t i = 0; i < ndim; ++i) { if (axis_set.count(static_cast(i)) == 0) { @@ -472,8 +476,8 @@ inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool a * * \return A Tensor whose op member is the concatenate operation */ -inline Tensor concatenate(const Array& inputs, int axis = 0, std::string name = "T_concat", - std::string tag = kInjective) { +inline Tensor concatenate(const ffi::Array& inputs, int axis = 0, + std::string name = "T_concat", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" << ", but got axis = " << axis << ", and ndim = " << ndim; @@ -482,7 +486,7 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string } ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; - Array axis_sizes; + ffi::Array axis_sizes; for (auto t : inputs) { axis_sizes.push_back(t->shape[axis]); } @@ -492,20 +496,20 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string join_size += axis_sizes[i]; } join_size = analyzer.Simplify(join_size); - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { out_shape.push_back(i == static_cast(axis) ? 
join_size : inputs[0]->shape[i]); } return compute( out_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { auto ret = inputs[0](indices); auto ind = indices[axis]; for (size_t i = 0; i < inputs.size() - 1; ++i) { ind -= axis_sizes[i]; - Array idx; + ffi::Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -531,7 +535,7 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string * * \return A Tensor whose op member is the stack operation */ -inline Tensor stack(const Array& inputs, int axis = 0, std::string name = "T_stack", +inline Tensor stack(const ffi::Array& inputs, int axis = 0, std::string name = "T_stack", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); ICHECK(-ndim - 1 <= axis && axis <= ndim) @@ -543,7 +547,7 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds"; const int stack_size = static_cast(inputs.size()); - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < static_cast(axis); ++i) out_shape.push_back(inputs[0]->shape[i]); out_shape.push_back(stack_size); for (size_t i = static_cast(axis); i < static_cast(ndim); ++i) @@ -551,8 +555,8 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name return compute( out_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; for (size_t i = 0; i < indices.size(); ++i) if (i != static_cast(axis)) idx.push_back(indices[i]); auto ind = indices[axis]; @@ -577,9 +581,9 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name * * \return A Tensor whose op member is the split operation */ -inline Array split_indices_array(const Tensor& x, Array split_indices, int axis, - std::string name = "T_split", - std::string tag = kInjective) { +inline ffi::Array split_indices_array(const Tensor& x, ffi::Array split_indices, + int axis, 
std::string name = "T_split", + std::string tag = kInjective) { if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -598,7 +602,7 @@ inline Array split_indices_array(const Tensor& x, Array split_ begin_ids.push_back(idx); } - Array> out_shapes; + ffi::Array> out_shapes; for (size_t i = 0; i < begin_ids.size(); ++i) { PrimExpr out_axis_size; if (i == begin_ids.size() - 1) { @@ -607,7 +611,7 @@ inline Array split_indices_array(const Tensor& x, Array split_ out_axis_size = begin_ids[i + 1] - begin_ids[i]; } - Array shape; + ffi::Array shape; for (size_t i = 0; i < static_cast(axis); ++i) { shape.push_back(x->shape[i]); } @@ -619,13 +623,13 @@ inline Array split_indices_array(const Tensor& x, Array split_ out_shapes.push_back(shape); } - Array result; + ffi::Array result; for (size_t i = 0; i < begin_ids.size(); ++i) { result.push_back(compute( out_shapes[i], - [&](const Array& indices) { + [&](const ffi::Array& indices) { auto begin = begin_ids[i]; - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(indices[j]); } @@ -706,10 +710,11 @@ inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExp * * \return A Tensor whose op member is the dynamic_strided_slice operation */ -inline Tensor dynamic_strided_slice_with_axes( - const Tensor& x, const Array& begin, const Array& end, - const Array& strides, const Array& axes, bool assume_inbound = true, - std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = kInjective) { +inline te::Tensor dynamic_strided_slice_with_axes( + const te::Tensor& x, const ffi::Array& begin, const ffi::Array& end, + const ffi::Array& strides, const ffi::Array& axes, + bool assume_inbound = true, std::string name = "T_dynamic_strided_slice_with_axes", + std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); ICHECK_EQ(begin.size(), end.size()); ICHECK_EQ(begin.size(), strides.size()); @@ -723,7 +728,7 @@ inline 
Tensor dynamic_strided_slice_with_axes( arith::Analyzer analyzer; - Array out_shape = x->shape; + ffi::Array out_shape = x->shape; for (size_t i = 0; i < begin.size(); i++) { int axis = axes[i]->value; PrimExpr new_shape = @@ -733,8 +738,9 @@ inline Tensor dynamic_strided_slice_with_axes( return te::compute( out_shape, - [&](const Array& indices) { - Array real_indices = indices.Map([](const auto& var) -> PrimExpr { return var; }); + [&](const ffi::Array& indices) { + ffi::Array real_indices = + indices.Map([](const auto& var) -> PrimExpr { return var; }); for (size_t i = 0; i < begin.size(); i++) { int axis = axes[i]->value; @@ -761,9 +767,9 @@ inline Tensor dynamic_strided_slice_with_axes( * * \return A Tensor whose op member is the dynamic_strided_slice operation */ -inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - bool assume_inbound = true, +inline Tensor dynamic_strided_slice(const Tensor& x, const ffi::Array& begin, + const ffi::Array& end, + const ffi::Array& strides, bool assume_inbound = true, std::string name = "T_dynamic_strided_slice", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); @@ -774,7 +780,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begi ICHECK_EQ(begin.size(), strides.size()); const size_t num_slice_axes = begin.size(); - Array out_shape; + ffi::Array out_shape; arith::Analyzer analyzer; for (size_t i = 0; i < num_slice_axes; ++i) { @@ -794,8 +800,8 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begi return te::compute( out_shape, - [&](const Array& indices) { - Array real_indices; + [&](const ffi::Array& indices) { + ffi::Array real_indices; for (size_t i = 0; i < num_slice_axes; ++i) { real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1)); } @@ -832,7 +838,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b 
ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); - Array begin_expr, end_expr, strides_expr; + ffi::Array begin_expr, end_expr, strides_expr; for (int64_t i = 0; i < num_dynamic_axes; ++i) { auto ind = make_const(index_dtype, i); begin_expr.push_back(begin(ind)); @@ -856,9 +862,12 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b * * \return The output shape of strided_slice using the arguments above */ -inline Array StridedSliceOutputShape( - const Array& ishape, const Array& begin, const Array& end, - const Array& strides, const Array& axes, const std::string& slice_mode) { +inline ffi::Array StridedSliceOutputShape(const ffi::Array& ishape, + const ffi::Array& begin, + const ffi::Array& end, + const ffi::Array& strides, + const ffi::Array& axes, + const std::string& slice_mode) { ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); @@ -884,9 +893,11 @@ inline Array StridedSliceOutputShape( * * \return A Tensor whose op member is the sstrided_slice operation */ -inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - const Array& axes, std::string slice_mode = "end", +inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array& begin, + const ffi::Array& end, + const ffi::Array& strides, + const ffi::Array& axes, + std::string slice_mode = "end", std::string name = "T_strided_slice_with_axes", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); @@ -903,8 +914,8 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg return te::compute( out_shape, - [&](const Array& indices) { - Array real_indices; + [&](const ffi::Array& indices) { + ffi::Array real_indices; for 
(size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < axes.size(); ++i) { auto stride = make_const(strides[i].dtype(), strides_vec[i]); @@ -930,15 +941,16 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg * * \return A Tensor whose op member is the strided_slice operation */ -inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, - const Array& strides, std::string slice_mode = "end", - std::string name = "T_strided_slice", std::string tag = kInjective) { +inline Tensor strided_slice(const Tensor& x, const ffi::Array& begin, + const ffi::Array& end, const ffi::Array& strides, + std::string slice_mode = "end", std::string name = "T_strided_slice", + std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); - Array axes; + ffi::Array axes; for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); - Array begin_full(begin); - Array end_full(end); - Array strides_full(strides); + ffi::Array begin_full(begin); + ffi::Array end_full(end); + ffi::Array strides_full(strides); DataType index_dtype = begin.size() > 0 ? 
begin[0]->dtype : DataType::Int(64); const IntImm one = IntImm(index_dtype, 1); @@ -971,9 +983,9 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const * * \return A Tensor whose op member is the split operation */ -inline Array split_n_sections(const Tensor& x, int num_sections, int axis, - std::string name = "T_split_sections", - std::string tag = kInjective) { +inline ffi::Array split_n_sections(const Tensor& x, int num_sections, int axis, + std::string name = "T_split_sections", + std::string tag = kInjective) { if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -983,7 +995,7 @@ inline Array split_n_sections(const Tensor& x, int num_sections, int axi ICHECK_GT(num_sections, 0) << "Slice count must be > 0"; - Array split_indices; + ffi::Array split_indices; auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections); for (int i = 0; i < num_sections; ++i) { // region at index 0 is added by split() @@ -1010,8 +1022,8 @@ inline Array split_n_sections(const Tensor& x, int num_sections, int axi inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, std::string mode = "fast", std::string name = "T_take", std::string tag = kInjective) { - Array a_shape = a->shape; - Array out_shape = indices->shape; + ffi::Array a_shape = a->shape; + ffi::Array out_shape = indices->shape; PrimExpr a_size = 1; for (size_t i = 0; i < a_shape.size(); ++i) { a_size = a_size * a_shape[i]; @@ -1020,7 +1032,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, if (mode == "clip") { return compute( out_shape, - [&](const Array& out_index) { + [&](const ffi::Array& out_index) { auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); return a(UnravelIndex(idx, a_shape)); }, @@ -1030,12 +1042,14 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, "Make sure input indices are in bound"; return compute( out_shape, - [&](const Array& out_index) { return 
a(UnravelIndex(indices(out_index), a_shape)); }, + [&](const ffi::Array& out_index) { + return a(UnravelIndex(indices(out_index), a_shape)); + }, name, tag); } else if (mode == "nan") { return compute( out_shape, - [&](const Array& out_index) { + [&](const ffi::Array& out_index) { auto idx = tvm::if_then_else( indices(out_index) < 0 || indices(out_index) >= a_size, tvm::FloatImm(a->dtype, std::numeric_limits::quiet_NaN()), indices(out_index)); @@ -1045,7 +1059,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, } else { // mode == "wrap" return compute( out_shape, - [&](const Array& out_index) { + [&](const ffi::Array& out_index) { auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size); return a(UnravelIndex(idx, a_shape)); }, @@ -1072,11 +1086,11 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,)."; auto length_dim = data->shape[axis]; auto batch_dim = data->shape[1 - axis]; - Array out_shape = data->shape; + ffi::Array out_shape = data->shape; Tensor out = compute( out_shape, - [&](const Array& out_index) { - Array len_index; + [&](const ffi::Array& out_index) { + ffi::Array len_index; auto tid = out_index[axis]; auto bid = out_index[1 - axis]; len_index.push_back(bid); @@ -1103,8 +1117,8 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub * * \return A Tensor whose op member is the take operation */ -inline Tensor take(const Tensor& a, Variant indices, int batch_dims, int axis, - std::string mode = "fast", std::string name = "T_take", +inline Tensor take(const Tensor& a, ffi::Variant indices, int batch_dims, + int axis, std::string mode = "fast", std::string name = "T_take", std::string tag = kInjective) { if (axis < 0) { axis += static_cast(a->shape.size()); @@ -1112,7 +1126,7 @@ inline Tensor take(const Tensor& a, Variant indices, int batch 
ICHECK_GE(axis, 0) << "axis out of bounds"; ICHECK_LT(axis, a->shape.size()) << "axis out of bounds"; auto axis_dim = a->shape[axis]; - auto indices_shape = [&]() -> Array { + auto indices_shape = [&]() -> ffi::Array { if (auto tensor = indices.as()) { return tensor->shape; } else { @@ -1145,7 +1159,7 @@ inline Tensor take(const Tensor& a, Variant indices, int batch // The result shape is a.shape[:axis] + indices.shape[batch_dims:] + // a.shape[axis + 1:]. - Array out_shape; + ffi::Array out_shape; for (int i = 0; i < batch_dims_; ++i) { out_shape.push_back(a->shape[i]); } @@ -1159,7 +1173,7 @@ inline Tensor take(const Tensor& a, Variant indices, int batch out_shape.push_back(a->shape[i]); } - auto get_index = [&](const Array& indices_position) -> PrimExpr { + auto get_index = [&](const ffi::Array& indices_position) -> PrimExpr { if (auto tensor = indices.as()) { return tensor.value()(indices_position); } else if (auto prim = indices.as()) { @@ -1174,12 +1188,12 @@ inline Tensor take(const Tensor& a, Variant indices, int batch if (batch_dims_ == 0) { return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1194,15 +1208,15 @@ inline Tensor take(const Tensor& a, Variant indices, int batch } else { return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = 0; j < static_cast(batch_dims_); ++j) { indices_position.push_back(out_index[j]); } for (size_t j = axis; j < static_cast(axis + indices_len - batch_dims_); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for 
(size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1220,12 +1234,12 @@ inline Tensor take(const Tensor& a, Variant indices, int batch "Make sure input indices are in bound"; return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1239,12 +1253,12 @@ inline Tensor take(const Tensor& a, Variant indices, int batch } else if (mode == "nan") { return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1262,21 +1276,24 @@ inline Tensor take(const Tensor& a, Variant indices, int batch } else { // mode == "wrap" return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } - auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim); + PrimExpr idx = get_index(indices_position); real_indices.push_back(idx); for (size_t j = axis + indices_len; j < out_index.size(); ++j) { real_indices.push_back(out_index[j]); } - return a(real_indices); + PrimExpr in_bounds = idx >= 0 && idx < axis_dim; + 
return tvm::if_then_else( + in_bounds, a(real_indices), + tvm::tir::make_const(a->dtype, std::numeric_limits::quiet_NaN())); }, name, tag); } @@ -1299,9 +1316,9 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, << y->dtype; auto get_out_shape = [&]() { auto bh1 = detail::BroadcastShape(x->shape, y->shape); - Array common_shape1(bh1.common_shape.begin(), bh1.common_shape.end()); + ffi::Array common_shape1(bh1.common_shape.begin(), bh1.common_shape.end()); auto bh2 = detail::BroadcastShape(condition->shape, common_shape1); - Array common_shape2(bh2.common_shape.begin(), bh2.common_shape.end()); + ffi::Array common_shape2(bh2.common_shape.begin(), bh2.common_shape.end()); return common_shape2; }; @@ -1311,7 +1328,7 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, auto x_bh = detail::BroadcastShape(x->shape, oshape); auto y_bh = detail::BroadcastShape(y->shape, oshape); - auto select = [&](tvm::Array ovars) { + auto select = [&](tvm::ffi::Array ovars) { auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars)); auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars)); auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars)); @@ -1345,7 +1362,7 @@ inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = // Calculate offset from last dimension axis += ndim; } - Array new_shape; + ffi::Array new_shape; for (size_t i = 0; i < static_cast(axis); ++i) { new_shape.push_back(x->shape[i]); } @@ -1356,8 +1373,8 @@ inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = return compute( new_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -1380,14 +1397,14 @@ inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = * * \return A Tensor 
whose op member is the tile operation */ -inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_tile", +inline Tensor tile(const Tensor& x, ffi::Array reps, std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); size_t rdim = reps.size(); size_t tdim = (ndim > rdim) ? ndim : rdim; - Array data_shape; - Array reps_shape; - Array new_shape; + ffi::Array data_shape; + ffi::Array reps_shape; + ffi::Array new_shape; if (ndim == rdim) { for (size_t i = 0; i < ndim; ++i) { data_shape.push_back(x->shape[i]); @@ -1406,12 +1423,13 @@ inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_t if (is_empty_shape(new_shape)) { return compute( - new_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); + new_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, + tag); } else { return compute( new_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; if (ndim >= rdim) { for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i])); } else { @@ -1435,17 +1453,18 @@ inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_t * * \return A Tensor whose op member is the tile operation */ -inline Tensor dyn_tile(const Tensor& x, Array new_shape, size_t rdim, +inline Tensor dyn_tile(const Tensor& x, ffi::Array new_shape, size_t rdim, std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); if (is_empty_shape(new_shape)) { return compute( - new_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); + new_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, + tag); } else { return compute( new_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; if (ndim >= rdim) { for (size_t i = 0; i < ndim; ++i) { idx.push_back(indexmod(indices[i], x->shape[i])); @@ 
-1489,19 +1508,19 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, } ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()); - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < ndim_i; ++i) { out_shape.push_back(indices->shape[i]); } return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t i = 0; i < ndim_i; ++i) { indices_position.push_back(out_index[i]); } - Array real_indices; + ffi::Array real_indices; for (size_t i = 0; i < ndim_i; ++i) { if (i == static_cast(axis)) { real_indices.push_back(indices(indices_position)); @@ -1533,7 +1552,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim size_t indices_dim0 = static_cast(GetConstInt(indices->shape[0])); ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more " << "than dimensions of data tensor"; - Array out_shape; + ffi::Array out_shape; for (size_t i = 1; i < ndim_i; ++i) { out_shape.push_back(indices->shape[i]); } @@ -1542,13 +1561,13 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim } return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; indices_position.push_back(0); for (size_t i = 0; i < ndim_i - 1; ++i) { indices_position.push_back(out_index[i]); } - Array real_indices; + ffi::Array real_indices; for (size_t i = 0; i < static_cast(batch_dims); ++i) { real_indices.push_back(out_index[i]); } @@ -1589,7 +1608,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B, bool trans_a = false, bool trans_b = false, std::string name = "T_matmul", std::string tag = kMatMul) { - tvm::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 
0 : 1]}; + tvm::ffi::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); auto l = [&](tvm::tir::Var i, tvm::tir::Var j) { return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k}); @@ -1613,19 +1632,19 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, ICHECK_GE(A->shape.size(), axes); ICHECK_GE(B->shape.size(), axes); - Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); + ffi::Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it); - Array iter_vars; + ffi::Array iter_vars; for (int i = 0; i < axes; ++i) iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i))); - auto func = [&A, &B, &iter_vars, axes](const Array& input_indices) { - Array A_indices(input_indices.begin(), - input_indices.begin() + (A->shape.size() - axes)); + auto func = [&A, &B, &iter_vars, axes](const ffi::Array& input_indices) { + ffi::Array A_indices(input_indices.begin(), + input_indices.begin() + (A->shape.size() - axes)); for (auto& v : iter_vars) A_indices.push_back(v); - Array B_indices; + ffi::Array B_indices; for (auto& v : iter_vars) B_indices.push_back(v); auto it = input_indices.begin() + (A->shape.size() - axes); @@ -1654,15 +1673,15 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, * * \return A Tensor computing the result */ -inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array A_axes, - Array B_axes, std::string name = "T_tensordot", +inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, ffi::Array A_axes, + ffi::Array B_axes, std::string name = "T_tensordot", std::string tag = kMatMul) { ICHECK_EQ(A_axes.size(), B_axes.size()); auto A_axes_val = GetConstIntValues(A_axes, "A_axes"); auto B_axes_val = 
GetConstIntValues(B_axes, "B_axes"); - Array output_shape; + ffi::Array output_shape; for (unsigned i = 0; i < A->shape.size(); ++i) if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end()) output_shape.push_back(A->shape[i]); @@ -1670,13 +1689,13 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Arrayshape[i]); - Array iter_vars; + ffi::Array iter_vars; for (unsigned i = 0; i < B_axes_val.size(); ++i) iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i))); - auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array& input_indices) { + auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const ffi::Array& input_indices) { int idx_input = 0; - Array A_indices; + ffi::Array A_indices; for (unsigned i = 0; i < A->shape.size(); ++i) { auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); if (axes_pos == A_axes_val.end()) { @@ -1686,7 +1705,7 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array B_indices; + ffi::Array B_indices; for (unsigned i = 0; i < B->shape.size(); ++i) { auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); if (axes_pos == B_axes_val.end()) { @@ -1720,8 +1739,8 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr return compute( {num_elem}, - [&](const Array& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name, - tag); + [&](const ffi::Array& indices) { return tvm::cast(dtype, start + step * indices[0]); }, + name, tag); } /*! 
@@ -1734,22 +1753,22 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr * * \return A Tensor whose op member is the meshgrid operation */ -inline Array meshgrid(const Array& inputs, const std::string& indexing, - std::string name = "T_meshgrid", std::string tag = kInjective) { +inline ffi::Array meshgrid(const ffi::Array& inputs, const std::string& indexing, + std::string name = "T_meshgrid", std::string tag = kInjective) { const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2; - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < inputs.size(); ++i) { const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i; out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]); } - Array result; + ffi::Array result; for (size_t i = 0; i < inputs.size(); ++i) { result.push_back(compute( out_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { const int src_index = (cartesian_indexing && i < 2) ? 
1 - i : i; auto ndim = inputs[i]->GetShape().size(); - Array real_indices = {}; + ffi::Array real_indices = {}; if (ndim > 0) { real_indices = {indices[src_index]}; } @@ -1789,19 +1808,19 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, ICHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout; - Array dst_shape = layout_converter.ForwardShape(src->shape); + ffi::Array dst_shape = layout_converter.ForwardShape(src->shape); - Map attrs = {{"schedule_rule", String(schedule_rule)}, - // Information about layouts needed for the schedule rule - {"src_layout", String(src_layout)}, - {"dst_layout", String(dst_layout)}, - {"input_shape", src->shape}}; + ffi::Map attrs = {{"schedule_rule", ffi::String(schedule_rule)}, + // Information about layouts needed for the schedule rule + {"src_layout", ffi::String(src_layout)}, + {"dst_layout", ffi::String(dst_layout)}, + {"input_shape", src->shape}}; return compute( dst_shape, - [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); + [&](const ffi::Array& dst_indices) { + ffi::Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + ffi::Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); PrimExpr in_range = PrimExpr(1) > PrimExpr(0); // init with dtype=bool and value=true for (size_t i = 0; i < src.ndim(); ++i) { in_range = in_range && (src_indices[i] < src->shape[i]); @@ -1812,7 +1831,7 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, } /*! 
\brief Utility function for auto_scheduler_layout_transform */ -inline void parse_auto_scheduler_layout(const String& layout, Array* shape, +inline void parse_auto_scheduler_layout(const ffi::String& layout, ffi::Array* shape, std::vector* axes) { int32_t factor = 0; std::string axis = ""; @@ -1848,22 +1867,21 @@ inline void parse_auto_scheduler_layout(const String& layout, Array* s * \param tag output tensor tag. * \return A tensor with shape in \p dst_layout */ -inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout, - const String& dst_layout, - const String name = "T_auto_scheduler_layout_trans", - const String tag = kInjective) { - Array src_shape; +inline Tensor auto_scheduler_layout_transform( + const Tensor& src, const ffi::String& src_layout, const ffi::String& dst_layout, + const ffi::String name = "T_auto_scheduler_layout_trans", const ffi::String tag = kInjective) { + ffi::Array src_shape; std::vector src_axes; - Array dst_shape; + ffi::Array dst_shape; std::vector dst_axes; parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes); parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes); return compute( dst_shape, - [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices; + [&](const ffi::Array& dst_indices) { + ffi::Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + ffi::Array src_indices; for (const std::string& src_axis : src_axes) { PrimExpr src_index = 0; CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); @@ -1915,21 +1933,22 @@ inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& s * In this case, the transformation pattern is: * A'[a, b, c, d] = A[a * 4 + c, b * 16 + d] */ -inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map, - const String name = "T_meta_schedule_layout_trans", - const String tag = kInjective) { +inline Tensor 
meta_schedule_layout_transform( + const Tensor& src, const tir::IndexMap& index_map, + const ffi::String name = "T_meta_schedule_layout_trans", const ffi::String tag = kInjective) { arith::Analyzer analyzer; - Array iter_domain; + ffi::Array iter_domain; iter_domain.reserve(src->shape.size()); for (const PrimExpr& e : src->shape) { iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e)); } - Array post_transform_shape = index_map->MapShape(src->shape, &analyzer); + ffi::Array post_transform_shape = index_map->MapShape(src->shape, &analyzer); return compute( post_transform_shape, [src, inv = index_map.Inverse(iter_domain, &analyzer), - &analyzer](const Array& indices) -> PrimExpr { - return src(inv->MapIndices(Array{indices.begin(), indices.end()}, &analyzer)); + &analyzer](const ffi::Array& indices) -> PrimExpr { + return src( + inv->MapIndices(ffi::Array{indices.begin(), indices.end()}, &analyzer)); }, name, tag); } @@ -1945,10 +1964,10 @@ inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::Index inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape", const std::string tag = kInjective) { int ndim = static_cast(src->shape.size()); - Array out_shape{ndim}; + ffi::Array out_shape{ndim}; return compute( out_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { auto idx = indices[0]; PrimExpr ret = 0; for (int i = 0; i < ndim; ++i) { @@ -1967,14 +1986,14 @@ inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = * \param tag output tensor tag. * \return Tensor of input shape. 
*/ -inline Tensor ndarray_size(const Tensor& src, const DataType& dtype, - const std::string& name = "ndarray_size", - const std::string& tag = kInjective) { +inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, + const std::string& name = "tensor_size", + const std::string& tag = kInjective) { int ndim = static_cast(src->shape.size()); - Array out_ndarray_size = {}; + ffi::Array out_tensor_size = {}; return compute( - out_ndarray_size, - [&](const Array& indices) { + out_tensor_size, + [&](const ffi::Array& indices) { PrimExpr ret = 1; for (int i = 0; i < ndim; ++i) { ret *= src->shape[i]; @@ -2000,7 +2019,7 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype, */ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, int depth, int axis, const DataType& dtype, - Array oshape = Array(), + ffi::Array oshape = ffi::Array(), const std::string name = "T_one_hot", const std::string tag = kInjective) { int true_axis = (axis == -1) ? indices->shape.size() : axis; if (oshape.size() == 0) { @@ -2019,8 +2038,8 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim PrimExpr off_value_cast = cast(dtype, off_value); return compute( oshape, - [&](const Array& iter_vars) { - Array indices_indices; + [&](const ffi::Array& iter_vars) { + ffi::Array indices_indices; for (size_t i = 0; i < iter_vars.size(); i++) { if (static_cast(i) == true_axis) { continue; @@ -2045,8 +2064,9 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim * \param tag output tensor tag. * \return Tensor of output_shape. 
*/ -inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array& output_shape, - const Tensor& sparse_values, const PrimExpr& default_value, +inline Tensor sparse_to_dense(const Tensor& sparse_indices, + const ffi::Array& output_shape, const Tensor& sparse_values, + const PrimExpr& default_value, const std::string name = "T_sparse_to_dense", const std::string tag = kInjective) { ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values"; @@ -2055,13 +2075,13 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Arrayshape.size(), 2) << "sparse_values tensor should be 0D or 1D only"; const auto rank_sparse_indices = static_cast(sparse_indices->shape.size()); - Array oshape; + ffi::Array oshape; for (auto l : output_shape) { oshape.push_back(l); } return compute( oshape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { PrimExpr ret = default_value; if (0 == rank_sparse_indices) { ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret); @@ -2106,9 +2126,9 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k return compute( input->shape, - [&](const Array& iter_vars) { + [&](const ffi::Array& iter_vars) { auto get_diag = [&]() { - Array diagonal_indices; + ffi::Array diagonal_indices; PrimExpr k, offset = 0; for (size_t i = 0; i < ndim - 1; i++) { diagonal_indices.push_back(iter_vars[i]); @@ -2152,18 +2172,18 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k * \param tag output tensor tag. * \return Output tensor. 
*/ -inline Tensor adv_index(const Tensor& data, const Array& indices, +inline Tensor adv_index(const Tensor& data, const ffi::Array& indices, const std::string name = "advanced_index", const std::string tag = kInjective) { ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!"; - Array oshape; - Array broadcast_shape; - Array bindices; + ffi::Array oshape; + ffi::Array broadcast_shape; + ffi::Array bindices; broadcast_shape = indices[0]->shape; for (size_t i = 1; i < indices.size(); ++i) { auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape); - broadcast_shape = Array(bh.common_shape.begin(), bh.common_shape.end()); + broadcast_shape = ffi::Array(bh.common_shape.begin(), bh.common_shape.end()); } if (indices.size() == 1) { // quick path @@ -2184,12 +2204,12 @@ inline Tensor adv_index(const Tensor& data, const Array& indices, return compute( oshape, - [&](const Array& iter_var) { - Array tensor_indices; + [&](const ffi::Array& iter_var) { + ffi::Array tensor_indices; for (size_t i = 0; i < broadcast_shape.size(); ++i) { tensor_indices.push_back(iter_var[i]); } - Array real_indices; + ffi::Array real_indices; for (size_t i = 0; i < bindices.size(); ++i) { real_indices.push_back(bindices[i](tensor_indices)); } @@ -2206,7 +2226,7 @@ namespace relax { // relax dynamic slice inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, const te::Tensor& end, const te::Tensor& strides, - Array output_shape, + ffi::Array output_shape, std::string name = "T_strided_slice_dynamic", std::string tag = kInjective) { const size_t num_dynamic_axes = x.ndim(); @@ -2225,8 +2245,8 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b return te::compute( output_shape, - [&](const Array& indices) { - Array real_indices; + [&](const ffi::Array& indices) { + ffi::Array real_indices; for (size_t i = 0; i < num_dynamic_axes; ++i) { auto ind = make_const(DataType::Int(64), i); 
real_indices.push_back(indices[i] * strides(ind) + tvm::min(begin(ind), x->shape[i] - 1)); diff --git a/include/tvm/topi/utils.h b/include/tvm/topi/utils.h index b5f2d6c38d61..41a2cce0e4f9 100644 --- a/include/tvm/topi/utils.h +++ b/include/tvm/topi/utils.h @@ -32,17 +32,17 @@ namespace topi { using namespace tvm::runtime; -/*! \brief Canonicalize an argument that may be Array or int to Array */ -inline Optional> ArrayOrInt(AnyView arg) { +/*! \brief Canonicalize an argument that may be ffi::Array or int to ffi::Array */ +inline ffi::Optional> ArrayOrInt(AnyView arg) { if (arg == nullptr) { return std::nullopt; } if (auto opt_int = arg.try_cast()) { - Array result; + ffi::Array result; result.push_back(opt_int.value()); return result; } else { - return arg.cast>(); + return arg.cast>(); } } } // namespace topi diff --git a/include/tvm/topi/vision/reorg.h b/include/tvm/topi/vision/reorg.h index 381272bb818c..f9a089d1abdc 100644 --- a/include/tvm/topi/vision/reorg.h +++ b/include/tvm/topi/vision/reorg.h @@ -72,7 +72,7 @@ inline Tensor reorg(const Tensor& data, int stride = 1, std::string name = "tens int out_h = h_in / stride; int out_w = w_in / stride; - Array out_shape = {batch, out_c, out_h, out_w}; + ffi::Array out_shape = {batch, out_c, out_h, out_w}; return reshape(out, out_shape); } } // namespace vision diff --git a/jvm/README.md b/jvm/README.md index 71c737a4d00a..355a17a7b266 100644 --- a/jvm/README.md +++ b/jvm/README.md @@ -19,7 +19,7 @@ This folder contains the Java interface for TVM runtime. It brings TVM runtime to Java virtual machine. -- It enables you to construct NDArray from Java native array and vice versa. +- It enables you to construct Tensor from Java native array and vice versa. - You can register and convert Java native functions to TVM functions. - It enables you to load shared libraries created by Python and C++. - It provides a simple interface for RPC server and client. 
@@ -95,7 +95,7 @@ The following code snippet demonstrate how to load generated shared library (add ```java import org.apache.tvm.Module; -import org.apache.tvm.NDArray; +import org.apache.tvm.Tensor; import org.apache.tvm.Device; import java.io.File; @@ -109,9 +109,9 @@ public class LoadAddFunc { Device dev = Device.cpu(); long[] shape = new long[]{2}; - NDArray arr = NDArray.empty(shape, dev); + Tensor arr = Tensor.empty(shape, dev); arr.copyFrom(new float[]{3f, 4f}); - NDArray res = NDArray.empty(shape, dev); + Tensor res = Tensor.empty(shape, dev); fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke(); System.out.println(Arrays.toString(res.asFloatArray())); diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index ee6b8e8cf5c5..29e105dee9f5 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -138,12 +138,12 @@ public Function pushArg(String arg) { /** * Push argument to the function. - * @param arg NDArray. + * @param arg Tensor. * @return this */ - public Function pushArg(NDArrayBase arg) { - if (arg instanceof NDArray) { - Base._LIB.tvmFFIFunctionPushArgHandle(((NDArray) arg).handle, TypeIndex.kTVMFFINDArray); + public Function pushArg(TensorBase arg) { + if (arg instanceof Tensor) { + Base._LIB.tvmFFIFunctionPushArgHandle(((Tensor) arg).handle, TypeIndex.kTVMFFITensor); } else { Base._LIB.tvmFFIFunctionPushArgHandle(arg.dltensorHandle, TypeIndex.kTVMFFIDLTensorPtr); } @@ -192,7 +192,7 @@ public Function pushArg(Device arg) { /** * Invoke function with arguments. - * @param args Can be Integer, Long, Float, Double, String, NDArray. + * @param args Can be Integer, Long, Float, Double, String, Tensor. * @return the result. */ public TVMValue call(Object... args) { @@ -203,10 +203,10 @@ public TVMValue call(Object... 
args) { } private static void pushArgToStack(Object arg) { - if (arg instanceof NDArrayBase) { - NDArrayBase nd = (NDArrayBase) arg; - if (nd instanceof NDArray) { - Base._LIB.tvmFFIFunctionPushArgHandle(((NDArray) nd).handle, TypeIndex.kTVMFFINDArray); + if (arg instanceof TensorBase) { + TensorBase nd = (TensorBase) arg; + if (nd instanceof Tensor) { + Base._LIB.tvmFFIFunctionPushArgHandle(((Tensor) nd).handle, TypeIndex.kTVMFFITensor); } else { Base._LIB.tvmFFIFunctionPushArgHandle(nd.dltensorHandle, TypeIndex.kTVMFFIDLTensorPtr); } diff --git a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java index f471883ca5bc..a1e15a873a60 100644 --- a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java +++ b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java @@ -52,7 +52,7 @@ class LibInfo { native int tvmFFIFunctionCreateFromCallback(Function.Callback function, Base.RefLong handle); - // NDArray + // Tensor native int tvmFFIDLTensorGetShape(long handle, List shape); native int tvmFFIDLTensorCopyFromTo(long from, long to); @@ -67,7 +67,7 @@ class LibInfo { // Device native int tvmSynchronize(int deviceType, int deviceId); - native int tvmNDArrayEmpty(long[] shape, int dtypeCode, int dtypeBits, + native int tvmTensorEmpty(long[] shape, int dtypeCode, int dtypeBits, int dtypeLanes, int deviceType, int deviceId, Base.RefLong handle); } diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 46a74346760e..174457131f05 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -46,7 +46,7 @@ private static Function getApi(String name) { } private Function entry = null; - private final String entryName = "__tvm_ffi_main__"; + private final String entryName = "main"; /** diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMType.java b/jvm/core/src/main/java/org/apache/tvm/TVMType.java index 
1c2719eeca90..658fdaedc1e5 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMType.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMType.java @@ -31,7 +31,7 @@ public class TVMType { /** * TVMType constructor. * @param typeStr type name, e.g., "float32", "float64", "uint8", etc. - * @param lanes NDArray lanes. + * @param lanes Tensor lanes. */ public TVMType(String typeStr, int lanes) { this.lanes = lanes; diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java index 45aef808f44c..532490a91367 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java @@ -45,7 +45,7 @@ public Function asFunction() { throw new UnsupportedOperationException(); } - public NDArrayBase asNDArray() { + public TensorBase asTensor() { throw new UnsupportedOperationException(); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArray.java b/jvm/core/src/main/java/org/apache/tvm/Tensor.java similarity index 90% rename from jvm/core/src/main/java/org/apache/tvm/NDArray.java rename to jvm/core/src/main/java/org/apache/tvm/Tensor.java index 6b151d7bf9d2..7b44049f9372 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArray.java +++ b/jvm/core/src/main/java/org/apache/tvm/Tensor.java @@ -23,13 +23,13 @@ import java.util.List; /** - * Lightweight NDArray class of TVM runtime. + * Lightweight Tensor class of TVM runtime. */ -public class NDArray extends NDArrayBase { +public class Tensor extends TensorBase { private final TVMType dtype; private final Device device; - NDArray(long handle, boolean isView, TVMType dtype, Device dev) { + Tensor(long handle, boolean isView, TVMType dtype, Device dev) { super(handle, isView); this.dtype = dtype; this.device = dev; @@ -37,7 +37,7 @@ public class NDArray extends NDArrayBase { /** * Copy from a native array. 
- * The NDArray type must by float64 + * The Tensor type must by float64 * @param sourceArray the source data */ public void copyFrom(double[] sourceArray) { @@ -54,7 +54,7 @@ public void copyFrom(double[] sourceArray) { /** * Copy from a native array. - * The NDArray type must by float32 + * The Tensor type must by float32 * @param sourceArray the source data */ public void copyFrom(float[] sourceArray) { @@ -71,7 +71,7 @@ public void copyFrom(float[] sourceArray) { /** * Copy from a native array. - * The NDArray type must by int64 + * The Tensor type must by int64 * @param sourceArray the source data */ public void copyFrom(long[] sourceArray) { @@ -88,7 +88,7 @@ public void copyFrom(long[] sourceArray) { /** * Copy from a native array. - * The NDArray type must by float32 + * The Tensor type must by float32 * @param sourceArray the source data */ public void copyFrom(int[] sourceArray) { @@ -105,7 +105,7 @@ public void copyFrom(int[] sourceArray) { /** * Copy from a native array. - * The NDArray type must by int16 + * The Tensor type must by int16 * @param sourceArray the source data */ public void copyFrom(short[] sourceArray) { @@ -122,7 +122,7 @@ public void copyFrom(short[] sourceArray) { /** * Copy from a native array. - * The NDArray type must by int8 + * The Tensor type must by int8 * @param sourceArray the source data */ public void copyFrom(byte[] sourceArray) { @@ -135,7 +135,7 @@ public void copyFrom(byte[] sourceArray) { /** * Copy from a native array. - * The NDArray type must by uint16 + * The Tensor type must by uint16 * @param sourceArray the source data */ public void copyFrom(char[] sourceArray) { @@ -167,8 +167,8 @@ public void copyFromRaw(byte[] sourceArray) { } /** - * Get shape of current NDArray. - * @return an array representing shape of current ndarray + * Get shape of current Tensor. 
+ * @return an array representing shape of current tensor */ public long[] shape() { List data = new ArrayList(); @@ -181,8 +181,8 @@ public long[] shape() { } /** - * Get total size of current NDArray. - * @return size of current NDArray. + * Get total size of current Tensor. + * @return size of current Tensor. */ public long size() { long product = 1L; @@ -195,7 +195,7 @@ public long size() { /** * Return a copied flat java array of current array (row-major). - * The NDArray dtype must be float64 + * The Tensor dtype must be float64 * @return A copy of array content. */ public double[] asDoubleArray() { @@ -213,7 +213,7 @@ public double[] asDoubleArray() { /** * Return a copied flat java array of current array (row-major). - * The NDArray dtype must be float32 + * The Tensor dtype must be float32 * @return A copy of array content. */ public float[] asFloatArray() { @@ -231,7 +231,7 @@ public float[] asFloatArray() { /** * Return a copied flat java array of current array (row-major). - * The NDArray dtype must be int64 + * The Tensor dtype must be int64 * @return A copy of array content. */ public long[] asLongArray() { @@ -249,7 +249,7 @@ public long[] asLongArray() { /** * Return a copied flat java array of current array (row-major). - * The NDArray dtype must be int32 + * The Tensor dtype must be int32 * @return A copy of array content. */ public int[] asIntArray() { @@ -267,7 +267,7 @@ public int[] asIntArray() { /** * Return a copied flat java array of current array (row-major). - * The NDArray dtype must be int16 + * The Tensor dtype must be int16 * @return A copy of array content. */ public short[] asShortArray() { @@ -285,7 +285,7 @@ public short[] asShortArray() { /** * Return a copied flat java array of current array (row-major). - * The NDArray dtype must be uint16 + * The Tensor dtype must be uint16 * @return A copy of array content. 
*/ public char[] asCharArray() { @@ -303,7 +303,7 @@ public char[] asCharArray() { /** * Return a copied flat java array of current array (row-major). - * The NDArray dtype must be int8 + * The Tensor dtype must be int8 * @return A copy of array content. */ public byte[] asByteArray() { @@ -319,7 +319,7 @@ public byte[] asByteArray() { * @return A copy of array content. */ public byte[] internal() { - NDArray tmp = NDArray.empty(shape(), dtype); + Tensor tmp = Tensor.empty(shape(), dtype); copyTo(tmp); int arrLength = dtype.numOfBytes * (int) size(); @@ -359,12 +359,12 @@ public Device device() { * @param dev The device of the array. * @return The array tvm supported. */ - public static NDArray empty(long[] shape, TVMType dtype, Device dev) { + public static Tensor empty(long[] shape, TVMType dtype, Device dev) { Base.RefLong refHandle = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmNDArrayEmpty( + Base.checkCall(Base._LIB.tvmTensorEmpty( shape, dtype.typeCode, dtype.bits, dtype.lanes, dev.deviceType, dev.deviceId, refHandle)); - return new NDArray(refHandle.value, false, dtype, dev); + return new Tensor(refHandle.value, false, dtype, dev); } /** @@ -373,7 +373,7 @@ public static NDArray empty(long[] shape, TVMType dtype, Device dev) { * @param dtype The data type of the array. * @return The array tvm supported. */ - public static NDArray empty(long[] shape, TVMType dtype) { + public static Tensor empty(long[] shape, TVMType dtype) { return empty(shape, dtype, Device.cpu(0)); } @@ -382,7 +382,7 @@ public static NDArray empty(long[] shape, TVMType dtype) { * @param shape The shape of the array. * @return The array tvm supported. */ - public static NDArray empty(long[] shape) { + public static Tensor empty(long[] shape) { return empty(shape, new TVMType("float32", 1), Device.cpu(0)); } @@ -392,7 +392,7 @@ public static NDArray empty(long[] shape) { * @param dev The device of the array. * @return The array tvm supported. 
*/ - public static NDArray empty(long[] shape, Device dev) { + public static Tensor empty(long[] shape, Device dev) { return empty(shape, new TVMType("float32", 1), dev); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java b/jvm/core/src/main/java/org/apache/tvm/TensorBase.java similarity index 86% rename from jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java rename to jvm/core/src/main/java/org/apache/tvm/TensorBase.java index 534dcb38d4a9..b150d65807ee 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java +++ b/jvm/core/src/main/java/org/apache/tvm/TensorBase.java @@ -18,26 +18,26 @@ package org.apache.tvm; /** - * Base class of NDArray. To handle callback array. + * Base class of Tensor. To handle callback array. * Only deep-copy supported. */ -public class NDArrayBase extends TVMValue { +public class TensorBase extends TVMValue { protected long handle; public final boolean isView; protected final long dltensorHandle; - NDArrayBase(long handle, boolean isView) { + TensorBase(long handle, boolean isView) { this.dltensorHandle = isView ? handle : handle + 8 * 2; this.handle = isView ? 0 : handle; this.isView = isView; } - @Override public NDArrayBase asNDArray() { + @Override public TensorBase asTensor() { return this; } /** - * Release the NDArray. + * Release the Tensor. */ public void release() { if (this.handle != 0) { @@ -56,7 +56,7 @@ public void release() { * @param target The target array to be copied, must have same shape as this array. 
* @return target */ - public NDArrayBase copyTo(NDArrayBase target) { + public TensorBase copyTo(TensorBase target) { Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromTo(this.dltensorHandle, target.dltensorHandle)); return target; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java b/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java index 97169bb6c58c..e29bae51828c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java +++ b/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java @@ -36,9 +36,9 @@ public class TypeIndex { public static final int kTVMFFIBytes = 66; public static final int kTVMFFIError = 67; public static final int kTVMFFIFunction = 68; - public static final int kTVMFFIArray = 69; - public static final int kTVMFFIMap = 70; - public static final int kTVMFFIShape = 71; - public static final int kTVMFFINDArray = 72; + public static final int kTVMFFIShape = 70; + public static final int kTVMFFITensor = 71; + public static final int kTVMFFIArray = 72; + public static final int kTVMFFIMap = 73; public static final int kTVMFFIModule = 73; } diff --git a/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java b/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java index c2a1f78fa432..56e9a21a2b83 100644 --- a/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java +++ b/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java @@ -78,14 +78,14 @@ public void test_sum_first_byte() { } @Test - public void test_sum_ndarray() { + public void test_sum_tensor() { final long[] shape = new long[]{2, 1}; Function func = Function.convertFunc(new Function.Callback() { @Override public Object invoke(TVMValue... 
args) { double sum = 0.0; for (TVMValue arg : args) { - NDArray arr = NDArray.empty(shape, new TVMType("float32")); - arg.asNDArray().copyTo(arr); + Tensor arr = Tensor.empty(shape, new TVMType("float32")); + arg.asTensor().copyTo(arr); float[] nativeArr = arr.asFloatArray(); for (int i = 0; i < nativeArr.length; ++i) { sum += nativeArr[i]; @@ -95,7 +95,7 @@ public void test_sum_ndarray() { return sum; } }); - NDArray arr = NDArray.empty(shape, new TVMType("float32")); + Tensor arr = Tensor.empty(shape, new TVMType("float32")); arr.copyFrom(new float[]{2f, 3f}); TVMValue res = func.pushArg(arr).pushArg(arr).invoke(); assertEquals(10.0, res.asDouble(), 1e-3); diff --git a/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java b/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java index 888cd18923be..5c692eecc3f6 100644 --- a/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java +++ b/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java @@ -42,11 +42,11 @@ public void test_load_add_func_cpu() { Device dev = new Device("cpu", 0); long[] shape = new long[]{2}; - NDArray arr = NDArray.empty(shape, dev); + Tensor arr = Tensor.empty(shape, dev); arr.copyFrom(new float[]{3f, 4f}); - NDArray res = NDArray.empty(shape, dev); + Tensor res = Tensor.empty(shape, dev); fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke(); assertArrayEquals(new float[]{6f, 8f}, res.asFloatArray(), 1e-3f); @@ -74,7 +74,7 @@ public void test_load_add_func_cuda() { final int dim = 100; long[] shape = new long[]{dim}; - NDArray arr = NDArray.empty(shape, dev); + Tensor arr = Tensor.empty(shape, dev); float[] data = new float[dim]; float[] dataX2 = new float[dim]; @@ -84,7 +84,7 @@ public void test_load_add_func_cuda() { } arr.copyFrom(data); - NDArray res = NDArray.empty(shape, dev); + Tensor res = Tensor.empty(shape, dev); fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke(); assertArrayEquals(dataX2, res.asFloatArray(), 1e-3f); diff --git 
a/jvm/core/src/test/java/org/apache/tvm/NDArrayTest.java b/jvm/core/src/test/java/org/apache/tvm/NDArrayTest.java deleted file mode 100644 index c4c34360f740..000000000000 --- a/jvm/core/src/test/java/org/apache/tvm/NDArrayTest.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.tvm; - -import org.junit.Test; - -import static org.junit.Assert.*; - -public class NDArrayTest { - @Test - public void test_from_float32() { - NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("float32")); - ndarray.copyFrom(new float[]{1, 2, 3, 4}); - assertArrayEquals(new float[]{1f, 2f, 3f, 4f}, ndarray.asFloatArray(), 1e-3f); - ndarray.release(); - } - - @Test - public void test_from_float64() { - NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("float64")); - ndarray.copyFrom(new double[]{1, 2, 3, 4}); - assertArrayEquals(new double[]{1.0, 2.0, 3.0, 4.0}, ndarray.asDoubleArray(), 1e-3); - ndarray.release(); - } - - @Test - public void test_from_int8() { - NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int8")); - ndarray.copyFrom(new byte[]{1, 2, 3, 4}); - assertArrayEquals(new byte[]{1, 2, 3, 4}, ndarray.asByteArray()); - ndarray.release(); - } - - @Test - public void test_from_int16() { - NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int16")); - ndarray.copyFrom(new short[]{1, 2, 3, 4}); - assertArrayEquals(new short[]{1, 2, 3, 4}, ndarray.asShortArray()); - ndarray.release(); - } - - @Test - public void test_from_int32() { - NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int32")); - ndarray.copyFrom(new int[]{1, 2, 3, 4}); - assertArrayEquals(new int[]{1, 2, 3, 4}, ndarray.asIntArray()); - ndarray.release(); - } - - @Test - public void test_from_int64() { - NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int64")); - ndarray.copyFrom(new long[]{1, 2, 3, 4}); - assertArrayEquals(new long[]{1, 2, 3, 4}, ndarray.asLongArray()); - ndarray.release(); - } - - @Test - public void test_from_uint16() { - NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("uint16")); - ndarray.copyFrom(new char[]{65535, 2, 3, 4}); - assertArrayEquals(new char[]{65535, 2, 3, 4}, ndarray.asCharArray()); - ndarray.release(); - } -} diff --git 
a/jvm/core/src/test/java/org/apache/tvm/TensorTest.java b/jvm/core/src/test/java/org/apache/tvm/TensorTest.java new file mode 100644 index 000000000000..546bf661e400 --- /dev/null +++ b/jvm/core/src/test/java/org/apache/tvm/TensorTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.tvm; + +import org.junit.Test; + +import static org.junit.Assert.*; + +public class TensorTest { + @Test + public void test_from_float32() { + Tensor tensor = Tensor.empty(new long[]{2, 2}, new TVMType("float32")); + tensor.copyFrom(new float[]{1, 2, 3, 4}); + assertArrayEquals(new float[]{1f, 2f, 3f, 4f}, tensor.asFloatArray(), 1e-3f); + tensor.release(); + } + + @Test + public void test_from_float64() { + Tensor tensor = Tensor.empty(new long[]{2, 2}, new TVMType("float64")); + tensor.copyFrom(new double[]{1, 2, 3, 4}); + assertArrayEquals(new double[]{1.0, 2.0, 3.0, 4.0}, tensor.asDoubleArray(), 1e-3); + tensor.release(); + } + + @Test + public void test_from_int8() { + Tensor tensor = Tensor.empty(new long[]{2, 2}, new TVMType("int8")); + tensor.copyFrom(new byte[]{1, 2, 3, 4}); + assertArrayEquals(new byte[]{1, 2, 3, 4}, tensor.asByteArray()); + tensor.release(); + } + + @Test + public void test_from_int16() { + Tensor tensor = Tensor.empty(new long[]{2, 2}, new TVMType("int16")); + tensor.copyFrom(new short[]{1, 2, 3, 4}); + assertArrayEquals(new short[]{1, 2, 3, 4}, tensor.asShortArray()); + tensor.release(); + } + + @Test + public void test_from_int32() { + Tensor tensor = Tensor.empty(new long[]{2, 2}, new TVMType("int32")); + tensor.copyFrom(new int[]{1, 2, 3, 4}); + assertArrayEquals(new int[]{1, 2, 3, 4}, tensor.asIntArray()); + tensor.release(); + } + + @Test + public void test_from_int64() { + Tensor tensor = Tensor.empty(new long[]{2, 2}, new TVMType("int64")); + tensor.copyFrom(new long[]{1, 2, 3, 4}); + assertArrayEquals(new long[]{1, 2, 3, 4}, tensor.asLongArray()); + tensor.release(); + } + + @Test + public void test_from_uint16() { + Tensor tensor = Tensor.empty(new long[]{2, 2}, new TVMType("uint16")); + tensor.copyFrom(new char[]{65535, 2, 3, 4}); + assertArrayEquals(new char[]{65535, 2, 3, 4}, tensor.asCharArray()); + tensor.release(); + } +} diff --git a/jvm/native/linux-x86_64/pom.xml 
b/jvm/native/linux-x86_64/pom.xml index c21a3d2ae5af..0bf5d88b76fe 100644 --- a/jvm/native/linux-x86_64/pom.xml +++ b/jvm/native/linux-x86_64/pom.xml @@ -118,7 +118,7 @@ under the License. -I../../../include - -I../../../ffi/include + -I../../../3rdparty/tvm-ffi/include -I${JAVA_HOME}/include -I${JAVA_HOME}/include/linux ${cflags} diff --git a/jvm/native/osx-x86_64/pom.xml b/jvm/native/osx-x86_64/pom.xml index e2bd0fd7ae9d..de468519b828 100644 --- a/jvm/native/osx-x86_64/pom.xml +++ b/jvm/native/osx-x86_64/pom.xml @@ -119,7 +119,7 @@ under the License. -I../../../include - -I../../../ffi/include + -I../../../3rdparty/tvm-ffi/include -I${JAVA_HOME}/include -I${JAVA_HOME}/include/darwin ${cflags} diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index 5db3e279cf3f..659c6e4f2943 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -151,8 +151,8 @@ jobject newFunction(JNIEnv* env, jlong value) { return object; } -jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) { - jclass cls = env->FindClass("org/apache/tvm/NDArrayBase"); +jobject newTensor(JNIEnv* env, jlong handle, jboolean isview) { + jclass cls = env->FindClass("org/apache/tvm/TensorBase"); jmethodID constructor = env->GetMethodID(cls, "", "(JZ)V"); jobject object = env->NewObject(cls, constructor, handle, isview); env->DeleteLocalRef(cls); @@ -218,10 +218,10 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) { return newFunction(env, reinterpret_cast(value.v_obj)); } case TypeIndex::kTVMFFIDLTensorPtr: { - return newNDArray(env, reinterpret_cast(value.v_ptr), true); + return newTensor(env, reinterpret_cast(value.v_ptr), true); } - case TypeIndex::kTVMFFINDArray: { - return newNDArray(env, reinterpret_cast(value.v_obj), false); + case TypeIndex::kTVMFFITensor: { + return newTensor(env, reinterpret_cast(value.v_obj), false); } case TypeIndex::kTVMFFISmallStr: { TVMFFIByteArray arr 
= TVMFFISmallBytesGetContentByteArray(&value); @@ -236,7 +236,7 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) { } case TypeIndex::kTVMFFIBytes: { jobject ret = newTVMValueBytes(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); - TVMFFIObjectFree(value.v_obj); + TVMFFIObjectDecRef(value.v_obj); return ret; } default: { diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index 3ebe7fddfa8f..e18d1171df1f 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -26,8 +26,8 @@ #else #include #include -#include #include +#include #include #endif #include @@ -322,10 +322,10 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionSetGlobal(JNIEn // Module JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIObjectFree(JNIEnv* env, jobject obj, jlong jhandle) { - return TVMFFIObjectFree(reinterpret_cast(jhandle)); + return TVMFFIObjectDecRef(reinterpret_cast(jhandle)); } -// NDArray +// Tensor JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorGetShape(JNIEnv* env, jobject obj, jlong jhandle, @@ -356,7 +356,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyFromTo(JNIE jlong jfrom, jlong jto) { TVM_FFI_SAFE_CALL_BEGIN(); - static auto fcopy_from_to = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyFromTo"); + static auto fcopy_from_to = tvm::ffi::Function::GetGlobalRequired("runtime.TVMTensorCopyFromTo"); fcopy_from_to(reinterpret_cast(jfrom), reinterpret_cast(jto)); TVM_FFI_SAFE_CALL_END(); } @@ -370,7 +370,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyFromJArray( DLTensor* to = reinterpret_cast(jto); size_t size = tvm::ffi::GetDataSize(*to); static auto fcopy_from_bytes = - tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyFromBytes"); + 
tvm::ffi::Function::GetGlobalRequired("runtime.TVMTensorCopyFromBytes"); fcopy_from_bytes(to, static_cast(pdata), size); env->ReleaseByteArrayElements(jarr, pdata, 0); TVM_FFI_SAFE_CALL_END(); @@ -384,7 +384,8 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyToJArray(JN DLTensor* from = reinterpret_cast(jfrom); size_t size = tvm::ffi::GetDataSize(*from); jbyte* pdata = env->GetByteArrayElements(jarr, NULL); - static auto fcopy_to_bytes = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyToBytes"); + static auto fcopy_to_bytes = + tvm::ffi::Function::GetGlobalRequired("runtime.TVMTensorCopyToBytes"); fcopy_to_bytes(from, static_cast(pdata), size); env->ReleaseByteArrayElements(jarr, static_cast(pdata), 0); // copy back to java array automatically @@ -401,7 +402,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, j TVM_FFI_SAFE_CALL_END(); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmNDArrayEmpty( +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmTensorEmpty( JNIEnv* env, jobject obj, jlongArray jshape, jint jdtypeCode, jint jdtypeBits, jint jdtypeLanes, jint jdeviceType, jint jdeviceId, jobject jret) { TVM_FFI_SAFE_CALL_BEGIN(); @@ -414,8 +415,8 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmNDArrayEmpty( dtype.lanes = static_cast(jdtypeLanes); DLDevice device{static_cast(jdeviceType), jdeviceId}; env->ReleaseLongArrayElements(jshape, shapeArray, 0); - static auto fempty = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayAllocWithScope"); - tvm::ffi::NDArray out = fempty(shape, dtype, device, nullptr).cast(); + static auto fempty = tvm::ffi::Function::GetGlobalRequired("runtime.TVMTensorAllocWithScope"); + tvm::ffi::Tensor out = fempty(shape, dtype, device, nullptr).cast(); void* handle = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out)); setLongField(env, jret, reinterpret_cast(handle)); TVM_FFI_SAFE_CALL_END(); diff --git 
a/licenses/LICENSE.dlpack.txt b/licenses/LICENSE.dlpack.txt deleted file mode 100644 index 20a9c8a7b4dc..000000000000 --- a/licenses/LICENSE.dlpack.txt +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. 
Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative 
Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2017 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/licenses/LICENSE.libbacktrace.txt b/licenses/LICENSE.libbacktrace.txt deleted file mode 100644 index e9e256244d69..000000000000 --- a/licenses/LICENSE.libbacktrace.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (C) 2012-2016 Free Software Foundation, Inc. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: - -# (1) Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. - -# (2) Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in -# the documentation and/or other materials provided with the -# distribution. 
- -# (3) The name of the author may not be used to -# endorse or promote products derived from this software without -# specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR -# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING -# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. diff --git a/pyproject.toml b/pyproject.toml index 65add46b09e0..54e0cc91dc3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,182 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ +[build-system] +requires = ["scikit-build-core>=0.10.0"] +build-backend = "scikit_build_core.build" + +[project] +name = "tvm" +# Note: Call version.py to update the version before building the wheel +version = "0.23.dev0" +description = "Apache TVM: An End-to-End Deep Learning Compiler Stack" +readme = "README.md" +license = { text = "Apache-2.0" } +requires-python = ">=3.9" +authors = [ + { name = "Apache TVM Community", email = "dev@tvm.apache.org" } +] +keywords = ["machine learning", "compiler", "deep learning", "inference"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", +] +# Core dependencies - these are the minimum required for basic TVM functionality +dependencies = [ + "apache-tvm-ffi", + "cloudpickle", + "ml_dtypes", + "numpy", + "packaging", + "psutil", + "scipy", + "tornado", + "typing_extensions" +] + +# Optional dependencies for different features +[project.optional-dependencies] +# Model importers +importer-coreml = ["coremltools"] +importer-keras = ["tensorflow", "tensorflow-estimator"] +importer-onnx = ["future", "onnx", "onnxoptimizer", "onnxruntime", "torch", "torchvision"] +importer-pytorch = ["torch", "torchvision"] +importer-tensorflow = ["tensorflow", "tensorflow-estimator"] +importer-tflite = ["tflite"] +importer-paddle = ["paddlepaddle"] + +# AutoTVM and autoscheduler +autotvm = ["xgboost"] +autoscheduler = ["xgboost"] + +# SMT support +z3 = ["z3-solver>=4.13.0"] + +# Development and testing +dev = [ + "black", + "isort", 
+ "mypy", + "pylint", + "pytest", + "pytest-xdist", + "pytest-cov", + "pytest-mock", + "pytest-benchmark", + "pytest-timeout", + "pytest-rerunfailures", + "pytest-repeat", + "pytest-xdist", + "pytest-cov", + "pytest-mock", + "pytest-benchmark", + "pytest-timeout", + "pytest-rerunfailures", + "pytest-repeat", +] + +# All optional dependencies (excluding dev) +all = [ + "coremltools", + "tensorflow", + "tensorflow-estimator", + "future", + "onnx", + "onnxoptimizer", + "onnxruntime", + "torch", + "torchvision", + "tflite", + "paddlepaddle", + "xgboost", + "z3-solver>=4.13.0" +] + +[project.urls] +Homepage = "https://tvm.apache.org/" +Documentation = "https://tvm.apache.org/docs/" +Repository = "https://github.com/apache/tvm" +"Bug Tracker" = "https://github.com/apache/tvm/issues" + +[tool.scikit-build] +# Point to the root CMakeLists.txt +cmake.source-dir = "." +cmake.build-type = "Release" + +# Configure the wheel to be Python version-agnostic +wheel.py-api = "py3" + +# Build configuration +build-dir = "build" + +# CMake configuration - ensure proper installation paths +cmake.args = ["-DTVM_BUILD_PYTHON_MODULE=ON"] + +# Wheel configuration +wheel.packages = ["python/tvm"] +wheel.install-dir = "tvm" + +# Source distribution configuration +sdist.include = [ + # Build files + "/CMakeLists.txt", + "/pyproject.toml", + "/cmake/**/*", + "/ */*", + + # Source code + "/src/**/*.cc", + "/src/**/*.h", + "/include/**/*.h", + + # Python source + "/python/tvm/**/*.py", + "/python/tvm/**/*.pyi", + + # Documentation and metadata + "/docs/**/*", + "/LICENSE", + "/README.md", + "/NOTICE", + + # Tests + "/tests/**/*", +] + +sdist.exclude = [ + "**/.git", + "**/.github", + "**/__pycache__", + "**/*.pyc", + "build", + "dist", + "**/3rdparty/*/docs", + "**/3rdparty/*/media", + "**/3rdparty/*/examples", + "**/3rdparty/*/test", +] + +# Logging +logging.level = "INFO" + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-v --tb=short" +python_files = ["test_*.py", "*_test.py"] 
+python_classes = ["Test*"] +python_functions = ["test_*"] + [tool.isort] profile = "black" src_paths = ["python", "tests/python"] @@ -51,5 +227,48 @@ exclude = ''' ''' [tool.ruff] +# Enable pycodestyle (`E`), Pyflakes (`F`), and isort (`I`) codes +select = ["E", "F", "I"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["A", "B", "C", "D", "E", "F", "I", "N", "UP", "W", "ARG", "B", "C4", "DTZ", "T10", "EM", "EXE", "FA", "ICN", "Q", "T20", "TID", "TCH", "RUF"] +unfixable = [] + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".darcs", + ".git", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "3rdparty", +] + line-length = 100 indent-width = 4 + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + + +[tool.ruff.mccabe] +max-complexity = 10 + +[tool.ruff.isort] +known-first-party = ["tvm"] diff --git a/python/setup.py b/python/setup.py deleted file mode 100644 index cf2eff2a3af4..000000000000 --- a/python/setup.py +++ /dev/null @@ -1,276 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, exec-used -"""Setup TVM package.""" -import os -import pathlib -import shutil -import sys - -from setuptools import find_packages -from setuptools.dist import Distribution - -# need to use distutils.core for correct placement of cython dll -if "--inplace" in sys.argv: - from distutils.core import setup - from distutils.extension import Extension -else: - from setuptools import setup - from setuptools.extension import Extension - -CURRENT_DIR = os.path.dirname(__file__) -FFI_MODE = os.environ.get("TVM_FFI", "auto") -CONDA_BUILD = os.getenv("CONDA_BUILD") is not None -INPLACE_BUILD = "--inplace" in sys.argv - - -def get_lib_path(): - """Get library path, name and version""" - # We can not import `libinfo.py` in setup.py directly since __init__.py - # Will be invoked which introduces dependencies - libinfo_py = os.path.join(CURRENT_DIR, "./tvm/libinfo.py") - libinfo = {"__file__": libinfo_py} - exec(compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"), libinfo, libinfo) - version = libinfo["__version__"] - if not CONDA_BUILD and not INPLACE_BUILD: - lib_path = libinfo["find_lib_path"]() - libs = [lib_path[0]] - if "runtime" not in libs[0]: - for name in lib_path[1:]: - if "runtime" in name: - libs.append(name) - break - - # Add byoc shared libraries, if present - for name in lib_path: - if "3rdparty" in name: - libs.append(name) - - # Add tvmc configuration json files - for name in lib_path: - candidate_path = os.path.abspath(os.path.join(os.path.dirname(name), "..", "configs")) - if os.path.isdir(candidate_path): - libs.append(candidate_path) - break - - for dir in [ - "3rdparty", - "jvm", - "web", - "rust", - "golang", - "include", - "src", - "cmake", - "CMakeLists.txt", - ]: - for name in lib_path: - candidate_path = os.path.abspath(os.path.join(os.path.dirname(name), "..", dir)) - if os.path.exists(candidate_path): - 
libs.append(candidate_path) - if dir == "3rdparty": - # remove large files - _remove_path(os.path.join(candidate_path, "cutlass", "docs")) - _remove_path(os.path.join(candidate_path, "cutlass", "media")) - _remove_path( - os.path.join(candidate_path, "cutlass_fpA_intB_gemm", "cutlass", "docs") - ) - _remove_path( - os.path.join( - candidate_path, "cutlass_fpA_intB_gemm", "cutlass", "media" - ) - ) - _remove_path( - os.path.join(candidate_path, "libflash_attn", "cutlass", "docs") - ) - _remove_path( - os.path.join(candidate_path, "libflash_attn", "cutlass", "media") - ) - break - else: - libs = None - - return libs, version - - -def git_describe_version(original_version): - """Get git describe version.""" - ver_py = os.path.join(CURRENT_DIR, "..", "version.py") - libver = {"__file__": ver_py} - exec(compile(open(ver_py, "rb").read(), ver_py, "exec"), libver, libver) - _, gd_version = libver["git_describe_version"]() - if gd_version != original_version and "--inplace" not in sys.argv: - print("Use git describe based version %s" % gd_version) - return gd_version - - -def _remove_path(path): - if os.path.exists(path): - if os.path.isfile(path): - os.remove(path) - elif os.path.isdir(path): - shutil.rmtree(path) - - -LIB_LIST, __version__ = get_lib_path() -__version__ = git_describe_version(__version__) - -if not CONDA_BUILD and not INPLACE_BUILD: - # Wheel cleanup - for path in LIB_LIST: - libname = os.path.basename(path) - _remove_path(f"tvm/{libname}") - - -def config_cython(): - """Try to configure cython and return cython configuration""" - # Enforce cython unless FFI_MODE is explicitly set to ctypes - # we might consider fully converge to cython later - if FFI_MODE == "ctypes": - return [] - try: - from Cython.Build import cythonize - - # for python 3.12+, use limited API for future compact - limited_api_kwargs = {} - if sys.version_info >= (3, 12): - limited_api_kwargs = { - "define_macros": [ - ("Py_LIMITED_API", 0x030C0000), - ], - "py_limited_api": True, - } - 
- ret = [] - extra_compile_args = ["-std=c++17", "-DDMLC_USE_LOGGING_LIBRARY="] - if os.name == "nt": - library_dirs = ["tvm", "../build/Release", "../build"] - libraries = ["tvm"] - extra_compile_args = [ - "/std:c++17", - "/D DMLC_USE_LOGGING_LIBRARY=", - ] - # library is available via conda env. - if CONDA_BUILD: - library_dirs = [os.environ["LIBRARY_LIB"]] - else: - library_dirs = None - libraries = None - - # the latest ffi source - for fn in os.listdir("tvm/ffi/cython"): - if not fn.endswith(".pyx"): - continue - ret.append( - Extension( - f"tvm.ffi.{fn[:-4]}", - ["tvm/ffi/cython/%s" % fn], - include_dirs=[ - "../ffi/include/", - "../ffi/3rdparty/dlpack/include", - ], - extra_compile_args=extra_compile_args, - library_dirs=library_dirs, - libraries=libraries, - language="c++", - **limited_api_kwargs, - ) - ) - return cythonize(ret, compiler_directives={"language_level": 3}) - except ImportError as error: - raise RuntimeError("Cython is not installed, please pip install cython") - - -class BinaryDistribution(Distribution): - def has_ext_modules(self): - return True - - def is_pure(self): - return False - - -setup_kwargs = {} -if not CONDA_BUILD and not INPLACE_BUILD: - with open("MANIFEST.in", "w") as fo: - for path in LIB_LIST: - if os.path.isfile(path): - shutil.copy(path, os.path.join(CURRENT_DIR, "tvm")) - _, libname = os.path.split(path) - fo.write(f"include tvm/{libname}\n") - - if os.path.isdir(path): - _, libname = os.path.split(path) - shutil.copytree(path, os.path.join(CURRENT_DIR, "tvm", libname)) - fo.write(f"recursive-include tvm/{libname} *\n") - - setup_kwargs = {"include_package_data": True} - - -def long_description_contents(): - with open(pathlib.Path(CURRENT_DIR).resolve().parent / "README.md", encoding="utf-8") as readme: - description = readme.read() - - return description - - -# Temporarily add this directory to the path so we can import the requirements generator -# tool. 
-sys.path.insert(0, os.path.dirname(__file__)) -import gen_requirements - -sys.path.pop(0) - -requirements = gen_requirements.join_requirements() -extras_require = { - piece: deps for piece, (_, deps) in requirements.items() if piece not in ("all", "core") -} - -setup( - name="tvm", - version=__version__, - description="TVM: An End to End Tensor IR/DSL Stack for Deep Learning Systems", - long_description=long_description_contents(), - long_description_content_type="text/markdown", - url="https://tvm.apache.org/", - download_url="https://github.com/apache/tvm/tags", - author="Apache TVM", - license="Apache", - # See https://pypi.org/classifiers/ - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - ], - keywords="machine learning", - zip_safe=False, - install_requires=requirements["core"][1], - extras_require=extras_require, - packages=find_packages(), - package_dir={"tvm": "tvm"}, - distclass=BinaryDistribution, - ext_modules=config_cython(), - **setup_kwargs, -) - - -if not CONDA_BUILD and not INPLACE_BUILD: - # Wheel cleanup - os.remove("MANIFEST.in") - for path in LIB_LIST: - libname = os.path.basename(path) - _remove_path(f"tvm/{libname}") diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 150e5d4b1dbc..55c78e43c07b 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -20,18 +20,18 @@ import sys import os +# ffi module must load first +from tvm_ffi import register_object, register_global_func, get_global_func + # top-level alias -# tvm._ffi from .base import TVMError, __version__, _RUNTIME_ONLY -from .ffi import register_object, register_func, get_global_func - # top-level alias # tvm.runtime from .runtime.object import Object -from .runtime.ndarray import device, cpu, cuda, opencl, vulkan, metal -from .runtime.ndarray import vpi, rocm, ext_dev, hexagon -from .runtime import ndarray as nd, DataType, DataTypeCode 
+from .runtime._tensor import device, cpu, cuda, opencl, vulkan, metal +from .runtime._tensor import vpi, rocm, ext_dev, hexagon +from .runtime import DataType, DataTypeCode # tvm.error from . import error diff --git a/python/tvm/arith/_ffi_api.py b/python/tvm/arith/_ffi_api.py index e05405b0fcc6..519423aa4e1f 100644 --- a/python/tvm/arith/_ffi_api.py +++ b/python/tvm/arith/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.arith""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("arith", __name__) +tvm_ffi.init_ffi_api("arith", __name__) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 434e2a3e65c6..d8c7e88656b9 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -17,9 +17,9 @@ # pylint: disable=invalid-name """Arithmetic data structure and utility""" import enum -from typing import Union +from typing import Union, Dict -import tvm.ffi +import tvm_ffi from tvm import ir, tir from tvm.arith import IntSet from tvm.runtime import Object @@ -47,7 +47,7 @@ class Extension(enum.Flag): ComparisonOfProductAndSum = 1 << 3 -@tvm.ffi.register_object("arith.ModularSet") +@tvm_ffi.register_object("arith.ModularSet") class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z""" @@ -55,7 +55,7 @@ def __init__(self, coeff, base): self.__init_handle_by_constructor__(_ffi_api.ModularSet, coeff, base) -@tvm.ffi.register_object("arith.ConstIntBound") +@tvm_ffi.register_object("arith.ConstIntBound") class ConstIntBound(Object): """Represent constant integer bound @@ -108,22 +108,80 @@ class Analyzer: def __init__(self): _mod = _ffi_api.CreateAnalyzer() - self._const_int_bound = _mod("const_int_bound") - self._const_int_bound_update = _mod("const_int_bound_update") - self._const_int_bound_is_bound = _mod("const_int_bound_is_bound") - self._bind = _mod("bind") - self._modular_set = _mod("modular_set") - self._simplify = 
_mod("Simplify") - self._rewrite_simplify = _mod("rewrite_simplify") - self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats") - self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats") - self._canonical_simplify = _mod("canonical_simplify") - self._int_set = _mod("int_set") - self._enter_constraint_context = _mod("enter_constraint_context") - self._can_prove_equal = _mod("can_prove_equal") - self._can_prove = _mod("can_prove") - self._get_enabled_extensions = _mod("get_enabled_extensions") - self._set_enabled_extensions = _mod("set_enabled_extensions") + self._assign_functions(_mod) + + def _assign_functions(self, mod_factory): + # Save factory for later use (e.g., clone) + self._factory = mod_factory + self._const_int_bound = mod_factory("const_int_bound") + self._const_int_bound_update = mod_factory("const_int_bound_update") + self._const_int_bound_is_bound = mod_factory("const_int_bound_is_bound") + self._bind = mod_factory("bind") + self._modular_set = mod_factory("modular_set") + self._simplify = mod_factory("Simplify") + self._rewrite_simplify = mod_factory("rewrite_simplify") + self._get_rewrite_simplify_stats = mod_factory("get_rewrite_simplify_stats") + self._reset_rewrite_simplify_stats = mod_factory("reset_rewrite_simplify_stats") + self._canonical_simplify = mod_factory("canonical_simplify") + self._int_set = mod_factory("int_set") + self._enter_constraint_context = mod_factory("enter_constraint_context") + self._can_prove_equal = mod_factory("can_prove_equal") + self._can_prove = mod_factory("can_prove") + self._get_smtlib2 = mod_factory("get_smtlib2") + self._set_z3_timeout_ms = mod_factory("set_z3_timeout_ms") + self._set_z3_rlimit = mod_factory("set_z3_rlimit") + self._get_z3_stats = mod_factory("get_z3_stats") + self._get_enabled_extensions = mod_factory("get_enabled_extensions") + self._set_enabled_extensions = mod_factory("set_enabled_extensions") + # Clone factory returns another mod_factory when invoked + 
self._clone_factory = mod_factory("clone") + + def get_smtlib2(self, expr: tir.PrimExpr = None) -> str: + return self._get_smtlib2(expr) + + def set_z3_timeout_ms(self, timeout_ms: int) -> None: + """Set z3 timeout in milliseconds. + + Parameters + ---------- + timeout_ms : int + The timeout in milliseconds. + """ + self._set_z3_timeout_ms(timeout_ms) + + def set_z3_rlimit(self, max_step: int) -> None: + """Set z3 max step. + + Parameters + ---------- + max_step : int + The maximum number of steps. + """ + self._set_z3_rlimit(max_step) + + def get_z3_stats(self) -> str: + """Get z3 statistics. + + Returns + ------- + stats : str + The z3 statistics. + """ + return self._get_z3_stats() + + def clone(self) -> "Analyzer": + """Create a deep copy of this Analyzer, including internal state. + + Returns + ------- + Analyzer + A new Analyzer instance with the same analysis state. + """ + # _clone_factory() returns a new factory bound to the cloned C++ Analyzer + new_factory = self._clone_factory() + obj = Analyzer.__new__(Analyzer) + Analyzer._assign_functions(obj, new_factory) + return obj def const_int_bound(self, expr: tir.PrimExpr) -> ConstIntBound: """Find constant integer bound for expr. @@ -227,7 +285,7 @@ def canonical_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr: """ return self._canonical_simplify(expr) - def int_set(self, expr: tir.PrimExpr, dom_map: dict[tir.Var, IntSet]) -> IntSet: + def int_set(self, expr: tir.PrimExpr, dom_map: Dict[tir.Var, IntSet]) -> IntSet: """Compute a symbolic IntSet that covers expr for all values in dom_map. Parameters diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index 7a0aae5fdaea..fc6c20dec1ce 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. """Integer set.""" -import tvm.ffi +import tvm_ffi from tvm.runtime import Object from . 
import _ffi_api -@tvm.ffi.register_object("ir.IntSet") +@tvm_ffi.register_object("ir.IntSet") class IntSet(Object): """Represent a set of integer in one dimension.""" @@ -65,7 +65,7 @@ def single_point(point): return _ffi_api.intset_single_point(point) -@tvm.ffi.register_object("arith.IntervalSet") +@tvm_ffi.register_object("arith.IntervalSet") class IntervalSet(IntSet): """Represent set of continuous interval [min_value, max_value] @@ -82,7 +82,7 @@ def __init__(self, min_value, max_value): self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value) -@tvm.ffi.register_object("arith.PresburgerSet") +@tvm_ffi.register_object("arith.PresburgerSet") class PresburgerSet(IntSet): """Represent of Presburger Set""" diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index a97cda10f8eb..72e4c46896ff 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. """integer constraints data structures and solvers""" -import tvm.ffi +import tvm_ffi from tvm.runtime import Object from . import _ffi_api -@tvm.ffi.register_object("arith.IntGroupBounds") +@tvm_ffi.register_object("arith.IntGroupBounds") class IntGroupBounds(Object): """Represent integer grouped bounds which are classified into lower bounds (include), upper bounds (include) and equalities. 
@@ -66,7 +66,7 @@ def find_best_range(self): return _ffi_api.IntGroupBounds_FindBestRange(self) -@tvm.ffi.register_object("arith.IntConstraints") +@tvm_ffi.register_object("arith.IntConstraints") class IntConstraints(Object): """Represent a set of integer constraints including variables, their ranges and the relations between them (either equations or inequalities) @@ -85,7 +85,7 @@ def __init__(self, variables, ranges, relations): self.__init_handle_by_constructor__(_ffi_api.IntConstraints, variables, ranges, relations) -@tvm.ffi.register_object("arith.IntConstraintsTransform") +@tvm_ffi.register_object("arith.IntConstraintsTransform") class IntConstraintsTransform(Object): """We can have different set of variables to represent the same integer constraints. For example, the following two constrains are equivalent, diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 328bb052b87f..69ad3022fb4a 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -16,18 +16,18 @@ # under the License. """Iterator (quasi)affine mapping patterns.""" from enum import IntEnum -import tvm.ffi +import tvm_ffi from tvm.runtime import Object from tvm.ir import PrimExpr from . import _ffi_api -@tvm.ffi.register_object("arith.IterMapExpr") +@tvm_ffi.register_object("arith.IterMapExpr") class IterMapExpr(PrimExpr): """Base class of all IterMap expressions.""" -@tvm.ffi.register_object("arith.IterMark") +@tvm_ffi.register_object("arith.IterMark") class IterMark(Object): """Mark the source as an iterator in [0, extent). @@ -44,7 +44,7 @@ def __init__(self, source, extent): self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent) -@tvm.ffi.register_object("arith.IterSplitExpr") +@tvm_ffi.register_object("arith.IterSplitExpr") class IterSplitExpr(IterMapExpr): """Split of an iterator. 
@@ -71,7 +71,7 @@ def __init__(self, source, lower_factor, extent, scale): ) -@tvm.ffi.register_object("arith.IterSumExpr") +@tvm_ffi.register_object("arith.IterSumExpr") class IterSumExpr(IterMapExpr): """Fuse multiple iterators by summing them with scaling. @@ -90,7 +90,7 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) -@tvm.ffi.register_object("arith.IterMapResult") +@tvm_ffi.register_object("arith.IterMapResult") class IterMapResult(Object): """Result of iter map detection.""" diff --git a/python/tvm/base.py b/python/tvm/base.py index 63e097999cf5..f5bdc215ce1e 100644 --- a/python/tvm/base.py +++ b/python/tvm/base.py @@ -26,8 +26,8 @@ # ---------------------------- # Python3 version. # ---------------------------- -if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 9): - PY3STATEMENT = "The minimal Python requirement is Python 3.9" +if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 8): + PY3STATEMENT = "The minimal Python requirement is Python 3.8" raise Exception(PY3STATEMENT) # ---------------------------- @@ -42,7 +42,7 @@ def _load_lib(): if sys.platform.startswith("win32"): for path in libinfo.get_dll_directories(): os.add_dll_directory(path) - lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) + lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL | os.RTLD_LAZY) return lib, os.path.basename(lib_path[0]) @@ -62,7 +62,7 @@ def _load_lib(): if _RUNTIME_ONLY: - from .ffi import registry as _tvm_ffi_registry + from tvm_ffi import registry as _tvm_ffi_registry _tvm_ffi_registry._SKIP_UNKNOWN_OBJECTS = True diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 04a69baee9c1..bd3583453533 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -348,12 +348,12 @@ def _linux_compile( if compile_shared or output.endswith(".so") or output.endswith(".dylib"): cmd += ["-shared"] cmd += ["-o", output] + if options: + cmd += options if isinstance(objects, str): cmd 
+= [objects] else: cmd += objects - if options: - cmd += options env = None if ccache_env is not None: if shutil.which("ccache"): @@ -362,6 +362,7 @@ def _linux_compile( env.update(ccache_env) else: raise ValueError("ccache not found") + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env) (out, _) = proc.communicate() if proc.returncode != 0: diff --git a/python/tvm/contrib/coreml_runtime.py b/python/tvm/contrib/coreml_runtime.py index def5d3c2e06e..1d185059f0bd 100644 --- a/python/tvm/contrib/coreml_runtime.py +++ b/python/tvm/contrib/coreml_runtime.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """CoreML runtime that load and run coreml models.""" -import tvm.ffi +import tvm_ffi from ..rpc import base as rpc_base @@ -35,13 +35,13 @@ def create(symbol, compiled_model_path, device): coreml_runtime : CoreMLModule Runtime coreml module that can be used to execute the coreml model. """ - device_type = device.device_type + device_type = device.dlpack_device_type() runtime_func = "tvm.coreml_runtime.create" if device_type >= rpc_base.RPC_SESS_MASK: fcreate = device._rpc_sess.get_function(runtime_func) else: - fcreate = tvm.ffi.get_global_func(runtime_func) + fcreate = tvm_ffi.get_global_func(runtime_func) assert fcreate, "Cannot find `tvm.coreml_runtime.create` function." 
return CoreMLModule(fcreate(symbol, compiled_model_path)) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 1c80d4a3b9e1..b69bc4f84ee5 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -20,7 +20,7 @@ import numpy as np import tvm -import tvm.ffi +import tvm_ffi from tvm import te # algos can be read from cudnn.h @@ -123,7 +123,7 @@ def _get_np_int32_array_handle(arr): Parameters ---------- - arr: numpy.NDArray + arr: numpy.Tensor source numpy array Returns @@ -349,7 +349,7 @@ def _conv_find_algo( dims - 2, pad, stride, dilation, x_shape, w_shape ) yshape = np.array(y_shape, dtype=np.int32) - func = tvm.ffi.get_global_func(func_name) + func = tvm_ffi.get_global_func(func_name) return func( tensor_format, dims - 2, diff --git a/python/tvm/contrib/cutlass/_ffi_api.py b/python/tvm/contrib/cutlass/_ffi_api.py index be71b0d48f13..d57825835b6b 100644 --- a/python/tvm/contrib/cutlass/_ffi_api.py +++ b/python/tvm/contrib/cutlass/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI API for CUTLASS BYOC.""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("contrib.cutlass", __name__) +tvm_ffi.init_ffi_api("contrib.cutlass", __name__) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index fe29cd59459b..ff804e83460c 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -147,7 +147,7 @@ def instantiate_attention_template(attrs): } CHECK(Attention::check_supported(p)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); kernel_fn<<>>(p); @@ -185,7 +185,7 @@ def instantiate_flash_attention_template(attrs): int v_batch_stride = v_row_stride * ${num_keys}; int o_batch_stride = o_row_stride * ${num_queries}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_forward( static_cast(${query}->data), @@ -235,7 +235,7 @@ def instantiate_flash_attention_template(attrs): int v_batch_stride = v_row_stride * ${num_keys}; int o_batch_stride = o_row_stride * ${num_queries}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_forward( static_cast(${qkv}->data), @@ -291,7 +291,7 @@ def instantiate_flash_attention_var_len_template(attrs): int v_row_stride = v_head_stride * ${num_kv_heads}; int o_row_stride = o_head_stride * ${num_q_heads}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); 
flash_attn::flash_attention_var_len_forward( static_cast(${query}->data), diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 0aea5bf1416a..4b2a50a5f1d8 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -23,10 +23,10 @@ import os from functools import reduce from typing import Optional, Sequence +from tvm_ffi import register_global_func import tvm from tvm import relax, runtime -from tvm.ffi.registry import register_func from tvm.contrib.nvcc import get_cuda_version from tvm.topi.utils import get_const_tuple @@ -821,7 +821,7 @@ def visit_span(self, span): return span -@register_func("contrib.cutlass.tune_relax_function") +@register_global_func("contrib.cutlass.tune_relax_function") def profile_relax_function(functions, options): """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates.""" tmp_dir = options.get("tmp_dir", "./tmp") @@ -840,7 +840,7 @@ def profile_relax_function(functions, options): return annotated_functions -@register_func("contrib.cutlass.compile") +@register_global_func("contrib.cutlass.compile") def compile_cutlass_module(c_source_module, options): """Compile all CUTLASS kernels in the given C-source module. 
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index b0afdcdd6e84..e323e2a14937 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -424,7 +424,7 @@ def instantiate_conv2d_template(attrs): TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); ${split_k_update} - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${data_arg}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${data_arg}->device.device_id)); status = conv2d_op(stream); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index 453839cc8130..d8940230e0e3 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -345,7 +345,7 @@ def instantiate_gemm_template(attrs): status = gemm_op.initialize(arguments, workspace.get()); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${A_arg}->device.device_id)); status = gemm_op(stream); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); @@ -428,7 +428,7 @@ def emit_fp16A_intB_matmul(attrs): int k = ${B_arg}->shape[0]; cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id)); + TVMFFIEnvGetStream(kDLCUDA, ${A_arg}->device.device_id)); """, attrs, ) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index c594b3897a6c..3a875ce220d0 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -24,7 +24,7 @@ import subprocess import tempfile -import tvm.ffi +import tvm_ffi from tvm.runtime import Object from tvm.tir import 
IntImm @@ -461,7 +461,7 @@ def _get_optional_int_annotation(annotations, key, default=None): return int(value) -@tvm.ffi.register_func("contrib.cutlass.instantiate_template") +@tvm_ffi.register_global_func("contrib.cutlass.instantiate_template") def instantiate_template(func_name, annotations, func_args): """Return CUTLASS host code based on a template and the provided annotations. diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py index d2a031024475..b0f7dc7c14f7 100644 --- a/python/tvm/contrib/cutlass/layer_norm_operation.py +++ b/python/tvm/contrib/cutlass/layer_norm_operation.py @@ -39,7 +39,7 @@ def instantiate_layer_norm_template(attrs): cutlass::TensorRef _beta((data_type*)${beta}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${input}->device.device_id)); cutlass::layernorm(size, _output, _input, _gamma, _beta, stream); """ diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py b/python/tvm/contrib/cutlass/rms_norm_operation.py index 51c18d4ae47b..3d038ab21011 100644 --- a/python/tvm/contrib/cutlass/rms_norm_operation.py +++ b/python/tvm/contrib/cutlass/rms_norm_operation.py @@ -38,7 +38,7 @@ def instantiate_rms_norm_template(attrs): cutlass::TensorRef _weight((data_type*)${weight}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${input}->device.device_id)); cutlass::rmsnorm(size, _output, _input, _weight, stream, ${rms_eps}); """ diff --git a/python/tvm/contrib/dlpack.py b/python/tvm/contrib/dlpack.py index 75b37cef6199..e6214ed3a259 100644 --- a/python/tvm/contrib/dlpack.py +++ 
b/python/tvm/contrib/dlpack.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Wrapping functions to bridge frameworks with DLPack support to TVM""" -from tvm.runtime import ndarray +import tvm.runtime def convert_func(tvm_func, tensor_type, to_dlpack_func): @@ -37,7 +37,7 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): def _wrapper(*args): args = tuple( - ndarray.from_dlpack(to_dlpack_func(arg)) if isinstance(arg, tensor_type) else arg + tvm.runtime.from_dlpack(to_dlpack_func(arg)) if isinstance(arg, tensor_type) else arg for arg in args ) return tvm_func(*args) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index f4b02ff80f73..eb1a0342c75e 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -35,7 +35,7 @@ from typing import Union from tvm.contrib.hexagon.hexagon_profiler import HexagonProfiler -from ...ffi import libinfo +from tvm_ffi import libinfo from .session import Session from .tools import HEXAGON_SIMULATOR_NAME diff --git a/python/tvm/contrib/hexagon/generate_take_op.py b/python/tvm/contrib/hexagon/generate_take_op.py index b70eb451a1a5..080a7d6a1953 100644 --- a/python/tvm/contrib/hexagon/generate_take_op.py +++ b/python/tvm/contrib/hexagon/generate_take_op.py @@ -84,7 +84,7 @@ def visit_call_(self, call_node: relax.Call) -> relax.Call: take_node = relax.call_tir( take_func_gv, relax.expr.Tuple( - [call_node.args[1][0], relax.expr.Constant(tvm.nd.array(LUT))] + [call_node.args[1][0], relax.expr.Constant(tvm.runtime.tensor(LUT))] ), call_node.struct_info, ) diff --git a/python/tvm/contrib/hexagon/meta_schedule.py b/python/tvm/contrib/hexagon/meta_schedule.py index 92298c011d4a..7c4ccdd5b20f 100644 --- a/python/tvm/contrib/hexagon/meta_schedule.py +++ b/python/tvm/contrib/hexagon/meta_schedule.py @@ -21,7 +21,7 @@ import tvm from tvm.ir.module import IRModule -from tvm.runtime import Module, NDArray +from 
tvm.runtime import Module, Tensor from tvm.target import Target from tvm.driver import build as tvm_build from tvm.tir.transform import RemoveWeightLayoutRewriteBlock @@ -140,10 +140,10 @@ def export_func(mod): return str(binary_path) def default_build_with_context( - mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]] + mod: IRModule, target: Target, _params: Optional[Dict[str, Tensor]] ) -> Module: with pass_context: - mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod) + mod = RemoveWeightLayoutRewriteBlock(skip_tensor_rewrite=True)(mod) return tvm_build(mod, target=target) if pass_context is not None: diff --git a/python/tvm/contrib/hexagon/tools.py b/python/tvm/contrib/hexagon/tools.py index f7f22db721ce..f010461df082 100644 --- a/python/tvm/contrib/hexagon/tools.py +++ b/python/tvm/contrib/hexagon/tools.py @@ -29,7 +29,7 @@ import tvm import tvm.contrib.cc as cc -from ...ffi.registry import register_func +from tvm_ffi import register_global_func # Linking Hexagon shared libraries. @@ -67,10 +67,10 @@ def register_linker(f): """Register a function that will return the path to the Hexagon linker.""" - return register_func("tvm.contrib.hexagon.hexagon_link", f, True) + return register_global_func("tvm.contrib.hexagon.hexagon_link", f, True) -@register_func("tvm.contrib.hexagon.hexagon_link") +@register_global_func("tvm.contrib.hexagon.hexagon_link") def hexagon_link() -> str: """Return path to the Hexagon linker.""" return str(HEXAGON_LINK_MAIN) @@ -112,7 +112,7 @@ def toolchain_version(toolchain=None) -> List[int]: raise RuntimeError("Cannot establish toolchain version") -@register_func("tvm.contrib.hexagon.link_shared") +@register_global_func("tvm.contrib.hexagon.link_shared") def link_shared(so_name, objs, extra_args=None): """Link shared library on Hexagon using the registered Hexagon linker. 
@@ -248,10 +248,10 @@ def __create_shared_mac(so_name, objs, **kwargs): return link_shared_macos(so_name, objs, kwargs) create_shared = __create_shared_mac - register_func("tvm.contrib.hexagon.link_shared", f=link_shared_macos, override=True) + register_global_func("tvm.contrib.hexagon.link_shared", f=link_shared_macos, override=True) else: # Linux and Win32 create_shared = cc.create_shared - register_func("tvm.contrib.hexagon.link_shared", f=link_shared, override=True) + register_global_func("tvm.contrib.hexagon.link_shared", f=link_shared, override=True) def create_aot_shared(so_name: Union[str, pathlib.Path], files, hexagon_arch: str, options=None): @@ -336,7 +336,7 @@ def pack_imports( """ path_bin = os.path.join(workspace_dir, "imports.bin") - pack_to_bin_f_name = "runtime.ModulePackImportsToNDArray" + pack_to_bin_f_name = "runtime.ModulePackImportsToTensor" fpack_to_bin = tvm.get_global_func(pack_to_bin_f_name) assert fpack_to_bin, f"Expecting {pack_to_bin_f_name} in registry" @@ -438,7 +438,7 @@ def allocate_hexagon_array( for dim_i, dim_f in zip(boundaries[:-1], boundaries[1:]) ] - arr = tvm.nd.empty(physical_shape, dtype=dtype, device=dev, mem_scope=mem_scope) + arr = tvm.runtime.empty(physical_shape, dtype=dtype, device=dev, mem_scope=mem_scope) if data is not None: arr.copyfrom(data.reshape(physical_shape)) diff --git a/python/tvm/contrib/miopen.py b/python/tvm/contrib/miopen.py index 22b08f38ca76..6ec2cd78e4d3 100644 --- a/python/tvm/contrib/miopen.py +++ b/python/tvm/contrib/miopen.py @@ -19,7 +19,7 @@ import ctypes import numpy as np import tvm -import tvm.ffi +import tvm_ffi from tvm import te @@ -29,7 +29,7 @@ def _get_np_int32_array_handle(arr): Parameters ---------- - arr: numpy.NDArray + arr: numpy.Tensor source numpy array Returns @@ -94,7 +94,7 @@ def conv2d_forward( oshape = np.zeros((len(x.shape)), dtype=np.int32) xshape = x.shape wshape = w.shape - setup_func = tvm.ffi.get_global_func("tvm.contrib.miopen.conv2d.setup") + setup_func = 
tvm_ffi.get_global_func("tvm.contrib.miopen.conv2d.setup") algo = setup_func( conv_mode, data_type, diff --git a/python/tvm/contrib/mrvl.py b/python/tvm/contrib/mrvl.py index 36c932cd1a1d..996f6f881882 100644 --- a/python/tvm/contrib/mrvl.py +++ b/python/tvm/contrib/mrvl.py @@ -23,11 +23,10 @@ import tempfile import base64 import numpy as np -import tvm -import tvm.ffi +import tvm_ffi -@tvm.ffi.register_func("tvm.mrvl.find_value_in_KV_pair") +@tvm_ffi.register_global_func("tvm.mrvl.find_value_in_KV_pair") def find_value_in_KV_pair(json_input: str, key_to_find: str) -> str: """This function takes the graph_json string and key to be searched in the json string, using json parser routine it loads the json string @@ -54,7 +53,7 @@ def find_value_in_KV_pair(json_input: str, key_to_find: str) -> str: return value -@tvm.ffi.register_func("tvm.mrvl.GetNodesJSONString") +@tvm_ffi.register_global_func("tvm.mrvl.GetNodesJSONString") def get_nodes_json_string(graph_json): """This takes the graph_json string from MrvlJSONSerializer and adds / modifies the json string to a form suitable for the Marvell Backend. @@ -206,7 +205,7 @@ def get_nodes_json_string(graph_json): return nodes_json_string -@tvm.ffi.register_func("tvm.mrvl.ModifyConstNames") +@tvm_ffi.register_global_func("tvm.mrvl.ModifyConstNames") def modify_const_names(nodes_json_str, consts_json_str): """This takes the graph module returned by build an generates nodes and constant meta data suitable for compilation by the back end. 
@@ -329,7 +328,7 @@ def get_working_dir(): return os.getcwd() -@tvm.ffi.register_func("tvm.mrvl.WriteJsonFile") +@tvm_ffi.register_global_func("tvm.mrvl.WriteJsonFile") def write_json_file(json_string, json_filename): """Generate json file under working directory""" working_dir = get_working_dir() @@ -351,7 +350,7 @@ def delete_temp_files(symbol_name): shutil.rmtree(bin_folder) -@tvm.ffi.register_func("tvm.mrvl.CompileModel") +@tvm_ffi.register_global_func("tvm.mrvl.CompileModel") def compile_model( symbol_name, nodes_json_string, @@ -414,7 +413,7 @@ def compile_model( raise RuntimeError(error_msg) -@tvm.ffi.register_func("tvm.mrvl.CleanUpSim") +@tvm_ffi.register_global_func("tvm.mrvl.CleanUpSim") def clean_up_sim(bin_file, input_json, input_bin, out_bin_prefix, num_outputs): os.remove(bin_file) os.remove(input_json) @@ -424,7 +423,7 @@ def clean_up_sim(bin_file, input_json, input_bin, out_bin_prefix, num_outputs): os.remove(out_bin) -@tvm.ffi.register_func("tvm.mrvl.SearchPath") +@tvm_ffi.register_global_func("tvm.mrvl.SearchPath") def search_path(file_name): path = shutil.which(file_name) if path is None: @@ -432,7 +431,7 @@ def search_path(file_name): return os.path.dirname(path) -@tvm.ffi.register_func("tvm.mrvl.JsonToBin") +@tvm_ffi.register_global_func("tvm.mrvl.JsonToBin") def convert_json_to_bin(json_file, input_bin_file): with open(json_file) as input_json: data = json.load(input_json) @@ -442,7 +441,7 @@ def convert_json_to_bin(json_file, input_bin_file): f.write(data_b) -@tvm.ffi.register_func("tvm.mrvl.RunSim") +@tvm_ffi.register_global_func("tvm.mrvl.RunSim") def run_simulation(run_command, sim_directory): cwd_path = get_working_dir() os.mkdir(sim_directory) @@ -452,6 +451,6 @@ def run_simulation(run_command, sim_directory): shutil.rmtree(sim_directory) -@tvm.ffi.register_func("tvm.mrvl.TempDir") +@tvm_ffi.register_global_func("tvm.mrvl.TempDir") def get_temp_dir(): return tempfile.gettempdir() diff --git a/python/tvm/contrib/msc/core/_ffi_api.py 
b/python/tvm/contrib/msc/core/_ffi_api.py index f7c975aff98a..ff027a0dec8e 100644 --- a/python/tvm/contrib/msc/core/_ffi_api.py +++ b/python/tvm/contrib/msc/core/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.core._ffi_api""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("msc.core", __name__) +tvm_ffi.init_ffi_api("msc.core", __name__) diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py index 96c9c23dfd9d..b2b97fc8b593 100644 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ b/python/tvm/contrib/msc/core/codegen/codegen.py @@ -129,7 +129,7 @@ def load( def to_relax( graph: MSCGraph, - weights: Optional[Dict[str, tvm.nd.array]] = None, + weights: Optional[Dict[str, tvm.runtime.Tensor]] = None, codegen_config: Optional[Dict[str, str]] = None, print_config: Optional[Dict[str, str]] = None, build_folder: msc_utils.MSCDirectory = None, diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py index 687d770c93a6..24825c99d485 100644 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ b/python/tvm/contrib/msc/core/frontend/translate.py @@ -67,13 +67,13 @@ def _normalize(info): def normalize_weights( - t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph -) -> Dict[str, tvm.nd.array]: + t_weights: Dict[MSCTensor, tvm.runtime.Tensor], graph: MSCGraph +) -> Dict[str, tvm.runtime.Tensor]: """Normalize the weghts. Parameters ---------- - t_weights: dict of + t_weights: dict of The weights extracted from IRModule. graph: tvm.contrib.msc.core.ir.MSCGraph The translated graph. 
@@ -88,7 +88,7 @@ def _to_data(ref_t, data): weight_t = graph.find_tensor(ref_t.name) if weight_t.ndim == 1: if ref_t.ndim != weight_t.ndim: - return tvm.nd.array(data.numpy().reshape(weight_t.get_shape())) + return tvm.runtime.tensor(data.numpy().reshape(weight_t.get_shape())) return data if ref_t.layout and weight_t.layout: ref_layout, weight_layout = ref_t.layout.name, weight_t.layout.name @@ -97,7 +97,7 @@ def _to_data(ref_t, data): l in ref_layout for l in weight_layout ), "layout mismatch {} compare to {}".format(ref_t, weight_t) permute = [ref_layout.index(l) for l in weight_layout] - return tvm.nd.array(data.numpy().transpose(*permute)) + return tvm.runtime.tensor(data.numpy().transpose(*permute)) return data weights = {t.name: _to_data(t, d) for t, d in t_weights.items() if graph.has_tensor(t.name)} @@ -111,11 +111,11 @@ def _to_data(ref_t, data): def from_relax( mod: tvm.IRModule, - params: Optional[Dict[str, tvm.nd.array]] = None, + params: Optional[Dict[str, tvm.runtime.Tensor]] = None, trans_config: Optional[Dict[str, str]] = None, build_config: Optional[Dict[str, str]] = None, opt_config: Optional[Dict[str, str]] = None, -) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]: +) -> Tuple[MSCGraph, Dict[str, tvm.runtime.Tensor]]: """Change IRModule to MSCGraph. Parameters @@ -195,10 +195,10 @@ def visit_var_binding_(self, binding) -> None: def byoc_partition( target: str, mod: tvm.IRModule, - params: Optional[Dict[str, tvm.nd.array]] = None, + params: Optional[Dict[str, tvm.runtime.Tensor]] = None, trans_config: Optional[Dict[str, str]] = None, build_config: Optional[Dict[str, str]] = None, -) -> Tuple[tvm.IRModule, List[Tuple[MSCGraph, Dict[str, tvm.nd.array]]]]: +) -> Tuple[tvm.IRModule, List[Tuple[MSCGraph, Dict[str, tvm.runtime.Tensor]]]]: """Partition module to target sub functions. 
Parameters diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 7bd88df5f6f4..6b40be4bf9de 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -18,6 +18,7 @@ from typing import Dict, Tuple, List, Optional, Union, Iterable, Any import numpy as np +import tvm_ffi import tvm from tvm.runtime import Object @@ -25,7 +26,7 @@ from tvm.contrib.msc.core import utils as msc_utils -@tvm.ffi.register_object("msc.core.MSCTensor") +@tvm_ffi.register_object("msc.core.MSCTensor") class MSCTensor(Object): """Tensor in MSCGraph @@ -194,12 +195,12 @@ def ndim(self) -> int: return len(self.shape) -@tvm.ffi.register_object("msc.core.BaseJoint") +@tvm_ffi.register_object("msc.core.BaseJoint") class BaseJoint(Object): """Base class of all MSC Nodes.""" -@tvm.ffi.register_object("msc.core.MSCJoint") +@tvm_ffi.register_object("msc.core.MSCJoint") class MSCJoint(BaseJoint): """Node in MSCGraph @@ -424,7 +425,7 @@ def equal(self, other: BaseJoint) -> bool: return msc_utils.dict_equal(self.get_attrs(), other.get_attrs()) -@tvm.ffi.register_object("msc.core.MSCPrim") +@tvm_ffi.register_object("msc.core.MSCPrim") class MSCPrim(BaseJoint): """Prim in MSCGraph @@ -448,7 +449,7 @@ def __init__( self.__init_handle_by_constructor__(_ffi_api.MSCPrim, index, name, optype, attrs, parents) -@tvm.ffi.register_object("msc.core.WeightJoint") +@tvm_ffi.register_object("msc.core.WeightJoint") class WeightJoint(BaseJoint): """Node in WeightGraph @@ -562,12 +563,12 @@ def has_attr(self, key: str) -> bool: return bool(_ffi_api.WeightJointHasAttr(self, key)) -@tvm.ffi.register_object("msc.core.BaseGraph") +@tvm_ffi.register_object("msc.core.BaseGraph") class BaseGraph(Object): """Base class of all MSC Graphs.""" -@tvm.ffi.register_object("msc.core.MSCGraph") +@tvm_ffi.register_object("msc.core.MSCGraph") class MSCGraph(BaseGraph): """The MSCGraph @@ -956,7 +957,7 @@ def visualize(self, path: Optional[str] = None) 
-> str: return graph_proto -@tvm.ffi.register_object("msc.core.WeightGraph") +@tvm_ffi.register_object("msc.core.WeightGraph") class WeightGraph(BaseGraph): """The WeightGraph diff --git a/python/tvm/contrib/msc/core/runtime/hook.py b/python/tvm/contrib/msc/core/runtime/hook.py index e129d9771b02..f87b2d3d06a0 100644 --- a/python/tvm/contrib/msc/core/runtime/hook.py +++ b/python/tvm/contrib/msc/core/runtime/hook.py @@ -136,9 +136,9 @@ def _apply( self, runner: object, graphs: List[MSCGraph], - weights: Dict[str, tvm.nd.array], + weights: Dict[str, tvm.runtime.Tensor], weights_path: str, - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Apply the default funcion Parameters @@ -147,7 +147,7 @@ def _apply( The runner context. graphs: list The translated graphs - weights: dict + weights: dict The translated weights. weights_path: str The weights path. @@ -156,7 +156,7 @@ def _apply( ------- graphs: list The updated graphs - weights: dict + weights: dict The updated weights. """ diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py index 074c7048c5e9..bd9cc01d76f2 100644 --- a/python/tvm/contrib/msc/core/runtime/runner.py +++ b/python/tvm/contrib/msc/core/runtime/runner.py @@ -340,7 +340,9 @@ def save_cache( title = self.runner_mark("SAVE_CACHE") self._logger.debug(msc_utils.msg_block(title, {"folder": cache_dir, "info": cache_info})) - def translate(self, apply_hooks: bool = True) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def translate( + self, apply_hooks: bool = True + ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Translate IRModule to MSCgraphs Parameters @@ -352,7 +354,7 @@ def translate(self, apply_hooks: bool = True) -> Tuple[List[MSCGraph], Dict[str, ------- graphs: list The translated graphs - weights: dict + weights: dict The translated weights. 
""" @@ -366,7 +368,7 @@ def translate(self, apply_hooks: bool = True) -> Tuple[List[MSCGraph], Dict[str, graphs, weights = self._apply_hook("after translate", hook, graphs, weights) return graphs, weights - def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Translate IRModule to MSCgraphs Parameters @@ -378,7 +380,7 @@ def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.n ------- graphs: list The translated graphs - weights: dict + weights: dict The translated weights. """ @@ -387,7 +389,7 @@ def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.n def reset_tools( self, graphs: List[MSCGraph] = None, - weights: List[Dict[str, tvm.nd.array]] = None, + weights: List[Dict[str, tvm.runtime.Tensor]] = None, tools: List[BaseTool] = None, cache_dir: msc_utils.MSCDirectory = None, ): @@ -397,7 +399,7 @@ def reset_tools( ------- graphs: list The msc graphs. - weights: list> + weights: list> The weights. tools: list The tools. @@ -408,7 +410,7 @@ def reset_tools( ------- graphs: list The msc graphs. - weights: list> + weights: list> The weights. """ @@ -444,14 +446,16 @@ def generate_model(self, apply_hooks: bool = True) -> Any: model = self._apply_hook("after generate", hook, model) return model - def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: + def _generate_model( + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Any: """Codegen the model according to framework Parameters ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. 
Returns @@ -763,7 +767,9 @@ def get_outputs(self) -> List[Dict[str, str]]: return self._model_info["outputs"] - def get_weights(self, framework: str = None, device: str = None) -> Iterable[tvm.nd.array]: + def get_weights( + self, framework: str = None, device: str = None + ) -> Iterable[tvm.runtime.Tensor]: """Get the weights from graphs Parameters @@ -775,7 +781,7 @@ def get_weights(self, framework: str = None, device: str = None) -> Iterable[tvm Returns ------- - weights: generator + weights: generator The generator of weight datas. """ @@ -787,23 +793,23 @@ def get_weights(self, framework: str = None, device: str = None) -> Iterable[tvm data = msc_utils.cast_array(data, framework, device) yield data - def get_runtime_params(self) -> Dict[str, tvm.nd.array]: + def get_runtime_params(self) -> Dict[str, tvm.runtime.Tensor]: """Get the runtime parameters Returns ------- - params: dict + params: dict The parameters from runtime. """ return self._get_runtime_params() - def _get_runtime_params(self) -> Dict[str, tvm.nd.array]: + def _get_runtime_params(self) -> Dict[str, tvm.runtime.Tensor]: """Get the runtime parameters Returns ------- - params: dict + params: dict The parameters from runtime. """ @@ -1146,7 +1152,7 @@ def support_device(cls, device: str) -> bool: class ModelRunner(BaseRunner): """Model runner of MSC""" - def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Translate IRModule to MSCgraphs Parameters @@ -1158,7 +1164,7 @@ def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.n ------- graphs: list The translated graphs - weights: dict + weights: dict The translated weights. 
""" @@ -1210,14 +1216,16 @@ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: f_graph.write(self._graphs[0].to_json()) return {"main": main_info} - def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: + def _generate_model( + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Any: """Codegen the model according to framework Parameters ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. Returns @@ -1319,7 +1327,7 @@ def visualize(self, visual_dir: msc_utils.MSCDirectory, export_graph: bool = Fal with open(visual_dir.relpath(self._byoc_graph.name + "_graph.json"), "w") as f_graph: f_graph.write(self._byoc_graph.to_json()) - def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Translate IRModule to MSCgraphs Parameters @@ -1331,7 +1339,7 @@ def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.n ------- graphs: list The translated graphs - weights: dict + weights: dict The translated weights. """ @@ -1405,14 +1413,16 @@ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: "byoc_mod": "byoc_module.json", } - def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: + def _generate_model( + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Any: """Codegen the model according to framework Parameters ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. 
Returns diff --git a/python/tvm/contrib/msc/core/tools/distill/distiller.py b/python/tvm/contrib/msc/core/tools/distill/distiller.py index 55b7947a6e20..7812627ebc75 100644 --- a/python/tvm/contrib/msc/core/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/core/tools/distill/distiller.py @@ -48,22 +48,22 @@ def setup(self) -> dict: return super().setup() def _reset( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. Returns ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. """ @@ -164,7 +164,7 @@ def _save_weights(self, weights: Dict[str, Any]): The distilled weights. """ - weights = {n: tvm.nd.array(msc_utils.cast_array(d)) for n, d in weights.items()} + weights = {n: tvm.runtime.tensor(msc_utils.cast_array(d)) for n, d in weights.items()} weights_path = self._weights_folder.relpath("distill_{}.bin".format(self._current_iter)) with open(weights_path, "wb") as f_params: f_params.write(tvm.runtime.save_param_dict(weights)) diff --git a/python/tvm/contrib/msc/core/tools/execute.py b/python/tvm/contrib/msc/core/tools/execute.py index 2a47d755619e..dce9b1f1316f 100644 --- a/python/tvm/contrib/msc/core/tools/execute.py +++ b/python/tvm/contrib/msc/core/tools/execute.py @@ -214,7 +214,7 @@ def process_tensor(tensor: Any, name: str, consumer: str, scope: str, tag: str = return tensor -@tvm.register_func("msc_tool.codegen_tensor") +@tvm.register_global_func("msc_tool.codegen_tensor") def codegen_tensor( tensor_ctx: Dict[str, str], name: str, consumer: str, scope: str, tag: str = "main" ) -> List[str]: @@ -356,7 +356,7 @@ def _execute_step_with_context( return step_ctx -@tvm.register_func("msc_tool.codegen_step") 
+@tvm.register_global_func("msc_tool.codegen_step") def codegen_step( step_ctx: Dict[str, str], step: str, graph_name: str, tag: str = "main" ) -> List[str]: @@ -384,7 +384,7 @@ def codegen_step( return step_ctx["processed"] -@tvm.register_func("msc_tool.callback_step") +@tvm.register_global_func("msc_tool.callback_step") def callback_step(step_ctx: Dict[str, Any], step: str, graph_name: str = "main", tag: str = "main"): """Execute tools for a step diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index 38f855d0ebce..95024e1abb41 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -104,22 +104,22 @@ def _update_stages(strategy): return super()._parse_strategys([_update_stages(s) for s in strategy_list]) def _reset( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. Returns ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. """ @@ -315,22 +315,22 @@ def _prunable(w_node: WeightJoint) -> bool: self._plan[w_node.name]["out_indices"] = [] def prune_graphs( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. Returns ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. 
""" @@ -375,7 +375,7 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): if w_config["out_indices"]: data = PruneMethod.prune_axis(data, out_axis, w_config["out_indices"]) pruned_tensors[w_name] = _prune_by_shape(weight, data.shape) - pruned_weights[w_name] = tvm.nd.array(data) + pruned_weights[w_name] = tvm.runtime.tensor(data) w_node.set_attr( "pruned_shape", ",".join([str(i) for i in pruned_tensors[w_name].get_shape()]), diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index 06a16f2bbe49..cb860729f792 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -372,16 +372,16 @@ def setup(self) -> dict: def reset( self, graphs: List[MSCGraph], - weights: Dict[str, tvm.nd.array], + weights: Dict[str, tvm.runtime.Tensor], cache_dir: msc_utils.MSCDirectory = None, - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Reset the tool with graphs and weights Parameters ---------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. cache_dir: MSCDirectory cache path for save/load info. @@ -390,7 +390,7 @@ def reset( ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. """ @@ -411,22 +411,22 @@ def reset( return self._graphs, self._weights def _reset( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. Returns ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. 
""" @@ -1440,22 +1440,22 @@ def setup(self) -> dict: return super().setup() def _reset( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. Returns ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. """ diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py index 47ea21266eb0..19f5b5a03236 100644 --- a/python/tvm/contrib/msc/core/transform/transform.py +++ b/python/tvm/contrib/msc/core/transform/transform.py @@ -122,7 +122,7 @@ def SetBYOCAttrs(target, entry_name: str = "main") -> tvm.ir.transform.Pass: def BindNamedParams( func_name: str, - params: Dict[str, tvm.runtime.NDArray], + params: Dict[str, tvm.runtime.Tensor], ) -> tvm.ir.transform.Pass: """Bind params of function of the module to constant tensors with span names. @@ -130,7 +130,7 @@ def BindNamedParams( ---------- func_name: str The function name to be bound - params: dict + params: dict The map from parameter or parameter name to constant tensors. 
diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index 189dd3ebbb37..03eed9b7fdd0 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -21,6 +21,7 @@ import numpy as np import tvm +import tvm.testing from tvm.contrib.msc.core import _ffi_api from .namespace import MSCFramework @@ -46,10 +47,10 @@ def _analysis(self, data: Any) -> Tuple[str, str, np.ndarray]: return MSCFramework.MSC, "list", "cpu" if isinstance(data, np.ndarray): return MSCFramework.MSC, "tensor", "cpu" - if isinstance(data, tvm.runtime.NDArray): - device = tvm.runtime.Device.DEVICE_TYPE_TO_NAME[data.device.device_type] - if data.device.device_id: - device += ":{}".format(data.device.device_id) + if isinstance(data, tvm.runtime.Tensor): + device = tvm.runtime.Device._DEVICE_TYPE_TO_NAME[data.device.dlpack_device_type()] + if data.device.index: + device += ":{}".format(data.device.index) return MSCFramework.TVM, "tensor", device if isinstance(data, tvm.relax.Var): return MSCFramework.TVM, "var", "cpu" @@ -71,7 +72,7 @@ def _analysis(self, data: Any) -> Tuple[str, str, np.ndarray]: def abstract(self) -> str: """Get abstract describe of the data""" - data = self._to_ndarray() + data = self._to_tensor() prefix = "[{},{}]".format(";".join([str(s) for s in data.shape]), data.dtype.name) if data.size < 10: return "{} {}".format(prefix, ",".join([str(i) for i in data.flatten()])) @@ -79,7 +80,7 @@ def abstract(self) -> str: prefix, data.max(), data.min(), data.sum() / data.size ) - def _to_ndarray(self) -> np.ndarray: + def _to_tensor(self) -> np.ndarray: """Cast array like object to np.ndarray Returns @@ -120,7 +121,7 @@ def _to_device(self, device: str) -> Any: if self._framework == MSCFramework.TORCH: return self._meta_data.to(self.get_device(device)) if self._framework == MSCFramework.TVM: - return tvm.nd.array(self._cast_data(), device=self.get_device(device)) + return 
tvm.runtime.tensor(self._cast_data(), device=self.get_device(device)) return self._meta_data def cast(self, framework: str, device: str = "cpu") -> Any: @@ -144,13 +145,13 @@ def cast(self, framework: str, device: str = "cpu") -> Any: return self._meta_data if framework == self._framework: return self._to_device(device) - data = self._to_ndarray() + data = self._to_tensor() if framework == MSCFramework.TORCH: import torch # pylint: disable=import-outside-toplevel return torch.from_numpy(data).to(self.get_device(device, framework)) if framework == MSCFramework.TVM: - return tvm.nd.array(data, device=self.get_device(device, framework)) + return tvm.runtime.tensor(data, device=self.get_device(device, framework)) return data def get_device(self, device: str, framework: str = None) -> Any: @@ -198,7 +199,7 @@ def is_array(cls, data: Any) -> bool: Whether the data is array like. """ - normal_types = (np.ndarray, tvm.runtime.NDArray, tvm.relax.Var) + normal_types = (np.ndarray, tvm.runtime.Tensor, tvm.relax.Var) if isinstance(data, normal_types): return True if isinstance(data, (list, tuple)) and all(isinstance(d, (int, float)) for d in data): @@ -365,11 +366,11 @@ def _add_report(name: str, gol: Any, data: Any, passed: bool): ) continue if gol.dtype.name in ("int32", "int64"): - passed = np.abs(gol - data), max() == 0 + passed = np.abs(gol - data).max() == 0 _add_report(name, gol, data, passed) continue try: np.testing.assert_allclose(gol, data, rtol=rtol, atol=atol, verbose=False) _add_report(name, gol, data, True) except: # pylint: disable=bare-except _add_report(name, gol, data, False) diff --git a/python/tvm/contrib/msc/core/utils/register.py b/python/tvm/contrib/msc/core/utils/register.py index be82e1d0907a..4f7dcc3688ef 100644 --- a/python/tvm/contrib/msc/core/utils/register.py +++ b/python/tvm/contrib/msc/core/utils/register.py @@ -58,7 +58,7 @@ def reset(cls): cls.REGISTERY = {} -def
register_func(name: str, func: callable, framework: str = MSCFramework.MSC): +def register_global_func(name: str, func: callable, framework: str = MSCFramework.MSC): """Register a func for framework. Parameters diff --git a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py index 5b85e16a53ba..f7cd2ea43e3e 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.framework.tensorflow._ffi_api""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("msc.framework.tensorflow", __name__) +tvm_ffi.init_ffi_api("msc.framework.tensorflow", __name__) diff --git a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py index f24150efcd6c..b9728b8f63cc 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py @@ -28,7 +28,7 @@ def to_tensorflow( graph: MSCGraph, - weights: Optional[Dict[str, tvm.nd.array]] = None, + weights: Optional[Dict[str, tvm.runtime.Tensor]] = None, codegen_config: Optional[Dict[str, str]] = None, print_config: Optional[Dict[str, str]] = None, build_folder: msc_utils.MSCDirectory = None, diff --git a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py index 1accaba8595a..36e4e75491fa 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py @@ -34,7 +34,7 @@ def from_tensorflow( build_config: Optional[Dict[str, str]] = None, opt_config: Optional[Dict[str, str]] = None, as_msc: bool = True, -) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]: +) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.runtime.Tensor]]:
"""Change tensorflow GraphDef to MSCGraph. Parameters diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py index 2297b3e82523..49e231b7a524 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py @@ -88,7 +88,7 @@ def destory(self): super().destory() def _generate_model( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] ) -> tf_v1.Graph: """Codegen the model according to framework @@ -96,7 +96,7 @@ def _generate_model( ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. Returns diff --git a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py index 4db71f3a19de..a09ab875fbed 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py @@ -16,6 +16,6 @@ # under the License.
"""tvm.contrib.msc.framework.tensorrt._ffi_api""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("msc.framework.tensorrt", __name__) +tvm_ffi.init_ffi_api("msc.framework.tensorrt", __name__) diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py index 4643d49c1e83..a3cd7224953c 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py @@ -33,7 +33,7 @@ def to_sub_tensorrt( graph: MSCGraph, - weights: Dict[str, tvm.nd.array], + weights: Dict[str, tvm.runtime.Tensor], codegen_config: Optional[Dict[str, str]] = None, print_config: Optional[Dict[str, str]] = None, build_folder: msc_utils.MSCDirectory = None, @@ -145,7 +145,7 @@ def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str: def to_tensorrt( mod: tvm.IRModule, graphs: List[MSCGraph], - weights: Dict[str, tvm.nd.array], + weights: Dict[str, tvm.runtime.Tensor], codegen_configs: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, print_configs: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, extra_options: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py index 4a02b02728de..59095aff4563 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py @@ -60,10 +60,10 @@ def transform_for_tensorrt( def partition_for_tensorrt( mod: tvm.IRModule, - params: Optional[Dict[str, tvm.nd.array]] = None, + params: Optional[Dict[str, tvm.runtime.Tensor]] = None, trans_config: Optional[Dict[str, str]] = None, build_config: Optional[Dict[str, str]] = None, -) -> Tuple[tvm.IRModule, List[Tuple[MSCGraph, Dict[str, tvm.nd.array]]]]: +) -> Tuple[tvm.IRModule, List[Tuple[MSCGraph, 
Dict[str, tvm.runtime.Tensor]]]]: """Partition module to tensorrt sub functions. Parameters diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py index 3dd392c7d8ac..43b9d096bd9e 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py @@ -79,14 +79,16 @@ def make_plan(self, tool_type: str, data_loader: Any = None) -> dict: assert quantizer.calibrated, "Failed to calibrate the tenosrrt quantizer" return super().make_plan(tool_type, data_loader) - def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: + def _generate_model( + self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] + ) -> Any: """Codegen the model according to framework Parameters ------- graphs: list The msc graphs. - weights: dict + weights: dict The weights. Returns diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py index 88cc55a65e1f..259085454f18 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py @@ -67,22 +67,22 @@ def setup(self) -> dict: return super().setup() def _reset( - self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.runtime.Tensor]] + ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.runtime.Tensor]]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: list> + weights: list> The weights Returns ------- graphs: list The msc graphs. 
- weights: list> + weights: list> The weights """ diff --git a/python/tvm/contrib/msc/framework/torch/_ffi_api.py b/python/tvm/contrib/msc/framework/torch/_ffi_api.py index d12fcf2e2f87..d1f27a53bdcf 100644 --- a/python/tvm/contrib/msc/framework/torch/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/torch/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.framework.torch._ffi_api""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("msc.framework.torch", __name__) +tvm_ffi.init_ffi_api("msc.framework.torch", __name__) diff --git a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py index 5ca5de400634..cac575f9e2c7 100644 --- a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py @@ -28,7 +28,7 @@ def to_torch( graph: MSCGraph, - weights: Optional[Dict[str, tvm.nd.array]] = None, + weights: Optional[Dict[str, tvm.runtime.Tensor]] = None, codegen_config: Optional[Dict[str, str]] = None, print_config: Optional[Dict[str, str]] = None, build_folder: msc_utils.MSCDirectory = None, diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py index b11051376014..eb6e8b5e56b0 100644 --- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py @@ -66,7 +66,7 @@ def from_torch( build_config: Optional[Dict[str, str]] = None, as_msc: bool = True, custom_convert_map: dict = None, -) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]: +) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.runtime.Tensor]]: """Change torch nn.Module to MSCGraph. 
Parameters diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py b/python/tvm/contrib/msc/framework/torch/runtime/runner.py index a4d37d08f521..de1356f08d06 100644 --- a/python/tvm/contrib/msc/framework/torch/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/torch/runtime/runner.py @@ -37,7 +37,7 @@ class TorchRunner(ModelRunner): """Runner of Torch""" - def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: """Translate IRModule to MSCgraphs Parameters @@ -49,7 +49,7 @@ def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.n ------- graph_list: list The translated graphs - weights_list: list> + weights_list: list> The translated weights """ graphs, weights = super()._translate(mod) @@ -107,12 +107,12 @@ def _call_runnable( ] return runnable(*torch_inputs) - def _get_runtime_params(self) -> Dict[str, tvm.nd.array]: + def _get_runtime_params(self) -> Dict[str, tvm.runtime.Tensor]: """Get the runtime parameters Returns ------- - params: dict + params: dict The parameters from runtime. """ diff --git a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py index a3683181b0e4..c9f63e21eaef 100644 --- a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. 
"""tvm.contrib.msc.framework.tvm._ffi_api""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("msc.framework.tvm", __name__) +tvm_ffi.init_ffi_api("msc.framework.tvm", __name__) diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py index 3c964464043a..31c2cc619ea8 100644 --- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py @@ -26,7 +26,7 @@ def to_relax( graph: MSCGraph, - weights: Optional[Dict[str, tvm.nd.array]] = None, + weights: Optional[Dict[str, tvm.runtime.Tensor]] = None, codegen_config: Optional[Dict[str, str]] = None, print_config: Optional[Dict[str, str]] = None, build_folder: msc_utils.MSCDirectory = None, diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py index c6ae512a64e6..a27200d7b6a5 100644 --- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py @@ -49,7 +49,7 @@ def __init__(self, runnable: tvm.relax.VirtualMachine, entry: str = "main"): self._runnable = runnable self._entry = entry - def __call__(self, *inputs) -> List[tvm.nd.array]: + def __call__(self, *inputs) -> List[tvm.runtime.Tensor]: execute_step("before_forward", *inputs) output = self._runnable[self._entry](*inputs) return execute_step("after_forward", output) @@ -250,13 +250,13 @@ def run_native( with tvm.transform.PassContext(opt_level=3): relax_exec = tvm.compile(model, target) runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda()) - tvm_inputs = [tvm.nd.array(inputs[i], device=tvm.cuda()) for i in input_names] + tvm_inputs = [tvm.runtime.tensor(inputs[i], device=tvm.cuda()) for i in input_names] else: target = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): relax_exec = tvm.compile(model, target) runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu()) - tvm_inputs = 
[tvm.nd.array(inputs[i]) for i in input_names] + tvm_inputs = [tvm.runtime.tensor(inputs[i]) for i in input_names] def _run_once(): return runnable["main"](*tvm_inputs) @@ -271,7 +271,7 @@ def _run_once(): else: outputs = _run_once() avg_time = -1 - if isinstance(outputs, tvm.runtime.NDArray): + if isinstance(outputs, tvm.runtime.Tensor): outputs = [outputs] assert len(output_names) == len(outputs), "Outputs mismatch, {} with {}".format( output_names, len(outputs) diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py index d56193d9f7c1..cc9e7e818355 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py @@ -81,9 +81,9 @@ def get_quantize_cache( scale_tensor = scale_tensor.astype(quantizer.find_tensor(name).dtype_name) zero_point = np.zeros_like(scale_tensor).astype("int8") scale_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_scale") - scale_tensor = tvm.relax.Constant(tvm.nd.array(scale_tensor), span=scale_span) + scale_tensor = tvm.relax.Constant(tvm.runtime.tensor(scale_tensor), span=scale_span) zp_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_zero_point") - zero_point = tvm.relax.Constant(tvm.nd.array(zero_point), span=zp_span) + zero_point = tvm.relax.Constant(tvm.runtime.tensor(zero_point), span=zp_span) quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor) quantizer._save_tensor_cache(name, consumer, "zero_point", zero_point) return scale_tensor, zero_point diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py index 173dc7c3d9e8..58fbd96c3741 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py @@ -85,8 +85,8 @@ def _execute_after_build( return 
super()._execute_after_build(output + gather_tensors) def _execute_after_forward( - self, outputs: List[tvm.runtime.NDArray] - ) -> Union[tvm.runtime.NDArray, List[tvm.runtime.NDArray]]: + self, outputs: List[tvm.runtime.Tensor] + ) -> Union[tvm.runtime.Tensor, List[tvm.runtime.Tensor]]: """Execute after model forward Parameters diff --git a/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py b/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py index 2bb0de02be22..39b8e4034b56 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py @@ -83,8 +83,8 @@ def _execute_after_build( return super()._execute_after_build(output + track_tensors) def _execute_after_forward( - self, outputs: List[tvm.runtime.NDArray] - ) -> Union[tvm.runtime.NDArray, List[tvm.runtime.NDArray]]: + self, outputs: List[tvm.runtime.Tensor] + ) -> Union[tvm.runtime.Tensor, List[tvm.runtime.Tensor]]: """Execute after model forward Parameters diff --git a/python/tvm/contrib/msc/plugin/_ffi_api.py b/python/tvm/contrib/msc/plugin/_ffi_api.py index c566d3b0d332..88f9204f3a02 100644 --- a/python/tvm/contrib/msc/plugin/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. 
"""tvm.contrib.msc.plugin._ffi_api""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("msc.plugin", __name__) +tvm_ffi.init_ffi_api("msc.plugin", __name__) diff --git a/python/tvm/contrib/msc/plugin/codegen/sources.py b/python/tvm/contrib/msc/plugin/codegen/sources.py index a4e89ad7ecd2..a0923cd3210e 100644 --- a/python/tvm/contrib/msc/plugin/codegen/sources.py +++ b/python/tvm/contrib/msc/plugin/codegen/sources.py @@ -686,14 +686,14 @@ class TVMUtils { }; #define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF(FuncName, Body) \ - TVM_FFI_STATIC_INIT_BLOCK({ \ + TVM_FFI_STATIC_INIT_BLOCK() { \ tvm::ffi::reflection::GlobalDef().def(FuncName, Body); \ - }) + } #define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED(FuncName, Body) \ - TVM_FFI_STATIC_INIT_BLOCK({ \ + TVM_FFI_STATIC_INIT_BLOCK() { \ tvm::ffi::reflection::GlobalDef().def_packed(FuncName, Body); \ - }) + } #endif // PLUGIN_SUPPORT_TVM """ @@ -1162,4 +1162,7 @@ def get_plugin_sources() -> Dict[str, str]: The base utils sources. """ - return {"plugin_base.h": get_plugin_base_h_code(), "plugin_utils.h": get_plugin_utils_h_code()} + return { + "plugin_base.h": get_plugin_base_h_code(), + "plugin_utils.h": get_plugin_utils_h_code(), + } diff --git a/python/tvm/contrib/msc/plugin/op/_ffi_api.py b/python/tvm/contrib/msc/plugin/op/_ffi_api.py index 0d8ad3c5e457..8ca5071cdaf6 100644 --- a/python/tvm/contrib/msc/plugin/op/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/op/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.plugin.op._ffi_api""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("msc.plugin.op", __name__) +tvm_ffi.init_ffi_api("msc.plugin.op", __name__) diff --git a/python/tvm/contrib/ndk.py b/python/tvm/contrib/ndk.py index c1441c496ae8..f3a23e55db0c 100644 --- a/python/tvm/contrib/ndk.py +++ b/python/tvm/contrib/ndk.py @@ -25,7 +25,7 @@ import tempfile from pathlib import Path -from ..ffi import register_func +from tvm_ffi import register_global_func from ..base import py_str from . 
import utils as _utils, tar as _tar, cc as _cc from .cc import get_target_by_dump_machine @@ -157,7 +157,7 @@ def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]: return _cc.get_global_symbol_section_map(path, nm=nm) -@register_func("meta_schedule.builder.export_ndk") +@register_global_func("meta_schedule.builder.export_ndk") def _ndk_export(mod): tmp_dir = tempfile.mkdtemp() binary_name = "tmp_binary.so" diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py index 1b4f51850805..a0aba75b019b 100644 --- a/python/tvm/contrib/nnpack.py +++ b/python/tvm/contrib/nnpack.py @@ -17,7 +17,7 @@ """External function interface to NNPACK libraries.""" import tvm from tvm import te -import tvm.ffi +import tvm_ffi def is_available(): @@ -232,4 +232,4 @@ def convolution_inference_weight_transform( ) -tvm.ffi._init_api("tvm.contrib.nnpack") +tvm_ffi.init_ffi_api("tvm.contrib.nnpack") diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index e9d8fac761c0..80aeec9740e6 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -20,16 +20,32 @@ import os import subprocess +import tempfile import warnings from typing import Tuple -import tvm.ffi +import tvm_ffi +import tvm from tvm.target import Target from ..base import py_str from . 
import utils +def _resolve_artifact_paths(temp, file_name, target_format, kernels_output_dir=None): + if kernels_output_dir is None: + return temp.relpath(f"{file_name}.cu"), temp.relpath(f"{file_name}.{target_format}") + + os.makedirs(kernels_output_dir, exist_ok=True) + source_fd, temp_code = tempfile.mkstemp( + prefix=f"{file_name}_", suffix=".cu", dir=kernels_output_dir + ) + os.close(source_fd) + file_stem, _ = os.path.splitext(os.path.basename(temp_code)) + temp_target = os.path.join(kernels_output_dir, f"{file_stem}.{target_format}") + return temp_code, temp_target + + def compile_cuda(code, target_format=None, arch=None, options=None, path_target=None): """Compile cuda code with NVCC from env. @@ -85,20 +101,16 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target= target_format = "ptx" if target_format not in ["cubin", "ptx", "fatbin"]: raise ValueError("target_format must be in cubin, ptx, fatbin") - temp_code = temp.relpath(f"{file_name}.cu") - temp_target = temp.relpath(f"{file_name}.{target_format}") - pass_context = tvm.get_global_func("transform.GetCurrentPassContext")() + pass_context = tvm_ffi.get_global_func("transform.GetCurrentPassContext")() kernels_output_dir = ( pass_context.config["cuda.kernels_output_dir"] if "cuda.kernels_output_dir" in pass_context.config else None ) - if kernels_output_dir is not None: - if not os.path.isdir(kernels_output_dir): - os.makedirs(kernels_output_dir) - temp_code = os.path.join(kernels_output_dir, f"{file_name}.cu") - temp_target = os.path.join(kernels_output_dir, f"{file_name}.{target_format}") + temp_code, temp_target = _resolve_artifact_paths( + temp, file_name, target_format, kernels_output_dir=kernels_output_dir + ) with open(temp_code, "w") as out_file: out_file.write(code) @@ -250,9 +262,10 @@ def get_cuda_version(cuda_path=None): def find_nvshmem_paths() -> Tuple[str, str]: """ Searches for the NVSHMEM include and library directories. 
- Returns: - A tuple containing the path to the include directory and the library directory. - (include_path, lib_path) + + Returns + ------- + A tuple containing the path to the include directory and the library directory. """ candidate_roots = [] @@ -311,14 +324,14 @@ def find_nvshmem_paths() -> Tuple[str, str]: raise RuntimeError("\n".join(error_message)) -@tvm.ffi.register_func +@tvm_ffi.register_global_func def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument """use nvcc to generate fatbin code for better optimization""" ptx = compile_cuda(code, target_format="fatbin") return ptx -@tvm.ffi.register_func("tvm_callback_libdevice_path") +@tvm_ffi.register_global_func("tvm_callback_libdevice_path") def find_libdevice_path(arch): """Utility function to find libdevice @@ -383,7 +396,7 @@ def callback_libdevice_path(arch): return "" -@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version") +@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version") def get_target_compute_version(target=None): """Utility function to get compute capability of compilation target. 
@@ -528,7 +541,7 @@ def have_cudagraph(): return False -@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16") +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16") def have_bf16(compute_version): """Either bf16 support is provided in the compute capability or not @@ -544,7 +557,7 @@ def have_bf16(compute_version): return False -@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8") +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8") def have_fp8(compute_version): """Whether fp8 support is provided in the specified compute capability or not @@ -562,7 +575,7 @@ def have_fp8(compute_version): return False -@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp4") +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp4") def have_fp4(compute_version): """Whether fp4 support is provided in the specified compute capability or not diff --git a/python/tvm/contrib/random.py b/python/tvm/contrib/random.py index 6a17693b9162..681978ff7132 100644 --- a/python/tvm/contrib/random.py +++ b/python/tvm/contrib/random.py @@ -17,7 +17,7 @@ """External function interface to random library.""" import tvm from tvm import te -import tvm.ffi +import tvm_ffi def randint(low, high, size, dtype="int32"): @@ -112,4 +112,4 @@ def normal(loc, scale, size): ) -tvm.ffi._init_api("tvm.contrib.random") +tvm_ffi.init_ffi_api("tvm.contrib.random") diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py index 6e6a985c2732..38e74b660c51 100644 --- a/python/tvm/contrib/rocm.py +++ b/python/tvm/contrib/rocm.py @@ -20,7 +20,7 @@ import os from os.path import join, exists -import tvm.ffi +import tvm_ffi from tvm.base import py_str import tvm.runtime import tvm.target @@ -99,7 +99,7 @@ def rocm_link(in_file, out_file, lld=None): raise RuntimeError(msg) -@tvm.ffi.register_func("tvm_callback_rocm_link") +@tvm_ffi.register_global_func("tvm_callback_rocm_link") def callback_rocm_link(obj_bin): """Links object file generated from LLVM to HSA Code Object @@ -123,7 
+123,7 @@ def callback_rocm_link(obj_bin): return cobj_bin -@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path") +@tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path") def callback_rocm_bitcode_path(rocdl_dir=None): """Utility function to find ROCm device library bitcodes @@ -227,7 +227,7 @@ def have_matrixcore(compute_version=None): return False -@tvm.ffi.register_func("tvm_callback_rocm_get_arch") +@tvm_ffi.register_global_func("tvm_callback_rocm_get_arch") def get_rocm_arch(rocm_path=None): """Utility function to get the AMD GPU architecture diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index aceeefd248f4..f3f5bf4c21fa 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """TFLite runtime that load and run tflite models.""" -import tvm.ffi +import tvm_ffi from ..rpc import base as rpc_base @@ -35,7 +35,7 @@ def create(tflite_model_bytes, device, runtime_target="cpu"): tflite_runtime : TFLiteModule Runtime tflite module that can be used to execute the tflite model. """ - device_type = device.device_type + device_type = device.dlpack_device_type() if runtime_target == "edge_tpu": runtime_func = "tvm.edgetpu_runtime.create" @@ -45,7 +45,7 @@ def create(tflite_model_bytes, device, runtime_target="cpu"): if device_type >= rpc_base.RPC_SESS_MASK: fcreate = device._rpc_sess.get_function(runtime_func) else: - fcreate = tvm.ffi.get_global_func(runtime_func) + fcreate = tvm_ffi.get_global_func(runtime_func) return TFLiteModule(fcreate(bytearray(tflite_model_bytes), device)) @@ -86,7 +86,7 @@ def set_input(self, index, value): value : the input value. 
The input key - params : dict of str to NDArray + params : dict of str to Tensor Additonal arguments """ self._set_input(index, value) @@ -96,7 +96,7 @@ def invoke(self): Parameters ---------- - input_dict: dict of str to NDArray + input_dict: dict of str to Tensor List of input values to be feed to """ self._invoke() diff --git a/python/tvm/contrib/thrust.py b/python/tvm/contrib/thrust.py index 9a05cfafbac3..8cf7c59fadfe 100644 --- a/python/tvm/contrib/thrust.py +++ b/python/tvm/contrib/thrust.py @@ -17,7 +17,7 @@ """Utilities for thrust""" import logging -from tvm.ffi import get_global_func +from tvm_ffi import get_global_func def maybe_warn(target, func_name): diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py index e24b88a3f8c3..a40c0cfbb07e 100644 --- a/python/tvm/contrib/tvmjs.py +++ b/python/tvm/contrib/tvmjs.py @@ -71,7 +71,7 @@ def _calculate_md5(filename): return hash_md5.hexdigest() -class NDArrayCacheShardingManager: +class TensorCacheShardingManager: """Internal helper to shard ndarrays.""" def __init__( @@ -198,10 +198,10 @@ def pending_nbytes(self): return len(self.curr_data) -def dump_ndarray_cache( +def dump_tensor_cache( params: Union[ - Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]], - Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]], + Mapping[str, Union[np.ndarray, tvm.runtime.Tensor]], + Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.Tensor]]], ], cache_dir: str, encode_format="f32-to-bf16", @@ -210,13 +210,13 @@ def dump_ndarray_cache( show_progress: bool = True, update_if_exists: bool = False, ): - """Dump parameters to NDArray cache. + """Dump parameters to Tensor cache. 
Parameters ---------- params: Union[ - Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]], - Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]], + Mapping[str, Union[np.ndarray, tvm.runtime.Tensor]], + Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.Tensor]]], ] The parameter dictionary or generator @@ -257,7 +257,7 @@ def dump_ndarray_cache( print("Start storing to cache %s" % cache_dir) shard_cap_nbytes = shard_cap_mb * (1 << 20) - nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json") + nd_cache_json = os.path.join(cache_dir, "tensor-cache.json") if update_if_exists and os.path.exists(nd_cache_json): with open(nd_cache_json, "r") as infile: old_data = json.load(infile) @@ -265,7 +265,7 @@ def dump_ndarray_cache( meta_data = old_data["metadata"] records = old_data["records"] - shard_manager = NDArrayCacheShardingManager( + shard_manager = TensorCacheShardingManager( cache_dir, "params_shard", shard_cap_nbytes, initial_shard_records=records ) @@ -277,10 +277,10 @@ def dump_ndarray_cache( v = v.numpy() # prefer to preserve original dtype, especially if the format was bfloat16 - dtype = origin_v.dtype if isinstance(origin_v, tvm.nd.NDArray) else v.dtype + dtype = origin_v.dtype if isinstance(origin_v, tvm.runtime.Tensor) else v.dtype - if dtype in DataType.NUMPY_DTYPE_TO_STR: - dtype = DataType.NUMPY_DTYPE_TO_STR[dtype] + if dtype in DataType._NUMPY_DTYPE_TO_STR: + dtype = DataType._NUMPY_DTYPE_TO_STR[dtype] else: dtype = str(dtype) @@ -325,15 +325,15 @@ def dump_ndarray_cache( if item["dtype"] == "float32": item["format"] = "raw" item["dtype"] = "bfloat16" - b16_nd_cache_json = os.path.join(cache_dir, "ndarray-cache-b16.json") + b16_nd_cache_json = os.path.join(cache_dir, "tensor-cache-b16.json") # also dump a file that contains bf16 with open(b16_nd_cache_json, "w") as outfile: json.dump({"metadata": meta_data, "records": records}, outfile, indent=4) print("Also saved a bf16 record to %s" % b16_nd_cache_json) -def load_ndarray_cache(cachepath: 
str, device: tvm.runtime.Device): - """Load the ndarray cache from the directory or json. +def load_tensor_cache(cachepath: str, device: tvm.runtime.Device): + """Load the tensor cache from the directory or json. Parameters @@ -345,7 +345,7 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device): The device we would like to load the data from. """ if not cachepath.endswith(".json"): - cachepath = os.path.join(cachepath, "ndarray-cache.json") + cachepath = os.path.join(cachepath, "tensor-cache.json") cachedir = os.path.dirname(cachepath) json_info = json.loads(open(cachepath, "r").read()) @@ -366,7 +366,7 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device): offset = rec["byteOffset"] nbytes = rec["nbytes"] - arr = tvm.nd.empty(shape, dtype, device=device) + arr = tvm.runtime.empty(shape, dtype, device=device) assert offset + nbytes <= len(raw_data) buffer_source = raw_data[offset : offset + nbytes] if dtype == "float8_e4m3fn": diff --git a/python/tvm/contrib/utils.py b/python/tvm/contrib/utils.py index 2c2baa849b40..b70626345b2b 100644 --- a/python/tvm/contrib/utils.py +++ b/python/tvm/contrib/utils.py @@ -47,6 +47,7 @@ class TempDirectory(object): # In debug mode, each tempdir is named after the sequence _NUM_TEMPDIR_CREATED = 0 _NUM_TEMPDIR_CREATED_LOCK = threading.Lock() + _DEBUG_PARENT_DIR_LOCK = threading.Lock() @classmethod def _increment_num_tempdir_created(cls): @@ -61,12 +62,14 @@ def _increment_num_tempdir_created(cls): @classmethod def _get_debug_parent_dir(cls): if cls._DEBUG_PARENT_DIR is None: - all_parents = f"{tempfile.gettempdir()}/tvm-debug-mode-tempdirs" - if not os.path.isdir(all_parents): - os.makedirs(all_parents) - cls._DEBUG_PARENT_DIR = tempfile.mkdtemp( - prefix=datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S___"), dir=all_parents - ) + with cls._DEBUG_PARENT_DIR_LOCK: + if cls._DEBUG_PARENT_DIR is None: + all_parents = f"{tempfile.gettempdir()}/tvm-debug-mode-tempdirs" + os.makedirs(all_parents, 
exist_ok=True) + cls._DEBUG_PARENT_DIR = tempfile.mkdtemp( + prefix=datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S___"), + dir=all_parents, + ) return cls._DEBUG_PARENT_DIR TEMPDIRS = set() @@ -94,6 +97,8 @@ def set_keep_for_debug(cls, set_to=True): cls._KEEP_FOR_DEBUG = old_keep_for_debug def __init__(self, custom_path=None, keep_for_debug=None): + self.temp_dir = None + self._created_with_keep_for_debug = False if self.TEMPDIRS is None: raise DirectoryCreatedPastAtExit() @@ -118,10 +123,13 @@ def __init__(self, custom_path=None, keep_for_debug=None): def remove(self): """Remove the tmp dir""" - if self.temp_dir: - if not self._created_with_keep_for_debug: - shutil.rmtree(self.temp_dir, ignore_errors=True) - self.TEMPDIRS.remove(self.temp_dir) + temp_dir = getattr(self, "temp_dir", None) + if temp_dir: + if not getattr(self, "_created_with_keep_for_debug", False): + shutil.rmtree(temp_dir, ignore_errors=True) + temp_dirs = getattr(self, "TEMPDIRS", None) + if temp_dirs is not None: + temp_dirs.discard(temp_dir) self.temp_dir = None @property diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py index bd70acf00f90..3d42d1972dcc 100644 --- a/python/tvm/dlight/__init__.py +++ b/python/tvm/dlight/__init__.py @@ -16,6 +16,7 @@ # under the License. """DLight package provides efficient schedules out-of-box for deep learning workloads.""" from . import gpu +from . import adreno from . import cpu from .analysis import ( BlockInfo, diff --git a/python/tvm/ffi/_ffi_api.py b/python/tvm/dlight/adreno/__init__.py similarity index 91% rename from python/tvm/ffi/_ffi_api.py rename to python/tvm/dlight/adreno/__init__.py index 60bd2463e9ac..ea2781455989 100644 --- a/python/tvm/ffi/_ffi_api.py +++ b/python/tvm/dlight/adreno/__init__.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""FFI API.""" -from .registry import _init_api - - -_init_api("ffi", __name__) +""" +Adreno schedule rules. +""" +from .convolution import Conv2d diff --git a/docker/Dockerfile.demo_gpu b/python/tvm/dlight/adreno/base.py similarity index 51% rename from docker/Dockerfile.demo_gpu rename to python/tvm/dlight/adreno/base.py index 4ef6b0c29cbc..d043706c2fc5 100644 --- a/docker/Dockerfile.demo_gpu +++ b/python/tvm/dlight/adreno/base.py @@ -14,23 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Base schedule rule for Adreno operators.""" -# Minimum docker image for demo purposes -# CI docker GPU env -# tag: v0.54 -FROM tlcpack/ci-gpu:v0.55 +from tvm.target import Target -COPY utils/apt-install-and-clear.sh /usr/local/bin/apt-install-and-clear +from ..base import ScheduleRule -# Jupyter notebook. -RUN pip3 install matplotlib Image "Pillow<7" jupyter[notebook] -# Build TVM -COPY install/install_tvm_gpu.sh /install/install_tvm_gpu.sh -RUN bash /install/install_tvm_gpu.sh +class AdrenoScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods + """The Schedule Rule specific to Adreno targets, + will return None if the target is not Adreno.""" -# Environment variables -ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/vta/python:${PYTHONPATH} -ENV PATH=/usr/local/nvidia/bin:${PATH} -ENV PATH=/usr/local/cuda/bin:${PATH} -ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LD_LIBRARY_PATH} + def is_target_available(self, target: Target) -> bool: + """Check whether the target is available for Adreno rule. + + Parameters + ---------- + target : Target + The compilation target to check. + + Returns + ------- + available : bool + Whether the target is available for this rule. 
+ """ + return super().is_target_available(target) and "adreno" in target.keys diff --git a/python/tvm/dlight/adreno/convolution.py b/python/tvm/dlight/adreno/convolution.py new file mode 100644 index 000000000000..fc2cc449a1c6 --- /dev/null +++ b/python/tvm/dlight/adreno/convolution.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring, invalid-name +"""A Conv2d schedule rule for Adreno GPU operators.""" +from dataclasses import dataclass +from typing import List, Optional + +from tvm import tir +from tvm.target import Target +from tvm.tir import IterVar +from tvm.tir.schedule.schedule import BlockRV + +from ..analysis import BlockInfo, IterInfo +from .base import AdrenoScheduleRule + + +def is_spatial_block(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + +def is_reduction_block(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + +def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for producer in sch.get_producers(block): + result.append(producer) + result.extend(_collect_producers(sch, producer)) + return result + + +def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for consumer in sch.get_consumers(block): + result.append(consumer) + result.extend(_collect_consumers(sch, consumer)) + return result + + +def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo: + def _iter_kind(loop: tir.IterVar) -> str: + return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") + + def _is_reduction_block(block: tir.schedule.BlockRV): + for iter_var in sch.get(block).iter_vars: + if _iter_kind(iter_var) == "R": + return True + return False + + return BlockInfo( + name=sch.get(block).name_hint, + iters=[ + IterInfo( + kind=_iter_kind(iter_var), + var=iter_var.var, + dom=iter_var.dom.extent, + loop_rv=loop_rv, + ) + for loop_rv, iter_var in zip(sch.get_loops(block), sch.get(block).iter_vars) + ], + block_rv=block, + 
reduction_block=_is_reduction_block(block), + ) + + +def get_reduction_blocks(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]) -> bool: + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all( + [is_reduction_block(sch, block) or is_spatial_block(sch, block) for block in blocks] + ): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction_block(sch, block)] + if len(reduction_blocks) != 1: + return None + + return reduction_blocks[0] + + +def is_convolution(sch: tir.Schedule, block: tir.schedule.BlockRV): + # TODO: Use buffer access patterns to discover convolution type kernels instead of using name. + return ( + sch.get(block).name_hint.count("conv2d_NCHWc_OIHWo") + and "".join([iter_type.kind for iter_type in get_block_info(sch, block).iters]) + == "SSSSSRRR" + ) + + +class Conv2d(AdrenoScheduleRule): + """The schedule rule for convolution computation""" + + @dataclass + class Config: + block_size_x: int = 8 + block_size_y: int = 8 + vector_size: int = 1 + unroll: int = 256 # 0 means no unroll + use_shared: bool = True + storage_align: bool = False + inner_x: bool = False + + def get_configs(self, target: Target) -> Config: + """Get the schedule config for the target""" + if target.kind.name == "cuda" or target.kind.name == "rocm": + return Conv2d.Config( + block_size_x=8, + block_size_y=16, + vector_size=2, + unroll=256, + use_shared=True, + storage_align=True, + inner_x=False, + ) + elif target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + return Conv2d.Config( + block_size_x=32, + block_size_y=4, + vector_size=8, + unroll=16, + use_shared=False, + storage_align=False, + inner_x=True, + ) + else: + return Conv2d.Config() + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> 
Optional[tir.Schedule]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + + if isinstance(func, tir.PrimFunc): + sch = tir.Schedule(func) + + # config = self.get_configs(target) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_block = get_reduction_blocks(sch, blocks) + + if reduction_block is None: + return None + if not is_convolution(sch, reduction_block): + return None + + def schedule_data_pad(blk): + axes = sch.get_loops(blk) + axes, vec = axes[:-1], axes[-1] + axis = sch.fuse(*axes) + bx, ty, tx = sch.split(axis, [None, 16, 16]) + sch.bind(bx, "blockIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def schedule_conv2d(blk): + # TODO: Loop Pattern mayn't be reliable, need to perform better analysis. + n, oc, oh, ow, ob, ic, kh, kw = sch.get_loops(blk) + sch.reorder(n, oc, oh, ow, ic, kh, kw, ob) + main_lp = sch.fuse(n, oc, oh, ow) + bx, ty, tx = sch.split(main_lp, [None, 16, 16]) + sch.bind(tx, "threadIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(bx, "blockIdx.x") + + ico, icv = sch.split(ic, [None, 4]) + sch.reorder(ico, kh, kw, icv, ob) + rblk = sch.cache_read(blk, 0, "local") + sch.compute_at(rblk, kw) + sch.vectorize(sch.get_loops(rblk)[-1]) + wblk = sch.cache_write(blk, 0, "local") + sch.reverse_compute_at(wblk, tx) + sch.vectorize(sch.get_loops(wblk)[-1]) + sch.vectorize(ob) + init_blk = sch.decompose_reduction(blk, ico) + sch.vectorize(sch.get_loops(init_blk)[-1]) + + def is_data_pad(block: tir.stmt.Block): + return is_spatial_block(sch, block) and tir.analysis.has_if_then_else(sch.get(block)) + + def schedule_conv2d_blocks(): + + # Do analysis to find block type + blocks = sch.get_child_blocks(root_block) + passed_reduction = False + for blk in blocks: + if is_reduction_block(sch, blk): + schedule_conv2d(blk) + passed_reduction = True + elif is_data_pad(blk): + schedule_data_pad(blk) + elif 
is_spatial_block(sch, blk): + try: + if not passed_reduction: + sch.compute_inline(blk) + else: + sch.reverse_compute_inline(blk) + except: # pylint: disable=W0702 + pass + else: + raise TypeError("Can't Schedule this Block", sch.get(blk)) + + schedule_conv2d_blocks() + return sch diff --git a/python/tvm/dlight/analysis/common_analysis.py b/python/tvm/dlight/analysis/common_analysis.py index a3499274e5a8..e3357c6e78db 100644 --- a/python/tvm/dlight/analysis/common_analysis.py +++ b/python/tvm/dlight/analysis/common_analysis.py @@ -18,9 +18,9 @@ from typing import List, Optional, Set, Union from typing_extensions import Literal +from tvm_ffi import get_global_func from tvm import ir, tir -from tvm.ffi import get_global_func from tvm.target.target import Target from tvm.tir import Schedule from tvm.tir.schedule import BlockRV diff --git a/python/tvm/dlight/benchmark/bench.py b/python/tvm/dlight/benchmark/bench.py index 7ab50d412575..b600e7efb783 100644 --- a/python/tvm/dlight/benchmark/bench.py +++ b/python/tvm/dlight/benchmark/bench.py @@ -106,7 +106,7 @@ def benchmark( input_infos = populuate_input_shape(args, dym_var_sample) # generate input tensors, including scalars # scalars are appended to the end of the list due to parsing order - input_tensors: List[Union[tvm.nd.NDArray, int]] = [] + input_tensors: List[Union[tvm.runtime.Tensor, int]] = [] scalar_input_tensors: List[int] = [] for input_shape, input_dtype in input_infos: if input_dtype == "scalar": @@ -116,7 +116,7 @@ def benchmark( else: # normal case like [1, n, 128], generate random tensor input_tensors.append( - tvm.nd.array(generate_input_data(list(input_shape), input_dtype), device=dev) + tvm.runtime.tensor(generate_input_data(list(input_shape), input_dtype), device=dev) ) # append scalar input tensors for rotary embedding input_tensors.extend(scalar_input_tensors) @@ -143,8 +143,8 @@ def benchmark( _, profile_result = rpc_run( rt_mod, - device_type=dev.DEVICE_TYPE_TO_NAME[dev.device_type], - 
args=[w.numpy() if isinstance(w, tvm.nd.NDArray) else w for w in input_tensors], + device_type=dev._DEVICE_TYPE_TO_NAME[dev.dlpack_device_type()], + args=[w.numpy() if isinstance(w, tvm.runtime.Tensor) else w for w in input_tensors], rpc_config=rpc_config, evaluator_config=evaluator_config, ) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index 1ceecc9c94c6..e56426fd5182 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.driver""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("driver", __name__) +tvm_ffi.init_ffi_api("driver", __name__) diff --git a/python/tvm/error.py b/python/tvm/error.py index 671f3292388b..edabbb3a45fc 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -25,7 +25,7 @@ Please also refer to :ref:`error-handling-guide`. """ -from tvm.ffi import register_error +from tvm_ffi import register_error class TVMError(RuntimeError): diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py index 6d1d4b7f339b..7fe94c6cb0df 100644 --- a/python/tvm/exec/disco_worker.py +++ b/python/tvm/exec/disco_worker.py @@ -22,45 +22,45 @@ from typing import Callable import tvm -from tvm.ffi import get_global_func, register_func -from tvm.runtime import NDArray, ShapeTuple, String -from tvm.runtime.ndarray import array +from tvm_ffi import get_global_func, register_global_func +from tvm.runtime import Tensor, ShapeTuple, String +from tvm.runtime import tensor -@register_func("tests.disco.add_one", override=True) +@register_global_func("tests.disco.add_one", override=True) def _add_one(x: int) -> int: return x + 1 -@register_func("tests.disco.add_one_float", override=True) +@register_global_func("tests.disco.add_one_float", override=True) def _add_one_float(x: float): return x + 0.5 -@register_func("tests.disco.add_one_ndarray", override=True) -def _add_one_ndarray(x: 
NDArray) -> NDArray: - return array(x.numpy() + 1) +@register_global_func("tests.disco.add_one_tensor", override=True) +def _add_one_tensor(x: Tensor) -> Tensor: + return tensor(x.numpy() + 1) -@register_func("tests.disco.str", override=True) +@register_global_func("tests.disco.str", override=True) def _str_func(x: str): return x + "_suffix" -@register_func("tests.disco.str_obj", override=True) +@register_global_func("tests.disco.str_obj", override=True) def _str_obj_func(x: str): assert isinstance(x, str) return String(x + "_suffix") -@register_func("tests.disco.shape_tuple", override=True) +@register_global_func("tests.disco.shape_tuple", override=True) def _shape_tuple_func(x: ShapeTuple): assert isinstance(x, ShapeTuple) return ShapeTuple(list(x) + [4, 5]) -@register_func("tests.disco.test_callback", override=True) -def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], NDArray]: +@register_global_func("tests.disco.test_callback", override=True) +def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], Tensor]: """For use in tests/python/disco/test_callback.py This function simulates a callback to be used for lazy parameter @@ -75,7 +75,7 @@ def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], NDArray]: Returns ------- - fget_item: Callable[[str,int], NDArray] + fget_item: Callable[[str,int], Tensor] A callback function that accepts a parameter's name and index, and returns the specified parameter. 
@@ -83,7 +83,7 @@ def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], NDArray]: """ import numpy as np # pylint: disable=import-outside-toplevel - def fget_item(param_name: str, param_index: int) -> NDArray: + def fget_item(param_name: str, param_index: int) -> Tensor: if param_index == 0: assert param_name == "A" arr = np.arange(16).reshape([4, 4]).astype("int32") @@ -92,7 +92,7 @@ def fget_item(param_name: str, param_index: int) -> NDArray: arr = np.arange(4).reshape([2, 2]).astype("float32") else: raise ValueError(f"Unexpected index {param_index}") - return tvm.nd.array(arr, device=device) + return tvm.runtime.tensor(arr, device=device) return fget_item diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index fd3ec55ba655..f8b4507f8e2f 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -40,7 +40,7 @@ def find_example_resource(): # recursively apend things in www, up to two levels resource_bases = [ os.path.join(base_path, "web", "dist", "www"), - os.path.join(base_path, "web", ".ndarray_cache"), + os.path.join(base_path, "web", ".tensor_cache"), ] for base in resource_bases: if not os.path.isdir(base): diff --git a/ffi/cmake/Utils/CxxWarning.cmake b/python/tvm/ffi.py similarity index 63% rename from ffi/cmake/Utils/CxxWarning.cmake rename to python/tvm/ffi.py index c272bfdf7bf2..88fa903a924c 100644 --- a/ffi/cmake/Utils/CxxWarning.cmake +++ b/python/tvm/ffi.py @@ -14,17 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -function(add_cxx_warning target_name) - # GNU, Clang, or AppleClang - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") - target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" "-Wpedantic" "-Wno-unused-parameter") - return() - endif() - # MSVC - if(MSVC) - # target_compile_options(${target_name} PRIVATE "/W4" "/WX") - return() - endif() - message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") -endfunction() +# pylint: disable=wildcard-import +"""Redirects to tvm_ffi""" +from tvm_ffi import * diff --git a/python/tvm/ffi/.gitignore b/python/tvm/ffi/.gitignore deleted file mode 100644 index eeb15feab328..000000000000 --- a/python/tvm/ffi/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -core.cpp -core.cpython* diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py deleted file mode 100644 index 801a8d298906..000000000000 --- a/python/tvm/ffi/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM FFI binding module. - -This module binds the TVM FFI C API to python. 
-This is a standalone module that can be -""" - -from .registry import register_object, register_func, get_global_func, _init_api -from .dtype import dtype, DataTypeCode -from .core import String, Bytes -from .core import Object, ObjectGeneric, Function -from .convert import convert -from .error import register_error -from .ndarray import Device, device -from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu -from .ndarray import from_dlpack, NDArray, Shape -from .container import Array, Map -from .module import Module, ModulePropertyMask, system_lib, load_module -from . import serialization -from . import access_path -from . import testing - - -__all__ = [ - "dtype", - "DataTypeCode", - "Device", - "Object", - "register_object", - "register_func", - "get_global_func", - "_init_api", - "Object", - "ObjectGeneric", - "Function", - "convert", - "String", - "Bytes", - "register_error", - "Device", - "device", - "cpu", - "cuda", - "rocm", - "opencl", - "metal", - "vpi", - "vulkan", - "ext_dev", - "hexagon", - "webgpu", - "from_dlpack", - "NDArray", - "Shape", - "Array", - "Map", - "testing", - "access_path", - "serialization", - "Module", - "ModulePropertyMask", - "system_lib", - "load_module", -] diff --git a/python/tvm/ffi/access_path.py b/python/tvm/ffi/access_path.py deleted file mode 100644 index fb8ab1b2edea..000000000000 --- a/python/tvm/ffi/access_path.py +++ /dev/null @@ -1,181 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Access path classes.""" - -from enum import IntEnum -from typing import List, Any -from . import core -from .registry import register_object - - -class AccessKind(IntEnum): - ATTR = 0 - ARRAY_ITEM = 1 - MAP_ITEM = 2 - ATTR_MISSING = 3 - ARRAY_ITEM_MISSING = 4 - MAP_ITEM_MISSING = 5 - - -@register_object("ffi.reflection.AccessStep") -class AccessStep(core.Object): - """Access step container""" - - -@register_object("ffi.reflection.AccessPath") -class AccessPath(core.Object): - """Access path container""" - - def __init__(self) -> None: - super().__init__() - raise ValueError( - "AccessPath can't be initialized directly. 
" - "Use AccessPath.root() to create a path to the root object" - ) - - @staticmethod - def root() -> "AccessPath": - """Create a root access path""" - return AccessPath._root() - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, AccessPath): - return False - return self._path_equal(other) - - def __ne__(self, other: Any) -> bool: - if not isinstance(other, AccessPath): - return True - return not self._path_equal(other) - - def is_prefix_of(self, other: "AccessPath") -> bool: - """Check if this access path is a prefix of another access path - - Parameters - ---------- - other : AccessPath - The access path to check if it is a prefix of this access path - - Returns - ------- - bool - True if this access path is a prefix of the other access path, False otherwise - """ - return self._is_prefix_of(other) - - def attr(self, attr_key: str) -> "AccessPath": - """Create an access path to the attribute of the current object - - Parameters - ---------- - attr_key : str - The key of the attribute to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._attr(attr_key) - - def attr_missing(self, attr_key: str) -> "AccessPath": - """Create an access path that indicate an attribute is missing - - Parameters - ---------- - attr_key : str - The key of the attribute to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._attr_missing(attr_key) - - def array_item(self, index: int) -> "AccessPath": - """Create an access path to the item of the current array - - Parameters - ---------- - index : int - The index of the item to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._array_item(index) - - def array_item_missing(self, index: int) -> "AccessPath": - """Create an access path that indicate an array item is missing - - Parameters - ---------- - index : int - The index of the item to access - - Returns - ------- - AccessPath - The extended access path - 
""" - return self._array_item_missing(index) - - def map_item(self, key: Any) -> "AccessPath": - """Create an access path to the item of the current map - - Parameters - ---------- - key : Any - The key of the item to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._map_item(key) - - def map_item_missing(self, key: Any) -> "AccessPath": - """Create an access path that indicate a map item is missing - - Parameters - ---------- - key : Any - The key of the item to access - - Returns - ------- - AccessPath - The extended access path - """ - return self._map_item_missing(key) - - def to_steps(self) -> List["AccessStep"]: - """Convert the access path to a list of access steps - - Returns - ------- - List[AccessStep] - The list of access steps - """ - return self._to_steps() - - __hash__ = core.Object.__hash__ diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py deleted file mode 100644 index 157840ba9d46..000000000000 --- a/python/tvm/ffi/container.py +++ /dev/null @@ -1,206 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Container classes.""" -import collections.abc - -from typing import Any, Mapping, Sequence -from . import core -from . 
import _ffi_api -from .registry import register_object - -__all__ = ["Array", "Map"] - - -def getitem_helper(obj, elem_getter, length, idx): - """Helper function to implement a pythonic getitem function. - - Parameters - ---------- - obj: object - The original object - - elem_getter : function - A simple function that takes index and return a single element. - - length : int - The size of the array - - idx : int or slice - The argument passed to getitem - - Returns - ------- - result : object - The result of getitem - """ - if isinstance(idx, slice): - start = idx.start if idx.start is not None else 0 - stop = idx.stop if idx.stop is not None else length - step = idx.step if idx.step is not None else 1 - if start < 0: - start += length - if stop < 0: - stop += length - return [elem_getter(obj, i) for i in range(start, stop, step)] - - if idx < -length or idx >= length: - raise IndexError(f"Index out of range. size: {length}, got index {idx}") - if idx < 0: - idx += length - return elem_getter(obj, idx) - - -@register_object("ffi.Array") -class Array(core.Object, collections.abc.Sequence): - """Array container""" - - def __init__(self, input_list: Sequence[Any]): - self.__init_handle_by_constructor__(_ffi_api.Array, *input_list) - - def __getitem__(self, idx): - return getitem_helper(self, _ffi_api.ArrayGetItem, len(self), idx) - - def __len__(self): - return _ffi_api.ArraySize(self) - - def __repr__(self): - # exception safety handling for chandle=None - if self.__chandle__() == 0: - return type(self).__name__ + "(chandle=None)" - return "[" + ", ".join([x.__repr__() for x in self]) + "]" - - -class KeysView(collections.abc.KeysView): - """Helper class to return keys view""" - - def __init__(self, backend_map): - self._backend_map = backend_map - - def __len__(self): - return len(self._backend_map) - - def __iter__(self): - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self._backend_map) - while True: - k = functor(0) - yield k - if 
not functor(2): - break - - def __contains__(self, k): - return self._backend_map.__contains__(k) - - -class ValuesView(collections.abc.ValuesView): - """Helper class to return values view""" - - def __init__(self, backend_map): - self._backend_map = backend_map - - def __len__(self): - return len(self._backend_map) - - def __iter__(self): - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self._backend_map) - while True: - v = functor(1) - yield v - if not functor(2): - break - - -class ItemsView(collections.abc.ItemsView): - """Helper class to return items view""" - - def __init__(self, backend_map): - self.backend_map = backend_map - - def __len__(self): - return len(self.backend_map) - - def __iter__(self): - if self.__len__() == 0: - return - functor = _ffi_api.MapForwardIterFunctor(self.backend_map) - while True: - k = functor(0) - v = functor(1) - yield (k, v) - if not functor(2): - break - - -@register_object("ffi.Map") -class Map(core.Object, collections.abc.Mapping): - """Map container.""" - - def __init__(self, input_dict: Mapping[Any, Any]): - list_kvs = [] - for k, v in input_dict.items(): - list_kvs.append(k) - list_kvs.append(v) - self.__init_handle_by_constructor__(_ffi_api.Map, *list_kvs) - - def __getitem__(self, k): - return _ffi_api.MapGetItem(self, k) - - def __contains__(self, k): - return _ffi_api.MapCount(self, k) != 0 - - def keys(self): - return KeysView(self) - - def values(self): - return ValuesView(self) - - def items(self): - """Get the items from the map""" - return ItemsView(self) - - def __len__(self): - return _ffi_api.MapSize(self) - - def __iter__(self): - return iter(self.keys()) - - def get(self, key, default=None): - """Get an element with a default value. - - Parameters - ---------- - key : object - The attribute key. - - default : object - The default object. - - Returns - ------- - value: object - The result value. 
- """ - return self[key] if key in self else default - - def __repr__(self): - # exception safety handling for chandle=None - if self.__chandle__() == 0: - return type(self).__name__ + "(chandle=None)" - return "{" + ", ".join([f"{k.__repr__()}: {v.__repr__()}" for k, v in self.items()]) + "}" diff --git a/python/tvm/ffi/convert.py b/python/tvm/ffi/convert.py deleted file mode 100644 index 5b25ddae259b..000000000000 --- a/python/tvm/ffi/convert.py +++ /dev/null @@ -1,68 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Conversion utilities to bring python objects into ffi values.""" -from numbers import Number -from typing import Any -from . import core -from . import container - - -def convert(value: Any) -> Any: - """Convert a python object to ffi values. - - Parameters - ---------- - value : Any - The python object to be converted. - - Returns - ------- - ffi_obj : Any - The converted TVM FFI object. 
- """ - if isinstance(value, core.Object): - return value - elif isinstance(value, core.PyNativeObject): - return value - elif isinstance(value, (bool, Number)): - return value - elif isinstance(value, (list, tuple)): - return container.Array(value) - elif isinstance(value, dict): - return container.Map(value) - elif isinstance(value, str): - return core.String(value) - elif isinstance(value, (bytes, bytearray)): - return core.Bytes(value) - elif isinstance(value, core.ObjectGeneric): - return value.asobject() - elif callable(value): - return core._convert_to_ffi_func(value) - elif value is None: - return None - elif hasattr(value, "__dlpack__"): - return core.from_dlpack( - value, - required_alignment=core.__dlpack_auto_import_required_alignment__, - ) - elif isinstance(value, Exception): - return core._convert_to_ffi_error(value) - else: - raise TypeError(f"don't know how to convert type {type(value)} to object") - - -core._set_func_convert_to_object(convert) diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi deleted file mode 100644 index e61eaf322db2..000000000000 --- a/python/tvm/ffi/cython/base.pxi +++ /dev/null @@ -1,287 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import ctypes -from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, int16_t -from libc.string cimport memcpy -from libcpp.vector cimport vector -from cpython.bytes cimport PyBytes_AsStringAndSize, PyBytes_FromStringAndSize, PyBytes_AsString -from cpython cimport Py_INCREF, Py_DECREF -from cpython cimport PyErr_CheckSignals, PyGILState_Ensure, PyGILState_Release, PyObject -from cpython cimport pycapsule, PyCapsule_Destructor -from cpython cimport PyErr_SetNone - - -# Cython binding for TVM FFI C API -cdef extern from "tvm/ffi/c_api.h": - cdef enum TVMFFITypeIndex: - kTVMFFIAny = -1 - kTVMFFINone = 0 - kTVMFFIInt = 1 - kTVMFFIBool = 2 - kTVMFFIFloat = 3 - kTVMFFIOpaquePtr = 4 - kTVMFFIDataType = 5 - kTVMFFIDevice = 6 - kTVMFFIDLTensorPtr = 7 - kTVMFFIRawStr = 8 - kTVMFFIByteArrayPtr = 9 - kTVMFFIObjectRValueRef = 10 - kTVMFFISmallStr = 11 - kTVMFFISmallBytes = 12 - kTVMFFIStaticObjectBegin = 64 - kTVMFFIObject = 64 - kTVMFFIStr = 65 - kTVMFFIBytes = 66 - kTVMFFIError = 67 - kTVMFFIFunction = 68 - kTVMFFIArray = 69 - kTVMFFIMap = 70 - kTVMFFIShape = 71 - kTVMFFINDArray = 72 - kTVMFFIModule = 73 - - ctypedef void* TVMFFIObjectHandle - - ctypedef struct DLDataType: - uint8_t code - uint8_t bits - int16_t lanes - - ctypedef struct DLDevice: - int device_type - int device_id - - ctypedef struct DLTensor: - void* data - DLDevice device - int ndim - DLDataType dtype - int64_t* shape - int64_t* strides - uint64_t byte_offset - - ctypedef struct DLPackVersion: - uint32_t major - uint32_t minor - - ctypedef struct DLManagedTensor: - DLTensor dl_tensor - void* manager_ctx - void (*deleter)(DLManagedTensor* self) - - ctypedef struct DLManagedTensorVersioned: - DLPackVersion version - DLManagedTensor dl_tensor - void* manager_ctx - void (*deleter)(DLManagedTensorVersioned* self) - uint64_t flags - - ctypedef struct TVMFFIObject: - int32_t type_index - int32_t ref_counter - void (*deleter)(TVMFFIObject* self) - - ctypedef struct TVMFFIAny: - int32_t type_index 
- int32_t zero_padding - int64_t v_int64 - double v_float64 - void* v_ptr - TVMFFIObject* v_obj - const char* v_c_str - DLDataType v_dtype - DLDevice v_device - - ctypedef struct TVMFFIByteArray: - const char* data - size_t size - - ctypedef struct TVMFFIShapeCell: - const int64_t* data - size_t size - - ctypedef struct TVMFFIErrorCell: - TVMFFIByteArray kind - TVMFFIByteArray message - TVMFFIByteArray traceback - void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback) - - ctypedef int (*TVMFFISafeCallType)( - void* handle, const TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result) noexcept - - cdef enum TVMFFIFieldFlagBitMask: - kTVMFFIFieldFlagBitMaskWritable = 1 << 0 - kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1 - kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2 - - ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept; - ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept; - ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept; - - ctypedef struct TVMFFIFieldInfo: - TVMFFIByteArray name - TVMFFIByteArray doc - TVMFFIByteArray type_schema - int64_t flags - int64_t size - int64_t alignment - int64_t offset - TVMFFIFieldGetter getter - TVMFFIFieldSetter setter - TVMFFIAny default_value - int32_t field_static_type_index - - ctypedef struct TVMFFIMethodInfo: - TVMFFIByteArray name - TVMFFIByteArray doc - TVMFFIByteArray type_schema - int64_t flags - TVMFFIAny method - - ctypedef struct TVMFFITypeMetadata: - TVMFFIByteArray doc - TVMFFIObjectCreator creator - int64_t total_size - - ctypedef struct TVMFFITypeInfo: - int32_t type_index - int32_t type_depth - TVMFFIByteArray type_key - const int32_t* type_acenstors - uint64_t type_key_hash - int32_t num_fields - int32_t num_methods - const TVMFFIFieldInfo* fields - const TVMFFIMethodInfo* methods - const TVMFFITypeMetadata* metadata - - int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil - int 
TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil - int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result) nogil - int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, - void (*deleter)(void*), TVMFFIObjectHandle* out) nogil - int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out) nogil - int TVMFFIFunctionSetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle f, int override) nogil - int TVMFFIFunctionGetGlobal(TVMFFIByteArray* name, TVMFFIObjectHandle* out) nogil - void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result) nogil - void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) nogil - TVMFFIObjectHandle TVMFFIErrorCreate(TVMFFIByteArray* kind, TVMFFIByteArray* message, - TVMFFIByteArray* traceback) nogil - - int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil - int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil - int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil - const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) nogil; - int TVMFFINDArrayFromDLPack(DLManagedTensor* src, int32_t require_alignment, - int32_t require_contiguous, TVMFFIObjectHandle* out) nogil - int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* src, - int32_t require_alignment, - int32_t require_contiguous, - TVMFFIObjectHandle* out) nogil - int TVMFFINDArrayToDLPack(TVMFFIObjectHandle src, DLManagedTensor** out) nogil - int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle src, - DLManagedTensorVersioned** out) nogil - const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil - TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) nogil - TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil - TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil - TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil - 
DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) nogil - DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil - -cdef extern from "tvm/ffi/extra/c_env_api.h": - ctypedef void* TVMFFIStreamHandle - - int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil - void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil - int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, - TVMFFIStreamHandle stream, - TVMFFIStreamHandle* opt_out_original_stream) nogil - - -cdef class ByteArrayArg: - cdef TVMFFIByteArray cdata - cdef object py_data - - def __cinit__(self, py_data): - if isinstance(py_data, bytearray): - py_data = bytes(py_data) - cdef char* data - cdef Py_ssize_t size - self.py_data = py_data - PyBytes_AsStringAndSize(py_data, &data, &size) - self.cdata.data = data - self.cdata.size = size - - cdef inline TVMFFIByteArray* cptr(self): - return &self.cdata - - -cdef inline py_str(const char* x): - """Convert a c_char_p to a python string - - Parameters - ---------- - x : c_char_p - A char pointer that can be passed to C API - """ - return x.decode("utf-8") - - -cdef inline str bytearray_to_str(const TVMFFIByteArray* x): - return PyBytes_FromStringAndSize(x.data, x.size).decode("utf-8") - - -cdef inline c_str(pystr): - """Create ctypes char * from a python string - - Parameters - ---------- - string : string type - python string - - Returns - ------- - str : c_char_p - A char pointer that can be passed to C API - """ - return pystr.encode("utf-8") - - -cdef inline object ctypes_handle(void* chandle): - """Cast C handle to ctypes handle.""" - return ctypes.cast(chandle, ctypes.c_void_p) - - -cdef inline void* c_handle(object handle): - """Cast C types handle to c handle.""" - cdef unsigned long long v_ptr - v_ptr = handle.value - return (v_ptr) - - -cdef _init_env_api(): - # Initialize env api for signal handling - # Also registers the gil state release and ensure as PyErr_CheckSignals - # function 
is called with gil released and we need to regrab the gil - CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyErr_CheckSignals"), PyErr_CheckSignals)) - CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyGILState_Ensure"), PyGILState_Ensure)) - CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyGILState_Release"), PyGILState_Release)) - -_init_env_api() diff --git a/python/tvm/ffi/cython/device.pxi b/python/tvm/ffi/cython/device.pxi deleted file mode 100644 index 90d641c44ffa..000000000000 --- a/python/tvm/ffi/cython/device.pxi +++ /dev/null @@ -1,167 +0,0 @@ - - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -_CLASS_DEVICE = None - -def _set_class_device(cls): - global _CLASS_DEVICE - _CLASS_DEVICE = cls - - -def _create_device_from_tuple(cls, device_type, device_id): - cdef DLDevice cdevice = TVMFFIDLDeviceFromIntPair(device_type, device_id) - ret = cls.__new__(cls) - (ret).cdevice = cdevice - return ret - - -cdef class Device: - """Device is a wrapper around DLDevice. 
- - Parameters - ---------- - device_type_or_name : Union[str, int] - The string representation of the device type - - device_id : int - The device id - """ - cdef DLDevice cdevice - - kDLCPU = 1 - kDLCUDA = 2 - kDLCUDAHost = 3 - kDLOpenCL = 4 - kDLVulkan = 7 - kDLMetal = 8 - kDLVPI = 9 - kDLROCM = 10 - kDLROCMHost = 11 - kDLExtDev = 12 - kDLCUDAManaged = 13 - kDLOneAPI = 14 - kDLWebGPU = 15 - kDLHexagon = 16 - - DEVICE_TYPE_TO_NAME = { - kDLCPU: "cpu", - kDLCUDA: "cuda", - kDLCUDAHost: "cuda_host", - kDLCUDAManaged: "cuda_managed", - kDLOpenCL: "opencl", - kDLVulkan: "vulkan", - kDLMetal: "metal", - kDLVPI: "vpi", - kDLROCM: "rocm", - kDLROCMHost: "rocm_host", - kDLExtDev: "ext_dev", - kDLOneAPI: "oneapi", - kDLWebGPU: "webgpu", - kDLHexagon: "hexagon", - } - - DEVICE_NAME_TO_TYPE = { - "llvm": kDLCPU, - "cpu": kDLCPU, - "c": kDLCPU, - "test": kDLCPU, - "hybrid": kDLCPU, - "composite": kDLCPU, - "cuda": kDLCUDA, - "nvptx": kDLCUDA, - "cl": kDLOpenCL, - "opencl": kDLOpenCL, - "vulkan": kDLVulkan, - "metal": kDLMetal, - "vpi": kDLVPI, - "rocm": kDLROCM, - "ext_dev": kDLExtDev, - "hexagon": kDLHexagon, - "webgpu": kDLWebGPU, - } - - def __init__(self, device_type_or_name, device_id = None): - if isinstance(device_type_or_name, str): - parts = device_type_or_name.split(":") - if len(parts) < 1 or len(parts) > 2: - raise ValueError(f"Invalid device: {device_type_or_name}") - if parts[0] not in self.DEVICE_NAME_TO_TYPE: - raise ValueError(f"Unknown device: {parts[0]}") - device_type = self.DEVICE_NAME_TO_TYPE[parts[0]] - if len(parts) == 2: - try: - device_id = int(parts[1]) - except ValueError: - raise ValueError(f"Invalid device id: {parts[1]}") - else: - device_type = device_type_or_name - device_id = device_id if device_id is not None else 0 - if not isinstance(device_id, int): - raise TypeError(f"Invalid device id: {device_id}") - self.cdevice = TVMFFIDLDeviceFromIntPair(device_type, device_id) - - def __reduce__(self): - cls = type(self) - return 
(_create_device_from_tuple, (cls, self.cdevice.device_type, self.cdevice.device_id)) - - def __eq__(self, other): - if not isinstance(other, Device): - return False - return ( - self.cdevice.device_type == (other).cdevice.device_type - and self.cdevice.device_id == (other).cdevice.device_id - ) - - def __ne__(self, other): - return not self.__eq__(other) - - def __device_type_name__(self): - return self.DEVICE_TYPE_TO_NAME[self.cdevice.device_type] - - def __str__(self): - cdef int dev_type = self.cdevice.device_type - name = self.__device_type_name__() - index = self.cdevice.device_id - return f"{name}:{index}" - - def __repr__(self): - cdef int dev_type = self.cdevice.device_type - name = self.__device_type_name__() - index = self.cdevice.device_id - return f"device(type='{name}', index={index})" - - def __hash__(self): - return hash((self.cdevice.device_type, self.cdevice.device_id)) - - @property - def device_type(self): - return self.cdevice.device_type - - @property - def device_id(self): - return self.cdevice.device_id - - -cdef inline object make_ret_device(TVMFFIAny result): - ret = _CLASS_DEVICE.__new__(_CLASS_DEVICE) - (ret).cdevice = result.v_device - return ret - - -_set_class_device(Device) diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi deleted file mode 100644 index 279b17f8c83c..000000000000 --- a/python/tvm/ffi/cython/dtype.pxi +++ /dev/null @@ -1,116 +0,0 @@ - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -_CLASS_DTYPE = None - -def _set_class_dtype(cls): - global _CLASS_DTYPE - _CLASS_DTYPE = cls - - -def _create_dtype_from_tuple(cls, code, bits, lanes): - cdef DLDataType cdtype - cdtype.code = code - cdtype.bits = bits - cdtype.lanes = lanes - ret = cls.__new__(cls, str(cdtype)) - (ret).cdtype = cdtype - return ret - - -cdef class DataType: - """DataType is a wrapper around DLDataType. - - Parameters - ---------- - dtype_str : str - The string representation of the data type - """ - cdef DLDataType cdtype - - def __init__(self, dtype_str): - cdef ByteArrayArg dtype_str_arg = ByteArrayArg(c_str(dtype_str)) - CHECK_CALL(TVMFFIDataTypeFromString(dtype_str_arg.cptr(), &(self.cdtype))) - - def __reduce__(self): - cls = type(self) - return (_create_dtype_from_tuple, - (cls, self.cdtype.code, self.cdtype.bits, self.cdtype.lanes)) - - def __eq__(self, other): - if not isinstance(other, DataType): - return False - return ( - self.cdtype.code == other.cdtype.code - and self.cdtype.bits == other.cdtype.bits - and self.cdtype.lanes == other.cdtype.lanes - ) - - def __ne__(self, other): - return not self.__eq__(other) - - @property - def type_code(self): - return self.cdtype.code - - @property - def bits(self): - return self.cdtype.bits - - @property - def lanes(self): - return self.cdtype.lanes - - @property - def itemsize(self): - """Get the number of bytes of a single element of this data type. When the number of lanes - is greater than 1, the itemsize is the size of the vector type. 
- - Returns - ------- - itemsize : int - The number of bytes of a single element of this data type - """ - lanes_as_int = self.cdtype.lanes - if lanes_as_int < 0: - raise ValueError("Cannot determine itemsize for scalable vector types") - return (self.cdtype.bits * self.cdtype.lanes + 7) // 8 - - def __str__(self): - cdef TVMFFIAny temp_any - cdef TVMFFIByteArray* bytes_ptr - cdef TVMFFIByteArray bytes - - CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &temp_any)) - if temp_any.type_index == kTVMFFISmallStr: - bytes = TVMFFISmallBytesGetContentByteArray(&temp_any) - res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) - return res - - bytes_ptr = TVMFFIBytesGetByteArrayPtr(temp_any.v_obj) - res = py_str(PyBytes_FromStringAndSize(bytes_ptr.data, bytes_ptr.size)) - CHECK_CALL(TVMFFIObjectFree(temp_any.v_obj)) - return res - - -cdef inline object make_ret_dtype(TVMFFIAny result): - cdtype = DataType.__new__(DataType) - (cdtype).cdtype = result.v_dtype - val = str.__new__(_CLASS_DTYPE, cdtype.__str__()) - val.__tvm_ffi_dtype__ = cdtype - return val diff --git a/python/tvm/ffi/cython/error.pxi b/python/tvm/ffi/cython/error.pxi deleted file mode 100644 index 968860390a3c..000000000000 --- a/python/tvm/ffi/cython/error.pxi +++ /dev/null @@ -1,134 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# error handling for FFI - -import types -import re - -ERROR_NAME_TO_TYPE = {} -ERROR_TYPE_TO_NAME = {} - -_WITH_APPEND_TRACEBACK = None -_TRACEBACK_TO_STR = None - - -cdef class Error(Object): - """Base class for all FFI errors, usually they are attached to errors - - Note - ---- - Do not directly raise this object, instead use the `py_error` method - to convert it to a python error then raise it. - """ - - def __init__(self, kind, message, traceback): - cdef ByteArrayArg kind_arg = ByteArrayArg(c_str(kind)) - cdef ByteArrayArg message_arg = ByteArrayArg(c_str(message)) - cdef ByteArrayArg traceback_arg = ByteArrayArg(c_str(traceback)) - (self).chandle = TVMFFIErrorCreate( - kind_arg.cptr(), message_arg.cptr(), traceback_arg.cptr() - ) - - def update_traceback(self, traceback): - """Update the traceback of the error - - Parameters - ---------- - traceback : str - The traceback to update. 
- """ - cdef ByteArrayArg traceback_arg = ByteArrayArg(c_str(traceback)) - TVMFFIErrorGetCellPtr(self.chandle).update_traceback(self.chandle, traceback_arg.cptr()) - - def py_error(self): - """ - Convert the FFI error to the python error - """ - error_cls = ERROR_NAME_TO_TYPE.get(self.kind, RuntimeError) - py_error = error_cls(self.message) - py_error = _WITH_APPEND_TRACEBACK(py_error, self.traceback) - py_error.__tvm_ffi_error__ = self - return py_error - - @property - def kind(self): - return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).kind)) - - @property - def message(self): - return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).message)) - - @property - def traceback(self): - return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).traceback)) - - -_register_object_by_index(kTVMFFIError, Error) - - -cdef inline Error move_from_last_error(): - # raise last error - error = Error.__new__(Error) - TVMFFIErrorMoveFromRaised(&(error).chandle) - return error - - -cdef inline int raise_existing_error() except -2: - return -2 - - -cdef inline int set_last_ffi_error(error) except -1: - """Set the last FFI error""" - cdef Error ffi_error - - kind = ERROR_TYPE_TO_NAME.get(type(error), "RuntimeError") - message = error.__str__() - py_traceback = _TRACEBACK_TO_STR(error.__traceback__) - c_traceback = bytearray_to_str(TVMFFITraceback("", 0, "")) - - # error comes from an exception thrown from C++ side - if hasattr(error, "__tvm_ffi_error__"): - # already have stack trace - ffi_error = error.__tvm_ffi_error__ - # attach the python traceback together with the C++ traceback to get full trace - ffi_error.update_traceback(c_traceback + py_traceback) - TVMFFIErrorSetRaised(ffi_error.chandle) - else: - ffi_error = Error(kind, message, c_traceback + py_traceback) - TVMFFIErrorSetRaised(ffi_error.chandle) - - -def _convert_to_ffi_error(error): - """Convert the python error to the FFI error""" - py_traceback = _TRACEBACK_TO_STR(error.__traceback__) - if 
hasattr(error, "__tvm_ffi_error__"): - error.__tvm_ffi_error__.update_traceback(py_traceback) - return error.__tvm_ffi_error__ - else: - kind = ERROR_TYPE_TO_NAME.get(type(error), "RuntimeError") - message = error.__str__() - return Error(kind, message, py_traceback) - - -cdef inline int CHECK_CALL(int ret) except -2: - """Check the return code of the C API function call""" - if ret == 0: - return 0 - # -2 brings exception - if ret == -2: - raise raise_existing_error() - raise move_from_last_error().py_error() diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi deleted file mode 100644 index 4148cc6c88e1..000000000000 --- a/python/tvm/ffi/cython/function.pxi +++ /dev/null @@ -1,516 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import ctypes -from numbers import Real, Integral - -try: - # optionally import torch and setup torch related utils - import torch -except ImportError: - torch = None - - -def load_torch_get_current_cuda_stream(): - """Create a faster get_current_cuda_stream for torch through cpp extension. 
- """ - from torch.utils import cpp_extension - - source = """ - #include - - int64_t get_current_cuda_stream(int device_id) { - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id); - // fast invariant, default stream is always 0 - if (stream.id() == 0) return 0; - // convert to cudaStream_t - return reinterpret_cast(static_cast(stream)); - } - """ - def fallback_get_current_cuda_stream(device_id): - """Fallback with python api""" - return torch.cuda.current_stream(device_id).cuda_stream - try: - result = cpp_extension.load_inline( - name="get_current_cuda_stream", - cpp_sources=[source], - cuda_sources=[], - extra_cflags=["-O3"], - extra_include_paths=cpp_extension.include_paths("cuda"), - functions=["get_current_cuda_stream"], - ) - return result.get_current_cuda_stream - except Exception: - return fallback_get_current_cuda_stream - -if torch is not None: - # when torch is available, jit compile the get_current_cuda_stream function - # the torch caches the extension so second loading is faster - torch_get_current_cuda_stream = load_torch_get_current_cuda_stream() - - -cdef inline object make_ret_small_str(TVMFFIAny result): - """convert small string to return value.""" - cdef TVMFFIByteArray bytes - bytes = TVMFFISmallBytesGetContentByteArray(&result) - return py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) - - -cdef inline object make_ret_small_bytes(TVMFFIAny result): - """convert small bytes to return value.""" - cdef TVMFFIByteArray bytes - bytes = TVMFFISmallBytesGetContentByteArray(&result) - return PyBytes_FromStringAndSize(bytes.data, bytes.size) - - -cdef inline object make_ret(TVMFFIAny result): - """convert result to return value.""" - # TODO: Implement - cdef int32_t type_index - type_index = result.type_index - if type_index == kTVMFFINDArray: - # specially handle NDArray as it needs a special dltensor field - return make_ndarray_from_any(result) - elif type_index >= kTVMFFIStaticObjectBegin: - return 
make_ret_object(result) - elif type_index == kTVMFFINone: - return None - elif type_index == kTVMFFIBool: - return bool(result.v_int64) - elif type_index == kTVMFFIInt: - return result.v_int64 - elif type_index == kTVMFFIFloat: - return result.v_float64 - elif type_index == kTVMFFISmallStr: - return make_ret_small_str(result) - elif type_index == kTVMFFISmallBytes: - return make_ret_small_bytes(result) - elif type_index == kTVMFFIOpaquePtr: - return ctypes_handle(result.v_ptr) - elif type_index == kTVMFFIDataType: - return make_ret_dtype(result) - elif type_index == kTVMFFIDevice: - return make_ret_device(result) - elif type_index == kTVMFFIDLTensorPtr: - return make_ret_dltensor(result) - elif type_index == kTVMFFIObjectRValueRef: - raise ValueError("Return value cannot be ObjectRValueRef") - elif type_index == kTVMFFIByteArrayPtr: - raise ValueError("Return value cannot be ByteArrayPtr") - elif type_index == kTVMFFIRawStr: - raise ValueError("Return value cannot be RawStr") - raise ValueError("Unhandled type index %d" % type_index) - - -cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args, - int* ctx_dev_type, int* ctx_dev_id, TVMFFIStreamHandle* ctx_stream) except -1: - """Pack arguments into c args tvm call accept""" - cdef unsigned long long temp_ptr - cdef DLTensor* temp_dltensor - cdef int is_cuda = 0 - - for i, arg in enumerate(py_args): - # clear the value to ensure zero padding on 32bit platforms - if sizeof(void*) != 8: - out[i].v_int64 = 0 - out[i].zero_padding = 0 - - if isinstance(arg, NDArray): - if (arg).chandle != NULL: - out[i].type_index = kTVMFFINDArray - out[i].v_ptr = (arg).chandle - else: - out[i].type_index = kTVMFFIDLTensorPtr - out[i].v_ptr = (arg).cdltensor - elif isinstance(arg, Object): - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - elif torch is not None and isinstance(arg, torch.Tensor): - is_cuda = arg.is_cuda - arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg), - 
required_alignment=__dlpack_auto_import_required_alignment__) - out[i].type_index = kTVMFFINDArray - out[i].v_ptr = (arg).chandle - temp_dltensor = TVMFFINDArrayGetDLTensorPtr((arg).chandle) - # record the stream and device for torch context - if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1: - ctx_dev_type[0] = temp_dltensor.device.device_type - ctx_dev_id[0] = temp_dltensor.device.device_id - temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id) - ctx_stream[0] = temp_ptr - temp_args.append(arg) - elif hasattr(arg, "__dlpack__"): - arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__) - out[i].type_index = kTVMFFINDArray - out[i].v_ptr = (arg).chandle - temp_args.append(arg) - elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None: - arg = arg.__tvm_ffi_object__ - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - elif isinstance(arg, bool): - # A python `bool` is a subclass of `int`, so this check - # must occur before `Integral`. 
- out[i].type_index = kTVMFFIBool - out[i].v_int64 = arg - elif isinstance(arg, Integral): - out[i].type_index = kTVMFFIInt - out[i].v_int64 = arg - elif isinstance(arg, float): - out[i].type_index = kTVMFFIFloat - out[i].v_float64 = arg - elif isinstance(arg, _CLASS_DTYPE): - # dtype is a subclass of str, so this check occur before str - arg = arg.__tvm_ffi_dtype__ - out[i].type_index = kTVMFFIDataType - out[i].v_dtype = (arg).cdtype - elif isinstance(arg, _CLASS_DEVICE): - out[i].type_index = kTVMFFIDevice - out[i].v_device = (arg).cdevice - elif isinstance(arg, str): - tstr = c_str(arg) - out[i].type_index = kTVMFFIRawStr - out[i].v_c_str = tstr - temp_args.append(tstr) - elif arg is None: - out[i].type_index = kTVMFFINone - elif isinstance(arg, Real): - out[i].type_index = kTVMFFIFloat - out[i].v_float64 = arg - elif isinstance(arg, (bytes, bytearray)): - arg = ByteArrayArg(arg) - out[i].type_index = kTVMFFIByteArrayPtr - out[i].v_int64 = 0 - out[i].v_ptr = (arg).cptr() - temp_args.append(arg) - elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): - arg = _FUNC_CONVERT_TO_OBJECT(arg) - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - temp_args.append(arg) - elif isinstance(arg, ctypes.c_void_p): - out[i].type_index = kTVMFFIOpaquePtr - out[i].v_ptr = c_handle(arg) - elif isinstance(arg, Exception): - arg = _convert_to_ffi_error(arg) - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - temp_args.append(arg) - elif isinstance(arg, ObjectRValueRef): - out[i].type_index = kTVMFFIObjectRValueRef - out[i].v_ptr = &(((arg.obj)).chandle) - elif callable(arg): - arg = _convert_to_ffi_func(arg) - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - temp_args.append(arg) - else: - raise TypeError("Unsupported argument type: %s" % type(arg)) - - -cdef inline int FuncCall3(void* chandle, - tuple args, - TVMFFIAny* result, - int* c_api_ret_code) 
except -1: - # fast path with stack alloca for less than 3 args - cdef TVMFFIAny[3] packed_args - cdef int nargs = len(args) - cdef int ctx_dev_type = -1 - cdef int ctx_dev_id = 0 - cdef TVMFFIStreamHandle ctx_stream = NULL - cdef TVMFFIStreamHandle prev_stream = NULL - temp_args = [] - make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream) - with nogil: - if ctx_dev_type != -1: - # set the stream based on ctx stream - c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream) - if c_api_ret_code[0] != 0: - return 0 - c_api_ret_code[0] = TVMFFIFunctionCall( - chandle, &packed_args[0], nargs, result - ) - # restore the original stream if it is not the same as the context stream - if ctx_dev_type != -1 and prev_stream != ctx_stream: - # restore the original stream - c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL) - if c_api_ret_code[0] != 0: - return 0 - return 0 - - -cdef inline int FuncCall(void* chandle, - tuple args, - TVMFFIAny* result, - int* c_api_ret_code) except -1: - cdef int nargs = len(args) - cdef int ctx_dev_type = -1 - cdef int ctx_dev_id = 0 - cdef TVMFFIStreamHandle ctx_stream = NULL - cdef TVMFFIStreamHandle prev_stream = NULL - - if nargs <= 3: - FuncCall3(chandle, args, result, c_api_ret_code) - return 0 - - cdef vector[TVMFFIAny] packed_args - packed_args.resize(nargs) - - temp_args = [] - make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream) - - with nogil: - if ctx_dev_type != -1: - c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream) - if c_api_ret_code[0] != 0: - return 0 - c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result) - # restore the original stream if it is not the same as the context stream - if ctx_dev_type != -1 and prev_stream != ctx_stream: - c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL) - if 
c_api_ret_code[0] != 0: - return 0 - - return 0 - - -cdef inline int ConstructorCall(void* constructor_handle, - tuple args, - void** handle) except -1: - """Call contructor of a handle function""" - cdef TVMFFIAny result - cdef int c_api_ret_code - # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone - result.type_index = kTVMFFINone - result.v_int64 = 0 - FuncCall(constructor_handle, args, &result, &c_api_ret_code) - CHECK_CALL(c_api_ret_code) - handle[0] = result.v_ptr - return 0 - - -class Function(Object): - """The Function object used in TVM FFI. - - See Also - -------- - tvm.ffi.register_func: How to register global function. - tvm.ffi.get_global_func: How to get global function. - """ - def __call__(self, *args): - cdef TVMFFIAny result - cdef int c_api_ret_code - # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone - result.type_index = kTVMFFINone - result.v_int64 = 0 - FuncCall((self).chandle, args, &result, &c_api_ret_code) - # NOTE: logic is same as check_call - # directly inline here to simplify traceback - if c_api_ret_code == 0: - return make_ret(result) - elif c_api_ret_code == -2: - raise_existing_error() - raise move_from_last_error().py_error() - -_register_object_by_index(kTVMFFIFunction, Function) - - -cdef class FieldGetter: - cdef TVMFFIFieldGetter getter - cdef int64_t offset - - def __call__(self, Object obj): - cdef TVMFFIAny result - cdef int c_api_ret_code - cdef void* field_ptr = ((obj).chandle) + self.offset - result.type_index = kTVMFFINone - result.v_int64 = 0 - c_api_ret_code = self.getter(field_ptr, &result) - CHECK_CALL(c_api_ret_code) - return make_ret(result) - - -cdef class FieldSetter: - cdef TVMFFIFieldSetter setter - cdef int64_t offset - - def __call__(self, Object obj, value): - cdef TVMFFIAny[1] packed_args - cdef int c_api_ret_code - cdef void* field_ptr = ((obj).chandle) + self.offset - cdef int nargs = 1 - temp_args = [] - make_args((value,), &packed_args[0], temp_args, NULL, 
NULL, NULL) - c_api_ret_code = self.setter(field_ptr, &packed_args[0]) - # NOTE: logic is same as check_call - # directly inline here to simplify traceback - if c_api_ret_code == 0: - return - elif c_api_ret_code == -2: - raise_existing_error() - raise move_from_last_error().py_error() - - -cdef _get_method_from_method_info(const TVMFFIMethodInfo* method): - cdef TVMFFIAny result - CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result)) - return make_ret(result) - - -def _member_method_wrapper(method_func): - def wrapper(self, *args): - return method_func(self, *args) - return wrapper - - -def _add_class_attrs_by_reflection(int type_index, object cls): - """Decorate the class attrs by reflection""" - cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index) - cdef const TVMFFIFieldInfo* field - cdef const TVMFFIMethodInfo* method - cdef int num_fields = info.num_fields - cdef int num_methods = info.num_methods - - for i in range(num_fields): - # attach fields to the class - field = &(info.fields[i]) - getter = FieldGetter.__new__(FieldGetter) - (getter).getter = field.getter - (getter).offset = field.offset - setter = FieldSetter.__new__(FieldSetter) - (setter).setter = field.setter - (setter).offset = field.offset - if (field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0: - setter = None - doc = ( - py_str(PyBytes_FromStringAndSize(field.doc.data, field.doc.size)) - if field.doc.size != 0 - else None - ) - name = py_str(PyBytes_FromStringAndSize(field.name.data, field.name.size)) - if hasattr(cls, name): - # skip already defined attributes - continue - setattr(cls, name, property(getter, setter, doc=doc)) - - for i in range(num_methods): - # attach methods to the class - method = &(info.methods[i]) - name = py_str(PyBytes_FromStringAndSize(method.name.data, method.name.size)) - doc = ( - py_str(PyBytes_FromStringAndSize(method.doc.data, method.doc.size)) - if method.doc.size != 0 - else None - ) - method_func = _get_method_from_method_info(method) - - 
if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod: - method_pyfunc = staticmethod(method_func) - else: - # must call into another method instead of direct capture - # to avoid the same method_func variable being used - # across multiple loop iterations - method_pyfunc = _member_method_wrapper(method_func) - - if doc is not None: - method_pyfunc.__doc__ = doc - method_pyfunc.__name__ = name - - if hasattr(cls, name): - # skip already defined attributes - continue - setattr(cls, name, method_pyfunc) - - return cls - - -def _register_global_func(name, pyfunc, override): - cdef TVMFFIObjectHandle chandle - cdef int c_api_ret_code - cdef int ioverride = override - cdef ByteArrayArg name_arg = ByteArrayArg(c_str(name)) - - if not isinstance(pyfunc, Function): - pyfunc = _convert_to_ffi_func(pyfunc) - - CHECK_CALL(TVMFFIFunctionSetGlobal(name_arg.cptr(), (pyfunc).chandle, ioverride)) - return pyfunc - - -def _get_global_func(name, allow_missing): - cdef TVMFFIObjectHandle chandle - cdef ByteArrayArg name_arg = ByteArrayArg(c_str(name)) - - CHECK_CALL(TVMFFIFunctionGetGlobal(name_arg.cptr(), &chandle)) - if chandle != NULL: - ret = Function.__new__(Function) - (ret).chandle = chandle - return ret - - if allow_missing: - return None - - raise ValueError("Cannot find global function %s" % name) - - -# handle callbacks -cdef void tvm_ffi_callback_deleter(void* fhandle) noexcept with gil: - local_pyfunc = (fhandle) - Py_DECREF(local_pyfunc) - - -cdef int tvm_ffi_callback(void* context, - const TVMFFIAny* packed_args, - int32_t num_args, - TVMFFIAny* result) noexcept with gil: - cdef list pyargs - cdef TVMFFIAny temp_result - local_pyfunc = (context) - pyargs = [] - for i in range(num_args): - CHECK_CALL(TVMFFIAnyViewToOwnedAny(&packed_args[i], &temp_result)) - pyargs.append(make_ret(temp_result)) - - try: - rv = local_pyfunc(*pyargs) - except Exception as err: - set_last_ffi_error(err) - return -1 - - temp_args = [] - make_args((rv,), &temp_result, temp_args, NULL, NULL, 
NULL) - CHECK_CALL(TVMFFIAnyViewToOwnedAny(&temp_result, result)) - - return 0 - - -def _convert_to_ffi_func(object pyfunc): - """Convert a python function to TVM FFI function""" - cdef TVMFFIObjectHandle chandle - Py_INCREF(pyfunc) - CHECK_CALL(TVMFFIFunctionCreate( - (pyfunc), - tvm_ffi_callback, - tvm_ffi_callback_deleter, - &chandle)) - ret = Function.__new__(Function) - (ret).chandle = chandle - return ret - -_STR_CONSTRUCTOR = _get_global_func("ffi.String", False) -_BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False) -_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True) -_OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True) diff --git a/python/tvm/ffi/cython/ndarray.pxi b/python/tvm/ffi/cython/ndarray.pxi deleted file mode 100644 index 9dfe1222dc7e..000000000000 --- a/python/tvm/ffi/cython/ndarray.pxi +++ /dev/null @@ -1,292 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -__dlpack_version__ = (1, 1) -__dlpack_auto_import_required_alignment__ = 8 -_CLASS_NDARRAY = None - - -def _set_class_ndarray(cls): - global _CLASS_NDARRAY - _CLASS_NDARRAY = cls - - -cdef const char* _c_str_dltensor = "dltensor" -cdef const char* _c_str_used_dltensor = "used_dltensor" -cdef const char* _c_str_dltensor_versioned = "dltensor_versioned" -cdef const char* _c_str_used_dltensor_versioned = "used_dltensor_versioned" - -cdef void _c_dlpack_deleter(object pycaps): - cdef DLManagedTensor* dltensor - if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor): - dltensor = pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor) - dltensor.deleter(dltensor) - -cdef void _c_dlpack_versioned_deleter(object pycaps): - cdef DLManagedTensorVersioned* dltensor - if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor_versioned): - dltensor = pycapsule.PyCapsule_GetPointer( - pycaps, _c_str_dltensor_versioned) - dltensor.deleter(dltensor) - - -cdef inline int _from_dlpack( - object dltensor, int required_alignment, - int required_contiguous, TVMFFIObjectHandle* out -) except -1: - cdef DLManagedTensor* ptr - cdef int c_api_ret_code - cdef int c_req_alignment = required_alignment - cdef int c_req_contiguous = required_contiguous - if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor): - ptr = pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor) - with nogil: - c_api_ret_code = TVMFFINDArrayFromDLPack( - ptr, c_req_alignment, c_req_contiguous, out) - CHECK_CALL(c_api_ret_code) - # set name and destructor to be empty - pycapsule.PyCapsule_SetDestructor(dltensor, NULL) - pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor) - return 0 - raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once") - - -cdef inline int _from_dlpack_versioned( - object dltensor, int required_alignment, - int required_contiguous, TVMFFIObjectHandle* out -) except -1: - cdef DLManagedTensorVersioned* ptr - cdef int c_api_ret_code - cdef int c_req_alignment = 
required_alignment - cdef int c_req_contiguous = required_contiguous - if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor_versioned): - ptr = pycapsule.PyCapsule_GetPointer( - dltensor, _c_str_dltensor_versioned) - with nogil: - c_api_ret_code = TVMFFINDArrayFromDLPackVersioned( - ptr, c_req_alignment, c_req_contiguous, out) - CHECK_CALL(c_api_ret_code) - # set name and destructor to be empty - pycapsule.PyCapsule_SetDestructor(dltensor, NULL) - pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor_versioned) - return 0 - raise ValueError("Expect a dltensor_versioned field, PyCapsule can only be consumed once") - - -def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True): - """ - Convert an external tensor to an NDArray. - - Parameters - ---------- - ext_tensor : object - The external tensor to convert. - - required_alignment : int - The minimum required alignment to check for the tensor. - - required_contiguous : bool - Whether to check for contiguous memory. - """ - cdef TVMFFIObjectHandle chandle - # as of most frameworks do not yet support v1.1 - # move to false as most frameworks get upgraded. 
- cdef int favor_legacy_dlpack = True - - if hasattr(ext_tensor, '__dlpack__'): - if favor_legacy_dlpack: - _from_dlpack( - ext_tensor.__dlpack__(), - required_alignment, - required_contiguous, - &chandle - ) - else: - try: - _from_dlpack_versioned( - ext_tensor.__dlpack__(max_version=__dlpack_version__), - required_alignment, - required_contiguous, - &chandle - ) - except TypeError: - _from_dlpack( - ext_tensor.__dlpack__(), - required_alignment, - required_contiguous, - &chandle - ) - else: - if pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor_versioned): - _from_dlpack_versioned( - ext_tensor, - required_alignment, - required_contiguous, - &chandle - ) - elif pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor): - _from_dlpack( - ext_tensor, - required_alignment, - required_contiguous, - &chandle - ) - else: - raise TypeError("Expect from_dlpack to take either a compatible tensor or PyCapsule") - return make_ndarray_from_chandle(chandle) - - -# helper class for shape handling -def _shape_obj_get_py_tuple(obj): - cdef TVMFFIShapeCell* shape = TVMFFIShapeGetCellPtr((obj).chandle) - return tuple(shape.data[i] for i in range(shape.size)) - - -cdef class NDArray(Object): - """N-dimensional array that is compatible with DLPack. 
- """ - cdef DLTensor* cdltensor - - @property - def is_view(self): - return self.cdltensor != NULL and self.chandle == NULL - - @property - def shape(self): - """Shape of this array""" - return tuple(self.cdltensor.shape[i] for i in range(self.cdltensor.ndim)) - - @property - def dtype(self): - """Data type of this array""" - cdef TVMFFIAny dtype_any - dtype_any.v_dtype = self.cdltensor.dtype - return make_ret_dtype(dtype_any) - - @property - def device(self): - """Device of this array""" - cdef TVMFFIAny device_any - device_any.v_device = self.cdltensor.device - return make_ret_device(device_any) - - def to_dlpack(self): - """Produce an array from a DLPack Tensor without copying memory - - Returns - ------- - dlpack : DLPack tensor view of the array data - - Note - ---- - This is an old style legacy API, consider use new dlpack api instead. - """ - cdef DLManagedTensor* dltensor - cdef int c_api_ret_code - - with nogil: - c_api_ret_code = TVMFFINDArrayToDLPack(self.chandle, &dltensor) - CHECK_CALL(c_api_ret_code) - return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) - - def _to_dlpack_versioned(self): - cdef DLManagedTensorVersioned* dltensor - cdef int c_api_ret_code - - with nogil: - c_api_ret_code = TVMFFINDArrayToDLPackVersioned(self.chandle, &dltensor) - CHECK_CALL(c_api_ret_code) - return pycapsule.PyCapsule_New( - dltensor, _c_str_dltensor_versioned, _c_dlpack_versioned_deleter) - - def __dlpack_device__(self): - cdef int device_type = self.cdltensor.device.device_type - cdef int device_id = self.cdltensor.device.device_id - return (device_type, device_id) - - def __dlpack__(self, *, stream=None, max_version=None, dl_device=None, copy=None): - """Produce a DLPack tensor from this array - - Parameters - ---------- - stream : Optional[int] - The stream to use for the DLPack tensor - - max_version : int, optional - The maximum version of the DLPack tensor to produce - - dl_device : Optional[Tuple[int, int]] - The device to use for the 
DLPack tensor - - copy : Optional[bool] - Whether to copy the data to the new device - - Returns - ------- - dlpack : DLPack tensor - - Raises - ------ - BufferError - Export failed - """ - if max_version is None: - # Keep and use the DLPack 0.X implementation - # Note: from March 2025 onwards (but ideally as late as - # possible), it's okay to raise BufferError here - return self.to_dlpack() - else: - # We get to produce `DLManagedTensorVersioned` now. Note that - # our_own_dlpack_version is the max version that the *producer* - # supports and fills in the `DLManagedTensorVersioned::version` - # field - if max_version[0] >= __dlpack_version__[0]: - if dl_device is not None and dl_device != self.__dlpack_device__(): - raise BufferError("dl_device of different type not supported") - if copy is not None and copy: - raise BufferError("copy not yet supported") - return self._to_dlpack_versioned() - elif max_version[0] < 1: - return self.to_dlpack() - else: - raise BufferError(f"Unsupported max_version {max_version}") - - -_set_class_ndarray(NDArray) -_register_object_by_index(kTVMFFINDArray, NDArray) - - -cdef inline object make_ret_dltensor(TVMFFIAny result): - cdef DLTensor* dltensor - dltensor = result.v_ptr - ndarray = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) - (ndarray).chandle = NULL - (ndarray).cdltensor = dltensor - return ndarray - - -cdef inline object make_ndarray_from_chandle(TVMFFIObjectHandle chandle): - # TODO: Implement - cdef NDArray ndarray - ndarray = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) - (ndarray).chandle = chandle - (ndarray).cdltensor = TVMFFINDArrayGetDLTensorPtr(chandle) - return ndarray - - -cdef inline object make_ndarray_from_any(TVMFFIAny any): - return make_ndarray_from_chandle(any.v_ptr) diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi deleted file mode 100644 index dad6bee51b34..000000000000 --- a/python/tvm/ffi/cython/object.pxi +++ /dev/null @@ -1,281 +0,0 @@ -# Licensed to the Apache Software Foundation 
(ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import warnings - -_CLASS_OBJECT = None -_FUNC_CONVERT_TO_OBJECT = None - - -def _set_class_object(cls): - global _CLASS_OBJECT - _CLASS_OBJECT = cls - -def _set_func_convert_to_object(func): - global _FUNC_CONVERT_TO_OBJECT - _FUNC_CONVERT_TO_OBJECT = func - - -def __object_repr__(obj): - """Object repr function that can be overridden by assigning to it""" - return type(obj).__name__ + "(" + str(obj.__ctypes_handle__().value) + ")" - - -def _new_object(cls): - """Helper function for pickle""" - return cls.__new__(cls) - - -_OBJECT_FROM_JSON_GRAPH_STR = None -_OBJECT_TO_JSON_GRAPH_STR = None - - -class ObjectGeneric: - """Base class for all classes that can be converted to object.""" - - def asobject(self): - """Convert value to object""" - raise NotImplementedError() - - -class ObjectRValueRef: - """Represent an RValue ref to an object that can be moved. - - Parameters - ---------- - obj : tvm.runtime.Object - The object that this value refers to - """ - - __slots__ = ["obj"] - - def __init__(self, obj): - self.obj = obj - - -cdef class Object: - """Base class of all TVM FFI objects. 
- """ - cdef void* chandle - - def __cinit__(self): - # initialize chandle to NULL to avoid leak in - # case of error before chandle is set - self.chandle = NULL - - def __dealloc__(self): - if self.chandle != NULL: - CHECK_CALL(TVMFFIObjectFree(self.chandle)) - self.chandle = NULL - - def __ctypes_handle__(self): - return ctypes_handle(self.chandle) - - def __chandle__(self): - cdef uint64_t chandle = self.chandle - return chandle - - def __reduce__(self): - cls = type(self) - return (_new_object, (cls,), self.__getstate__()) - - def __getstate__(self): - if _OBJECT_TO_JSON_GRAPH_STR is None: - raise RuntimeError("ffi.ToJSONGraphString is not registered, make sure build project with extra API") - if not self.__chandle__() == 0: - # need to explicit convert to str in case String - # returned and triggered another infinite recursion in get state - return {"handle": str(_OBJECT_TO_JSON_GRAPH_STR(self, None))} - return {"handle": None} - - def __setstate__(self, state): - # pylint: disable=assigning-non-slot, assignment-from-no-return - if _OBJECT_FROM_JSON_GRAPH_STR is None: - raise RuntimeError("ffi.FromJSONGraphString is not registered, make sure build project with extra API") - handle = state["handle"] - if handle is not None: - self.__init_handle_by_constructor__(_OBJECT_FROM_JSON_GRAPH_STR, handle) - else: - self.chandle = NULL - - def __repr__(self): - # exception safety handling for chandle=None - if self.chandle == NULL: - return type(self).__name__ + "(chandle=None)" - return str(__object_repr__(self)) - - def __eq__(self, other): - return self.same_as(other) - - def __ne__(self, other): - return not self.__eq__(other) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. 
- - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. - """ - # avoid error raised during construction. - self.chandle = NULL - cdef void* chandle - ConstructorCall( - (fconstructor).chandle, args, &chandle) - self.chandle = chandle - - def same_as(self, other): - """Check object identity. - - Parameters - ---------- - other : object - The other object to compare against. - - Returns - ------- - result : bool - The comparison result. - """ - if not isinstance(other, Object): - return False - return self.chandle == (other).chandle - - def __hash__(self): - cdef uint64_t hash_value = self.chandle - return hash_value - - def _move(self): - """Create an RValue reference to the object and mark the object as moved. - - This is a advanced developer API that can be useful when passing an - unique reference to an Object that you no longer needed to a function. - - A unique reference can trigger copy on write optimization that avoids - copy when we transform an object. - - Note - ---- - All the reference of the object becomes invalid after it is moved. - Be very careful when using this feature. - - Returns - ------- - rvalue : The rvalue reference. - """ - return ObjectRValueRef(self) - - def __move_handle_from__(self, other): - """Move the handle from other to self""" - self.chandle = (other).chandle - (other).chandle = NULL - - -class PyNativeObject: - """Base class of all TVM objects that also subclass python's builtin types.""" - __slots__ = [] - - def __init_tvm_ffi_object_by_constructor__(self, fconstructor, *args): - """Initialize the internal tvm_ffi_object by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. 
- - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return object is directly set into the object - """ - obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) - obj.__init_handle_by_constructor__(fconstructor, *args) - self.__tvm_ffi_object__ = obj - - -"""Maps object type index to its constructor""" -cdef list OBJECT_TYPE = [] -"""Maps object type to its type index""" -cdef dict OBJECT_INDEX = {} - - -def _register_object_by_index(int index, object cls): - """register object class""" - global OBJECT_TYPE - while len(OBJECT_TYPE) <= index: - OBJECT_TYPE.append(None) - OBJECT_TYPE[index] = cls - OBJECT_INDEX[cls] = index - - -def _object_type_key_to_index(str type_key): - """get the type index of object class""" - cdef int32_t tidx - type_key_arg = ByteArrayArg(c_str(type_key)) - if TVMFFITypeKeyToIndex(type_key_arg.cptr(), &tidx) == 0: - return tidx - return None - -cdef inline str _type_index_to_key(int32_t tindex): - """get the type key of object class""" - cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(tindex) - cdef const TVMFFIByteArray* type_key - if info == NULL: - return "" - type_key = &(info.type_key) - return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size)) - - -cdef inline object make_ret_object(TVMFFIAny result): - global OBJECT_TYPE - cdef int32_t tindex - cdef object cls - tindex = result.type_index - - if tindex < len(OBJECT_TYPE): - cls = OBJECT_TYPE[tindex] - if cls is not None: - if issubclass(cls, PyNativeObject): - obj = Object.__new__(Object) - (obj).chandle = result.v_obj - return cls.__from_tvm_ffi_object__(cls, obj) - obj = cls.__new__(cls) - (obj).chandle = result.v_obj - return obj - - # object is not found in registered entry - # in this case we need to report an warning - type_key = _type_index_to_key(tindex) - warnings.warn(f"Returning type `{type_key}` which is not registered via register_object, fallback to Object") - obj 
= _CLASS_OBJECT.__new__(_CLASS_OBJECT) - (obj).chandle = result.v_obj - return obj - - -_set_class_object(Object) diff --git a/python/tvm/ffi/cython/string.pxi b/python/tvm/ffi/cython/string.pxi deleted file mode 100644 index 4ab5c48ce07b..000000000000 --- a/python/tvm/ffi/cython/string.pxi +++ /dev/null @@ -1,85 +0,0 @@ - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# helper class for string/bytes handling - -cdef inline str _string_obj_get_py_str(obj): - cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((obj).chandle) - return py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) - - -cdef inline bytes _bytes_obj_get_py_bytes(obj): - cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((obj).chandle) - return PyBytes_FromStringAndSize(bytes.data, bytes.size) - - - -class String(str, PyNativeObject): - __slots__ = ["__tvm_ffi_object__"] - """String object that is possibly returned by FFI call. - - Note - ---- - This class subclasses str so it can be directly treated as str. - There is no need to construct this object explicitly. 
- """ - def __new__(cls, value): - val = str.__new__(cls, value) - val.__tvm_ffi_object__ = None - return val - - # pylint: disable=no-self-argument - def __from_tvm_ffi_object__(cls, obj): - """Construct from a given tvm object.""" - content = _string_obj_get_py_str(obj) - val = str.__new__(cls, content) - val.__tvm_ffi_object__ = obj - return val - - -_register_object_by_index(kTVMFFIStr, String) - - -class Bytes(bytes, PyNativeObject): - """Bytes object that is possibly returned by FFI call. - - Note - ---- - This class subclasses bytes so it can be directly treated as bytes. - There is no need to construct this object explicitly. - """ - def __new__(cls, value): - val = bytes.__new__(cls, value) - val.__tvm_ffi_object__ = None - return val - - # pylint: disable=no-self-argument - def __from_tvm_ffi_object__(cls, obj): - """Construct from a given tvm object.""" - content = _bytes_obj_get_py_bytes(obj) - val = bytes.__new__(cls, content) - val.__tvm_ffi_object__ = obj - return val - - -_register_object_by_index(kTVMFFIBytes, Bytes) - -# We special handle str/bytes constructor in cython to avoid extra cyclic deps -# as the str/bytes construction must be done in the inner loop of function call -_STR_CONSTRUCTOR = None -_BYTES_CONSTRUCTOR = None diff --git a/python/tvm/ffi/dtype.py b/python/tvm/ffi/dtype.py deleted file mode 100644 index 32986a4eb0bf..000000000000 --- a/python/tvm/ffi/dtype.py +++ /dev/null @@ -1,135 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""dtype class.""" -# pylint: disable=invalid-name -from enum import IntEnum -import numpy as np - -from . import core - - -class DataTypeCode(IntEnum): - """DataType code in DLTensor.""" - - INT = 0 - UINT = 1 - FLOAT = 2 - HANDLE = 3 - BFLOAT = 4 - Float8E3M4 = 7 - Float8E4M3 = 8 - Float8E4M3B11FNUZ = 9 - Float8E4M3FN = 10 - Float8E4M3FNUZ = 11 - Float8E5M2 = 12 - Float8E5M2FNUZ = 13 - Float8E8M0FNU = 14 - Float6E2M3FN = 15 - Float6E3M2FN = 16 - Float4E2M1FN = 17 - - -class dtype(str): - """TVM FFI dtype class. - - Parameters - ---------- - dtype_str : str - - Note - ---- - This class subclasses str so it can be directly passed - into other array api's dtype arguments. - """ - - __slots__ = ["__tvm_ffi_dtype__"] - - NUMPY_DTYPE_TO_STR = { - np.dtype(np.bool_): "bool", - np.dtype(np.int8): "int8", - np.dtype(np.int16): "int16", - np.dtype(np.int32): "int32", - np.dtype(np.int64): "int64", - np.dtype(np.uint8): "uint8", - np.dtype(np.uint16): "uint16", - np.dtype(np.uint32): "uint32", - np.dtype(np.uint64): "uint64", - np.dtype(np.float16): "float16", - np.dtype(np.float32): "float32", - np.dtype(np.float64): "float64", - } - if hasattr(np, "float_"): - NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64" - - def __new__(cls, content): - content = str(content) - val = str.__new__(cls, content) - val.__tvm_ffi_dtype__ = core.DataType(content) - return val - - def __repr__(self): - return f"dtype('{self}')" - - def with_lanes(self, lanes): - """ - Create a new dtype with the given number of lanes. 
- - Parameters - ---------- - lanes : int - The number of lanes. - - Returns - ------- - dtype - The new dtype with the given number of lanes. - """ - cdtype = core._create_dtype_from_tuple( - core.DataType, self.__tvm_ffi_dtype__.type_code, self.__tvm_ffi_dtype__.bits, lanes - ) - val = str.__new__(dtype, str(cdtype)) - val.__tvm_ffi_dtype__ = cdtype - return val - - @property - def itemsize(self): - return self.__tvm_ffi_dtype__.itemsize - - @property - def type_code(self): - return self.__tvm_ffi_dtype__.type_code - - @property - def bits(self): - return self.__tvm_ffi_dtype__.bits - - @property - def lanes(self): - return self.__tvm_ffi_dtype__.lanes - - -try: - import ml_dtypes - - dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2" - dtype.NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn" -except ImportError: - pass - -core._set_class_dtype(dtype) diff --git a/python/tvm/ffi/error.py b/python/tvm/ffi/error.py deleted file mode 100644 index a7714cb58ffd..000000000000 --- a/python/tvm/ffi/error.py +++ /dev/null @@ -1,193 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Error handling.""" -import re -import types -import sys -import ast -from . import core - - -def _parse_traceback(traceback): - """Parse the traceback string into a list of (filename, lineno, func) - - Parameters - ---------- - traceback : str - The traceback string. - - Returns - ------- - result : List[Tuple[str, int, str]] - The list of (filename, lineno, func) - """ - pattern = r'File "(.+?)", line (\d+), in (.+)' - result = [] - for line in traceback.split("\n"): - match = re.match(pattern, line.strip()) - if match: - try: - filename = match.group(1) - lineno = int(match.group(2)) - func = match.group(3) - result.append((filename, lineno, func)) - except ValueError: - pass - return result - - -class TracebackManager: - """ - Helper to manage traceback generation - """ - - def __init__(self): - self._code_cache = {} - - def _get_cached_code_object(self, filename, lineno, func): - # Hack to create a code object that points to the correct - # line number and function name - key = (filename, lineno, func) - # cache the code object to avoid re-creating it - if key in self._code_cache: - return self._code_cache[key] - # Parse to AST and zero out column info - # since column info are not accurate in original trace - tree = ast.parse("_getframe()", filename=filename, mode="eval") - for node in ast.walk(tree): - if hasattr(node, "col_offset"): - node.col_offset = 0 - if hasattr(node, "end_col_offset"): - node.end_col_offset = 0 - # call into get frame, bt changes the context - code_object = compile(tree, filename, "eval") - # replace the function name and line number - code_object = code_object.replace(co_name=func, co_firstlineno=lineno) - self._code_cache[key] = code_object - return code_object - - def _create_frame(self, filename, lineno, func): - """Create a frame object from the filename, lineno, and func""" - code_object = 
self._get_cached_code_object(filename, lineno, func) - # call into get frame, but changes the context so the code - # points to the correct frame - context = {"_getframe": sys._getframe} - # pylint: disable=eval-used - return eval(code_object, context, context) - - def append_traceback(self, tb, filename, lineno, func): - """Append a traceback to the given traceback - - Parameters - ---------- - tb : types.TracebackType - The traceback to append to. - filename : str - The filename of the traceback - lineno : int - The line number of the traceback - func : str - The function name of the traceback - - Returns - ------- - new_tb : types.TracebackType - The new traceback with the appended frame. - """ - frame = self._create_frame(filename, lineno, func) - return types.TracebackType(tb, frame, frame.f_lasti, lineno) - - -_TRACEBACK_MANAGER = TracebackManager() - - -def _with_append_traceback(py_error, traceback): - """Append the traceback to the py_error and return it""" - tb = py_error.__traceback__ - for filename, lineno, func in reversed(_parse_traceback(traceback)): - tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func) - return py_error.with_traceback(tb) - - -def _traceback_to_str(tb): - """Convert the traceback to a string""" - lines = [] - while tb is not None: - frame = tb.tb_frame - lineno = tb.tb_lineno - filename = frame.f_code.co_filename - funcname = frame.f_code.co_name - lines.append(f' File "{filename}", line {lineno}, in {funcname}\n') - tb = tb.tb_next - return "".join(lines) - - -core._WITH_APPEND_TRACEBACK = _with_append_traceback -core._TRACEBACK_TO_STR = _traceback_to_str - - -def register_error(name_or_cls=None, cls=None): - """Register an error class so it can be recognized by the ffi error handler. - - Parameters - ---------- - name_or_cls : str or class - The name of the error class. - - cls : class - The class to register. - - Returns - ------- - fregister : function - Register function if f is not specified. 
- - Examples - -------- - .. code-block:: python - - @tvm.error.register_error - class MyError(RuntimeError): - pass - - err_inst = tvm.error.create_ffi_error("MyError: xyz") - assert isinstance(err_inst, MyError) - """ - if callable(name_or_cls): - cls = name_or_cls - name_or_cls = cls.__name__ - - def register(mycls): - """internal register function""" - err_name = name_or_cls if isinstance(name_or_cls, str) else mycls.__name__ - core.ERROR_NAME_TO_TYPE[err_name] = mycls - core.ERROR_TYPE_TO_NAME[mycls] = err_name - return mycls - - if cls is None: - return register - return register(cls) - - -register_error("RuntimeError", RuntimeError) -register_error("ValueError", ValueError) -register_error("TypeError", TypeError) -register_error("AttributeError", AttributeError) -register_error("KeyError", KeyError) -register_error("IndexError", IndexError) -register_error("AssertionError", AssertionError) diff --git a/python/tvm/ffi/module.py b/python/tvm/ffi/module.py deleted file mode 100644 index 0895b317c1d4..000000000000 --- a/python/tvm/ffi/module.py +++ /dev/null @@ -1,258 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Module related objects and functions.""" -# pylint: disable=invalid-name - -from enum import IntEnum -from . import _ffi_api - -from . import core -from .registry import register_object - -__all__ = ["Module", "ModulePropertyMask", "system_lib", "load_module"] - - -class ModulePropertyMask(IntEnum): - """Runtime Module Property Mask.""" - - BINARY_SERIALIZABLE = 0b001 - RUNNABLE = 0b010 - COMPILATION_EXPORTABLE = 0b100 - - -@register_object("ffi.Module") -class Module(core.Object): - """Runtime Module.""" - - def __new__(cls): - instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter - instance.entry_name = "__tvm_ffi_main__" - instance._entry = None - return instance - - @property - def entry_func(self): - """Get the entry function - - Returns - ------- - f : tvm.ffi.Function - The entry function if exist - """ - if self._entry: - return self._entry - self._entry = self.get_function("__tvm_ffi_main__") - return self._entry - - @property - def kind(self): - """Get type key of the module.""" - return _ffi_api.ModuleGetKind(self) - - @property - def imports(self): - """Get imported modules - - Returns - ---------- - modules : list of Module - The module - """ - return self.imports_ - - def implements_function(self, name, query_imports=False): - """Returns True if the module has a definition for the global function with name. Note - that has_function(name) does not imply get_function(name) is non-null since the module - may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function - without further compilation. However, get_function(name) non null should always imply - has_function(name). - - Parameters - ---------- - name : str - The name of the function - - query_imports : bool - Whether to also query modules imported by this module. - - Returns - ------- - b : Bool - True if module (or one of its imports) has a definition for name. 
- """ - return _ffi_api.ModuleImplementsFunction(self, name, query_imports) - - def get_function(self, name, query_imports=False): - """Get function from the module. - - Parameters - ---------- - name : str - The name of the function - - query_imports : bool - Whether also query modules imported by this module. - - Returns - ------- - f : tvm.ffi.Function - The result function. - """ - func = _ffi_api.ModuleGetFunction(self, name, query_imports) - if func is None: - raise AttributeError(f"Module has no function '{name}'") - return func - - def import_module(self, module): - """Add module to the import list of current one. - - Parameters - ---------- - module : tvm.runtime.Module - The other module. - """ - _ffi_api.ModuleImportModule(self, module) - - def __getitem__(self, name): - if not isinstance(name, str): - raise ValueError("Can only take string as function name") - return self.get_function(name) - - def __call__(self, *args): - if self._entry: - return self._entry(*args) - # pylint: disable=not-callable - return self.entry_func(*args) - - def inspect_source(self, fmt=""): - """Get source code from module, if available. - - Parameters - ---------- - fmt : str, optional - The specified format. - - Returns - ------- - source : str - The result source code. - """ - return _ffi_api.ModuleInspectSource(self, fmt) - - def get_write_formats(self): - """Get the format of the module.""" - return _ffi_api.ModuleGetWriteFormats(self) - - def get_property_mask(self): - """Get the runtime module property mask. The mapping is stated in ModulePropertyMask. - - Returns - ------- - mask : int - Bitmask of runtime module property - """ - return _ffi_api.ModuleGetPropertyMask(self) - - def is_binary_serializable(self): - """Module 'binary serializable', save_to_bytes is supported. - - Returns - ------- - b : Bool - True if the module is binary serializable. 
- """ - return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 - - def is_runnable(self): - """Module 'runnable', get_function is supported. - - Returns - ------- - b : Bool - True if the module is runnable. - """ - return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 - - def is_compilation_exportable(self): - """Module 'compilation exportable', write_to_file is supported for object or source. - - Returns - ------- - b : Bool - True if the module is compilation exportable. - """ - return (self.get_property_mask() & ModulePropertyMask.COMPILATION_EXPORTABLE) != 0 - - def clear_imports(self): - """Remove all imports of the module.""" - _ffi_api.ModuleClearImports(self) - - def write_to_file(self, file_name, fmt=""): - """Write the current module to file. - - Parameters - ---------- - file_name : str - The name of the file. - fmt : str - The format of the file. - - See Also - -------- - runtime.Module.export_library : export the module to shared library. - """ - _ffi_api.ModuleWriteToFile(self, file_name, fmt) - - -def system_lib(symbol_prefix=""): - """Get system-wide library module singleton. - - System lib is a global module that contains self register functions in startup. - Unlike normal dso modules which need to be loaded explicitly. - It is useful in environments where dynamic loading api like dlopen is banned. - - The system lib is intended to be linked and loaded during the entire life-cyle of the program. - If you want dynamic loading features, use dso modules instead. - - Parameters - ---------- - symbol_prefix: Optional[str] - Optional symbol prefix that can be used for search. When we lookup a symbol - symbol_prefix + name will first be searched, then the name without symbol_prefix. - - Returns - ------- - module : runtime.Module - The system-wide library module. - """ - return _ffi_api.SystemLib(symbol_prefix) - - -def load_module(path): - """Load module from file. 
- - Parameters - ---------- - path : str - The path to the module file. - - Returns - ------- - module : ffi.Module - The loaded module - """ - return _ffi_api.ModuleLoadFromFile(path) diff --git a/python/tvm/ffi/ndarray.py b/python/tvm/ffi/ndarray.py deleted file mode 100644 index 05856bdae7a2..000000000000 --- a/python/tvm/ffi/ndarray.py +++ /dev/null @@ -1,255 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""NDArray related objects and functions.""" - -from numbers import Integral -from . import core -from .core import Device, NDArray, from_dlpack -from . import registry -from . 
import _ffi_api - - -@registry.register_object("ffi.Shape") -class Shape(tuple, core.PyNativeObject): - """Shape object that is possibly returned by FFI call.""" - - def __new__(cls, content): - if any(not isinstance(x, Integral) for x in content): - raise ValueError("Shape must be a tuple of integers") - val = tuple.__new__(cls, content) - val.__init_tvm_ffi_object_by_constructor__(_ffi_api.Shape, *content) - return val - - # pylint: disable=no-self-argument - def __from_tvm_ffi_object__(cls, obj): - """Construct from a given tvm object.""" - content = core._shape_obj_get_py_tuple(obj) - val = tuple.__new__(cls, content) - val.__tvm_ffi_object__ = obj - return val - - -def device(dev_type, dev_id=0): - """Construct a TVM FFIdevice with given device type and id. - - Parameters - ---------- - dev_type: int or str - The device type mask or name of the device. - - dev_id : int, optional - The integer device id - - Returns - ------- - dev: tvm.ffi.Device - - Examples - -------- - Device can be used to create reflection of device by - string representation of the device type. - - .. 
code-block:: python - - assert tvm.ffi.device("cuda:0") == tvm.ffi.cuda(1) - assert tvm.ffi.device("cpu", 0) == tvm.ffi.cpu(0) - """ - if isinstance(dev_type, str): - dev_type = dev_type.split(" ")[0] - return core._CLASS_DEVICE(dev_type, dev_id) - - -def cpu(dev_id=0): - """Construct a CPU device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLCPU, dev_id) - - -def cuda(dev_id=0): - """Construct a CUDA GPU device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLCUDA, dev_id) - - -def rocm(dev_id=0): - """Construct a ROCM device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLROCM, dev_id) - - -def opencl(dev_id=0): - """Construct a OpenCL device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLOpenCL, dev_id) - - -def metal(dev_id=0): - """Construct a metal device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLMetal, dev_id) - - -def vpi(dev_id=0): - """Construct a VPI simulated device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLVPI, dev_id) - - -def vulkan(dev_id=0): - """Construct a Vulkan device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLVulkan, dev_id) - - -def ext_dev(dev_id=0): - """Construct a extension device - - Parameters - 
---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - - Note - ---- - This API is reserved for quick testing of new - device by plugin device API as ext_dev. - """ - return device(Device.kDLExtDev, dev_id) - - -def hexagon(dev_id=0): - """Construct a Hexagon device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLHexagon, dev_id) - - -def webgpu(dev_id=0): - """Construct a webgpu device. - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - dev : Device - The created device - """ - return device(Device.kDLWebGPU, dev_id) - - -__all__ = [ - "from_dlpack", - "NDArray", - "device", - "cpu", - "cuda", - "rocm", - "opencl", - "metal", - "vpi", - "vulkan", - "ext_dev", - "hexagon", - "webgpu", -] diff --git a/python/tvm/ffi/registry.py b/python/tvm/ffi/registry.py deleted file mode 100644 index 9302b251733b..000000000000 --- a/python/tvm/ffi/registry.py +++ /dev/null @@ -1,179 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""FFI registry to register function and objects.""" -import sys -from . 
import core - -# whether we simplify skip unknown objects regtistration -_SKIP_UNKNOWN_OBJECTS = False - - -def register_object(type_key=None): - """register object type. - - Parameters - ---------- - type_key : str or cls - The type key of the node - - Examples - -------- - The following code registers MyObject - using type key "test.MyObject" - - .. code-block:: python - - @tvm.ffi.register_object("test.MyObject") - class MyObject(Object): - pass - """ - object_name = type_key if isinstance(type_key, str) else type_key.__name__ - - def register(cls): - """internal register function""" - type_index = core._object_type_key_to_index(object_name) - if type_index is None: - if _SKIP_UNKNOWN_OBJECTS: - return cls - raise ValueError("Cannot find object type index for %s" % object_name) - core._add_class_attrs_by_reflection(type_index, cls) - core._register_object_by_index(type_index, cls) - return cls - - if isinstance(type_key, str): - return register - - return register(type_key) - - -def register_func(func_name, f=None, override=False): - """Register global function - - Parameters - ---------- - func_name : str or function - The function name - - f : function, optional - The function to be registered. - - override: boolean optional - Whether override existing entry. - - Returns - ------- - fregister : function - Register function if f is not specified. - """ - if callable(func_name): - f = func_name - func_name = f.__name__ - - if not isinstance(func_name, str): - raise ValueError("expect string function name") - - def register(myf): - """internal register function""" - return core._register_global_func(func_name, myf, override) - - if f: - return register(f) - return register - - -def get_global_func(name, allow_missing=False): - """Get a global function by name - - Parameters - ---------- - name : str - The name of the global function - - allow_missing : bool - Whether allow missing function or raise an error. 
- - Returns - ------- - func : Function - The function to be returned, None if function is missing. - """ - return core._get_global_func(name, allow_missing) - - -def list_global_func_names(): - """Get list of global functions registered. - - Returns - ------- - names : list - List of global functions names. - """ - name_functor = get_global_func("ffi.FunctionListGlobalNamesFunctor")() - num_names = name_functor(-1) - return [name_functor(i) for i in range(num_names)] - - -def remove_global_func(name): - """Remove a global function by name - - Parameters - ---------- - name : str - The name of the global function - """ - get_global_func("ffi.FunctionRemoveGlobal")(name) - - -def _init_api(namespace, target_module_name=None): - """Initialize api for a given module name - - namespace : str - The namespace of the source registry - - target_module_name : str - The target module name if different from namespace - """ - target_module_name = target_module_name if target_module_name else namespace - - if namespace.startswith("tvm."): - prefix = namespace[4:] - else: - prefix = namespace - - target_module = sys.modules[target_module_name] - - for name in list_global_func_names(): - if not name.startswith(prefix): - continue - - fname = name[len(prefix) + 1 :] - if fname.find(".") != -1: - continue - - f = get_global_func(name) - f.__name__ = fname - setattr(target_module, f.__name__, f) - - -__all__ = [ - "register_object", - "register_func", - "get_global_func", - "list_global_func_names", - "remove_global_func", - "_init_api", -] diff --git a/python/tvm/ffi/serialization.py b/python/tvm/ffi/serialization.py deleted file mode 100644 index 25d9bcefb828..000000000000 --- a/python/tvm/ffi/serialization.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Serialization related utilities to enable some object can be pickled""" - -from typing import Optional, Any -from . import _ffi_api - - -def to_json_graph_str(obj: Any, metadata: Optional[dict] = None): - """ - Dump an object to a JSON graph string. - - The JSON graph string is a string representation of of the object - graph includes the reference information of same objects, which can - be used for serialization and debugging. - - Parameters - ---------- - obj : Any - The object to save. - - metadata : Optional[dict], optional - Extra metadata to save into the json graph string. - - Returns - ------- - json_str : str - The JSON graph string. - """ - return _ffi_api.ToJSONGraphString(obj, metadata) - - -def from_json_graph_str(json_str: str): - """ - Load an object from a JSON graph string. - - The JSON graph string is a string representation of of the object - graph that also includes the reference information. - - Parameters - ---------- - json_str : str - The JSON graph string to load. - - Returns - ------- - obj : Any - The loaded object. 
- """ - return _ffi_api.FromJSONGraphString(json_str) - - -__all__ = ["from_json_graph_str", "to_json_graph_str"] diff --git a/python/tvm/ffi/testing.py b/python/tvm/ffi/testing.py deleted file mode 100644 index 843a10c896a8..000000000000 --- a/python/tvm/ffi/testing.py +++ /dev/null @@ -1,63 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Testing utilities.""" - -from . import _ffi_api -from .core import Object -from .registry import register_object - - -@register_object("testing.TestObjectBase") -class TestObjectBase(Object): - """ - Test object base class. - """ - - -@register_object("testing.TestObjectDerived") -class TestObjectDerived(TestObjectBase): - """ - Test object derived class. - """ - - -def create_object(type_key: str, **kwargs) -> Object: - """ - Make an object by reflection. - - Parameters - ---------- - type_key : str - The type key of the object. - kwargs : dict - The keyword arguments to the object. - - Returns - ------- - obj : object - The created object. - - Note - ---- - This function is only used for testing purposes and should - not be used in other cases. 
- """ - args = [type_key] - for k, v in kwargs.items(): - args.append(k) - args.append(v) - return _ffi_api.MakeObjectFromPackedArgs(*args) diff --git a/python/tvm/ir/_ffi_analysis_api.py b/python/tvm/ir/_ffi_analysis_api.py index ca38c2309f41..9d7c12332c18 100644 --- a/python/tvm/ir/_ffi_analysis_api.py +++ b/python/tvm/ir/_ffi_analysis_api.py @@ -16,7 +16,7 @@ # under the License. """FFI APIs for tvm.ir.analysis""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("ir.analysis", __name__) +tvm_ffi.init_ffi_api("ir.analysis", __name__) diff --git a/python/tvm/ir/_ffi_api.py b/python/tvm/ir/_ffi_api.py index 6434a3925e98..798e69fca507 100644 --- a/python/tvm/ir/_ffi_api.py +++ b/python/tvm/ir/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.ir""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("ir", __name__) +tvm_ffi.init_ffi_api("ir", __name__) diff --git a/python/tvm/ir/_ffi_instrument_api.py b/python/tvm/ir/_ffi_instrument_api.py index d88faf7fddd0..18aea5cf8a2f 100644 --- a/python/tvm/ir/_ffi_instrument_api.py +++ b/python/tvm/ir/_ffi_instrument_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.instrument""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("instrument", __name__) +tvm_ffi.init_ffi_api("instrument", __name__) diff --git a/python/tvm/ir/_ffi_transform_api.py b/python/tvm/ir/_ffi_transform_api.py index 1a27fc58776c..8a2f517e2145 100644 --- a/python/tvm/ir/_ffi_transform_api.py +++ b/python/tvm/ir/_ffi_transform_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.transform""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("transform", __name__) +tvm_ffi.init_ffi_api("transform", __name__) diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index cab982f4e783..fb408cdb8c70 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. """TVM Attribute module, which is mainly used for defining attributes of operators.""" -import tvm.ffi +import tvm_ffi from tvm.runtime import Object import tvm.runtime._ffi_node_api from . import _ffi_api -@tvm.ffi.register_object("ir.Attrs") +@tvm_ffi.register_object("ir.Attrs") class Attrs(Object): """Attribute node, which is mainly use for defining attributes of operators. @@ -73,7 +73,7 @@ def __getitem__(self, item): return getattr(self, item) -@tvm.ffi.register_object("ir.DictAttrs") +@tvm_ffi.register_object("ir.DictAttrs") class DictAttrs(Attrs): """Dictionary attributes.""" diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 088ca6b96506..651ab392039c 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -15,9 +15,8 @@ # specific language governing permissions and limitations # under the License. """Common base structures.""" -import tvm.ffi import tvm.error -from tvm.ffi import get_global_func, register_object +from tvm_ffi import get_global_func, register_object from tvm.runtime import Object, _ffi_node_api from . import _ffi_api, json_compact @@ -196,7 +195,7 @@ def structural_equal(lhs, rhs, map_free_vars=False): return bool(_ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) # type: ignore # pylint: disable=no-member -def get_first_structural_mismatch(lhs, rhs, map_free_vars=False, skip_ndarray_content=False): +def get_first_structural_mismatch(lhs, rhs, map_free_vars=False, skip_tensor_content=False): """Like structural_equal(), but returns the AccessPath pair of the first detected mismatch. 
Parameters @@ -211,7 +210,7 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False, skip_ndarray_co Whether free variables (i.e. variables without a definition site) should be mapped as equal to each other. - skip_ndarray_content : bool + skip_tensor_content : bool Whether to skip the content of ndarray. Returns @@ -222,7 +221,7 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False, skip_ndarray_co """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - return _ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars, skip_ndarray_content) # type: ignore # pylint: disable=no-member + return _ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars, skip_tensor_content) # type: ignore # pylint: disable=no-member def assert_structural_equal(lhs, rhs, map_free_vars=False): diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py index 4bc6fcae21ca..eecc78cba6d1 100644 --- a/python/tvm/ir/container.py +++ b/python/tvm/ir/container.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Additional container data structures used across IR variants.""" -from tvm.ffi import Array, Map +from tvm_ffi import Array, Map __all__ = ["Array", "Map"] diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py index ac4adc3306e6..4a521dfa587e 100644 --- a/python/tvm/ir/diagnostics/__init__.py +++ b/python/tvm/ir/diagnostics/__init__.py @@ -22,9 +22,9 @@ and the DiagnosticRenderer. """ import enum -import tvm.ffi +import tvm_ffi from . import _ffi_api -from ... import get_global_func, register_func, Object +from ... 
import get_global_func, register_global_func, Object def get_renderer(): @@ -38,7 +38,7 @@ def get_renderer(): return _ffi_api.GetRenderer() -@tvm.register_func("diagnostics.override_renderer") +@tvm_ffi.register_global_func("diagnostics.override_renderer") def override_renderer(render_func): """ Sets a custom renderer for diagnostics. @@ -54,7 +54,7 @@ def override_renderer(render_func): def _render_factory(): return DiagnosticRenderer(render_func) - register_func("diagnostics.OverrideRenderer", _render_factory, override=True) + register_global_func("diagnostics.OverrideRenderer", _render_factory, override=True) else: _ffi_api.ClearRenderer() @@ -69,7 +69,7 @@ class DiagnosticLevel(enum.IntEnum): HELP = 50 -@tvm.ffi.register_object("Diagnostic") +@tvm_ffi.register_object("Diagnostic") class Diagnostic(Object): """A single diagnostic object from TVM.""" @@ -77,7 +77,7 @@ def __init__(self, level, span, message): self.__init_handle_by_constructor__(_ffi_api.Diagnostic, level, span, message) -@tvm.ffi.register_object("DiagnosticRenderer") +@tvm_ffi.register_object("DiagnosticRenderer") class DiagnosticRenderer(Object): """ A diagnostic renderer, which given a diagnostic context produces a "rendered" @@ -100,7 +100,7 @@ def render(self, ctx): # Register the diagnostic context. -@tvm.ffi.register_object("DiagnosticContext") +@tvm_ffi.register_object("DiagnosticContext") class DiagnosticContext(Object): """ A diagnostic context which records active errors diff --git a/python/tvm/ir/diagnostics/_ffi_api.py b/python/tvm/ir/diagnostics/_ffi_api.py index fb157c977510..65fb2cc896f3 100644 --- a/python/tvm/ir/diagnostics/_ffi_api.py +++ b/python/tvm/ir/diagnostics/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""FFI for TVM diagnostics.""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("diagnostics", __name__) +tvm_ffi.init_ffi_api("diagnostics", __name__) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 008924c227b5..19abb6bd1eae 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -18,21 +18,22 @@ from numbers import Number from typing import Optional -import tvm.ffi +import tvm +import tvm_ffi from ..runtime import Object, Scriptable from . import _ffi_api from .base import Node, Span -@tvm.ffi.register_object("ir.BaseExpr") +@tvm_ffi.register_object("ir.BaseExpr") class BaseExpr(Node): """Base class of all the expressions.""" span: Optional[Span] -@tvm.ffi.register_object("ir.PrimExpr") +@tvm_ffi.register_object("ir.PrimExpr") class PrimExpr(BaseExpr): """Base class of all primitive expressions. @@ -43,7 +44,7 @@ class PrimExpr(BaseExpr): dtype: str -@tvm.ffi.register_object("ir.RelaxExpr") +@tvm_ffi.register_object("ir.RelaxExpr") class RelaxExpr(BaseExpr): """Base class of all non-primitive expressions.""" @@ -59,7 +60,7 @@ def struct_info(self) -> Optional["tvm.relax.StructInfo"]: return _ffi_api.ExprStructInfo(self) -@tvm.ffi.register_object("ir.GlobalVar") +@tvm_ffi.register_object("ir.GlobalVar") class GlobalVar(RelaxExpr): """A global variable in the IR. @@ -105,7 +106,7 @@ def __call__(self, *args: RelaxExpr) -> BaseExpr: raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}") -@tvm.ffi.register_object("ir.Range") +@tvm_ffi.register_object("ir.Range") class Range(Node, Scriptable): """Represent a range in TVM. 
diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index f6fc42ccbc07..75718503aae1 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -19,6 +19,8 @@ from typing import Union, Dict from enum import IntEnum +import tvm_ffi + import tvm.runtime from tvm.runtime.object import Object from .expr import RelaxExpr @@ -34,7 +36,7 @@ class CallingConv(IntEnum): DEVICE_KERNEL_LAUNCH = 2 -@tvm.ffi.register_object("ir.BaseFunc") +@tvm_ffi.register_object("ir.BaseFunc") class BaseFunc(RelaxExpr): """Base class of all functions.""" diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py index d4b4fdca1654..185e10b88cce 100644 --- a/python/tvm/ir/global_info.py +++ b/python/tvm/ir/global_info.py @@ -16,11 +16,12 @@ # under the License. """Global Info.""" import tvm +import tvm_ffi from tvm.runtime.object import Object from . import _ffi_api -@tvm.ffi.register_object("ir.GlobalInfo") +@tvm_ffi.register_object("ir.GlobalInfo") class GlobalInfo(Object): """Base node for all global info that can appear in the IR""" @@ -36,7 +37,7 @@ def same_as(self, other): return super().__eq__(other) -@tvm.ffi.register_object("ir.DummyGlobalInfo") +@tvm_ffi.register_object("ir.DummyGlobalInfo") class DummyGlobalInfo(GlobalInfo): def __init__(self) -> None: self.__init_handle_by_constructor__( @@ -44,7 +45,7 @@ def __init__(self) -> None: ) -@tvm.ffi.register_object("ir.VDevice") +@tvm_ffi.register_object("ir.VDevice") class VDevice(GlobalInfo): def __init__( self, diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 1e1505858f50..0f1bcf3adfda 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -16,16 +16,21 @@ # under the License. 
# pylint: disable=invalid-name,unused-argument """Common pass instrumentation across IR variants.""" -import inspect import functools +import inspect +import re +import shutil +from pathlib import Path +from typing import Union + +import tvm_ffi -import tvm.ffi import tvm.runtime from . import _ffi_instrument_api -@tvm.ffi.register_object("instrument.PassInstrument") +@tvm_ffi.register_object("instrument.PassInstrument") class PassInstrument(tvm.runtime.Object): """A pass instrument implementation. @@ -225,7 +230,7 @@ def create_pass_instrument(pi_cls): return create_pass_instrument -@tvm.ffi.register_object("instrument.PassInstrument") +@tvm_ffi.register_object("instrument.PassInstrument") class PassTimingInstrument(tvm.runtime.Object): """A wrapper to create a passes time instrument that implemented in C++""" @@ -288,3 +293,49 @@ class PrintBeforeAll: def run_before_pass(self, mod, info): print(f"Before Running Pass: {info}") print(mod) + + +@pass_instrument +class DumpIR: + """Dump the IR after the pass runs.""" + + def __init__(self, dump_dir: Union[Path, str], refresh: bool = False): + if isinstance(dump_dir, Path): + self.dump_dir = dump_dir + else: + self.dump_dir = Path(dump_dir) + self.counter = 0 + if refresh and self.dump_dir.is_dir(): + self._safe_remove_dump_dir() + + def _safe_remove_dump_dir(self): + """Remove dump directory only if it contains only dumped IR files.""" + # Pattern for dumped files: {counter:03d}_{pass_name}.py + dump_pattern = re.compile(r"^\d{3}_.*\.py$") + + # Check all files in the directory + for item in self.dump_dir.iterdir(): + # If there's a subdirectory or a file that doesn't match the pattern, abort + if item.is_dir() or not dump_pattern.match(item.name): + print( + f"WARNING: Skipping removal of {self.dump_dir} as it contains " + f"non-dumped files or directories. Please clean it manually." 
+ ) + return + + # Safe to remove - only contains dumped files + try: + shutil.rmtree(self.dump_dir) + except OSError as e: + print(f"WARNING: Failed to remove directory {self.dump_dir}: {e}") + + def run_after_pass(self, mod, info): + self.dump_dir.mkdir(parents=True, exist_ok=True) + try: + sanitized_pass_name = re.sub(r'[<>:"/\\|?*]', "_", info.name) + with open(self.dump_dir / f"{self.counter:03d}_{sanitized_pass_name}.py", "w") as f: + f.write(mod.script()) + except Exception: # pylint: disable=broad-exception-caught + print(f"WARNING: Failed to dump IR for pass {info.name}") + finally: + self.counter += 1 diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3b99db85986e..21c86c05ec4c 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -20,7 +20,8 @@ from typing import Dict, Union -import tvm.ffi +import tvm +import tvm_ffi from tvm.runtime import Scriptable from tvm.runtime.object import Object @@ -30,7 +31,7 @@ from .base import Node -@tvm.ffi.register_object("ir.IRModule") +@tvm_ffi.register_object("ir.IRModule") class IRModule(Node, Scriptable): """IRModule that holds functions and type definitions. @@ -66,6 +67,7 @@ def __init__(self, functions=None, attrs=None, global_infos=None): attrs, global_infos, ) + self.pyfuncs = {} def clone(self) -> "IRModule": return _ffi_api.Module_Clone(self) diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index e5111ccc8220..5b62d3fe8df7 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -16,13 +16,13 @@ # under the License. # pylint: disable=invalid-name """Primitive operators in the TVM IR.""" -import tvm.ffi +import tvm_ffi from . 
import _ffi_api from .expr import RelaxExpr -@tvm.ffi.register_object("ir.Op") +@tvm_ffi.register_object("ir.Op") class Op(RelaxExpr): """Primitive operator in the IR.""" diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index 2038df4b3104..bc38089b2254 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -16,11 +16,12 @@ # under the License. """Suppliers that are used to guarantee uniqueness of names and GlobalVars.""" import tvm +import tvm_ffi from tvm import Object, IRModule from . import _ffi_api -@tvm.ffi.register_object("ir.NameSupply") +@tvm_ffi.register_object("ir.NameSupply") class NameSupply(Object): """NameSupply that can be used to generate unique names. @@ -77,7 +78,7 @@ def contains_name(self, name, add_prefix=True): return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) -@tvm.ffi.register_object("ir.GlobalVarSupply") +@tvm_ffi.register_object("ir.GlobalVarSupply") class GlobalVarSupply(Object): """GlobalVarSupply that holds a mapping between names and GlobalVars. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index b8f4c36c30c7..fd9a2ac3b212 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -19,13 +19,13 @@ import inspect import functools -import tvm.ffi +import tvm_ffi import tvm.runtime from . import _ffi_transform_api -@tvm.ffi.register_object("transform.PassInfo") +@tvm_ffi.register_object("transform.PassInfo") class PassInfo(tvm.runtime.Object): """The class contains the meta data required by a pass. It is the container of information needed by running an optimization or analysis. @@ -50,7 +50,7 @@ def __init__(self, opt_level, name, required=None, traceable=False): ) -@tvm.ffi.register_object("transform.PassContext") +@tvm_ffi.register_object("transform.PassContext") class PassContext(tvm.runtime.Object): """The basis where a TVM optimization/analysis runs on. 
Each pass context contains a number of auxiliary information that is used @@ -138,7 +138,7 @@ def list_configs(): return _ffi_transform_api.ListConfigs() -@tvm.ffi.register_object("transform.Pass") +@tvm_ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): """The base class of all passes. All methods here are just simple wrappers that are implemented in the backend. They are defined for users to @@ -167,7 +167,7 @@ def __call__(self, mod): return _ffi_transform_api.RunPass(self, mod) -@tvm.ffi.register_object("transform.ModulePass") +@tvm_ffi.register_object("transform.ModulePass") class ModulePass(Pass): """A pass that works on tvm.IRModule. Users don't need to interact with this class directly. Instead, a module pass should be created through @@ -178,7 +178,7 @@ class ModulePass(Pass): """ -@tvm.ffi.register_object("transform.Sequential") +@tvm_ffi.register_object("transform.Sequential") class Sequential(Pass): """A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class. @@ -348,7 +348,7 @@ def create_module_pass(pass_arg): return create_module_pass -def PrintIR(header="", show_meta_data=False): +def PrintIR(header=""): """A special trace pass that prints the header and IR. Parameters @@ -356,14 +356,11 @@ def PrintIR(header="", show_meta_data=False): header : str The header to be displayed along with the dump. - show_meta_data : bool - A boolean flag to indicate if meta data should be printed. - Returns -------- The pass """ - return _ffi_transform_api.PrintIR(header, show_meta_data) + return _ffi_transform_api.PrintIR(header) def ApplyPassToFunction( diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 0f287be96146..68bed8fb69f0 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -16,14 +16,14 @@ # under the License. """Unified type system in the project.""" import tvm -import tvm.ffi +import tvm_ffi from tvm.runtime import Scriptable from . 
import _ffi_api from .base import Node -@tvm.ffi.register_object("ir.Type") +@tvm_ffi.register_object("ir.Type") class Type(Node, Scriptable): """The base class of all types.""" @@ -39,7 +39,7 @@ def same_as(self, other): return super().__eq__(other) -@tvm.ffi.register_object("ir.PrimType") +@tvm_ffi.register_object("ir.PrimType") class PrimType(Type): """Primitive data type in the low level IR @@ -53,7 +53,7 @@ def __init__(self, dtype): self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype) -@tvm.ffi.register_object("ir.PointerType") +@tvm_ffi.register_object("ir.PointerType") class PointerType(Type): """PointerType used in the low-level TIR. @@ -70,7 +70,7 @@ def __init__(self, element_type, storage_scope=""): self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type, storage_scope) -@tvm.ffi.register_object("ir.TupleType") +@tvm_ffi.register_object("ir.TupleType") class TupleType(Type): """The type of tuple values. @@ -84,7 +84,7 @@ def __init__(self, fields): self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) -@tvm.ffi.register_object("ir.FuncType") +@tvm_ffi.register_object("ir.FuncType") class FuncType(Type): """Function type. @@ -110,7 +110,7 @@ def __init__(self, arg_types, ret_type): ) -@tvm.ffi.register_object("ir.TensorMapType") +@tvm_ffi.register_object("ir.TensorMapType") class TensorMapType(Type): """TensorMapType used in the low-level TIR. diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py index d0175fda5706..70950958024d 100644 --- a/python/tvm/ir/type_relation.py +++ b/python/tvm/ir/type_relation.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. """Type relation and function for type checking.""" -import tvm.ffi +import tvm_ffi from .type import Type, TypeConstraint from . import _ffi_api -@tvm.ffi.register_object("TypeCall") +@tvm_ffi.register_object("TypeCall") class TypeCall(Type): """Type function application. 
@@ -43,7 +43,7 @@ def __init__(self, func, args): self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args) -@tvm.ffi.register_object("TypeRelation") +@tvm_ffi.register_object("TypeRelation") class TypeRelation(TypeConstraint): """User defined type relation, it is an input-output relation on types. diff --git a/python/tvm/libinfo.py b/python/tvm/libinfo.py index 2b6d11e0b21a..5062e80997e2 100644 --- a/python/tvm/libinfo.py +++ b/python/tvm/libinfo.py @@ -53,7 +53,7 @@ def get_dll_directories(): dll_path = [] if os.environ.get("TVM_LIBRARY_PATH", None): - dll_path.append(os.environ["TVM_LIBRARY_PATH"]) + dll_path.extend(os.environ["TVM_LIBRARY_PATH"].split(":")) if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":")) @@ -66,6 +66,7 @@ def get_dll_directories(): # Pip lib directory dll_path.append(ffi_dir) + dll_path.append(os.path.join(ffi_dir, "lib")) # Default cmake build directory dll_path.append(os.path.join(source_dir, "build")) dll_path.append(os.path.join(source_dir, "build", "Release")) @@ -130,10 +131,7 @@ def find_lib_path(name=None, search_path=None, optional=False): elif sys.platform.startswith("darwin"): lib_dll_names = ["libtvm.dylib"] runtime_dll_names = ["libtvm_runtime.dylib"] - ext_lib_dll_names = [ - "3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/libfpA_intB_gemm.dylib", - "3rdparty/libflash_attn/src/libflash_attn.dylib", - ] + ext_lib_dll_names = [] else: lib_dll_names = ["libtvm.so"] runtime_dll_names = ["libtvm_runtime.so"] @@ -195,7 +193,9 @@ def find_include_path(name=None, search_path=None, optional=False): include_path : list(string) List of all found paths to header files. 
""" - if os.environ.get("TVM_HOME", None): + if os.environ.get("TVM_SOURCE_DIR", None): + source_dir = os.environ["TVM_SOURCE_DIR"] + elif os.environ.get("TVM_HOME", None): source_dir = os.environ["TVM_HOME"] else: ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -204,7 +204,7 @@ def find_include_path(name=None, search_path=None, optional=False): if os.path.isdir(os.path.join(source_dir, "include")): break else: - raise AssertionError("Cannot find the source directory given ffi_dir: {ffi_dir}") + raise AssertionError(f"Cannot find the source directory given ffi_dir: {ffi_dir}") third_party_dir = os.path.join(source_dir, "3rdparty") header_path = [] @@ -232,9 +232,19 @@ def find_include_path(name=None, search_path=None, optional=False): dmlc_include_path = [] else: tvm_include_path = [os.path.join(p, "include") for p in header_path] - tvm_ffi_include_path = [os.path.join(p, "ffi/include") for p in header_path] - dlpack_include_path = [os.path.join(p, "dlpack/include") for p in header_path] - dmlc_include_path = [os.path.join(p, "dmlc-core/include") for p in header_path] + + # Augment with system-installed tvm_ffi includes if available + from tvm_ffi import libinfo as _tvm_ffi_libinfo # type: ignore + tvm_ffi_include_path = [] + tvm_ffi_include_path.append(_tvm_ffi_libinfo.find_include_path()) + + dlpack_include_path = [ + os.path.join(p, "3rdparty", "tvm-ffi", "3rdparty", "dlpack", "include") + for p in header_path + ] + dmlc_include_path = [ + os.path.join(p, "3rdparty", "dmlc-core", "include") for p in header_path + ] # try to find include path include_found = [p for p in tvm_include_path if os.path.exists(p) and os.path.isdir(p)] @@ -259,4 +269,4 @@ def find_include_path(name=None, search_path=None, optional=False): # We use the version of the incoming release for code # that is under development. 
# The following line is set by tvm/python/update_version.py -__version__ = "0.22.dev0" +__version__ = "0.23.dev0" diff --git a/python/tvm/meta_schedule/_ffi_api.py b/python/tvm/meta_schedule/_ffi_api.py index 89b8df086001..1a06aef5a482 100644 --- a/python/tvm/meta_schedule/_ffi_api.py +++ b/python/tvm/meta_schedule/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.meta_schedule""" -from ..ffi import _init_api +import tvm_ffi -_init_api("meta_schedule", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("meta_schedule", __name__) # pylint: disable=protected-access diff --git a/python/tvm/meta_schedule/arg_info.py b/python/tvm/meta_schedule/arg_info.py index 69c8d6d4c5dc..3f8d721ed1f0 100644 --- a/python/tvm/meta_schedule/arg_info.py +++ b/python/tvm/meta_schedule/arg_info.py @@ -17,7 +17,7 @@ """The argument information""" from typing import Any, List, Union -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.ir import IRModule from tvm.runtime import DataType, Object, ShapeTuple from tvm.tir import PrimFunc diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index f323e15bd532..39493781404a 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -21,9 +21,9 @@ from typing_extensions import Literal # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.ir import IRModule -from tvm.runtime import NDArray, Object +from tvm.runtime import Tensor, Object from tvm.target import Target from .. import _ffi_api @@ -39,19 +39,19 @@ class BuilderInput(Object): The IRModule to be built. target : Target The target to be built for. 
- params: Optional[Dict[str, NDArray]] + params: Optional[Dict[str, Tensor]] The parameters for Relax build module """ mod: IRModule target: Target - params: Optional[Dict[str, NDArray]] + params: Optional[Dict[str, Tensor]] def __init__( self, mod: IRModule, target: Target, - params: Optional[Dict[str, NDArray]] = None, + params: Optional[Dict[str, Tensor]] = None, ) -> None: """Constructor. @@ -61,7 +61,7 @@ def __init__( The IRModule to be built. target : Target The target to be built for. - params: Optional[Dict[str, NDArray]] + params: Optional[Dict[str, Tensor]] The parameters for Relax build module """ self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index 0f68ef7afb1f..c5d8b21d89ba 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -19,9 +19,9 @@ import tempfile from typing import Callable, Dict, List, Optional, Union -from tvm.ffi import register_func +from tvm_ffi import register_global_func from tvm.ir import IRModule -from tvm.runtime import Module, NDArray, load_param_dict, save_param_dict +from tvm.runtime import Module, Tensor, load_param_dict, save_param_dict from tvm.target import Target from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind @@ -33,18 +33,18 @@ T_BUILD = Callable[ # pylint: disable=invalid-name - [IRModule, Target, Optional[Dict[str, NDArray]]], Module + [IRModule, Target, Optional[Dict[str, Tensor]]], Module ] T_EXPORT = Callable[[Module], str] # pylint: disable=invalid-name -def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]: +def _serialize_params(params: Optional[Dict[str, Tensor]]) -> Optional[bytearray]: if params is None: return None return save_param_dict(params) -def _deserialize_params(params: Optional[bytearray]) -> Optional[Dict[str, NDArray]]: +def _deserialize_params(params: Optional[bytearray]) 
-> Optional[Dict[str, Tensor]]: if params is None: return None return load_param_dict(params) @@ -81,7 +81,7 @@ class LocalBuilder(PyBuilder): def default_build( mod: IRModule, target: Target, - params: Optional[Dict[str, NDArray]] + params: Optional[Dict[str, Tensor]] ) -> Module: ... @@ -234,8 +234,8 @@ def _worker_func( return artifact_path -@register_func("meta_schedule.builder.default_build") -def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]]) -> Module: +@register_global_func("meta_schedule.builder.default_build") +def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, Tensor]]) -> Module: """Default build function. Parameters @@ -244,7 +244,7 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA The IRModule to be built. target : Target The target to be built. - _params : Optional[Dict[str, NDArray]] + _params : Optional[Dict[str, Tensor]] The parameters to be used for the build. Must be None. Returns @@ -254,14 +254,15 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA """ # pylint: disable=import-outside-toplevel from tvm.driver import build as tvm_build + import tvm.tir.tensor_intrin # pylint: disable=unused-import from tvm.tir.transform import RemoveWeightLayoutRewriteBlock # pylint: enable=import-outside-toplevel - mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod) + mod = RemoveWeightLayoutRewriteBlock(skip_tensor_rewrite=True)(mod) return tvm_build(mod, target=target) -@register_func("meta_schedule.builder.default_export") +@register_global_func("meta_schedule.builder.default_export") def default_export(mod: Module) -> str: """Default export function. 
@@ -282,7 +283,7 @@ def default_export(mod: Module) -> str: return artifact_path -@register_func("meta_schedule.builder.get_local_builder") +@register_global_func("meta_schedule.builder.get_local_builder") def get_local_builder() -> LocalBuilder: """Get the local builder. diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index 9abd50b94c75..f51d2f2ac89b 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -24,7 +24,7 @@ # isort: on import numpy as np # type: ignore -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from .. import _ffi_api diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py index 9191eee6a68f..ef846a6c7c5f 100644 --- a/python/tvm/meta_schedule/cost_model/mlp_model.py +++ b/python/tvm/meta_schedule/cost_model/mlp_model.py @@ -32,7 +32,7 @@ import tvm from ...contrib.tar import tar, untar -from ...runtime import NDArray +from ...runtime import Tensor from ...target import Target from ..cost_model import PyCostModel from ..database import JSONDatabase @@ -441,7 +441,7 @@ def extract_features( """ extractor = extractor or PerStoreFeature(extract_workload=True) - def _feature(feature: NDArray) -> np.ndarray: + def _feature(feature: Tensor) -> np.ndarray: return feature.numpy().astype("float32") def _mean_cost(res: RunnerResult) -> float: diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 5806454cdddb..a14dceef379f 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -26,7 +26,7 @@ import numpy as np # type: ignore from ...contrib.tar import tar, untar -from ...runtime import NDArray +from ...runtime import Tensor from ..cost_model import PyCostModel from ..feature_extractor 
import FeatureExtractor from ..logging import get_logger @@ -484,7 +484,7 @@ def update( group = self.data.get(new_group_hash, None) # Step 2. Extract features - def _feature(x: NDArray) -> np.ndarray: + def _feature(x: Tensor) -> np.ndarray: return x.numpy().astype("float32") def _mean_cost(x: RunnerResult) -> float: diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 7abaead68018..08bcbd33c7ad 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -22,7 +22,7 @@ # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.ir.module import IRModule from tvm.runtime import Object from tvm.target import Target diff --git a/python/tvm/meta_schedule/database/json_database.py b/python/tvm/meta_schedule/database/json_database.py index f3b188493767..7c6f7459cacc 100644 --- a/python/tvm/meta_schedule/database/json_database.py +++ b/python/tvm/meta_schedule/database/json_database.py @@ -18,7 +18,7 @@ import os.path as osp from typing import Optional -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .database import Database @@ -38,10 +38,10 @@ class JSONDatabase(Database): A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash - - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + - "ignore-tensor": Same as "structural", but ignore tensor raw data during equality testing and hashing. - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - given module. The "ignore-ndarray" varint is used for the extracted + given module. The "ignore-tensor" variant is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. 
""" diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index 53755333839c..1d6d4121231c 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A database that stores TuningRecords in memory""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .database import Database @@ -31,10 +31,10 @@ class MemoryDatabase(Database): A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash - - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + - "ignore-tensor": Same as "structural", but ignore tensor raw data during equality testing and hashing. - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - given module. The "ignore-ndarray" varint is used for the extracted + given module. The "ignore-tensor" variant is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. """ diff --git a/python/tvm/meta_schedule/database/ordered_union_database.py b/python/tvm/meta_schedule/database/ordered_union_database.py index a451d8ee2fd1..717d2f3001c9 100644 --- a/python/tvm/meta_schedule/database/ordered_union_database.py +++ b/python/tvm/meta_schedule/database/ordered_union_database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A database consists of multiple databases.""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. 
import _ffi_api from .database import Database diff --git a/python/tvm/meta_schedule/database/schedule_fn_database.py b/python/tvm/meta_schedule/database/schedule_fn_database.py index 3b7dfa79f6bf..74b2a6eb60da 100644 --- a/python/tvm/meta_schedule/database/schedule_fn_database.py +++ b/python/tvm/meta_schedule/database/schedule_fn_database.py @@ -17,7 +17,7 @@ """A database for injecting handcrafted schedule functions.""" from typing import Callable -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.tir import Schedule from .. import _ffi_api @@ -37,10 +37,10 @@ class ScheduleFnDatabase(Database): A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash - - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + - "ignore-tensor": Same as "structural", but ignore tensor raw data during equality testing and hashing. - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - given module. The "ignore-ndarray" varint is used for the extracted + given module. The "ignore-tensor" variant is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. """ diff --git a/python/tvm/meta_schedule/database/union_database.py b/python/tvm/meta_schedule/database/union_database.py index 7f896c1da61f..3a1afbe32adf 100644 --- a/python/tvm/meta_schedule/database/union_database.py +++ b/python/tvm/meta_schedule/database/union_database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A database consists of multiple databases.""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. 
import _ffi_api from .database import Database diff --git a/python/tvm/meta_schedule/extracted_task.py b/python/tvm/meta_schedule/extracted_task.py index 0cdede120b6f..df66e774e595 100644 --- a/python/tvm/meta_schedule/extracted_task.py +++ b/python/tvm/meta_schedule/extracted_task.py @@ -17,7 +17,7 @@ """Extracted tasks from high-level IR.""" from typing import List -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object from tvm.target import Target diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py index bd37214db997..b50a22142943 100644 --- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -22,9 +22,9 @@ # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object -from tvm.runtime.ndarray import NDArray +from tvm.runtime._tensor import Tensor from .. import _ffi_api from ..search_strategy import MeasureCandidate @@ -40,7 +40,7 @@ class FeatureExtractor(Object): def extract_from( self, context: TuneContext, candidates: List[MeasureCandidate] - ) -> List[NDArray]: + ) -> List[Tensor]: """Extract features from the given measure candidate. Parameters @@ -52,7 +52,7 @@ def extract_from( Returns ------- - features : List[NDArray] + features : List[Tensor] The feature tvm ndarray extracted. """ result = _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member @@ -108,7 +108,7 @@ class PyFeatureExtractor: def extract_from( self, context: TuneContext, candidates: List[MeasureCandidate] - ) -> List[NDArray]: + ) -> List[Tensor]: """Extract features from the given measure candidate. Parameters @@ -120,7 +120,7 @@ def extract_from( Returns ------- - features : List[NDArray] + features : List[Tensor] The feature tvm ndarray extracted. 
""" raise NotImplementedError diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py index b1098bd4ea7c..673a722955d2 100644 --- a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py +++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py @@ -18,7 +18,7 @@ """We extract one feature vector per BufferStoreNode statement in a TIR Stmt, so we call this feature as "per-store" feature. """ -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .feature_extractor import FeatureExtractor diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py index 18b84c364ad4..908dde400ec8 100644 --- a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py @@ -18,7 +18,7 @@ from typing import List, Tuple, Union import numpy as np # type: ignore -from tvm.runtime.ndarray import NDArray, array +import tvm.runtime from ..feature_extractor import PyFeatureExtractor from ..search_strategy import MeasureCandidate @@ -54,11 +54,11 @@ def __init__(self, *, feature_size: int = 30, max_block_num: int = 5, seed=0): def extract_from( self, context: TuneContext, candidates: List[MeasureCandidate] - ) -> List[NDArray]: + ) -> List[tvm.runtime.Tensor]: np.random.set_state(self.random_state) result = [ np.random.rand(np.random.randint(1, self.max_block_num + 1), self.feature_size) for candidate in candidates ] self.random_state = np.random.get_state() - return [array(x) for x in result] + return [tvm.runtime.tensor(x) for x in result] diff --git a/python/tvm/meta_schedule/measure_callback/add_to_database.py b/python/tvm/meta_schedule/measure_callback/add_to_database.py index f40dffeaad44..e0a6f5a273fc 100644 --- 
a/python/tvm/meta_schedule/measure_callback/add_to_database.py +++ b/python/tvm/meta_schedule/measure_callback/add_to_database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A callback that adds the measurement results into the database""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .measure_callback import MeasureCallback diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py index 17a7f45460e9..885f70e88de8 100644 --- a/python/tvm/meta_schedule/measure_callback/measure_callback.py +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -23,7 +23,7 @@ # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from .. import _ffi_api diff --git a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py index 82c18f8f9065..23808b7e99d7 100644 --- a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py +++ b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A callback that removes the build artifacts from the disk""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .measure_callback import MeasureCallback diff --git a/python/tvm/meta_schedule/measure_callback/update_cost_model.py b/python/tvm/meta_schedule/measure_callback/update_cost_model.py index 5b8b0306d421..7cf60c095b97 100644 --- a/python/tvm/meta_schedule/measure_callback/update_cost_model.py +++ b/python/tvm/meta_schedule/measure_callback/update_cost_model.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""A measure callback that updates the cost model""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .measure_callback import MeasureCallback diff --git a/python/tvm/meta_schedule/mutator/mutate_compute_location.py b/python/tvm/meta_schedule/mutator/mutate_compute_location.py index 5ebe04a6b13a..620e0062cbff 100644 --- a/python/tvm/meta_schedule/mutator/mutate_compute_location.py +++ b/python/tvm/meta_schedule/mutator/mutate_compute_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A mutator that mutates the compute-at location decision of SampleComputeLocation""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py index c7736fdcf71d..fc077cd0d4aa 100644 --- a/python/tvm/meta_schedule/mutator/mutate_parallel.py +++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the parallel extent""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py index 2225ca76c77d..4c9fa44c50a0 100644 --- a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py +++ b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the thread binding extent""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. 
import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutate_tile_size.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py index 90cccdc3f5db..f40894f5ba0f 100644 --- a/python/tvm/meta_schedule/mutator/mutate_tile_size.py +++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the decision of instruction Sample-Perfect-Tile""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutate_unroll.py b/python/tvm/meta_schedule/mutator/mutate_unroll.py index 9575c3fc22d9..97999c2888f8 100644 --- a/python/tvm/meta_schedule/mutator/mutate_unroll.py +++ b/python/tvm/meta_schedule/mutator/mutate_unroll.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates auto unroll step""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index 6991c72bec41..211e2c2b5015 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -22,7 +22,7 @@ # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import Trace diff --git a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py index 5c50b2064426..5c18475ea0ca 100644 --- a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py +++ b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py @@ -16,7 +16,7 @@ # under the License. 
"""A postprocessor that checks if the IRModule has any strided memory copies""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py index 34c13aded935..da604e42cc81 100644 --- a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py +++ b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that checks if the IRModule has any loop with non-constant extent""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index 33daabc3951c..8e89ad2fe138 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -22,7 +22,7 @@ # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import Schedule diff --git a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py index 20c354ce601d..d20c22d0f6d8 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py +++ b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py @@ -17,7 +17,7 @@ """A postprocessor that rewrites the cooperative fetch annotation to actual vectorized cooperative fetching in loop bindings.""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. 
import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_layout.py b/python/tvm/meta_schedule/postproc/rewrite_layout.py index 13556f1909d2..73b6dde9f76a 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_layout.py +++ b/python/tvm/meta_schedule/postproc/rewrite_layout.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that rewrites the layout of input tensor""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py index 0be7cdbe118f..30235517f9c6 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -17,7 +17,7 @@ """A postprocessor that applies parallelization, vectorization and auto unrolling according to the annotation of each block""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py index 30c8cf9b0699..5bbe2b88381e 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that rewrites reduction block by moving the init block out.""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. 
import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py index e04ddcbdf223..8f0edb869586 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py +++ b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that tensorize related components.""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py index ca4c9cdcd624..b274c2f55c11 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that adds thread binding to unbound blocks""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py index 1a74eadaa906..48fbe8f4b14c 100644 --- a/python/tvm/meta_schedule/postproc/verify_gpu_code.py +++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that verifies if the GPU code is correct""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py index 51a38624d28e..96ece2270bbc 100644 --- a/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py +++ b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py @@ -16,7 +16,7 @@ # under the License. 
"""A postprocessor that verifies the VTCM usage of a given schedule.""" -from tvm.ffi.registry import register_object +from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py index 65c1079d65b0..1a41f589de4c 100644 --- a/python/tvm/meta_schedule/profiler.py +++ b/python/tvm/meta_schedule/profiler.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from typing import Dict, Optional -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from . import _ffi_api diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index 613405c8ad3b..dc78d2400a74 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -23,10 +23,10 @@ # isort: on -from tvm.ffi import get_global_func, register_func +from tvm_ffi import get_global_func, register_global_func from tvm.ir import IRModule from tvm.ir.transform import PassContext -from tvm.runtime import NDArray +from tvm.runtime import Tensor from tvm.target import Target from tvm.tir.expr import IntImm @@ -56,7 +56,7 @@ def extract_tasks( mod: Union[IRModule, "relax.Function"], target: Target, - params: Optional[Dict[str, NDArray]] = None, + params: Optional[Dict[str, Tensor]] = None, module_equality: str = "structural", ) -> List[ExtractedTask]: """Extract tuning tasks from a relax program. @@ -67,16 +67,16 @@ def extract_tasks( The module or function to tune target : tvm.target.Target The compilation target - params : Optional[Dict[str, tvm.runtime.NDArray]] + params : Optional[Dict[str, tvm.runtime.Tensor]] The associated parameters of the program module_equality : Optional[str] A string to specify the module equality testing and hashing method. 
It must be one of the followings: - "structural": Use StructuralEqual/Hash - - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + - "ignore-tensor": Same as "structural", but ignore tensor raw data during equality testing and hashing. - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - given module. The "ignore-ndarray" varint is used for the extracted + given module. The "ignore-tensor" variant is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. @@ -159,7 +159,7 @@ def extracted_tasks_to_tune_contexts( def tune_relax( mod: Union[IRModule, "relax.Function"], - params: Dict[str, NDArray], + params: Dict[str, Tensor], target: Union[str, Target], work_dir: str, max_trials_global: int, @@ -184,7 +184,7 @@ def tune_relax( ---------- mod : Union[IRModule, relax.Function] The module or function to tune - params : Optional[Dict[str, tvm.runtime.NDArray]] + params : Optional[Dict[str, tvm.runtime.Tensor]] The associated parameters of the program target : Union[Target, str] The compilation target @@ -221,10 +221,10 @@ def tune_relax( A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash - - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + - "ignore-tensor": Same as "structural", but ignore tensor raw data during equality testing and hashing. - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - given module. The "ignore-ndarray" variant is used for the extracted + given module. The "ignore-tensor" variant is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. 
@@ -269,10 +269,10 @@ def tune_relax( ) -@register_func("tvm.meta_schedule.tune_relax") +@register_global_func("tvm.meta_schedule.tune_relax") def _tune_relax( mod: Union[IRModule, "relax.Function"], - params: Dict[str, NDArray], + params: Dict[str, Tensor], target: Union[str, Target], work_dir: str, max_trials_global: int, @@ -297,7 +297,7 @@ def _tune_relax( ---------- mod : Union[IRModule, relax.Function] The module or function to tune - params : Optional[Dict[str, tvm.runtime.NDArray]] + params : Optional[Dict[str, tvm.runtime.Tensor]] The associated parameters of the program target : Union[Target, str] The compilation target @@ -334,10 +334,10 @@ def _tune_relax( A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash - - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + - "ignore-tensor": Same as "structural", but ignore tensor raw data during equality testing and hashing. - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - given module. The "ignore-ndarray" varint is used for the extracted + given module. The "ignore-tensor" variant is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. @@ -380,7 +380,7 @@ def compile_relax( database: Database, mod: IRModule, target: Union[Target, str], - params: Optional[Dict[str, NDArray]], + params: Optional[Dict[str, Tensor]], enable_warning: bool = False, ) -> "relax.VMExecutable": """Compile a relax program with a MetaSchedule database. 
@@ -393,7 +393,7 @@ def compile_relax( The Relax program to be compiled target : tvm.target.Target The compilation target - params : Optional[Dict[str, tvm.runtime.NDArray]] + params : Optional[Dict[str, tvm.runtime.Tensor]] The associated parameters of the program enable_warning : bool A boolean value indicating if to print warnings for TIR functions not diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index 7ff1065a191f..b35e47c94dda 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -148,7 +148,7 @@ def resource_handler(): rt_mod = tvm.runtime.load_module(artifact_path) # Step 2: Allocate input arguments with Profiler.timeit("LocalRunner/alloc_argument"): - device = tvm.runtime.device(dev_type=device_type, dev_id=0) + device = tvm.runtime.device(device_type, 0) repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( device, args_info, @@ -392,7 +392,7 @@ def default_cleanup() -> None: pass # pylint: disable=unnecessary-pass -@tvm.register_func("meta_schedule.runner.get_local_runner") +@tvm.register_global_func("meta_schedule.runner.get_local_runner") def get_local_builder() -> LocalRunner: """Get the local Runner. diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index b249be7ded74..9d61a7b0b4d6 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -384,7 +384,7 @@ def resource_handler(): # Step 1. Create session with Profiler.timeit("RPCRunner/create_session"): session = f_create_session(rpc_config) - device = session.device(dev_type=device_type, dev_id=0) + device = session.device(device_type, 0) # Step 2. 
Upload the module with Profiler.timeit("RPCRunner/upload_module"): _, remote_path = osp.split(artifact_path) diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 0c2609469a19..0d7cd32bd7a5 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -22,7 +22,7 @@ # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from .. import _ffi_api diff --git a/python/tvm/meta_schedule/runner/utils.py b/python/tvm/meta_schedule/runner/utils.py index ef0d4b5f98f7..d4af6726cee0 100644 --- a/python/tvm/meta_schedule/runner/utils.py +++ b/python/tvm/meta_schedule/runner/utils.py @@ -17,8 +17,9 @@ """Runner utility functions""" import itertools from typing import Any, Callable, Dict, List +import tvm.runtime -from ...runtime import Device, Module, ndarray +from ...runtime import Device, Module from .config import EvaluatorConfig T_ARG_INFO_JSON_OBJ = List[Any] # pylint: disable=invalid-name @@ -52,8 +53,8 @@ def alloc_argument_common( The allocation args """ - def alloc_tensor(_, dtype, shape) -> ndarray.NDArray: - arg = ndarray.empty(shape=shape, dtype=dtype, device=device) + def alloc_tensor(_, dtype, shape) -> tvm.runtime.Tensor: + arg = tvm.runtime.empty(shape=shape, dtype=dtype, device=device) f_random_fill(arg) return arg diff --git a/python/tvm/meta_schedule/schedule/cuda/layout_transform.py b/python/tvm/meta_schedule/schedule/cuda/layout_transform.py index 949ef915c9ff..58540839397d 100644 --- a/python/tvm/meta_schedule/schedule/cuda/layout_transform.py +++ b/python/tvm/meta_schedule/schedule/cuda/layout_transform.py @@ -501,7 +501,7 @@ def get_max_tile_size() -> int: return max_tile_size -@tvm.register_func("meta_schedule.cuda.layout_transform") +@tvm.register_global_func("meta_schedule.cuda.layout_transform") def cuda_layout_transform_schedule_rule( sch: tvm.tir.Schedule, block: BlockRV, testing_tile_sizes: 
Optional[List[int]] = None ) -> List[tvm.tir.Schedule]: diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py index ceb18a6c3aa6..2bef40fffe74 100644 --- a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py +++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py @@ -17,7 +17,7 @@ """Add-rfactor Rule that add-rfactor to some blocks if needed""" from typing import Optional -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py index 26f61aa8ceb6..2e383c75eb91 100644 --- a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py @@ -16,7 +16,7 @@ # under the License. """Create a rule that applies customized rules registered using block attribute `schedule_rule`. The rule will be dispatched according to target keys.""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/auto_bind.py b/python/tvm/meta_schedule/schedule_rule/auto_bind.py index ef34e45061f7..0704b03f740f 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_bind.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_bind.py @@ -17,7 +17,7 @@ """Auto-bind Rule that binds blocks to threads if needed""" from typing import List, Optional -from tvm.ffi import register_object +from tvm_ffi import register_object from .. 
import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py index 8cd122ec93d3..b789dd750707 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_inline.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -17,7 +17,7 @@ """Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions""" from typing import List, Optional -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py index d2c780b72854..0c79d4f08bac 100644 --- a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py +++ b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py @@ -17,7 +17,7 @@ """Rules which apply cross-thread reduction to some reduction blocks correspondingly when needed""" from typing import List -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 2f389190d662..41c97a7862b4 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Callable from tvm.tir.schedule import Schedule, BlockRV -from tvm.ffi import register_object +from tvm_ffi import register_object from .. 
import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py index e9626c40e39c..259620b3f715 100644 --- a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py +++ b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py @@ -18,7 +18,7 @@ each block in a follow-up post processor""" from typing import List, Optional -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py index 81de07afbbed..8f1c96f6eb0a 100644 --- a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py +++ b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Rule that randomly select a compute-at location for a free block""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. 
import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 5684e68c715f..98c81e5b8f30 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -25,7 +25,7 @@ # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import BlockRV, Schedule diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index 1833ef23bda1..04f9310e6e0d 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Evolutionary Search Strategy""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .search_strategy import SearchStrategy diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py index 09e5c58d077a..682c9638c513 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_func.py +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. 
import _ffi_api from .search_strategy import SearchStrategy diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index a25596524451..e04a440da68a 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .search_strategy import SearchStrategy diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index ab4a6fb7b636..cfb45dafdeb2 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -24,7 +24,7 @@ from typing_extensions import Literal # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import Schedule @@ -87,6 +87,21 @@ class SearchStrategy(Object): ], ] + def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument + """Prevent direct instantiation of abstract SearchStrategy class. + + SearchStrategy is an abstract class and cannot be directly instantiated. + Use SearchStrategy.create() or a concrete subclass instead. + """ + if cls is SearchStrategy: + raise TypeError( + "Cannot instantiate abstract class SearchStrategy. " + "Use SearchStrategy.create() with a valid strategy type " + "(e.g., 'evolutionary', 'replay-trace', 'replay-func') " + "or use a concrete subclass instead." + ) + return super().__new__(cls) # pylint: disable=no-value-for-parameter + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the search strategy with tuning context. 
diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index eee9ea0d0e5d..45b81bdf3e59 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Post Order Apply Space Generator.""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .space_generator import ( diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py index 2cb1538a5abc..d01cd7fdcbd1 100644 --- a/python/tvm/meta_schedule/space_generator/schedule_fn.py +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Union of meta Schedule design space generators.""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. 
import _ffi_api from .space_generator import ( diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 8c9effa6e656..35f9e2236764 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -24,7 +24,7 @@ from typing_extensions import Literal # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object from tvm.tir.schedule import Schedule diff --git a/python/tvm/meta_schedule/space_generator/space_generator_union.py b/python/tvm/meta_schedule/space_generator/space_generator_union.py index f512f6535550..0b8ceb453116 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator_union.py +++ b/python/tvm/meta_schedule/space_generator/space_generator_union.py @@ -17,7 +17,7 @@ """Union of meta Schedule design space generators.""" from typing import List -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from .space_generator import ( diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py index 7bac23bb3fad..18d7e2be614a 100644 --- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Gradient Based Task Scheduler""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. 
import _ffi_api from ..logging import get_logger, get_logging_func diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 6475b4102a1d..78504608f9ab 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Round Robin Task Scheduler""" -from tvm.ffi import register_object +from tvm_ffi import register_object from .. import _ffi_api from ..logging import get_logger, get_logging_func diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index 9d6fec88b63b..4513f6081560 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -22,7 +22,7 @@ # isort: on -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from .. 
import _ffi_api diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 2da672b40561..490929402dc7 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -48,6 +48,6 @@ def run_module_via_rpc( session.upload(filename) _, filename = os.path.split(filename) rt_mod = session.load_module(filename) - dev = session.device(dev_type=dev_type, dev_id=0) + dev = session.device(dev_type, 0) nd_args = {k: ndarray.array(v, dev) for k, v in args.items()} return continuation(rt_mod, dev, nd_args) diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py index 08618a289d52..4b1155b2a235 100644 --- a/python/tvm/meta_schedule/testing/tune_utils.py +++ b/python/tvm/meta_schedule/testing/tune_utils.py @@ -19,7 +19,7 @@ import numpy as np # type: ignore import tvm -from tvm.runtime import NDArray +from tvm.runtime import Tensor def generate_input_data( @@ -81,8 +81,8 @@ def create_calculator(backend: str) -> Callable: def f_calculator( rt_mod: tvm.runtime.Module, dev: tvm.runtime.Device, # pylint: disable=unused-argument - input_data: Dict[str, NDArray], - ) -> List[NDArray]: + input_data: Dict[str, Tensor], + ) -> List[Tensor]: """Fetch the result of running the given runtime module. 
Parameters diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 14ff32c0178a..8b5a87f61932 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -22,10 +22,11 @@ from statistics import mean from typing import Callable, Tuple, Union, List, Any import numpy as np # type: ignore +from tvm_ffi import get_global_func, register_global_func + import tvm from tvm import meta_schedule as ms -from tvm.ffi import get_global_func, register_func from tvm.ir import IRModule from tvm.support import describe from tvm.target import Target @@ -202,24 +203,24 @@ def __hash__(self) -> int: def initializer() -> None: """Initializer function to register the functions on PopenWorker.""" - @register_func("tvm.meta_schedule.testing.default_check_metric") + @register_global_func("tvm.meta_schedule.testing.default_check_metric") def default_check_metric( # pylint: disable=unused-variable,unreachable-code - lhs: List[tvm.nd.NDArray], rhs: List[tvm.nd.NDArray] + lhs: List[tvm.runtime.Tensor], rhs: List[tvm.runtime.Tensor] ) -> bool: """Check if the outputs are equal Parameters ---------- - lhs : List[tvm.nd.NDArray] - The first list of NDArrays to compare. + lhs : List[tvm.runtime.Tensor] + The first list of Tensors to compare. - rhs : List[tvm.nd.NDArray] - The second list of NDArrays to compare. + rhs : List[tvm.runtime.Tensor] + The second list of Tensors to compare. Returns ------- is_equal : bool - Whether the two lists of NDArrays are equal. + Whether the two lists of Tensors are equal. 
""" assert len(lhs) == len(rhs), "Different number of outputs from two modules" for i in range(len(lhs)): # pylint: disable=consider-using-enumerate @@ -228,10 +229,10 @@ def default_check_metric( # pylint: disable=unused-variable,unreachable-code return True -@register_func("tvm.meta_schedule.testing.default_input_generator") +@register_global_func("tvm.meta_schedule.testing.default_input_generator") def default_input_generator( # pylint: disable=unused-variable mod: IRModule, -) -> List[tvm.nd.NDArray]: +) -> List[tvm.runtime.Tensor]: """Default input generator function Parameters @@ -241,25 +242,27 @@ def default_input_generator( # pylint: disable=unused-variable Returns ------- - inputs : List[tvm.nd.NDArray] + inputs : List[tvm.runtime.Tensor] The generated input data. """ args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) inputs = [ - tvm.nd.array(generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype)) + tvm.runtime.tensor( + generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype) + ) for arg_info in args_info ] return inputs -def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: - """Convert a list of TVM NDArray to a list of numpy array +def to_numpy(a: List[tvm.runtime.Tensor]) -> List[np.ndarray]: + """Convert a list of TVM Tensor to a list of numpy array Parameters ---------- - a : List[tvm.nd.NDArray] - The list of TVM NDArray to be converted + a : List[tvm.runtime.Tensor] + The list of TVM Tensor to be converted Returns ------- @@ -270,8 +273,8 @@ def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: return [x.numpy() for x in a] -def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: - """Convert a list of numpy array to a list of TVM NDArray +def to_tvm_tensor(a: List[np.ndarray]) -> List[tvm.runtime.Tensor]: + """Convert a list of numpy array to a list of TVM Tensor Parameters ---------- @@ -280,11 +283,11 @@ def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: 
Returns ------- - b : List[tvm.nd.NDArray] - The list of TVM NDArray. + b : List[tvm.runtime.Tensor] + The list of TVM Tensor. """ - assert a is not None, "Empty result cannot be converted to TVM NDArray" - return [tvm.nd.array(x) for x in a] + assert a is not None, "Empty result cannot be converted to TVM Tensor" + return [tvm.runtime.tensor(x) for x in a] def is_failed_record(record: ms.database.TuningRecord) -> bool: @@ -435,7 +438,9 @@ def f_with_args_alloc_argument_common( args_list : List[T_ARGUMENT_LIST] The list of argument lists. """ - return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)] + return [ + [tvm.runtime.tensor(arg, device=device) for arg in inputs] for _ in range(alloc_repeat) + ] def f_with_args_run_evaluator_common( rt_mod: tvm.runtime.Module, @@ -486,8 +491,8 @@ def f_with_args_run_evaluator_common( # fetch comparison function passed = check_and_run( ARGS.check_metric_func, - to_tvm_ndarray(original_res), - to_tvm_ndarray(scheduled_res), + to_tvm_tensor(original_res), + to_tvm_tensor(scheduled_res), ) print_result( @@ -555,7 +560,7 @@ def local_build_and_run( """ # potential memory leak https://github.com/apache/tvm/issues/11096 lib = tvm.compile(mod, target=target) - tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs] + tvm_inputs = [tvm.runtime.tensor(inp, device=device) for inp in inputs] device.sync() func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat) benchmark_res = func(*tvm_inputs) diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index b171c9711802..69a71ba3d6d9 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -19,10 +19,10 @@ # isort: off from typing_extensions import Literal +from tvm_ffi import register_global_func # isort: on from tvm import ir, tir -from tvm.ffi import register_func from tvm.target import Target from 
tvm.tir.expr import IntImm @@ -161,7 +161,7 @@ def tune_tir( # pylint: disable=too-many-locals ) -@register_func("tvm.meta_schedule.tune_tir") +@register_global_func("tvm.meta_schedule.tune_tir") def _tune_tir( mod: Union[ir.IRModule, tir.PrimFunc], target: Union[str, Target], diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 78c05fed533e..2cda77ba0978 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -77,10 +77,10 @@ def tune_tasks( It must be one of the followings: - "structural": Use StructuralEqual/Hash - - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality + - "ignore-tensor": Same as "structural", but ignore tensor raw data during equality testing and hashing. - "anchor-block": Apply equality testing and hashing on the anchor block extracted from - a given module. The "ignore-ndarray" varint is used for the extracted blocks or in + a given module. The "ignore-tensor" varint is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. 
post_optimization : Optional[Bool] diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 5512b7a2682b..35a8d468a75c 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -21,10 +21,11 @@ # isort: off from typing_extensions import Literal +from tvm_ffi import register_object, register_global_func + # isort: on from tvm import IRModule -from tvm.ffi import register_object, register_func from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule @@ -41,7 +42,7 @@ from .space_generator import SpaceGenerator -@register_func("tvm.meta_schedule.normalize_mod") +@register_global_func("tvm.meta_schedule.normalize_mod") def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): @@ -52,9 +53,8 @@ def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: if not isinstance(mod, IRModule): raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") func_names = mod.get_global_vars() - (func_name,) = func_names - if len(func_names) == 1 and func_name.name_hint != "main": - mod = IRModule({"main": mod[func_name]}) + if len(func_names) == 1 and func_names[0].name_hint != "main": + mod = IRModule({"main": mod[func_names[0]]}) return mod @@ -123,6 +123,15 @@ def __init__( if search_strategy is not None: if not isinstance(search_strategy, SearchStrategy): search_strategy = SearchStrategy.create(search_strategy) + # Additional check: ensure it's not the abstract SearchStrategy class itself + # Use type() for exact type check (not isinstance which would match subclasses) + elif type(search_strategy) is SearchStrategy: # pylint: disable=unidiomatic-typecheck + raise TypeError( + "Cannot use abstract SearchStrategy class directly. 
" + "Use SearchStrategy.create() with a valid strategy type " + "(e.g., 'evolutionary', 'replay-trace', 'replay-func') " + "or use a concrete subclass instead." + ) if logger is None: logger = get_logger(__name__) if not isinstance(num_threads, int): diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 2f18f54a816f..ba0b4846a3cc 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -22,7 +22,7 @@ import numpy as np # type: ignore import psutil # type: ignore -from tvm.ffi import get_global_func, register_func +from tvm_ffi import get_global_func, register_global_func from tvm.error import TVMError from tvm.ir import Array, IRModule, Map from tvm.rpc import RPCSession @@ -104,10 +104,17 @@ def method(*args, **kwargs): metadata = getattr(base, "_tvm_metadata") fields = metadata.get("fields", []) methods = metadata.get("methods", []) + base_cls = metadata["cls"] + derived_slots = ( + ("_inst",) + if hasattr(base_cls, "__weakref__") or getattr(base_cls, "__weakrefoffset__", 0) + else ("_inst", "__weakref__") + ) - class TVMDerivedObject(metadata["cls"]): # type: ignore + class TVMDerivedObject(base_cls): # type: ignore """The derived object to avoid cyclic dependency.""" + __slots__ = derived_slots _cls = cls _type = "TVMDerivedObject" @@ -163,7 +170,7 @@ def __setattr__(self, name, value): return TVMDerivedObject -@register_func("meta_schedule.cpu_count") +@register_global_func("meta_schedule.cpu_count") def _cpu_count_impl(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system @@ -219,7 +226,7 @@ def cpu_count(logical: bool = True) -> int: return _cpu_count_impl(logical) -@register_func("meta_schedule.using_ipython") +@register_global_func("meta_schedule.using_ipython") def _using_ipython() -> bool: """Return whether the current process is running in an IPython shell. 
@@ -234,7 +241,7 @@ def _using_ipython() -> bool: return False -@register_func("meta_schedule.print_interactive_table") +@register_global_func("meta_schedule.print_interactive_table") def print_interactive_table(data: str) -> None: """Print the dataframe interactive table in notebook. @@ -327,7 +334,7 @@ def get_global_func_on_rpc_session( return result -@register_func("meta_schedule.remove_build_dir") +@register_global_func("meta_schedule.remove_build_dir") def remove_build_dir(artifact_path: str) -> None: """Clean up the build directory""" shutil.rmtree(os.path.dirname(artifact_path)) diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index b88000119897..a96063c543e0 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -98,6 +98,9 @@ # utils from .utils import convert_to_expr +# BasePyModule +from .base_py_module import BasePyModule + # Import submodules in the last to avoid dependency from . import exec_builder from . import expr diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py index db1ca055865a..c5e98a22eaaf 100644 --- a/python/tvm/relax/_ffi_api.py +++ b/python/tvm/relax/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI API for Relax.""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax", __name__) +tvm_ffi.init_ffi_api("relax", __name__) diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py index fb44606f1122..0a230fbd8bb6 100644 --- a/python/tvm/relax/analysis/_ffi_api.py +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations """FFI APIs""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.analysis", __name__) +tvm_ffi.init_ffi_api("relax.analysis", __name__) diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py index 17d7a18a338d..97a999788b93 100644 --- a/python/tvm/relax/backend/_ffi_api.py +++ b/python/tvm/relax/backend/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """FFI API for Relax backend.""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.backend", __name__) +tvm_ffi.init_ffi_api("relax.backend", __name__) diff --git a/python/tvm/relax/backend/adreno/__init__.py b/python/tvm/relax/backend/adreno/__init__.py index b3364f2f4b4a..b97ea399ab19 100644 --- a/python/tvm/relax/backend/adreno/__init__.py +++ b/python/tvm/relax/backend/adreno/__init__.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. """The Relax Adreno backend compilation pipeline and other passes.""" + +from . import transform + from .pipeline import ( finalize_passes, get_default_pipeline, diff --git a/python/tvm/ffi/cython/core.pyx b/python/tvm/relax/backend/adreno/transform/__init__.py similarity index 81% rename from python/tvm/ffi/cython/core.pyx rename to python/tvm/relax/backend/adreno/transform/__init__.py index 010341187ce6..abeb56ac488c 100644 --- a/python/tvm/ffi/cython/core.pyx +++ b/python/tvm/relax/backend/adreno/transform/__init__.py @@ -14,13 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Adreno Relax transformations. 
""" - -include "./base.pxi" -include "./dtype.pxi" -include "./device.pxi" -include "./object.pxi" -include "./error.pxi" -include "./string.pxi" -include "./ndarray.pxi" -include "./function.pxi" +from .transform import ( + AnnotateCustomMemoryScope, + FoldVDeviceScopeChange, +) diff --git a/python/tvm/relax/backend/adreno/transform/_ffi_api.py b/python/tvm/relax/backend/adreno/transform/_ffi_api.py new file mode 100644 index 000000000000..d665ba02a70e --- /dev/null +++ b/python/tvm/relax/backend/adreno/transform/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for Adreno transform""" +import tvm.ffi + +tvm.ffi.init_ffi_api("relax.backend.adreno.transform", __name__) diff --git a/python/tvm/relax/backend/adreno/transform/transform.py b/python/tvm/relax/backend/adreno/transform/transform.py new file mode 100644 index 000000000000..9a01d7be97dd --- /dev/null +++ b/python/tvm/relax/backend/adreno/transform/transform.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Adreno Relax transformation passes.""" +from typing import Optional + +import tvm.ir +from tvm.target import Target + +from . import _ffi_api + + +def AnnotateCustomMemoryScope(target: Optional[Target] = None) -> tvm.ir.transform.Pass: + """Allocate the memory scope information. This is Adreno specific pass to annotate + The memory scope information and realize the same with RealizeVDevice pass followed by + updating the Prim Function var_buffer mapping using SpecializePrimFuncBasedOnCallSite. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.AnnotateCustomMemoryScope(target) # type: ignore + + +def FoldVDeviceScopeChange() -> tvm.ir.transform.Pass: + """This pass is a texture specific pass that can optimize unnecessary to_device copies. + Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + store into global scope avoiding unnecessary device copy. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. 
+ """ + return _ffi_api.FoldVDeviceScopeChange() # type: ignore diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index 47a4946ca97d..6b5b1293ff21 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -16,11 +16,7 @@ # under the License. """FlashInfer JIT compilation module for CUDA backend""" -import hashlib -import json -import os -import subprocess -from concurrent.futures import ThreadPoolExecutor +import re from pathlib import Path from typing import List @@ -28,150 +24,28 @@ from tvm.target import Target -def _compile_flashinfer_kernels( - name: str, source_paths: List[Path], target: Target, num_threads: int -) -> List[Path]: - from flashinfer.jit.env import ( # pylint: disable=import-outside-toplevel - CUTLASS_INCLUDE_DIRS, - FLASHINFER_CSRC_DIR, - FLASHINFER_INCLUDE_DIR, - FLASHINFER_JIT_DIR, - FLASHINFER_TVM_BINDING_DIR, - ) - - # ------------------------------------------------------------------------ - # Caching Flow: create build_directory and compute cache hash. 
- # ------------------------------------------------------------------------ - build_directory = FLASHINFER_JIT_DIR / name - build_directory.mkdir(parents=True, exist_ok=True) - - def get_object_file_path(src: Path) -> Path: - obj_name = src.stem + ".o" - obj_path = build_directory / obj_name - return obj_path - - # Compute latest modification time among all source files - latest_src_mtime = max(src.stat().st_mtime for src in source_paths) - - # Get modification time for the current file (the one that contains this function) - current_file_mtime = Path(__file__).stat().st_mtime - - # Build the hash key from metadata - hash_key = { - "name": name, - "target": str(target), - "latest_src_mtime": latest_src_mtime, - "current_file_mtime": current_file_mtime, - } - - hash_value = hashlib.md5( - json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8") - ).hexdigest() - - # Check if a valid hash exists in the build directory - hash_file = build_directory / "hash.md5" - if hash_file.exists(): - with open(hash_file, "r") as f: - cached_hash = f.read().strip() - if cached_hash == hash_value: - # Check that all object files exist - object_files = [] - all_exist = True - for src in source_paths: - obj_path = get_object_file_path(src) - if not obj_path.exists(): - all_exist = False - break - object_files.append(obj_path) - if all_exist: - return object_files - - # If we are here, cache is missing or outdated. 
Write the new hash and compile the paths - with open(hash_file, "w") as f: - f.write(hash_value) - - # ------------------------------------------------------------------------ - # 1) Common CUDA compile flags - # ------------------------------------------------------------------------ - cuda_cflags = [ - "-O3", - "-std=c++17", - "--threads", - str(num_threads), - "-g", - "-use_fast_math", - "--expt-relaxed-constexpr", - # DMLC default - "-DDMLC_USE_FOPEN64=0", - "-DDMLC_USE_LOGGING_LIBRARY=", - # Enable `-fPIC` for the host compiler - "-Xcompiler=-fPIC", - "-DFLASHINFER_ENABLE_F16", - "-DFLASHINFER_ENABLE_BF16", - "-DFLASHINFER_ENABLE_FP8_E4M3", - "-DFLASHINFER_ENABLE_FP8_E5M2", - ] +def _rename_exported_func_names(source_paths: List[Path], prefix: str): + """Rename the ffi-exported function names in the source files to the given prefix.""" + pattern = re.compile(r"^(\s*TVM_FFI_DLL_EXPORT_TYPED_FUNC\()([A-Za-z0-9_]+)(,.*)$") + for source_path in source_paths: + if not source_path.name.endswith("_binding.cu"): + continue - # Determine compute version - compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split(".")) - if compute_version in ["90"]: - compute_version += "a" - cuda_cflags += [ - "-gencode", - f"arch=compute_{compute_version},code=sm_{compute_version}", - ] - - # ------------------------------------------------------------------------ - # 2) Include paths - # ------------------------------------------------------------------------ - tvm_home = os.environ["TVM_SOURCE_DIR"] - include_paths = [ - FLASHINFER_INCLUDE_DIR, - FLASHINFER_CSRC_DIR, - FLASHINFER_TVM_BINDING_DIR, - Path(tvm_home).resolve() / "include", - Path(tvm_home).resolve() / "ffi" / "include", - Path(tvm_home).resolve() / "3rdparty" / "dlpack" / "include", - Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include", - ] + CUTLASS_INCLUDE_DIRS - - # ------------------------------------------------------------------------ - # 3) Function to compile a single 
source file - # ------------------------------------------------------------------------ - def compile_single_source(src: Path) -> Path: - # Derive the .o filename from the source filename - obj_path = get_object_file_path(src) - - # Construct the command - cmd = ( - ["nvcc"] - + cuda_cflags - + [f"-I{inc_path}" for inc_path in include_paths] - + ["-c", "-o", str(obj_path), str(src)] - ) + original_text = source_path.read_text(encoding="utf-8") + lines = original_text.splitlines(keepends=True) + updated = False + for idx, line in enumerate(lines): + line_body = line.rstrip("\r\n") + line_ending = line[len(line_body) :] + match = pattern.match(line_body) + if not match: + continue + new_body = f"{match.group(1)}{prefix}_{match.group(2)}{match.group(3)}" + lines[idx] = new_body + line_ending + updated = True - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - out, err = proc.communicate() - if proc.returncode != 0: - raise RuntimeError( - f"FlashInfer JIT compilation failed for {src}\n" - f"Command: {' '.join(cmd)}\n" - f"stdout:\n{out.decode('utf-8')}\n" - f"stderr:\n{err.decode('utf-8')}" - ) - return obj_path - - # ------------------------------------------------------------------------ - # 4) Compile each source in parallel using ThreadPoolExecutor - # ------------------------------------------------------------------------ - object_files = [] - with ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(compile_single_source, src) for src in source_paths] - for f in futures: - object_files.append(f.result()) # Will raise if there's a compilation error - - # Return list of generated object files for any further linking steps - return object_files + if updated: + source_path.write_text("".join(lines), encoding="utf-8") def _load_flashinfer_modules(object_files: List[Path]) -> List[tvm.runtime.Module]: @@ -187,9 +61,8 @@ def gen_flashinfer_prefill_module( dtype_o: str, qk_head_dim: int, v_head_dim: int, - 
target: Target, - enable_inline_rope: bool = True, - num_threads: int = 8, + enable_inline_rope: bool, + return_static_libs: bool = False, ) -> List[tvm.runtime.Module]: """Generate a FlashInfer module for prefill. @@ -205,12 +78,12 @@ def gen_flashinfer_prefill_module( The head dimension of the query and key tensors. v_head_dim : int The head dimension of the value tensor. - target : Target - The target device to compile for. enable_inline_rope : bool Whether to enable inline rotary positional embedding. - num_threads : int - The number of threads to use for compilation. + return_static_libs : bool + Whether to return static library modules instead of compiled modules. + When it is False, it returns the loaded shared library that links all the object files. + When it is True, it returns the static libraries of each compiled object files. Returns ------- @@ -218,7 +91,7 @@ def gen_flashinfer_prefill_module( """ try: from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_customize_batch_prefill_tvm_binding, + gen_customize_batch_prefill_module, ) except ImportError: raise ImportError( @@ -248,32 +121,33 @@ def gen_flashinfer_prefill_module( if backend == "fa2" else "#include " ) - jit_args = { - "backend": backend, - "uri": f"batch_prefill_tvm_dtype_q_{dtype_q}_" + jit_spec = gen_customize_batch_prefill_module( + backend=backend, + uri=f"batch_prefill_tvm_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_o}_" + f"qk_head_dim_{qk_head_dim}_" + f"v_head_dim_{v_head_dim}_" + f"enable_inline_rope_{enable_inline_rope}", - "dtype_q": torch_dtype_q, - "dtype_kv": torch_dtype_kv, - "dtype_o": torch_dtype_o, - "idtype": torch.int32, - "head_dim_qk": qk_head_dim, - "head_dim_vo": v_head_dim, - "additional_tensor_names": [], - "additional_tensor_dtypes": [], - "additional_scalar_names": ["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], - "additional_scalar_dtypes": ["double", "double", "double"], - "variant_name": variant_name, - 
"variant_decl": variant_decl, - "enable_inline_rope": enable_inline_rope, - } - uri, source_paths = gen_customize_batch_prefill_tvm_binding(**jit_args) - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules + dtype_q=torch_dtype_q, + dtype_kv=torch_dtype_kv, + dtype_o=torch_dtype_o, + idtype=torch.int32, + head_dim_qk=qk_head_dim, + head_dim_vo=v_head_dim, + pos_encoding_mode=int(enable_inline_rope), + additional_tensor_names=[], + additional_tensor_dtypes=[], + additional_scalar_names=["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], + additional_scalar_dtypes=["double", "double", "double"], + variant_name=variant_name, + variant_decl=variant_decl, + ) + _rename_exported_func_names(jit_spec.sources, "batch_prefill") + if return_static_libs: + jit_spec.build(verbose=False) + return _load_flashinfer_modules(jit_spec.get_object_paths()) + return [jit_spec.build_and_load()] def gen_flashinfer_decode_module( @@ -282,8 +156,8 @@ def gen_flashinfer_decode_module( dtype_o: str, qk_head_dim: int, v_head_dim: int, - target: Target, - num_threads: int = 8, + enable_inline_rope: bool, + return_static_libs: bool = False, ) -> List[tvm.runtime.Module]: """Generate a FlashInfer module for decode. @@ -299,10 +173,12 @@ def gen_flashinfer_decode_module( The head dimension of the query and key tensors. v_head_dim : int The head dimension of the value tensor. - target : Target - The target device to compile for. - num_threads : int - The number of threads to use for compilation. + enable_inline_rope : bool + Whether to enable inline rotary positional embedding. + return_static_libs : bool + Whether to return static library modules instead of compiled modules. + When it is False, it returns the loaded shared library that links all the object files. + When it is True, it returns the static libraries of each compiled object files. 
Returns ------- @@ -310,7 +186,7 @@ def gen_flashinfer_decode_module( """ try: from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_customize_batch_decode_tvm_binding, + gen_customize_batch_decode_module, ) except ImportError: raise ImportError( @@ -325,29 +201,32 @@ def gen_flashinfer_decode_module( torch_dtype_q = getattr(torch, dtype_q) torch_dtype_kv = getattr(torch, dtype_kv) torch_dtype_o = getattr(torch, dtype_o) - jit_args = { - "uri": f"batch_decode_tvm_dtype_q_{dtype_q}_" + jit_spec = gen_customize_batch_decode_module( + uri=f"batch_decode_tvm_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_o}_" + f"qk_head_dim_{qk_head_dim}_" - + f"v_head_dim_{v_head_dim}", - "dtype_q": torch_dtype_q, - "dtype_kv": torch_dtype_kv, - "dtype_o": torch_dtype_o, - "idtype": torch.int32, - "head_dim_qk": qk_head_dim, - "head_dim_vo": v_head_dim, - "additional_tensor_names": [], - "additional_tensor_dtypes": [], - "additional_scalar_names": ["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], - "additional_scalar_dtypes": ["double", "double", "double"], - "variant_name": "DefaultAttention", - "variant_decl": "#include ", - } - uri, source_paths = gen_customize_batch_decode_tvm_binding(**jit_args) - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules + + f"v_head_dim_{v_head_dim}_" + + f"enable_inline_rope_{enable_inline_rope}", + dtype_q=torch_dtype_q, + dtype_kv=torch_dtype_kv, + dtype_o=torch_dtype_o, + idtype=torch.int32, + head_dim_qk=qk_head_dim, + head_dim_vo=v_head_dim, + pos_encoding_mode=int(enable_inline_rope), + additional_tensor_names=[], + additional_tensor_dtypes=[], + additional_scalar_names=["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], + additional_scalar_dtypes=["double", "double", "double"], + variant_name="DefaultAttention", + variant_decl="#include ", + ) + _rename_exported_func_names(jit_spec.sources, "batch_decode") 
+ if return_static_libs: + jit_spec.build(verbose=False) + return _load_flashinfer_modules(jit_spec.get_object_paths()) + return [jit_spec.build_and_load()] def gen_flashinfer_mla_module( @@ -356,8 +235,7 @@ def gen_flashinfer_mla_module( dtype_o: str, head_dim_ckv: int, head_dim_kpe: int, - target: Target, - num_threads: int = 8, + return_static_libs: bool = False, ) -> List[tvm.runtime.Module]: """Generate a FlashInfer module for MLA. @@ -377,6 +255,10 @@ def gen_flashinfer_mla_module( The target device to compile for. num_threads : int The number of threads to use for compilation. + return_static_libs : bool + Whether to return static library modules instead of compiled modules. + When it is False, it returns the loaded shared library that links all the object files. + When it is True, it returns the static libraries of each compiled object files. Returns ------- @@ -384,7 +266,7 @@ def gen_flashinfer_mla_module( """ try: from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_batch_mla_tvm_binding, + gen_batch_mla_module, ) except ImportError: raise ImportError( @@ -399,51 +281,65 @@ def gen_flashinfer_mla_module( torch_dtype_q = getattr(torch, dtype_q) torch_dtype_kv = getattr(torch, dtype_kv) torch_dtype_o = getattr(torch, dtype_o) - jit_args = { - "uri": f"batch_mla_tvm_dtype_q_{dtype_q}_" - + f"dtype_kv_{dtype_kv}_" - + f"dtype_o_{dtype_o}_" - + f"head_dim_ckv_{head_dim_ckv}_" - + f"head_dim_kpe_{head_dim_kpe}", - "dtype_q": torch_dtype_q, - "dtype_kv": torch_dtype_kv, - "dtype_o": torch_dtype_o, - "dtype_idx": torch.int32, - "head_dim_ckv": head_dim_ckv, - "head_dim_kpe": head_dim_kpe, - } - uri, source_paths = gen_batch_mla_tvm_binding(**jit_args) - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules + jit_spec = gen_batch_mla_module( + backend="fa2", + dtype_q=torch_dtype_q, + dtype_kv=torch_dtype_kv, + dtype_o=torch_dtype_o, + 
dtype_idx=torch.int32, + head_dim_ckv=head_dim_ckv, + head_dim_kpe=head_dim_kpe, + use_profiler=False, + ) + _rename_exported_func_names(jit_spec.sources, "batch_mla") + if return_static_libs: + jit_spec.build(verbose=False) + return _load_flashinfer_modules(jit_spec.get_object_paths()) + return [jit_spec.build_and_load()] -def gen_sampling_module(target: Target, num_threads: int = 8): - """ - Generate a FlashInfer module for sampling kernels. +def gen_grouped_gemm_module( + target: Target, return_static_libs: bool = False +) -> List[tvm.runtime.Module]: + """Generate a FlashInfer module for FP8 grouped GEMM. Parameters ---------- target : Target - The target device for which the module will be compiled. - num_threads : int, optional - The number of threads to use during compilation (default is 8). + The target device to compile for. + return_static_libs : bool + Whether to return static library modules instead of compiled modules. + When it is False, it returns the loaded shared library that links all the object files. + When it is True, it returns the static libraries of each compiled object files. Returns ------- List[tvm.runtime.Module] - A list of compiled static library modules for the FlashInfer sampling kernels. + A list of compiled static library modules for FlashInfer FP8 grouped GEMM kernels. + + Note + _____ + when apply grouped gemm on A: (total_m, k), B: (batch_size, n, k), m_indptr: (batch_size, ) + requires all m in m_indptr to be multiple of 4 """ + # NOTE: This function is still under development, + # and we currently only support SM100 grouped gemm try: - from flashinfer.jit import ( # pylint: disable=import-outside-toplevel - gen_sampling_tvm_binding, + from flashinfer.gemm import ( # pylint: disable=import-outside-toplevel + gen_gemm_sm100_module, ) except ImportError: raise ImportError( "FlashInfer is not installed. Please follow instructions " "in https://docs.flashinfer.ai to install FlashInfer." 
) - uri, source_paths = gen_sampling_tvm_binding(uri="sampling") - object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) - modules = _load_flashinfer_modules(object_files) - return modules + + compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split(".")) + if compute_version == "100": + jit_spec = gen_gemm_sm100_module() + else: + raise ValueError(f"Unsupported compute version: {compute_version}") + if return_static_libs: + jit_spec.build(verbose=False) + return _load_flashinfer_modules(jit_spec.get_object_paths()) + return [jit_spec.build_and_load()] diff --git a/python/tvm/relax/backend/gpu_generic/sampling.py b/python/tvm/relax/backend/gpu_generic/sampling.py index 2634a0742713..9a0d01ef2331 100644 --- a/python/tvm/relax/backend/gpu_generic/sampling.py +++ b/python/tvm/relax/backend/gpu_generic/sampling.py @@ -19,6 +19,7 @@ import math from typing import Callable, Optional +import tvm from tvm.script import tir as T from tvm.tir import PrimFunc @@ -69,6 +70,9 @@ def gpu_multinomial_from_uniform( The generated function """ + target = tvm.target.Target.current() + target_dtype = "int32" if "webgpu" in str(target) else "int64" + TX = T.int64(tx_len) # threadIdx.x TY = T.int64(ty_len) # threadIdx.y @@ -282,7 +286,8 @@ def parallel_sampling_from_prob( # at least one iteration while T.tvm_thread_invariant( (step_iter[()] == 0 or aggregate[()] < u - eps) - and T.Cast("int64", step_iter[()]) < T.ceildiv(vocab_size, block_elem) + and T.Cast(target_dtype, step_iter[()]) + < T.Cast(target_dtype, T.ceildiv(vocab_size, block_elem)) ): single_batch_sampling( prob, @@ -290,7 +295,7 @@ def parallel_sampling_from_prob( vocab_size, ty, tx, - T.Cast("int64", step_iter[()]), + T.Cast(target_dtype, step_iter[()]), 0.0, aggregate, u, diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py index 139e5cc2b997..dfc891dc1f31 100644 --- a/python/tvm/relax/backend/metal/coreml.py +++ 
b/python/tvm/relax/backend/metal/coreml.py @@ -19,7 +19,7 @@ import os import shutil -import tvm.ffi +import tvm_ffi from tvm.contrib import coreml_runtime from tvm.contrib.xcode import compile_coreml @@ -463,7 +463,7 @@ def compile(self, out_dir): compile_coreml(model, self.model_name, out_dir) -@tvm.ffi.register_func("relax.ext.coreml") +@tvm_ffi.register_global_func("relax.ext.coreml") def coreml_compiler(funcs, options, constant_names): """ Create a CoreML runtime from a Relax module. diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py new file mode 100644 index 000000000000..41ef44fb300b --- /dev/null +++ b/python/tvm/relax/base_py_module.py @@ -0,0 +1,627 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""BasePyModule: Base class for IRModules with Python function support.""" + +import inspect +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import tvm +from tvm import relax, tir +from tvm.ir import IRModule +from tvm.runtime import Device, Tensor, PackedFunc +from tvm.target import Target + +try: + from torch.utils.dlpack import to_dlpack as to_dlpack_legacy +except ImportError: + to_dlpack_legacy = None + +try: + from tvm_ffi._optional_torch_c_dlpack import load_torch_c_dlpack_extension + + _FASTER_DLPACK_EXTENSION = load_torch_c_dlpack_extension() +except ImportError: + _FASTER_DLPACK_EXTENSION = None + + +class BasePyModule: + """Base class that allows Python functions in IRModule with DLPack conversion. + + This class provides the infrastructure for: + 1. JIT compilation of TIR and Relax functions. + 2. DLPack-based conversion between PyTorch tensors and TVM Tensors. + 3. Wrapping Relax functions for easy Python calling. + 4. Cross-function calls between Python, TIR, and Relax functions. + + Only IRModules that inherit from this class are allowed to contain Python functions. 
+ """ + + def __del__(self): + """Clean up registered Python functions on module destruction.""" + try: + clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry") + clear_func() + except (ValueError, AttributeError): + pass + + def __init__( + self, + ir_mod: IRModule, + device: Device, + target: Optional[Target] = None, + ): + """Initialize BasePyModule with JIT compilation and DLPack conversion.""" + self.device = device + self.ir_mod = ir_mod + + # Delegate IRModule operations + self.functions = ir_mod.functions + self.attrs = ir_mod.attrs + self.global_infos = ir_mod.global_infos + self.__getitem__ = ir_mod.__getitem__ + self.__setitem__ = ir_mod.__setitem__ + self.functions_items = ir_mod.functions_items + self.with_attr = ir_mod.with_attr + self.get_attr = ir_mod.get_attr + self.update_global_info = ir_mod.update_global_info + + def _getattr_python_function(name: str) -> Any: + """Support direct attribute access to funcs and IRModule methods.""" + if name in self.pyfuncs: + return self.pyfuncs[name] + if name in self.compiled_tir_funcs: + return self.compiled_tir_funcs[name] + if self.relax_vm and name in self.relax_func_names: + try: + return self.relax_vm[name] + except AttributeError: # More specific exception + return None + if hasattr(self.ir_mod, name): + return getattr(self.ir_mod, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + self.__getattr__ = _getattr_python_function + + self.compiled_tir_funcs: Dict[str, PackedFunc] = {} + self.extern_funcs: Dict[str, PackedFunc] = {} + self.tir_func_names: List[str] = [] + self.relax_func_names: List[str] = [] + self.relax_vm: Optional[relax.VirtualMachine] = None + self.pyfuncs: Dict[str, Any] = {} + + if target is None: + target = Target.from_device(device) + elif isinstance(target, str): + target = Target(target) + self.target = target + + self._collect_function_names() + self._compile_functions() + self._wrap_tir_functions() + 
self._wrap_relax_functions() + self._register_python_functions() + + def _collect_function_names(self): + """Collect names of TIR and Relax functions from IRModule.""" + for global_var, func in self.ir_mod.functions_items(): + if isinstance(func, tir.PrimFunc): + self.tir_func_names.append(global_var.name_hint) + elif isinstance(func, relax.Function): + self.relax_func_names.append(global_var.name_hint) + + def _compile_functions(self): + """Compile TIR and Relax functions using JIT compilation.""" + # Compile TIR functions first + tir_mod = tvm.IRModule( + { + gv: func + for gv, func in self.ir_mod.functions_items() + if isinstance(func, tir.PrimFunc) + } + ) + if tir_mod: + try: + tir_exec_mod = tvm.compile(tir_mod, target=self.target) + for func_name in self.tir_func_names: + self.compiled_tir_funcs[func_name] = tir_exec_mod[func_name] + # pylint: disable=broad-exception-caught + except Exception as error: + print(f"Warning: Failed to compile one or more TIR functions: {error}") + + relax_mod = tvm.IRModule( + { + gv: func + for gv, func in self.ir_mod.functions_items() + if isinstance(func, relax.Function) + } + ) + if relax_mod: + try: + exec_mod = tvm.compile(self.ir_mod, target=self.target) + self.relax_vm = relax.VirtualMachine(exec_mod, self.device) + # pylint: disable=broad-exception-caught + except Exception as error: + print(f"Warning: Failed to compile Relax VM: {error}") + self.relax_vm = None + + def _wrap_tir_functions(self): + """Wrap TIR functions to make them accessible as instance attributes.""" + for func_name, func in self.compiled_tir_funcs.items(): + setattr(self, func_name, func) + + def _wrap_relax_functions(self): + """Wrap Relax functions to be callable from Python with auto conversion.""" + for func_name in self.relax_func_names: + + def _create_relax_wrapper(name): + def wrapper(*args, **kwargs): + """Wrapper for Relax function with automatic tensor conversion.""" + if hasattr(self.ir_mod, "pyfuncs") and name in self.ir_mod.pyfuncs: + 
return self.ir_mod.pyfuncs[name](*args, **kwargs) + + if self.relax_vm is not None: + converted_args = self._convert_pytorch_to_tvm(list(args)) + converted_kwargs = { + k: self._convert_pytorch_to_tvm(v) for k, v in kwargs.items() + } + result = self.relax_vm[name](*converted_args, **converted_kwargs) + return self._convert_tvm_to_pytorch(result) + + raise RuntimeError( + f"Neither converted Python function nor Relax VM available for {name}" + ) + + wrapper.__name__ = name + wrapper.__doc__ = f"Wrapped Relax function: {name}" + return wrapper + + setattr(self, func_name, _create_relax_wrapper(func_name)) + + def _register_python_functions(self): + """Register Python functions with the VM runtime for call_py_func support.""" + if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs: + return + + try: + register_py_func = tvm.get_global_func("vm.builtin.register_py_func") + except ValueError: + return + + for func_name, py_func in self.ir_mod.pyfuncs.items(): + + def create_py_func_wrapper(name, original_func): + def wrapper(*args, **kwargs): + converted_args = [self._convert_tvm_to_pytorch(arg) for arg in args] + converted_kwargs = { + k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items() + } + + result = original_func(self, *converted_args, **converted_kwargs) + + return self._convert_pytorch_to_tvm(result) + + wrapper.__name__ = name + return wrapper + + wrapped_func = create_py_func_wrapper(func_name, py_func) + register_py_func(func_name, wrapped_func) + + def call_tir(self, tir_func, args, out_sinfo): + """Call a TIR function with PyTorch tensors.""" + # Try to get function name from different sources + if isinstance(tir_func, str): + func_name = tir_func + elif hasattr(tir_func, "name"): + func_name = tir_func.name + elif hasattr(tir_func, "__name__"): + func_name = tir_func.__name__ + else: + # Try to find by function object reference + for name, func in self.compiled_tir_funcs.items(): + if func == tir_func: + func_name = name + break + 
else: + func_name = None + + if not func_name or func_name not in self.compiled_tir_funcs: + available_funcs = list(self.compiled_tir_funcs.keys()) + raise ValueError( + f"Could not resolve or find compiled TIR function: {tir_func}. " + f"Available functions: {available_funcs}" + ) + func = self.compiled_tir_funcs[func_name] + + out = self._create_output_tensors(out_sinfo, args) + tvm_args = self._convert_pytorch_to_tvm(args) + tvm_out = self._convert_pytorch_to_tvm(out) + + func(*tvm_args, *tvm_out) + + result = self._convert_tvm_to_pytorch(tvm_out) + return result[0] if len(result) == 1 else result + + def call_dps_packed(self, func_name: str, args, out_sinfo): + """Call a packed function with PyTorch tensors, converting TVM Tensors via DLPack.""" + if hasattr(self, func_name) and callable(getattr(self, func_name)): + return getattr(self, func_name)(*args) + + if func_name not in self.extern_funcs: + try: + self.extern_funcs[func_name] = tvm.get_global_func(func_name) + except ValueError as error: + raise ValueError( + f"Function '{func_name}' not found as a global function. " + f"Please implement it as a method or register it." 
+ ) from error + func = self.extern_funcs[func_name] + + out = self._create_output_tensors(out_sinfo, args) + tvm_args = self._convert_pytorch_to_tvm(args) + tvm_out = self._convert_pytorch_to_tvm(out) + func(*tvm_args, *tvm_out) + return out[0] if len(out) == 1 else out + + def call_py_func(self, func_name: str, args): + """Call a Python function stored in the module's pyfuncs.""" + if func_name not in self.pyfuncs: + raise ValueError(f"Python function '{func_name}' not found in module pyfuncs") + py_func = self.pyfuncs[func_name] + return py_func(self, *args) + + def _create_output_tensors(self, out_sinfo, in_args=None): + # pylint: disable=import-outside-toplevel + import torch + + sinfo_list = out_sinfo if isinstance(out_sinfo, list) else [out_sinfo] + out_tensors = [] + for sinfo in sinfo_list: + if isinstance(sinfo, (tuple, list)) and all( + isinstance(x, (int, np.integer)) for x in sinfo + ): + out_tensors.append(torch.zeros(list(map(int, sinfo)), dtype=torch.float32)) + continue + + if hasattr(sinfo, "shape") and hasattr(sinfo, "dtype"): + concrete_shape = self._infer_concrete_shape_from_args(sinfo.shape, in_args) + torch_dtype = self._convert_tvm_dtype_to_torch(sinfo.dtype) + out_tensors.append(torch.zeros(concrete_shape, dtype=torch_dtype)) + continue + + out_tensors.append(torch.zeros((1,), dtype=torch.float32)) + return out_tensors + + def _infer_concrete_shape_from_args(self, shape, in_args): + + concrete = [] + symbolic_positions = [] + for idx, dim in enumerate(shape): + if isinstance(dim, (int, np.integer)): + concrete.append(int(dim)) + elif isinstance(dim, tir.IntImm): + concrete.append(int(dim.value)) + else: + concrete.append(None) + symbolic_positions.append(idx) + + if not symbolic_positions: + return concrete + + candidates = [] + if in_args is not None: + if not isinstance(in_args, (list, tuple)): + in_args = [in_args] + for obj in in_args: + if hasattr(obj, "shape") and isinstance(obj.shape, (tuple, list)): + try: + 
candidates.append(tuple(int(x) for x in obj.shape)) + continue + except (ValueError, TypeError): + # Skip objects with invalid shapes + pass + + target_ndim = len(shape) + for cand in candidates: + if len(cand) == target_ndim: + for pos in symbolic_positions: + concrete[pos] = cand[pos] + if all(x is not None for x in concrete): + return concrete + + raise ValueError( + "Cannot infer concrete output shape from symbolic shape and inputs. " + "Please provide a concrete `out_sinfo` (e.g., a tuple/list of ints) " + "or ensure input tensors carry shapes that determine output extents." + ) + + def _convert_tvm_dtype_to_torch(self, tvm_dtype: str) -> "torch.dtype": + """Convert TVM dtype string to PyTorch dtype.""" + # pylint: disable=import-outside-toplevel + import torch + + dtype_mapping = { + "float32": torch.float32, + "float64": torch.float64, + "int32": torch.int32, + "int64": torch.int64, + "bool": torch.bool, + } + return dtype_mapping.get(str(tvm_dtype), torch.float32) + + def _convert_pytorch_to_tvm( + self, tensors: Union[Any, List[Any], Tuple[Any, ...]] + ) -> Union[Tensor, List[Tensor]]: + """Convert PyTorch tensors to TVM Tensors using DLPack.""" + # pylint: disable=import-outside-toplevel + import torch + + if isinstance(tensors, (list, tuple)): + return [self._convert_single_pytorch_to_tvm(t) for t in tensors] + return self._convert_single_pytorch_to_tvm(tensors) + + def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor: + """Convert a single PyTorch tensor to TVM Tensor with faster DLPack converter.""" + # pylint: disable=import-outside-toplevel + import torch + + if isinstance(tensor, Tensor): + return tensor + if isinstance(tensor, torch.Tensor): + # 1. Try faster C++ DLPack converter + if _FASTER_DLPACK_EXTENSION is not None: + try: + dlpack = torch.to_dlpack(tensor) + return tvm.runtime.from_dlpack(dlpack) + except (AttributeError, ValueError): + pass # Fall through to the next method + + # 2. 
Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7) + try: + dlpack = torch.to_dlpack(tensor) + return tvm.runtime.from_dlpack(dlpack) + except (AttributeError, ValueError): + pass # Fall through to the next method + + # 3. Try legacy `torch.utils.dlpack.to_dlpack` + if to_dlpack_legacy: + try: + dlpack = to_dlpack_legacy(tensor) + return tvm.runtime.from_dlpack(dlpack) + except (AttributeError, ValueError) as error_legacy: + print( + f"Warning: Legacy DLPack conversion failed ({error_legacy}), " + f"using numpy fallback." + ) + + # 4. If all DLPack methods fail, use numpy fallback + numpy_array = tensor.detach().cpu().numpy() + return tvm.runtime.tensor(numpy_array, device=self.device) + + # For other types (like scalars, lists), convert to numpy first + try: + numpy_array = np.array(tensor, dtype=np.float32) + return tvm.runtime.tensor(numpy_array, device=self.device) + except (TypeError, ValueError) as error: + raise TypeError( + f"Unsupported type for conversion to TVM Tensor: {type(tensor)}" + ) from error + + def _convert_tvm_to_pytorch( + self, tvm_tensors: Union[Any, List[Any]] + ) -> Union["torch.Tensor", List["torch.Tensor"]]: + """Convert TVM Tensors to PyTorch tensors using DLPack.""" + if isinstance(tvm_tensors, (list, tuple)): + return [self._convert_single_tvm_to_pytorch(tensor) for tensor in tvm_tensors] + return self._convert_single_tvm_to_pytorch(tvm_tensors) + + def _convert_single_tvm_to_pytorch(self, tvm_tensor: Any) -> "torch.Tensor": + """Convert a single TVM Tensor to PyTorch tensor using faster DLPack converter.""" + # pylint: disable=import-outside-toplevel + import torch + + if isinstance(tvm_tensor, torch.Tensor): + return tvm_tensor + if not isinstance(tvm_tensor, Tensor): + return torch.tensor(tvm_tensor) + + # 1. Try faster C++ DLPack converter + if _FASTER_DLPACK_EXTENSION is not None: + try: + return torch.from_dlpack(tvm_tensor) + except (AttributeError, ValueError): + pass # Fall through to the next method + + # 2. 
Try standard DLPack conversion + try: + return torch.from_dlpack(tvm_tensor) + # pylint: disable=broad-exception-caught + except Exception as error: + print(f"Warning: DLPack conversion from TVM failed ({error}), using numpy fallback") + numpy_array = tvm_tensor.numpy() + return torch.from_numpy(numpy_array) + + def get_function(self, name: str) -> Optional[PackedFunc]: + """Get a compiled function by name.""" + if name in self.compiled_tir_funcs: + return self.compiled_tir_funcs[name] + if name in self.extern_funcs: + return self.extern_funcs[name] + if self.relax_vm and name in self.relax_func_names: + try: + if hasattr(self, name): + return getattr(self, name) + return self.relax_vm[name] + except AttributeError as error: + print(f"Warning: Failed to get Relax function '{name}': {error}") + return None + + def list_functions(self) -> Dict[str, List[str]]: + """List all available functions.""" + return { + "tir": self.tir_func_names, + "relax": self.relax_func_names, + "extern": list(self.extern_funcs.keys()), + } + + def add_python_function(self, name: str, func: callable): + """Add a Python function to the module.""" + self.pyfuncs[name] = func + + # Create a wrapper that handles both instance methods and static functions + # pylint: disable=import-outside-toplevel + import functools + + @functools.wraps(func) + def wrapper(*args, **kwargs): + sig = inspect.signature(func) + params = list(sig.parameters.keys()) + + if params and params[0] == "self": + return func(self, *args, **kwargs) + else: + return func(*args, **kwargs) + + # Set the wrapper as an instance attribute + setattr(self, name, wrapper) + + def script( + self, + *, + name: Optional[str] = None, + show_meta: bool = False, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + module_alias: str = "cls", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = 
False, + num_context_lines: int = -1, + syntax_sugar: bool = True, + show_object_address: bool = False, + show_all_struct_info: bool = True, + ) -> str: + """Print TVM IR into TVMScript text format with Python function support. + + This method extends the standard IRModule script() method to handle + Python functions stored in the IRModule's pyfuncs attribute. + """ + # First get the standard IRModule script + base_script = self.ir_mod.script( + name=name, + show_meta=show_meta, + ir_prefix=ir_prefix, + tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + module_alias=module_alias, + buffer_dtype=buffer_dtype, + int_dtype=int_dtype, + float_dtype=float_dtype, + verbose_expr=verbose_expr, + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + syntax_sugar=syntax_sugar, + show_object_address=show_object_address, + show_all_struct_info=show_all_struct_info, + ) + + # If there are no Python functions, return the base script + if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs: + return base_script + + # Insert Python functions into the script + return self._insert_python_functions(base_script, indent_spaces) + + def _insert_python_functions(self, base_script: str, indent_spaces: int) -> str: + """Insert Python functions into the TVMScript output.""" + lines = base_script.split("\n") + result_lines = [] + + # Find the class definition line and insert Python functions after it + class_found = False + class_indent = 0 + + for line in lines: + result_lines.append(line) + + # Look for class definition + if not class_found and line.strip().startswith("class "): + class_found = True + class_indent = len(line) - len(line.lstrip()) + + # Insert Python functions after the class definition + if hasattr(self.ir_mod, "pyfuncs") and self.ir_mod.pyfuncs: + for func_name, func in self.ir_mod.pyfuncs.items(): + # Get the function source code + func_source = self._get_function_source(func) + if func_source: + # 
Format the function with proper indentation + formatted_func = self._format_python_function( + func_name, func_source, class_indent + indent_spaces + ) + result_lines.append(formatted_func) + result_lines.append("") # Add empty line for separation + + return "\n".join(result_lines) + + def _get_function_source(self, func: callable) -> Optional[str]: + """Get the source code of a Python function.""" + try: + source = inspect.getsource(func) + return source + except (OSError, TypeError): + # If we can't get the source, return None + return None + + def _format_python_function(self, _func_name: str, func_source: str, indent: int) -> str: + """Format a Python function with proper indentation for TVMScript.""" + lines = func_source.split("\n") + formatted_lines = [] + + for line in lines: + # Skip the function definition line if it's already properly indented + if line.strip().startswith("def ") or line.strip().startswith("@"): + # Keep decorators and function definition as is + formatted_lines.append(" " * indent + line.strip()) + else: + # Add proper indentation for the function body + formatted_lines.append(" " * indent + line.strip()) + + return "\n".join(formatted_lines) + + def show( + self, style: Optional[str] = None, black_format: Optional[bool] = None, **kwargs + ) -> None: + """A sugar for print highlighted TVM script with Python function support. + + This method extends the standard IRModule show() method to handle + Python functions stored in the IRModule's pyfuncs attribute. 
+ """ + from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel + + if black_format is None: + env = os.environ.get("TVM_BLACK_FORMAT") + black_format = env and int(env) + + script_content = self.script(**kwargs) + cprint(script_content, style=style, black_format=black_format) diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py index 22215206ac4b..077f8feebb90 100644 --- a/python/tvm/relax/binding_rewrite.py +++ b/python/tvm/relax/binding_rewrite.py @@ -20,13 +20,13 @@ from typing import Optional import tvm -import tvm.ffi +import tvm_ffi from tvm.runtime import Object from . import Binding, DataflowBlock, Expr, Function, Var from . import _ffi_api -@tvm.ffi.register_object("relax.DataflowBlockRewrite") +@tvm_ffi.register_object("relax.DataflowBlockRewrite") class DataflowBlockRewrite(Object): """ A binding/statement-level dataflow block rewriter. diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index e09a9fab263a..8c777eb53756 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union import tvm +import tvm_ffi from tvm import relax as rx from tvm import tir from tvm.ir.module import IRModule @@ -100,7 +101,7 @@ def __exit__(self, ptype, value, trace): self._bb.end_scope() -@tvm.ffi.register_object("relax.BlockBuilder") +@tvm_ffi.register_object("relax.BlockBuilder") class BlockBuilder(Object): """A builder to build Relax IR for testing and dev. @@ -298,6 +299,10 @@ def _normalize_python_tuple(self, expr: Union[Expr, Sequence[Expr]]): """ if isinstance(expr, (list, tuple)): return Tuple([self._normalize_python_tuple(element) for element in expr]) + elif expr is None: + from . 
import op + + return op.null_value() else: return expr diff --git a/python/tvm/relax/distributed/_ffi_api.py b/python/tvm/relax/distributed/_ffi_api.py index 6544a8d35572..71185a1276da 100644 --- a/python/tvm/relax/distributed/_ffi_api.py +++ b/python/tvm/relax/distributed/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.distributed""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.distributed", __name__) +tvm_ffi.init_ffi_api("relax.distributed", __name__) diff --git a/python/tvm/relax/distributed/global_info.py b/python/tvm/relax/distributed/global_info.py index 3f549ecfa37e..34d3f2da4720 100644 --- a/python/tvm/relax/distributed/global_info.py +++ b/python/tvm/relax/distributed/global_info.py @@ -18,7 +18,7 @@ """Global Info Data structures for distributed tensor.""" from typing import List, Union, Tuple -import tvm +import tvm_ffi from tvm.ir import Range from tvm.ir.global_info import GlobalInfo from tvm.runtime import ShapeTuple @@ -26,7 +26,7 @@ from . import _ffi_api as ffi -@tvm.ffi.register_object("relax.distributed.DeviceMesh") +@tvm_ffi.register_object("relax.distributed.DeviceMesh") class DeviceMesh(GlobalInfo): """Device mesh express a view of topology of devices, represented by an n-d matrix of device ids. 
diff --git a/python/tvm/relax/distributed/struct_info.py b/python/tvm/relax/distributed/struct_info.py index 50087b98841a..554c83e47490 100644 --- a/python/tvm/relax/distributed/struct_info.py +++ b/python/tvm/relax/distributed/struct_info.py @@ -18,7 +18,7 @@ """Struct Info for distributed tensor.""" import enum from typing import List -import tvm +import tvm_ffi from tvm.relax.struct_info import StructInfo, TensorStructInfo from tvm.ir import Span from tvm.runtime.object import Object @@ -33,7 +33,7 @@ class PlacementSpecKind(enum.IntEnum): kReplica = 1 -@tvm.ffi.register_object("relax.distributed.PlacementSpec") +@tvm_ffi.register_object("relax.distributed.PlacementSpec") class PlacementSpec(Object): """Describes how data is distributed in one dimension of the device mesh @@ -80,7 +80,7 @@ def replica() -> "PlacementSpec": return _ffi_api.Replica() -@tvm.ffi.register_object("relax.distributed.Placement") +@tvm_ffi.register_object("relax.distributed.Placement") class Placement(Object): """Describes how data is distributed in each dimension of the device mesh @@ -110,7 +110,7 @@ def from_text(text: str) -> "Placement": return _ffi_api.PlacementFromText(text) -@tvm.ffi.register_object("relax.DTensorStructInfo") +@tvm_ffi.register_object("relax.DTensorStructInfo") class DTensorStructInfo(StructInfo): """StructInfo of a Distributed Tensor value. diff --git a/python/tvm/relax/distributed/transform/_ffi_api.py b/python/tvm/relax/distributed/transform/_ffi_api.py index b694a67116d2..35808cc2bc93 100644 --- a/python/tvm/relax/distributed/transform/_ffi_api.py +++ b/python/tvm/relax/distributed/transform/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.distributed.transform""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.distributed.transform", __name__) +tvm_ffi.init_ffi_api("relax.distributed.transform", __name__) diff --git a/python/tvm/relax/dpl/_ffi.py b/python/tvm/relax/dpl/_ffi.py index 72bf073bedfc..b03e5800e8fc 100644 --- a/python/tvm/relax/dpl/_ffi.py +++ b/python/tvm/relax/dpl/_ffi.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """DataFlow Pattern Language FFI bindings.""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.dpl", __name__) +tvm_ffi.init_ffi_api("relax.dpl", __name__) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index eca885e03acb..ef7516f31f46 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -22,8 +22,9 @@ import typing from typing import Dict, List, Optional, Tuple, Union +import tvm_ffi + import tvm -import tvm.ffi as tvm_ffi from tvm.ir.container import Array from tvm.ir.expr import PrimExpr from tvm.ir.op import Op diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index a9782057c8fb..6dd730e83147 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -20,7 +20,7 @@ from tvm.ir import IRModule from tvm.runtime import Object -from tvm.ffi import register_object +from tvm_ffi import register_object from .pattern import DFPattern from .context import PatternContext diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index 4c5647daf756..50d6c0679eca 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -19,7 +19,7 @@ from enum import IntEnum from typing import Optional, Union, List import tvm -from tvm.runtime import Object +import tvm_ffi from tvm.runtime.container import ShapeTuple from .vm_build import VMExecutable 
from . import _ffi_api @@ -56,8 +56,8 @@ def __exit__(self, ptype, value, trace): self.exit_callback() -@tvm.ffi.register_object("relax.ExecBuilder") -class ExecBuilder(Object): +@tvm_ffi.register_object("relax.ExecBuilder") +class ExecBuilder(tvm_ffi.core.Object): """A builder to emit instructions and build executable for the virtual machine.""" def __init__(self) -> None: @@ -106,7 +106,7 @@ def convert_constant(self, const: object) -> int: def emit_call( self, name: str, - args: Optional[List[Union[tvm.nd.NDArray, tvm.DataType]]] = None, + args: Optional[List[Union[tvm.runtime.Tensor, tvm.DataType]]] = None, dst: int = None, ) -> None: """emit a call instruction which calls a packed function.""" @@ -120,7 +120,7 @@ def emit_call( shape_tuple = ShapeTuple(arg) new_arg = self.convert_constant(shape_tuple) args_.append(new_arg) - elif isinstance(arg, (tvm.nd.NDArray, tvm.DataType, ShapeTuple)): + elif isinstance(arg, (tvm.runtime.Tensor, tvm.DataType, ShapeTuple)): new_arg = self.convert_constant(arg) args_.append(new_arg) else: diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index ee9caf3a835b..e9bc9a7a3e98 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -21,13 +21,13 @@ import numpy as _np # type: ignore -import tvm -import tvm.ffi +import tvm_ffi + import tvm.ir import tvm.relax from tvm import DataType +import tvm.runtime from tvm.runtime import Object -from tvm.runtime import ndarray as _nd from ..ir import BaseFunc, Node, Span from ..runtime import Scriptable, String @@ -42,7 +42,7 @@ GlobalVar = Union[tvm.ir.GlobalVar] -@tvm.ffi.register_object("relax.Id") +@tvm_ffi.register_object("relax.Id") class Id(Object): """Unique identifier(name) used in Var. Guaranteed to be stable across all passes. @@ -56,7 +56,7 @@ def __init__(self): # NOTE: place base struct info in expr to avoid cyclic dep # from expr to struct info. 
-@tvm.ffi.register_object("ir.StructInfo") +@tvm_ffi.register_object("ir.StructInfo") class StructInfo(Node, Scriptable): """The base class of all StructInfo. @@ -185,8 +185,7 @@ def __rfloordiv__(self, other: Expr) -> "ExprWithOp": return _binary_rhs_helper(other) def __mod__(self, other: Expr) -> "ExprWithOp": - # TODO(siyuan): Support it after mod operator is supported in relax - raise ValueError("relax.mod is not supported yet.") + return _binary_op_helper(self, other, _op_ffi_api.mod) # type: ignore def __rmod__(self, other: Expr) -> "ExprWithOp": return _binary_rhs_helper(other) @@ -308,7 +307,7 @@ def elem_offset(self) -> "Expr": return tvm.relax.Call(op, [self]) -class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric): +class _DLTensorDTypeProxy(tvm.runtime.ObjectConvertible): """A proxy object for unpacking DLDatatype from DLTensor Exposes accessors for `DLDataType` fields `type_code`, `lanes`, @@ -388,7 +387,7 @@ def bits(self) -> Expr: return tvm.relax.Call(op, [self.tensor]) -class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric): +class _DLTensorShapeProxy(tvm.runtime.ObjectConvertible): """A proxy object for unpacking the shape from DLTensor Exposes accessors for the `DLTensor::shape` field. Accessing @@ -458,7 +457,7 @@ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr: return tvm.relax.Call(op, [self.tensor, axis]) -class _DLTensorStrideProxy(tvm.runtime.ObjectGeneric): +class _DLTensorStrideProxy(tvm.runtime.ObjectConvertible): """A proxy object for unpacking the strides from DLTensor Exposes accessors for the `DLTensor::strides` field. Accessing @@ -528,7 +527,7 @@ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr: return tvm.relax.Call(op, [self.tensor, axis]) -@tvm.ffi.register_object("relax.expr.Call") +@tvm_ffi.register_object("relax.expr.Call") class Call(ExprWithOp): """Function call node in Relax. 
@@ -577,7 +576,7 @@ def __init__( ) -@tvm.ffi.register_object("relax.expr.If") +@tvm_ffi.register_object("relax.expr.If") class If(ExprWithOp): """A conditional expression in Relax. @@ -609,7 +608,7 @@ def __init__( ) -@tvm.ffi.register_object("relax.expr.Tuple") +@tvm_ffi.register_object("relax.expr.Tuple") class Tuple(ExprWithOp): """Tuple expression that groups several fields together. @@ -644,7 +643,7 @@ def __len__(self) -> int: return len(self.fields) -@tvm.ffi.register_object("relax.expr.TupleGetItem") +@tvm_ffi.register_object("relax.expr.TupleGetItem") class TupleGetItem(ExprWithOp): """Get index-th item from a tuple. @@ -670,7 +669,7 @@ def __init__(self, tuple_value: Expr, index: int, span: Optional[Span] = None): ) -@tvm.ffi.register_object("relax.expr.ShapeExpr") +@tvm_ffi.register_object("relax.expr.ShapeExpr") class ShapeExpr(ExprWithOp): """A shape expression which allows users to construct a shape containing PrimExpr. @@ -708,13 +707,13 @@ def make_shape(shape: Union[List[Any], typing.Tuple[Any, ...]]) -> ShapeExpr: raise ValueError("Wrong type") -@tvm.ffi.register_object("relax.expr.Constant") +@tvm_ffi.register_object("relax.expr.Constant") class Constant(ExprWithOp): """Constant Tensor Parameters ---------- - data: tvm.nd.NDArray + data: tvm.runtime.Tensor The data of the constant tensor. struct_info: Optional[StructInfo] @@ -728,12 +727,12 @@ class Constant(ExprWithOp): Scalar constants are represented by ndim-0 constant tensors. """ - data: tvm.nd.NDArray + data: tvm.runtime.Tensor span: Optional[Span] def __init__( self, - data: tvm.nd.NDArray, + data: tvm.runtime.Tensor, struct_info: Optional[StructInfo] = None, span: Optional[Span] = None, ) -> None: @@ -742,7 +741,7 @@ def __init__( ) -@tvm.ffi.register_object("relax.expr.Var") +@tvm_ffi.register_object("relax.expr.Var") class Var(ExprWithOp): """The variable class for all Relax bindings. 
@@ -789,7 +788,7 @@ def name_hint(self) -> str: return name -@tvm.ffi.register_object("relax.expr.DataflowVar") +@tvm_ffi.register_object("relax.expr.DataflowVar") class DataflowVar(Var): """A sub-type of the variable node used to mark dataflow variables from normal visible "function local" bindings. @@ -838,7 +837,7 @@ def __init__( ) -@tvm.ffi.register_object("relax.expr.PrimValue") +@tvm_ffi.register_object("relax.expr.PrimValue") class PrimValue(Expr, Scriptable): """The prim expr representing the value.""" @@ -850,7 +849,7 @@ def __init__(self, value: Union[PrimExpr, int], span: Optional[Span] = None) -> self.__init_handle_by_constructor__(_ffi_api.PrimValue, value, span) # type: ignore -@tvm.ffi.register_object("relax.expr.StringImm") +@tvm_ffi.register_object("relax.expr.StringImm") class StringImm(Expr, Scriptable): """Represent a string literal constant.""" @@ -861,7 +860,7 @@ def __init__(self, value: str, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore -@tvm.ffi.register_object("relax.expr.DataTypeImm") +@tvm_ffi.register_object("relax.expr.DataTypeImm") class DataTypeImm(Expr, Scriptable): """Represent a data type constant.""" @@ -872,7 +871,7 @@ def __init__(self, value: Union[DataType, str], span: Optional[Span] = None) -> self.__init_handle_by_constructor__(_ffi_api.DataTypeImm, value, span) # type: ignore -@tvm.ffi.register_object("relax.expr.Binding") +@tvm_ffi.register_object("relax.expr.Binding") class Binding(Node, Scriptable): """The base class of a binding in Relax.""" @@ -880,7 +879,7 @@ class Binding(Node, Scriptable): span: Optional[Span] -@tvm.ffi.register_object("relax.expr.MatchCast") +@tvm_ffi.register_object("relax.expr.MatchCast") class MatchCast(Binding): """Runtime-match the value to the struct info. 
@@ -912,7 +911,7 @@ def __init__( ) -@tvm.ffi.register_object("relax.expr.VarBinding") +@tvm_ffi.register_object("relax.expr.VarBinding") class VarBinding(Binding): """Variable binding, bind he variable of the lhs with the rhs. @@ -934,7 +933,7 @@ def __init__(self, var: Var, value: Expr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value, span) # type: ignore -@tvm.ffi.register_object("relax.expr.BindingBlock") +@tvm_ffi.register_object("relax.expr.BindingBlock") class BindingBlock(Node, Scriptable): """base class of binding block, bindings inside can be impure (with side effect or control flow)""" @@ -946,7 +945,7 @@ def __init__(self, bindings: List[Binding], span: Optional[Span] = None) -> None self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings, span) # type: ignore -@tvm.ffi.register_object("relax.expr.DataflowBlock") +@tvm_ffi.register_object("relax.expr.DataflowBlock") class DataflowBlock(BindingBlock): """dataflow block, bindings inside are pure (no side effect and no control flow)""" @@ -958,7 +957,7 @@ def __init__(self, bindings: List[Binding], span: Optional[Span] = None) -> None self.__init_handle_by_constructor__(_ffi_api.DataflowBlock, bindings, span) # type: ignore -@tvm.ffi.register_object("relax.expr.SeqExpr") +@tvm_ffi.register_object("relax.expr.SeqExpr") class SeqExpr(ExprWithOp): """A sequence of binding blocks followed by an expression.""" @@ -970,7 +969,7 @@ def __init__(self, blocks: List[BindingBlock], body: Expr, span: Optional[Span] self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body, span) # type: ignore -@tvm.ffi.register_object("relax.expr.Function") +@tvm_ffi.register_object("relax.expr.Function") class Function(BaseFunc, Scriptable): """A Relax function.""" @@ -1057,7 +1056,7 @@ def bind_params( self, binding_map: Mapping[ Union[str, Var], - Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr], + Union[int, float, PrimExpr, 
tvm.runtime.Tensor, _np.ndarray, Expr], ], ) -> "Function": """Return a new function with updated symbolic variable @@ -1066,7 +1065,7 @@ def bind_params( ---------- binding_map: Mapping[ Union[str, Var], - Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr], + Union[int, float, PrimExpr, tvm.runtime.Tensor, _np.ndarray, Expr], ] The mapping of values to be replaced. @@ -1094,7 +1093,7 @@ def _normalize_value(value): # Relax uses int64 for symbolic variables, but the FFI # converts python integers into int32. return tvm.tir.const(value, "int64") - elif isinstance(value, (_np.ndarray, tvm.nd.NDArray)): + elif isinstance(value, (_np.ndarray, tvm.runtime.Tensor)): return tvm.relax.const(value) else: return value @@ -1109,7 +1108,7 @@ def inline_functions( return _ffi_api.FunctionInlineFunctions(self, function_map) # type: ignore -@tvm.ffi.register_object("relax.expr.ExternFunc") +@tvm_ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc, ExprWithOp): """extern function, which represents a PackedFunc.""" @@ -1133,13 +1132,13 @@ def extern(name: str, struct_info: Optional[StructInfo] = None, span: Optional[S def const( - value: Union[bool, int, float, _np.ndarray, tvm.nd.NDArray], dtype: Optional[str] = None + value: Union[bool, int, float, _np.ndarray, tvm.runtime.Tensor], dtype: Optional[str] = None ) -> Constant: """Create a constant value. Parameters ---------- - value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray] + value: Union[bool, int, float, numpy.ndarray, tvm.runtime.Tensor] The constant value. dtype: Optional[str] @@ -1154,6 +1153,9 @@ def const( - bool maps to "bool" - other using the same default rule as numpy. 
""" + # Needed for bf16 and fp8 support (does not come with numpy) + import ml_dtypes # pylint: disable=unused-import,import-outside-toplevel + if isinstance(value, (Number, (bool, list))): value = _np.array(value, dtype=dtype) @@ -1169,15 +1171,15 @@ def const( if isinstance(value, (_np.ndarray, _np.generic)): if dtype is not None: value = value.astype(dtype) - value = _nd.array(value) + value = tvm.runtime.tensor(value) - if not isinstance(value, _nd.NDArray): - raise ValueError("value has to be scalar or NDArray") + if not isinstance(value, tvm.runtime.Tensor): + raise ValueError("value has to be scalar or Tensor") return Constant(value) -@tvm.ffi.register_object("relax.TEPlaceholderOp") +@tvm_ffi.register_object("relax.TEPlaceholderOp") class TEPlaceholderOp(tvm.te.tensor.Operation): """The placeholder op that represents a relax expression.""" diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index a40b81c233ef..e5e77251c66d 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -18,7 +18,7 @@ """The expression functor of Relax.""" from typing import Callable, Optional -import tvm +import tvm_ffi from tvm.ir import Op from tvm.runtime import Object from tvm.runtime.support import derived_object @@ -261,8 +261,8 @@ def visit_var_def(self, var: Var): raise TypeError("Invalid type: {0}".format(type(var))) -@tvm.ffi.register_object("expr_functor.PyExprVisitor") -class _PyExprVisitor(Object): +@tvm_ffi.register_object("expr_functor.PyExprVisitor") +class _PyExprVisitor(tvm_ffi.core.Object): """ A TVM object to support customization of ExprVisitor on the python side. This is the decorated result returned from visitor decorator. 
@@ -781,7 +781,7 @@ def visit_span(self, span: Span) -> None: return _ffi_api.ExprVisitorVisitSpan(self._outer(), span) # type: ignore -@tvm.ffi.register_object("expr_functor.PyExprMutator") +@tvm_ffi.register_object("expr_functor.PyExprMutator") class _PyExprMutator(Object): """ A TVM object to support customization of ExprMutator on the python side. diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py index ba2960c159fc..5b18d5e27d9b 100644 --- a/python/tvm/relax/frontend/common.py +++ b/python/tvm/relax/frontend/common.py @@ -23,7 +23,7 @@ from tvm import topi -def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]: +def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.runtime.Tensor]]]: """Detach the attribute "params" in the functions of the input IRModule as separate dictionary of params. @@ -37,7 +37,7 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n detached_mod : tvm.IRModule The IRModule after the detachment. - params_dict : Dict[str, List[tvm.nd.NDArray]] + params_dict : Dict[str, List[tvm.runtime.Tensor]] The detached params. The dict keys corresponds to the names of the functions in the input IRModule that have attribute "params". """ @@ -46,10 +46,8 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n for gv, func in mod.functions_items(): if "params" in func.attrs: params = list(func.attrs["params"]) - if not all([isinstance(param, tvm.nd.NDArray) for param in params]): - raise ValueError( - 'The value "params" attribute is expected to be a list of NDArray.' 
- ) + if not all([isinstance(param, tvm.runtime.Tensor) for param in params]): + raise ValueError('The value "params" attribute is expected to be a list of Tensor.') params_dict[gv.name_hint] = params detached_mod[gv] = func.without_attr("params") else: @@ -125,5 +123,5 @@ def autopad( topi.nn.mirror_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), "REFLECT" ) else: - # TODO(gigiblender) Support edge mode. - raise NotImplementedError("Pad mode {} not implemented".format(pad_type)) + # edge mode - replicate border values + return bb.emit_te(topi.nn.replicate_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist()) diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py index f490af7062b0..d9036348835a 100644 --- a/python/tvm/relax/frontend/nn/__init__.py +++ b/python/tvm/relax/frontend/nn/__init__.py @@ -17,7 +17,7 @@ """A PyTorch-like API to build IRModules.""" # pylint: disable=redefined-builtin from . import op, spec -from .core import Effect, Module, ModuleList, Object, Parameter, Tensor +from .core import Effect, Module, ModuleDict, ModuleList, Object, Parameter, Tensor from .exporter import add_extern from .extern import ExternModule, ObjectModule, SourceModule from .modules import ( diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 068b2090db5b..b15ba685b76d 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -42,9 +42,9 @@ from tvm import tir from tvm.ir import IRModule from tvm.ir.transform import Pass -from tvm.runtime import Device, NDArray +import tvm.runtime +from tvm.runtime import Device from tvm.runtime import device as as_device -from tvm.runtime import ndarray from tvm.runtime.vm import VirtualMachine from tvm.target import Target @@ -225,7 +225,7 @@ class Parameter(Tensor): it is called a bound parameter, otherwise it is called an unbound parameter. 
""" - _data: Optional[NDArray] + _data: Optional[Tensor] attrs: Dict[str, Any] def __init__( @@ -251,16 +251,16 @@ def __init__( self.attrs = OrderedDict() @property - def data(self) -> Optional[NDArray]: + def data(self) -> Optional[Tensor]: """Returns the concrete value of the parameter if it is bound to a concrete value, - otherwise returns None. The returned value is a tvm.runtime.NDArray.""" + otherwise returns None. The returned value is a tvm.runtime.Tensor.""" return self._data @data.setter - def data(self, data: Union[None, NDArray, np.ndarray, "torch.Tensor"]) -> None: + def data(self, data: Union[None, tvm.runtime.Tensor, np.ndarray, "torch.Tensor"]) -> None: """Set the concrete value of the parameter. The data should be one of the following: - None: unbind the parameter to concrete values - - tvm.runtime.NDArray + - tvm.runtime.Tensor - numpy.ndarray - torch.Tensor and any other DLPack-compliant tensors """ @@ -268,10 +268,10 @@ def data(self, data: Union[None, NDArray, np.ndarray, "torch.Tensor"]) -> None: self._data = data return # Try to do zero-copy if possible - if isinstance(data, NDArray): + if isinstance(data, tvm.runtime.Tensor): pass elif isinstance(data, np.ndarray): - data = ndarray.array(data) + data = tvm.runtime.tensor(data) elif hasattr(data, "__dlpack__"): data = _from_dlpack(data) else: @@ -526,7 +526,7 @@ def _compile(spec, device, pipeline, debug): ), device, ) - params = _param_to_ndarray(params, device) + params = _param_to_tensor(params, device) return spec, vm, params device = as_device(device) @@ -540,6 +540,56 @@ def _compile(spec, device, pipeline, debug): raise ValueError(f"Unknown out_format: {out_format}") +class ModuleDict(Module): + """Holds submodules in a dict.""" + + def __init__(self, modules: Optional[OrderedDict[str, Module]] = None): + if modules is None: + self.modules = OrderedDict() + else: + self.modules = OrderedDict(modules) + + def __iter__(self): + return iter(self.modules.values()) + + def 
__getitem__(self, key: str) -> Module: + return self.modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.modules[key] = module + + def __len__(self) -> int: + return len(self.modules) + + def keys(self) -> Iterator[str]: + return self.modules.keys() + + def values(self) -> Iterator[Module]: + return self.modules.values() + + def items(self) -> Iterator[Tuple[str, Module]]: + return self.modules.items() + + def get(self, key: str, default: Optional[Module] = None) -> Optional[Module]: + return self.modules.get(key, default) + + def update(self, modules: Dict[str, Module]) -> None: + self.modules.update(modules) + + def clear(self) -> None: + self.modules.clear() + + def pop(self, key: str) -> Module: + return self.modules.pop(key) + + def __contains__(self, key: str) -> bool: + return key in self.modules + + def to(self, dtype: Optional[str] = None) -> None: # pylint: disable=invalid-name + for module in self.modules.values(): + module.to(dtype=dtype) + + class ModuleList(Module): """Holds submodules in a list.""" @@ -611,6 +661,10 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any] for i, subitem in enumerate(root): yield from _attribute_finder(subitem, prefix + f"{i}.", condition_yield) return + elif isinstance(root, ModuleDict): + for name, subitem in root.items(): + yield from _attribute_finder(subitem, prefix + f"{name}.", condition_yield) + return for name, item in root.__dict__.items(): if condition_yield(item): yield prefix + name, item @@ -620,6 +674,13 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any] prefix + name + ".", condition_yield, ) + elif isinstance(item, ModuleDict): + for sub_name, sub_item in item.items(): + yield from _attribute_finder( + sub_item, + prefix + name + f".{sub_name}.", + condition_yield, + ) elif isinstance(item, Module): yield from _attribute_finder( item, @@ -628,24 +689,26 @@ def _attribute_finder(root: Module, prefix: str, 
condition_yield: Callable[[Any] ) -def _from_dlpack(tensor) -> NDArray: +def _from_dlpack(tensor) -> tvm.runtime.Tensor: try: - return ndarray.from_dlpack(tensor) + return tvm.runtime.from_dlpack(tensor) except RuntimeError: pass # special logic for PyTorch device_type = tensor.device.type device_id = tensor.device.index or 0 - return ndarray.array( + return tvm.runtime.tensor( tensor.numpy(), device=Device( - Device.DEVICE_NAME_TO_TYPE[device_type], + Device._DEVICE_NAME_TO_TYPE[device_type], device_id, ), ) -def _param_to_ndarray(params: List[Tuple[str, Parameter]], device: Device) -> List[NDArray]: +def _param_to_tensor( + params: List[Tuple[str, Parameter]], device: Device +) -> List[tvm.runtime.Tensor]: results = [] missing = [] for name, param in params: diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index e7248b0f4b27..b35f6e0d220c 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -310,8 +310,8 @@ def get_includes(tvm_pkg: Optional[List[str]] = None) -> List[Path]: results = [ tvm_home / "include", tvm_home / "3rdparty/dmlc-core/include", - tvm_home / "ffi/include", - tvm_home / "ffi/3rdparty/dlpack/include", + tvm_home / "3rdparty/tvm-ffi/include", + tvm_home / "3rdparty/tvm-ffi/3rdparty/dlpack/include", ] if tvm_pkg: for relative in tvm_pkg: @@ -387,12 +387,14 @@ def compile(self, output_path: Path) -> None: options=self.compile_options, cc=self.compiler, cwd=temp_dir, - ccache_env={ - "CCACHE_COMPILERCHECK": "content", - "CCACHE_NOHASHDIR": "1", - } - if shutil.which("ccache") - else None, + ccache_env=( + { + "CCACHE_COMPILERCHECK": "content", + "CCACHE_NOHASHDIR": "1", + } + if shutil.which("ccache") + else None + ), ) shutil.move(str(object_path), str(output_path)) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index e6e171da9903..e94d5c42957b 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py 
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -371,8 +371,7 @@ def __init__( # pylint: disable=too-many-locals enable_disaggregation : bool Whether to enable disaggregation in the KV cache. """ - if rope_mode == RopeMode.INLINE: - assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim." + assert rope_mode != RopeMode.INLINE, "FlashInfer RoPE does not support inline mode." attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind if attn_kind_single == "mha_sliding": @@ -383,8 +382,8 @@ def __init__( # pylint: disable=too-many-locals dtype_o=dtype, qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim), v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim), - target=target, - enable_inline_rope=rope_mode == RopeMode.INLINE, + enable_inline_rope=False, + return_static_libs=True, ) flashinfer_decode_mods = ( rx.backend.cuda.flashinfer.gen_flashinfer_decode_module( @@ -393,7 +392,8 @@ def __init__( # pylint: disable=too-many-locals dtype_o=dtype, qk_head_dim=qk_head_dim, v_head_dim=v_head_dim, - target=target, + enable_inline_rope=False, + return_static_libs=True, ) if attn_kind_single == "mha" else [] @@ -405,7 +405,7 @@ def __init__( # pylint: disable=too-many-locals dtype_o=dtype, head_dim_ckv=v_head_dim, head_dim_kpe=qk_head_dim - v_head_dim, - target=target, + return_static_libs=True, ) if attn_kind_single == "mla" else [] @@ -417,8 +417,8 @@ def __init__( # pylint: disable=too-many-locals bb = rx.BlockBuilder.current() mha_functions = ( [ - rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_with_paged_kv_cache_run"), rx.ExternFunc("batch_prefill_with_kv_cache_plan")]), - rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_decode_with_paged_kv_cache_run"), rx.ExternFunc("batch_decode_with_paged_kv_cache_plan")]), + rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_paged_run"), 
rx.ExternFunc("batch_prefill_plan")]), + rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_decode_run"), rx.ExternFunc("batch_decode_plan")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), @@ -427,7 +427,8 @@ def __init__( # pylint: disable=too-many-locals if attn_kind_single == "mha" else [rx.Tuple([]) for _ in range(6)] ) - mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else []) + ragged_prefill_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_ragged_run"), rx.ExternFunc("batch_prefill_plan")]) if attn_kind_single == "mha" else rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_ragged_run"), rx.ExternFunc("batch_prefill_plan"), rx.PrimValue(mla_original_qk_head_dim), rx.PrimValue(mla_original_v_head_dim)]) + mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_run"), rx.ExternFunc("batch_mla_plan")] if attn_kind_single == "mla" else []) attn_merge_functions = [ bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"), ] @@ -463,7 +464,7 @@ def __init__( # pylint: disable=too-many-locals rx.op.zeros((), dtype), bb.add_func(_kv_cache_transpose_append(num_key_value_heads, qk_head_dim, dtype), "kv_cache_transpose_append"), 
bb.add_func(_kv_cache_transpose_append_mla(qk_head_dim, dtype), "kv_cache_transpose_append_mla"), - rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_with_ragged_kv_cache_run"), rx.ExternFunc("batch_prefill_with_kv_cache_plan")]), + ragged_prefill_function, *mha_functions, mla_function, rx.Tuple(attn_merge_functions), diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 1a1659b29e18..35eeb4f5f32f 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st return cos_freq, sin_freq, {freq_var: freq} +def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals + s: tir.Var, + d: tir.Var, + d_range: int, + theta: float, + dtype: str, + factor: float, + low_freq_factor: float, + high_freq_factor: float, + original_max_position_embeddings: float, +): + """Compute the inverse frequency of RoPE for llama4 RoPE scaling.""" + orig_freq = tir.const(1, "float32") / tir.power( + theta, 2 * (d // 2) / tir.const(d_range, "float32") + ) + orig_freq_var = tir.Var("orig_freq", "float32") + + llama4_inv_scaling_factor = 1.0 / factor + + if high_freq_factor == low_freq_factor: + wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var + threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32") + + scaled_freq = tir.if_then_else( + wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var + ) + smoothed_freq = s * scaled_freq + + else: + # Original smooth interpolation logic + inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) + + llama4_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor + llama4_beta = low_freq_factor * inv_diff_freq_factor + smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var - 
llama4_beta)) + smoothed_freq = s * ( + (1.0 - smooth) * orig_freq_var * llama4_inv_scaling_factor + smooth * orig_freq_var + ) + + smoothed_freq_var = tir.Var("smoothed_freq", "float32") + cos_freq = tir.cos(smoothed_freq_var).astype(dtype) + sin_freq = tir.sin(smoothed_freq_var).astype(dtype) + return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq} + + def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals s: tir.Var, d: tir.Var, @@ -208,6 +253,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable: high_freq_factor=rope_scaling["high_freq_factor"], original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], ) + if rope_scaling["rope_type"] == "llama4": + return partial( + rope_freq_llama4, + factor=rope_scaling["factor"], + low_freq_factor=rope_scaling["low_freq_factor"], + high_freq_factor=rope_scaling["high_freq_factor"], + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + ) if rope_scaling["rope_type"] == "longrope": return partial( rope_freq_longrope, @@ -411,6 +464,10 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments rotary_dim = head_dim scale = tir.const(scale, "float32") is_longrope_scaling = rope_scaling.get("rope_type") == "longrope" + if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling: + original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] + else: + original_max_position_embeddings = 0 def _rope( # pylint: disable=too-many-arguments x: T.Buffer, @@ -486,6 +543,226 @@ def fused_rope( # pylint: disable=too-many-locals else: v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + @T.prim_func + def fused_rope_longrope_scaling( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + ext_factors: T.Buffer((rotary_dim,), "float32"), # type: ignore + ): + 
T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": True, + } + ) + seq_len = T.int64() + position_map_elem_offset = T.int64() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) + # long factors is the first half, short factors is the second half + long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data) + short_factors = T.Buffer( + (rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2) + ) + + if seq_len > original_max_position_embeddings: + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + long_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + long_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + else: + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + short_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + short_factors 
if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + + if is_longrope_scaling: + return fused_rope_longrope_scaling + return fused_rope + + +def llama4_rope_with_position_map( # pylint: disable=too-many-arguments + theta: float, + scale: float, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + dtype: str, + rope_scaling: Dict[str, Any], + rotary_dim: Optional[int] = None, +): + """Return the TIR function that computes Llama-style RoPE with q position map. + + Parameters + ---------- + theta : float + The theta value, or "base" in RoPE, which controls the frequency. + + scale : float + The RoPE scaling factor. + + head_dim : int + The number of features on each head. + + num_q_heads : int + The number of query heads. + + num_kv_heads : int + The number of key/value heads. It differs from `num_q_heads` in group-query attention. + + dtype : str + The dtype of qkv data. + + rope_scaling : Dict + The configuration of RoPE scaling. + + rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. By default, the + rotary_dim is the same as head_dim. 
+ """ + fused_heads = num_q_heads + num_kv_heads * 2 + if rotary_dim is None: + rotary_dim = head_dim + scale = tir.const(scale, "float32") + is_longrope_scaling = rope_scaling.get("rope_type") == "longrope" + + def _rope( # pylint: disable=too-many-arguments + x: T.Buffer, + s: tir.Var, + h: tir.Var, + d: tir.Var, + pos: tir.Var, + ext_factors: Optional[T.Buffer] = None, + ): + kwargs = {} + if ext_factors: + kwargs["ext_factors"] = ext_factors + cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)( + pos * scale, d, rotary_dim, theta, "float32", **kwargs + ) + cos = cos_freq * x[s, h, d].astype("float32") + if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj": + sin = sin_freq * tir.if_then_else( + d % 2 == 0, + -x[s, h, d + 1], + x[s, h, d - 1], + ).astype("float32") + else: + # Data layout is different for llama4 vs llama3 + sin = sin_freq * tir.if_then_else( + d % 2 == 0, + -x[s, h, d + 1], + x[s, h, d - 1], + ).astype("float32") + expr = (cos + sin).astype(dtype) + for var, value in var_map.items(): + expr = tir.Let(var, value, expr) + return expr + + @T.prim_func(private=True) + def fused_rope( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + apply_rope: T.int64, + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": True, + } + ) + seq_len = T.int32() + position_map_elem_offset = T.int32() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = 
T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + @T.prim_func def fused_rope_longrope_scaling( # pylint: disable=too-many-locals var_qkv: T.handle, diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index b61656a2e6bd..5ca5f72787b7 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -27,7 +27,7 @@ class IOEffect(Effect): """ - Modeling IO side effect, for example, printing the content of NDArrays on screen, inserting + Modeling IO side effect, for example, printing the content of Tensors on screen, inserting debug breakpoints, etc. """ diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 1e42c862fee6..50d4772d8ca1 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1174,6 +1174,92 @@ def exp(x: Tensor, name: str = "exp") -> Tensor: return wrap_nested(_op.exp(x._expr), name) +def log(x: Tensor, name: str = "log") -> Tensor: + r"""Applies the natural logarithm function. + + .. math:: + \text{Log}(x) = \log(x) + + Parameters + ---------- + x : Tensor + The input data to the operator. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + Note + ---- + The input tensor is required to have float dtype + """ + return wrap_nested(_op.log(x._expr), name) + + +def floor(x: Tensor, name: str = "floor") -> Tensor: + r"""Computes the floor of the input tensor. + + .. math:: + \text{Floor}(x) = \floor(x) + + Parameters + ---------- + x : Tensor + The input data to the operator. 
+ + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return wrap_nested(_op.floor(x._expr), name) + + +def arange( + start: int, + end: Optional[int] = None, + step: int = 1, + dtype: Optional[str] = "float32", + name: str = "arange", +) -> Tensor: + r"""Construct a tensor with evenly spaced elements. + + Parameters + ---------- + start : int + The start of the interval. + + end : Optional[int] + The end of the interval. If not given, it will be set to start, + and start will be set to 0. + + step : int + The step size. + + dtype : Optional[str] + The data type of the created tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.arange(start, end, step, dtype), name) + + def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor: """Permutes the dimensions of the input tensor. @@ -2087,7 +2173,7 @@ def extern( out: OutType, ) -> OutType: """Invoke an extern function during runtime. The extern function must be registered with the " - TVM runtime using `reflection::GlobalDef().def` (C++), or `tvm.register_func` (Python). + TVM runtime using `reflection::GlobalDef().def` (C++), or `tvm.register_global_func` (Python). Parameters ---------- @@ -2144,7 +2230,7 @@ def debug_func( .. code-block:: python - @tvm.register_func(name_of_debug_func) + @tvm.register_global_func(name_of_debug_func) def debug_func(lineno: str, arg_0, arg_1, ...) -> None: ... 
diff --git a/python/tvm/relax/frontend/nn/torch.py b/python/tvm/relax/frontend/nn/torch.py index ae98868dae09..183cb11731e3 100644 --- a/python/tvm/relax/frontend/nn/torch.py +++ b/python/tvm/relax/frontend/nn/torch.py @@ -21,7 +21,7 @@ import torch from tvm.ir import Array -from tvm.runtime import NDArray, ShapeTuple, ndarray +from tvm.runtime import Tensor, ShapeTuple, _tensor from tvm.runtime.vm import VirtualMachine from . import core @@ -34,14 +34,14 @@ class TorchModule: # pylint: disable=too-few-public-methods spec: _spec.ModuleSpec vm: VirtualMachine # pylint: disable=invalid-name - params: List[NDArray] + params: List[Tensor] effects: List[Any] def __init__( # pylint: disable=invalid-name self, spec: _spec.ModuleSpec, vm: VirtualMachine, - params: List[NDArray], + params: List[Tensor], ): try: self.effects = vm["_initialize_effect"]() @@ -87,7 +87,7 @@ def _closure(*args): def _tvm_to_torch(arg): if isinstance(arg, (list, tuple, Array)): return [_tvm_to_torch(i) for i in arg] - if isinstance(arg, ndarray.NDArray): + if isinstance(arg, _tensor.Tensor): return torch.utils.dlpack.from_dlpack(arg) if isinstance(arg, ShapeTuple): return list(arg) diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py index 82f301006697..d2467a2bf81d 100644 --- a/python/tvm/relax/frontend/nn/visitor.py +++ b/python/tvm/relax/frontend/nn/visitor.py @@ -79,6 +79,24 @@ def visit_param(self, name: str, node: nn.Effect) -> Any: """ return self.visit(name, node) + def visit_moduledict(self, name: str, node: nn.ModuleDict) -> Any: + """The base visiting method for mutation of nn.ModuleDict nodes. + + Parameters + ---------- + name : str + The name of the current node in parent's attribute. + + node : nn.ModuleDict + The current node of nn.ModuleDict to mutate. + + Returns + ------ + ret_node: Any + The new node to replace current node. 
+ """ + return self.visit(name, node) + def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any: """The base visiting method for mutation of nn.ModuleList nodes. @@ -88,7 +106,7 @@ def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any: The name of the current node in parent's attribute. node : nn.ModuleList - The current node of nn.MoModuleListdule to mutate. + The current node of nn.ModuleList to mutate. Returns ------ @@ -124,7 +142,9 @@ def _get_child_name(parent: str, child: str) -> str: if isinstance(node, nn.ModuleList): for i in range(len(node)): - if isinstance(node[i], nn.ModuleList): + if isinstance(node[i], nn.ModuleDict): + node[i] = self.visit_moduledict(f"{name}.{i}", node[i]) + elif isinstance(node[i], nn.ModuleList): node[i] = self.visit_modulelist(f"{name}.{i}", node[i]) elif isinstance(node[i], nn.Module): node[i] = self.visit_module(f"{name}.{i}", node[i]) @@ -132,9 +152,23 @@ def _get_child_name(parent: str, child: str) -> str: node[i] = self.visit_effect(f"{name}.{i}", node[i]) elif isinstance(node[i], nn.Parameter): node[i] = self.visit_param(f"{name}.{i}", node[i]) + elif isinstance(node, nn.ModuleDict): + for k, v in node.items(): + if isinstance(v, nn.ModuleDict): + node[k] = self.visit_moduledict(_get_child_name(name, k), v) + elif isinstance(v, nn.ModuleList): + node[k] = self.visit_modulelist(_get_child_name(name, k), v) + elif isinstance(v, nn.Module): + node[k] = self.visit_module(_get_child_name(name, k), v) + elif isinstance(v, nn.Effect): + node[k] = self.visit_effect(_get_child_name(name, k), v) + elif isinstance(v, nn.Parameter): + node[k] = self.visit_param(_get_child_name(name, k), v) else: for key, value in node.__dict__.items(): - if isinstance(value, nn.ModuleList): + if isinstance(value, nn.ModuleDict): + setattr(node, key, self.visit_moduledict(_get_child_name(name, key), value)) + elif isinstance(value, nn.ModuleList): setattr(node, key, self.visit_modulelist(_get_child_name(name, key), value)) elif 
isinstance(value, nn.Module): setattr(node, key, self.visit_module(_get_child_name(name, key), value)) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b91106e64a91..24a4014f840a 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -340,6 +340,8 @@ def base_impl(cls, bb, inputs, attr, params): x = _to_numpy(inputs[0]) y = _to_numpy(inputs[1]) output = cls.numpy_op(x, y) # pylint: disable=not-callable + if isinstance(x, relax.PrimValue) and isinstance(y, relax.PrimValue): + return relax.PrimValue(output.item()) if x.dtype == y.dtype: # no numpy precision widening output = output.astype(x.dtype) @@ -643,11 +645,34 @@ class Transpose(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): + data = inputs[0] axes = attr.get("perm", None) - if isinstance(inputs[0], relax.Constant): - output = _np.transpose(inputs[0].data.numpy(), axes) + + if hasattr(data.struct_info, "ndim"): + input_ndim = data.struct_info.ndim + elif hasattr(data.struct_info, "shape") and data.struct_info.shape: + input_ndim = len(data.struct_info.shape) + else: + if isinstance(data, relax.Constant): + input_ndim = data.data.numpy().ndim + else: + input_ndim = None + + if input_ndim == 0: + return data + + if input_ndim is not None and axes is not None: + if len(axes) != input_ndim: + raise ValueError( + f"Transpose: number of axes in perm attribute ({len(axes)}) " + f"must equal the number of input tensor dimensions ({input_ndim})" + ) + + if isinstance(data, relax.Constant): + output = _np.transpose(data.data.numpy(), axes) return relax.const(output, output.dtype) - return relax.op.permute_dims(inputs[0], axes) + + return relax.op.permute_dims(data, axes) class Unsqueeze(OnnxOpConverter): @@ -1155,11 +1180,12 @@ class FastGelu(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - if inputs[1]: + x = inputs[0] + if 
len(inputs) > 1 and inputs[1] is not None: bias = inputs[1] bias_shape = bias.struct_info.shape assert len(bias_shape) == 1, "bias term must be a 1D tensor" - x += bias + x = bb.emit(relax.op.add(x, bias)) # Declare consts const_dtype = x.struct_info.dtype @@ -1169,11 +1195,13 @@ def _impl_v1(cls, bb, inputs, attr, params): const2 = relax.const(0.044715 * math.sqrt(2 / math.pi), dtype=const_dtype) # Compute FastGelu - term1 = relax.op.multiply(half, x) - term2 = relax.op.multiply(const1, x) - term3 = relax.op.multiply(const2, relax.op.power(x, relax.const(3, const_dtype))) - tanh = relax.op.tanh(relax.op.add(term2, term3)) - return relax.op.multiply(term1, relax.op.add(one, tanh)) + term1 = bb.emit(relax.op.multiply(half, x)) + term2 = bb.emit(relax.op.multiply(const1, x)) + # use x^3 = x * x * x instead of pow(x, 3) for better performance + x_cubed = bb.emit(relax.op.multiply(relax.op.multiply(x, x), x)) + term3 = bb.emit(relax.op.multiply(const2, x_cubed)) + tanh = bb.emit(relax.op.tanh(relax.op.add(term2, term3))) + return bb.emit(relax.op.multiply(term1, relax.op.add(one, tanh))) class BiasGelu(OnnxOpConverter): @@ -1698,7 +1726,8 @@ class Softplus(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): dtype = inputs[0].struct_info.dtype - return relax.op.log(relax.op.exp(inputs[0]) + relax.const(1, dtype=dtype)) + threshold = 10.0 if dtype == "float16" else 20.0 + return relax.op.nn.softplus(inputs[0], threshold=threshold) class Softsign(OnnxOpConverter): @@ -1909,15 +1938,47 @@ def _impl_v13(cls, bb, inputs, attr, params): if isinstance(shape, relax.ShapeExpr): data_shape = list(data.struct_info.shape) target_shape = list(shape.values) + original_data_shape = [ + dim.value if hasattr(dim, "value") else str(dim) for dim in data_shape + ] + original_target_shape = [ + dim.value if hasattr(dim, "value") else str(dim) for dim in target_shape + ] data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape assert len(data_shape) == 
len(target_shape) - # Fix small target shapes or target shapes assigned to -1 + # Apply ONNX v13 Expand broadcasting rules for i, s in enumerate(target_shape): - if isinstance(s, tvm.tir.IntImm) and ( - (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i]) - or s.value == -1 - ): - target_shape[i] = data_shape[i] + if isinstance(s, tvm.tir.IntImm): + if s.value == -1: + # -1 means preserve the input dimension + target_shape[i] = data_shape[i] + elif isinstance(data_shape[i], tvm.tir.IntImm) and data_shape[i].value == 1: + # Input dimension is 1, can broadcast to any target dimension >= 1 + if s.value < 1: + raise ValueError( + f"ONNX Expand: Invalid target dimension {s.value} " + f"at possition {i}. Target dimensions must be >= 1." + ) + elif ( + isinstance(data_shape[i], tvm.tir.IntImm) and s.value == data_shape[i].value + ): + # Dimensions match, no change needed + pass + elif s.value == 1: + # Target dimension is 1 but input dimension is not 1 + # This would "squeeze" the dimension - preserve input for safety + target_shape[i] = data_shape[i] + else: + if isinstance(data_shape[i], tvm.tir.IntImm): + raise ValueError( + f"ONNX Expand: Cannot broadcast input shape {original_data_shape} " + f"to target shape {original_target_shape}. " + f"At dimension {i}: input size {data_shape[i].value} is " + f"incompatible with target size {s.value}. " + f"ONNX broadcasting requires corresponding dimensions to have " + f"the same value or one of them to be 1." + ) + # For dynamic shapes, let broadcast_to handle it if target_shape == data_shape: return data return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape)) @@ -1928,6 +1989,8 @@ def _impl_v13(cls, bb, inputs, attr, params): # ONNX Expand operator requires preserving target rank and broadcasting # according to standard rules. Dimensions are right-aligned. 
data_shape = [dim.value for dim in data.struct_info.shape] + original_data_shape = data_shape.copy() + original_new_shape = new_shape.copy() # Right-align the shapes if len(new_shape) > len(data_shape): @@ -1937,8 +2000,32 @@ def _impl_v13(cls, bb, inputs, attr, params): # Fix small target shapes - if target dim is smaller than input dim # use the input dim (ONNX-specific behavior). for i in range(len(new_shape)): - if new_shape[i] < data_shape[i]: + if new_shape[i] == -1: + # -1 means preserve the input dimension + new_shape[i] = data_shape[i] + elif data_shape[i] == 1: + # Input dimension is 1, can broadcast to any target dimension >= 1 + if new_shape[i] < 1: + raise ValueError( + f"ONNX Expand: Invalid target dimension {new_shape[i]} " + f"at possition {i}. Target dimensions must be >= 1." + ) + elif new_shape[i] == data_shape[i]: + # Dimensions match, no change needed + pass + elif new_shape[i] == 1: + # Target dimension is 1 but input dimension is not 1 + # This would "squeeze" the dimension - preserve input for safety new_shape[i] = data_shape[i] + else: + raise ValueError( + f"ONNX Expand: Cannot broadcast input shape {original_data_shape} " + f"to target shape {original_new_shape}. " + f"At dimension {i}: input size {data_shape[i]} is incompatible " + f"with target size {new_shape[i]}. " + f"ONNX broadcasting requires corresponding dimensions to have the same " + f"value or one of them to be 1." + ) return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape)) # Otherwise handle dynamic shapes. 
@@ -1955,7 +2042,18 @@ def _impl_v13(cls, bb, inputs, attr, params): for i in range(shape_ndim): shape_vars.append(tvm.tir.Var("x_%d" % i, "int64")) bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) - return bb.normalize(relax.op.broadcast_to(data, relax.ShapeExpr(shape_vars))) + + # Applying broadcasting rules for dynamic shapes + data_shape = list(data.struct_info.shape) + data_ndim = len(data_shape) + target_ndim = shape_ndim + padded_data = data + + if target_ndim > data_ndim: + padded_data_shape = [tir.IntImm("int64", 1)] * (target_ndim - data_ndim) + data_shape + padded_data = bb.normalize(relax.op.reshape(data, relax.ShapeExpr(padded_data_shape))) + + return bb.normalize(relax.op.broadcast_to(padded_data, relax.ShapeExpr(shape_vars))) class Attention(OnnxOpConverter): @@ -3385,6 +3483,182 @@ def _impl_v11(cls, bb, inputs, attr, params): return input_sequence[position] +class NonMaxSuppression(OnnxOpConverter): + """Converts an onnx NonMaxSuppression node into an equivalent Relax expression.""" + + @classmethod + def _impl_v10(cls, bb, inputs, attr, params): + """ + NonMaxSuppression performs non-maximum suppression (NMS) on all classes. 
+ + Inputs: + - boxes: (N, 4) tensor of bounding boxes in format [x1, y1, x2, y2] + - scores: (N, C) tensor of scores for each box and class + - max_output_boxes_per_class: maximum number of boxes to keep per class + - iou_threshold: IoU threshold for NMS + - score_threshold: score threshold for filtering + + Outputs: + - selected_indices: (M, 3) tensor with [batch_idx, class_idx, box_idx] + """ + boxes = inputs[0] + scores = inputs[1] + max_output_boxes_per_class = inputs[2] if len(inputs) > 2 else None + iou_threshold = inputs[3] if len(inputs) > 3 else None + score_threshold = inputs[4] if len(inputs) > 4 else None + + center_point_box = attr.get("center_point_box", 0) + + if max_output_boxes_per_class is not None and isinstance( + max_output_boxes_per_class, relax.Constant + ): + max_output_boxes_per_class = int(max_output_boxes_per_class.data.numpy()) + elif max_output_boxes_per_class is not None and isinstance( + max_output_boxes_per_class, relax.Var + ): + var_name = max_output_boxes_per_class.name_hint + if var_name in params[1]: + _, param_value = params[1][var_name] + max_output_boxes_per_class = int(param_value.numpy().item()) + else: + max_output_boxes_per_class = 100 # Default value + else: + max_output_boxes_per_class = 100 # Default value + + if iou_threshold is not None and isinstance(iou_threshold, relax.Constant): + iou_threshold = float(iou_threshold.data.numpy()) + else: + iou_threshold = 0.5 # Default value + + if score_threshold is not None and isinstance(score_threshold, relax.Constant): + score_threshold = float(score_threshold.data.numpy()) + elif score_threshold is not None and isinstance(score_threshold, relax.Var): + var_name = score_threshold.name_hint + if var_name in params[1]: + _, param_value = params[1][var_name] + score_threshold = float(param_value.numpy().item()) + else: + score_threshold = 0.0 # Default value + else: + score_threshold = 0.0 # Default value + + if center_point_box != 0: + split_result = relax.op.split(boxes, 4, 
axis=2) + xc = split_result[0] + yc = split_result[1] + w = split_result[2] + h = split_result[3] + half_w = w / relax.const(2.0, boxes.struct_info.dtype) + half_h = h / relax.const(2.0, boxes.struct_info.dtype) + x1 = xc - half_w + x2 = xc + half_w + y1 = yc - half_h + y2 = yc + half_h + boxes = relax.op.concat([y1, x1, y2, x2], axis=2) + + nms_out = bb.normalize( + relax.op.vision.all_class_non_max_suppression( + boxes, + scores, + relax.const(max_output_boxes_per_class, dtype="int64"), + relax.const(iou_threshold, dtype="float32"), + relax.const(score_threshold, dtype="float32"), + output_format="onnx", + ) + ) + + selected_indices = bb.emit(relax.TupleGetItem(nms_out, 0)) + + return selected_indices + + +class AllClassNMS(OnnxOpConverter): + """Converts an onnx AllClassNMS node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + """ + AllClassNMS performs non-maximum suppression (NMS) on all classes. + + Inputs: + - boxes: (N, 4) tensor of bounding boxes in format [x1, y1, x2, y2] + - scores: (N, C) tensor of scores for each box and class + - max_output_boxes_per_class: maximum number of boxes to keep per class + - iou_threshold: IoU threshold for NMS + - score_threshold: score threshold for filtering + + Outputs: + - selected_indices: (M, 3) tensor with [batch_idx, class_idx, box_idx] + """ + boxes = inputs[0] + scores = inputs[1] + max_output_boxes_per_class = inputs[2] if len(inputs) > 2 else None + iou_threshold = inputs[3] if len(inputs) > 3 else None + score_threshold = inputs[4] if len(inputs) > 4 else None + + center_point_box = attr.get("center_point_box", 0) + + if max_output_boxes_per_class is not None and isinstance( + max_output_boxes_per_class, relax.Constant + ): + max_output_boxes_per_class = int(max_output_boxes_per_class.data.numpy()) + elif max_output_boxes_per_class is not None and isinstance( + max_output_boxes_per_class, relax.Var + ): + var_name = max_output_boxes_per_class.name_hint + 
if var_name in params[1]: + _, param_value = params[1][var_name] + max_output_boxes_per_class = int(param_value.numpy().item()) + else: + max_output_boxes_per_class = 100 # Default value + else: + max_output_boxes_per_class = 100 # Default value + + if iou_threshold is not None and isinstance(iou_threshold, relax.Constant): + iou_threshold = float(iou_threshold.data.numpy()) + else: + iou_threshold = 0.5 # Default value + + if score_threshold is not None and isinstance(score_threshold, relax.Constant): + score_threshold = float(score_threshold.data.numpy()) + elif score_threshold is not None and isinstance(score_threshold, relax.Var): + var_name = score_threshold.name_hint + if var_name in params[1]: + _, param_value = params[1][var_name] + score_threshold = float(param_value.numpy().item()) + else: + score_threshold = 0.0 # Default value + else: + score_threshold = 0.0 # Default value + + if center_point_box != 0: + split_result = relax.op.split(boxes, 4, axis=2) + xc = split_result[0] + yc = split_result[1] + w = split_result[2] + h = split_result[3] + half_w = w / relax.const(2.0, boxes.struct_info.dtype) + half_h = h / relax.const(2.0, boxes.struct_info.dtype) + x1 = xc - half_w + x2 = xc + half_w + y1 = yc - half_h + y2 = yc + half_h + boxes = relax.op.concat([y1, x1, y2, x2], axis=2) + + nms_out = bb.normalize( + relax.op.vision.all_class_non_max_suppression( + boxes, + scores, + relax.const(max_output_boxes_per_class, dtype="int64"), + relax.const(iou_threshold, dtype="float32"), + relax.const(score_threshold, dtype="float32"), + output_format="onnx", + ) + ) + + return nms_out + + def _get_convert_map(): return { # defs/experimental @@ -3535,7 +3809,8 @@ def _get_convert_map(): # "LRN": LRN, # "MaxRoiPool": MaxRoiPool, # "RoiAlign": RoiAlign, - # "NonMaxSuppression": NonMaxSuppression, + "NonMaxSuppression": NonMaxSuppression, + "AllClassNMS": AllClassNMS, # "GridSample": GridSample, "Upsample": Upsample, # others @@ -3829,9 +4104,9 @@ def 
_parse_value_proto(self, value_proto: onnx.onnx_ml_pb2.GraphProto): name = value_proto return name - def _parse_array(self, tensor_proto: onnx.onnx_ml_pb2.TensorProto) -> tvm.nd.array: + def _parse_array(self, tensor_proto: onnx.onnx_ml_pb2.TensorProto) -> tvm.runtime.tensor: np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims)) - return tvm.nd.array(np_array) + return tvm.runtime.tensor(np_array) def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> Dict[str, Any]: """Convert a list of AttributeProto to a dict, with names as keys.""" diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 1895119e79f4..47eb66621008 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -23,6 +23,7 @@ import math from typing import Callable, Dict, Optional, Tuple, Union, List +import tvm from tvm import relax, tir @@ -45,6 +46,17 @@ def __init__(self) -> None: ########## Utilities ########## + def update_convert_map(self, custom_convert_map: Dict[str, Callable]): + """Update self.convert_map with custom convert map + + Parameters + ---------- + custom_convert_map : Dict[str, Callable] + A custom op conversion map in the same format as self.convert_map + """ + + self.convert_map.update(custom_convert_map) + @staticmethod def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None): """converts the PyTorch scalar type input_type to a TVM dtype.""" @@ -54,16 +66,34 @@ def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] input_type = env[input_type] input_type = input_type.lower() if isinstance(input_type, str) else input_type - if input_type in ["float", "float32", "torch.float32", torch.float32]: - return "float32" - elif input_type in ["float16", "torch.float16", torch.float16]: + # Float types + if input_type in ["float16", 
"torch.float16", torch.float16]: return "float16" + elif input_type in ["float", "float32", "torch.float32", torch.float32]: + return "float32" + elif input_type in ["float64", "double", "torch.float64", torch.float64]: + return "float64" elif input_type in ["bfloat16", "torch.bfloat16", torch.bfloat16]: return "bfloat16" - elif input_type in ["int64", "torch.int64", torch.int64]: - return "int64" + # Signed integer types + elif input_type in ["int8", "torch.int8", torch.int8]: + return "int8" + elif input_type in ["int16", "torch.int16", torch.int16]: + return "int16" elif input_type in ["int32", "torch.int32", torch.int32]: return "int32" + elif input_type in ["int64", "torch.int64", torch.int64]: + return "int64" + # Unsigned integer types + elif input_type in ["uint8", "torch.uint8", torch.uint8]: + return "uint8" + elif input_type in ["uint16", "torch.uint16", torch.uint16]: + return "uint16" + elif input_type in ["uint32", "torch.uint32", torch.uint32]: + return "uint32" + elif input_type in ["uint64", "torch.uint64", torch.uint64]: + return "uint64" + # Boolean elif input_type in ["bool", "torch.bool", torch.bool]: return "bool" else: @@ -88,6 +118,52 @@ def shape_of(tensor): return tensor.shape raise ValueError("Unsupported type: {}".format(type(tensor))) + @staticmethod + def _promote_common_dtype(lhs_dtype: Optional[str], rhs_dtype: Optional[str]) -> Optional[str]: + """Return the promoted dtype following PyTorch rules, or None if unsupported.""" + import torch # type: ignore + + if lhs_dtype is None or rhs_dtype is None or lhs_dtype == rhs_dtype: + return None + + tvm_to_torch = { + "float64": torch.float64, + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "int64": torch.int64, + "int32": torch.int32, + "int16": torch.int16, + "int8": torch.int8, + "uint8": torch.uint8, + "bool": torch.bool, + } + torch_to_tvm = {v: k for k, v in tvm_to_torch.items()} + + lhs_torch = tvm_to_torch.get(lhs_dtype) + rhs_torch = 
tvm_to_torch.get(rhs_dtype) + if lhs_torch is None or rhs_torch is None: + return None + + promoted = torch.promote_types(lhs_torch, rhs_torch) + return torch_to_tvm.get(promoted, None) + + @staticmethod + def _is_no_bias(bias): + """Check if bias represents 'no bias' condition. + + This handles both Python None and relax.op.null_value() expressions + that might be used to represent missing bias parameters. + """ + if bias is None: + return True + + # Check if this is a null_value expression + if isinstance(bias, relax.Call) and bias.op.name == "relax.null_value": + return True + + return False + def retrieve_args(self, node: fx.Node): return self._retrieve_args(node.args) @@ -102,6 +178,8 @@ def _retrieve_args(self, node): return [self._retrieve_args(x) for x in node] elif isinstance(node, dict): return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + elif node is None: + return None else: return node @@ -316,10 +394,19 @@ def _prelu(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis)) def _round(self, node: fx.Node) -> relax.Expr: - if node.kwargs.get("decimals", 0) != 0: - raise ValueError("specifying decimals for round is not supported yet") arg = self.env[node.args[0]] - return self.block_builder.emit(relax.op.round(arg)) + decimals = node.kwargs.get("decimals", 0) + + if decimals == 0: + return self.block_builder.emit(relax.op.round(arg)) + + # For decimals != 0, use: round(x * 10^decimals) / 10^decimals + dtype = arg.struct_info.dtype + scale = relax.const(10**decimals, dtype) + scaled = relax.op.multiply(arg, scale) + rounded = relax.op.round(scaled) + result = relax.op.divide(rounded, scale) + return self.block_builder.emit(result) def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -390,6 +477,17 @@ def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: def convert(node: fx.Node) -> relax.Var: def promote_binary_op_args(lhs, rhs): if 
isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + lhs_si = getattr(lhs, "struct_info", None) + rhs_si = getattr(rhs, "struct_info", None) + if isinstance(lhs_si, relax.TensorStructInfo) and isinstance( + rhs_si, relax.TensorStructInfo + ): + target_dtype = self._promote_common_dtype(lhs_si.dtype, rhs_si.dtype) + if target_dtype is not None: + if lhs_si.dtype != target_dtype: + lhs = self.block_builder.emit(relax.op.astype(lhs, target_dtype)) + if rhs_si.dtype != target_dtype: + rhs = self.block_builder.emit(relax.op.astype(rhs, target_dtype)) return lhs, rhs elif isinstance(lhs, relax.Expr): assert isinstance(lhs.struct_info, relax.TensorStructInfo) @@ -635,6 +733,7 @@ def _avg_pool2d_impl( stride: Optional[Union[int, Tuple[int, int]]] = None, padding: Optional[int] = 0, ceil_mode: Optional[bool] = False, + count_include_pad: Optional[bool] = True, ) -> relax.Var: # Expand to 4D by adding batch dim if input is 3D x_ndim = x.struct_info.ndim @@ -649,6 +748,7 @@ def _avg_pool2d_impl( strides=stride, padding=padding, ceil_mode=ceil_mode, + count_include_pad=count_include_pad, layout="NCHW", ) ) @@ -664,7 +764,8 @@ def _avg_pool2d(self, node: fx.Node) -> relax.Var: stride = args[2] if len(args) > 2 else kwargs.get("stride", None) padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) - return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + count_include_pad = args[5] if len(args) > 5 else kwargs.get("count_include_pad", True) + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode, count_include_pad) def _avg_pool3d_impl( self, @@ -756,7 +857,7 @@ def _conv_transpose1d_impl( ) ) - if bias is None: + if self._is_no_bias(bias): return conv1d_transpose assert len(self.shape_of(bias)) == 1 @@ -810,7 +911,7 @@ def _conv_transpose2d_impl( ) ) - if bias is None: + if self._is_no_bias(bias): return conv2d_transpose assert 
len(self.shape_of(bias)) == 1 @@ -862,7 +963,7 @@ def _conv1d_impl( ) ) - if bias is None: + if self._is_no_bias(bias): return conv1d assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1)) @@ -911,7 +1012,7 @@ def _conv2d_impl( ) ) - if bias is None: + if self._is_no_bias(bias): return conv2d assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1, 1)) @@ -960,7 +1061,7 @@ def _conv3d_impl( ) ) - if bias is None: + if self._is_no_bias(bias): return conv3d assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) @@ -985,6 +1086,80 @@ def _conv3d(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _convolution(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + transposed = args[6] if len(args) > 6 else False + output_padding = args[7] if len(args) > 7 else 0 + groups = args[8] if len(args) > 8 else 1 + + input_shape = self.shape_of(x) + ndim = len(input_shape) + + if transposed: + if ndim == 3: # 1D convolution (N, C, W) + return self._conv_transpose1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + output_padding=output_padding, + ) + elif ndim == 4: # 2D convolution (N, C, H, W) + return self._conv_transpose2d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + output_padding=output_padding, + ) + else: + raise ValueError(f"Unsupported transposed convolution dimensionality: {ndim}") + else: + if ndim == 3: # 1D convolution (N, C, W) + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + elif ndim == 4: # 2D convolution (N, C, H, W) + return self._conv2d_impl( + 
x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + elif ndim == 5: # 3D convolution (N, C, D, H, W) + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + else: + raise ValueError(f"Unsupported convolution dimensionality: {ndim}") + def _cross_entropy_loss( self, preds: relax.Expr, @@ -1221,6 +1396,54 @@ def _max_pool3d(self, node: fx.Node) -> relax.Var: ceil_mode = args[5] if len(args) > 5 else False return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool1d_with_indices(self, node: fx.Node) -> relax.Var: + # max_pool1d_with_indices returns (output, indices) + # We only compute the output and create a placeholder for indices + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + output = self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + # Create a placeholder for indices (empty tensor with same shape as output) + indices = relax.op.zeros_like(output) + return self.block_builder.emit(relax.Tuple([output, indices])) + + def _max_pool2d_with_indices(self, node: fx.Node) -> relax.Var: + # max_pool2d_with_indices returns (output, indices) + # We only compute the output and create a placeholder for indices + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + output = self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + # Create a placeholder for indices (empty tensor with same shape as output) + indices = relax.op.zeros_like(output) + 
return self.block_builder.emit(relax.Tuple([output, indices])) + + def _max_pool3d_with_indices(self, node: fx.Node) -> relax.Var: + # max_pool3d_with_indices returns (output, indices) + # We only compute the output and create a placeholder for indices + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + output = self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + # Create a placeholder for indices (empty tensor with same shape as output) + indices = relax.op.zeros_like(output) + return self.block_builder.emit(relax.Tuple([output, indices])) + def _pad(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] pad = node.args[1] @@ -1239,6 +1462,23 @@ def _pad(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value)) + def _constant_pad_nd(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + pad = node.args[1] + value = node.args[2] if len(node.args) > 2 else node.kwargs.get("value", 0.0) + value = 0.0 if value is None else value + + # Calculate symmetric padding width for each dimension + # and applying them in reverse order to match the input dimensions. 
+ input_ndim = x.struct_info.ndim + pad_width = [0] * (input_ndim * 2) + pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)] + reversed_pairs = list(reversed(pad_pairs)) + flattened = [v for pair in reversed_pairs for v in pair] + pad_width[-len(flattened) :] = flattened + + return self.block_builder.emit(relax.op.nn.pad(x, pad_width, "constant", value)) + def _pixel_shuffle(self, node: fx.Node) -> relax.Var: data = self.env[node.args[0]] upscale_factor = node.args[1] @@ -1249,10 +1489,49 @@ def _pixel_shuffle(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor)) def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: - transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(self.env[node.args[0]]) - key = transpose_S_H(self.env[node.args[1]]) - value = transpose_S_H(self.env[node.args[2]]) + query_tensor = self.env[node.args[0]] + key_tensor = self.env[node.args[1]] + value_tensor = self.env[node.args[2]] + + # Check the dimensionality of the input tensors + query_ndim = len(query_tensor.struct_info.shape) + + # TVM's nn.attention requires 4D inputs in format (batch, num_heads, seq_len, head_dim) + # For 2D inputs (seq_len, head_dim), we need to reshape to 4D first + if query_ndim == 2: + # 2D input: (seq_len, head_dim) -> expand to (1, 1, seq_len, head_dim) + # Add batch dimension at axis 0 + query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0)) + key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0)) + value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0)) + # Add num_heads dimension at axis 1 + query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=1)) + key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=1)) + value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=1)) + + # No permutation needed for 2D inputs after expanding to 4D + 
# After attention, squeeze back to 2D: (1, 1, seq_len, head_dim) -> (seq_len, head_dim) + def transpose_and_reshape_back(tensor): + # Squeeze batch and num_heads dimensions + return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1])) + + elif query_ndim == 4: + # 4D input: (batch, seq_len, num_heads, head_dim) + # -> (batch, num_heads, seq_len, head_dim) + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = self.block_builder.emit(transpose_S_H(query_tensor)) + key = self.block_builder.emit(transpose_S_H(key_tensor)) + value = self.block_builder.emit(transpose_S_H(value_tensor)) + + # For 4D, transpose back after attention + def transpose_and_reshape_back(tensor): + return self.block_builder.emit(transpose_S_H(tensor)) + + else: + raise ValueError( + f"scaled_dot_product_attention expects 2D or 4D inputs, but got {query_ndim}D input" + ) + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) assert dropout_p == 0.0, "Dropout is not supported" @@ -1264,20 +1543,24 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: msg = "Only a float mask is supported for the attn_mask input." 
assert "float" in attn_mask.struct_info.dtype, msg - return self.block_builder.emit( - transpose_S_H( - relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) - ) + attention_output = self.block_builder.emit( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) ) + return transpose_and_reshape_back(attention_output) + def _unbind(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) assert isinstance(dim, int), "Expected 2nd argument of unbind as int" selections = self.shape_of(x)[dim].value - ret, split = [], self.block_builder.emit(relax.op.split(x, selections, dim)) - for i in range(selections): - ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + ret = [] + if selections == 1: + ret.append(self.block_builder.emit(relax.op.squeeze(x, axis=dim))) + else: + split = self.block_builder.emit(relax.op.split(x, selections, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) return self.block_builder.emit(relax.Tuple(ret)) ########## Statistical ########## @@ -1357,6 +1640,21 @@ def _var(self, node: fx.Node) -> relax.Var: keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) return self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim)) + def _any(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + + # max doesn't support boolean tensors directly, so we compute it in int8 and cast back + if x.struct_info.dtype == "bool": + x = relax.op.astype(x, "int8") + ret = relax.op.max(x, dim, keepdims=keepdim) + return self.block_builder.emit(relax.op.astype(ret, "bool")) + + # For boolean tensors, any is equivalent to max (checking if any 
element is True) + return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim)) + ########## Search ########## def _argmax_argmin(self, op: Callable) -> Callable: @@ -1508,13 +1806,92 @@ def _index_put(self, node: fx.Node) -> relax.Var: raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate))) if isinstance(indices, (list, tuple)): - indices = relax.Tuple(indices) + # In PyTorch index_put, None means "select all elements" for that dimension + non_none_indices = [(i, idx) for i, idx in enumerate(indices) if idx is not None] + + if len(non_none_indices) < len(indices): + data_shape = self.shape_of(tensor) + processed_indices = [] + + max_ndim = max((idx.struct_info.ndim for _, idx in non_none_indices), default=1) + + for i, idx in enumerate(indices): + if idx is None: + # Replace None with arange for full dimension indexing + arange_idx = self.block_builder.emit( + relax.op.arange( + relax.PrimValue(0), data_shape[i], relax.PrimValue(1), "int64" + ) + ) + # Reshape to [dim_size, 1, 1, ...] 
for broadcasting + # Add an extra dimension so it broadcasts with other indices + arange_idx = self.block_builder.emit( + relax.op.reshape(arange_idx, [data_shape[i]] + [1] * max_ndim) + ) + processed_indices.append(arange_idx) + else: + processed_indices.append(idx) + + indices = relax.Tuple(processed_indices) + else: + indices = relax.Tuple(indices) return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate)) def _index_tensor(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) + data = args[0] indices = args[1] - return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) + + # In PyTorch's aten.index.Tensor, None means "select all elements" for that dimension + non_none_indices = [(i, idx) for i, idx in enumerate(indices) if idx is not None] + + # Special case: if there's only one non-None index, use take operation + if len(non_none_indices) == 1: + axis, index_tensor = non_none_indices[0] + return self.block_builder.emit(relax.op.take(data, index_tensor, axis=axis)) + + # Check if all indices can be squeezed to 1D for sequential take + def is_squeezable(idx): + if idx.struct_info.ndim == 1: + return True + if idx.struct_info.ndim == 2: + shape = idx.struct_info.shape + for d in shape: + if isinstance(d, int) and d == 1: + return True + # Check for tir.IntImm + if hasattr(d, "value") and d.value == 1: + return True + return False + + all_squeezable = all(is_squeezable(idx) for _, idx in non_none_indices) + if all_squeezable: + result = data + for axis, idx in reversed(non_none_indices): + if idx.struct_info.ndim > 1: + idx = self.block_builder.emit(relax.op.squeeze(idx)) + result = self.block_builder.emit(relax.op.take(result, idx, axis=axis)) + return result + + # General case: replace None with arange, reshaped for broadcasting + max_ndim = max((idx.struct_info.ndim for _, idx in non_none_indices), default=1) + processed_indices = [] + data_shape = self.shape_of(data) + + for i, idx in enumerate(indices): 
+ if idx is None: + arange_idx = self.block_builder.emit( + relax.op.arange(relax.PrimValue(0), data_shape[i], relax.PrimValue(1), "int64") + ) + # Reshape to [dim_size, 1, 1, ...] for broadcasting + arange_idx = self.block_builder.emit( + relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1)) + ) + processed_indices.append(arange_idx) + else: + processed_indices.append(idx) + + return self.block_builder.emit(relax.op.index_tensor(data, processed_indices)) def _meshgrid(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) @@ -1545,6 +1922,41 @@ def _slice_scatter(self, node: fx.Node) -> relax.Var: end = args[4] if len(args) > 4 else node.kwargs.get("end", self.shape_of(input_tensor)[dim]) step = args[5] if len(args) > 5 else node.kwargs.get("step", 1) + # Normalize bounds to match PyTorch behavior (negative and open-ended slices). + input_shape = self.shape_of(input_tensor) + axis = dim if dim >= 0 else dim + len(input_shape) + + def _normalize_bound(bound): + # PyTorch uses a large positive value (2^63-1) to mean "len". 
+ max_index_val = 9223372036854775807 + + def _adjust(val): + if isinstance(val, (int, tir.IntImm)): + int_val = int(val) + if int_val >= max_index_val: + return input_shape[axis] + if int_val < 0: + return input_shape[axis] + int_val + if isinstance(input_shape[axis], (int, tir.IntImm)) and int_val > int( + input_shape[axis] + ): + return input_shape[axis] + return val + + if isinstance(bound, relax.PrimValue): + value = _adjust(bound.value) + return relax.PrimValue(value) + + bound = _adjust(bound) + if not isinstance(bound, relax.PrimValue): + bound = relax.PrimValue(bound) + return bound + + start = _normalize_bound(start) + end = _normalize_bound(end) + if not isinstance(step, relax.PrimValue): + step = relax.PrimValue(step) + return self.block_builder.emit( relax.op.slice_scatter(input_tensor, src, start, end, step, axis=dim) ) @@ -1650,6 +2062,12 @@ def _reshape(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + + # Skip identity reshape + current_shape = self.shape_of(x) + if list(current_shape) == list(dims): + return x + return self.block_builder.emit(relax.op.reshape(x, dims)) def _reshape_as(self, node: fx.Node) -> relax.Var: @@ -1700,6 +2118,23 @@ def _split(self, node: fx.Node) -> relax.Var: def _squeeze(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + # Support both "dim" and "dims" parameters + if dim is None: + dim = node.kwargs.get("dims", None) + + # If dims is a list, filter out axes where dimension is not 1 + # This is needed because PyTorch decomposition may pass all axes + if isinstance(dim, (list, tuple)) and len(dim) > 0: + shape = self.shape_of(x) + # Filter to only include axes where the dimension is 1 + valid_dims = [] + for d in dim: + axis = d if d >= 0 else len(shape) + d + if axis < len(shape): + valid_dims.append(d) + # If no valid dims, 
use None to squeeze all size-1 dimensions + dim = valid_dims if valid_dims else None + return self.block_builder.emit(relax.op.squeeze(x, dim)) def _stack(self, node: fx.Node) -> relax.Var: @@ -1753,9 +2188,21 @@ def _detach(self, node: fx.Node) -> relax.Var: return self.env[node.args[0]] def _copy_(self, node: fx.Node) -> relax.Var: - # Copies the source tensor's into the destination tensor - # In TVM, that means simply returning the source tensor - return self.env[node.args[1]] + dest = self.env[node.args[0]] + src = self.env[node.args[1]] + + # Match PyTorch semantics: cast to destination dtype and broadcast to destination shape. + if src.struct_info.dtype != dest.struct_info.dtype: + src = self.block_builder.emit(relax.op.astype(src, dest.struct_info.dtype)) + + dest_shape = self.shape_of(dest) + src_shape = self.shape_of(src) + if dest_shape != src_shape: + src = self.block_builder.emit(relax.op.broadcast_to(src, dest_shape)) + + # copy_ writes into the destination tensor, so update env accordingly + self.env[node.args[0]] = src + return src def _to_copy(self, node: fx.Node) -> relax.Var: # Returns a copy of the input tensor @@ -1815,7 +2262,11 @@ def _arange(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) def _empty(self, node: fx.Node) -> relax.Var: - dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + import torch + + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) return self.block_builder.emit(relax.op.zeros(node.args[0], dtype)) def _empty_like(self, node: fx.Node) -> relax.Var: @@ -1864,8 +2315,16 @@ def _full(self, node: fx.Node) -> relax.Var: def _full_like(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - fill_value = relax.const(node.args[1]) - return self.block_builder.emit(relax.op.full_like(x, fill_value)) + value = node.args[1] + fill_value = relax.const(value) + + x_dtype = x.struct_info.dtype + 
fill_dtype = None + if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)): + if not ("float" in x_dtype or "bfloat16" in x_dtype): + fill_dtype = "float32" + + return self.block_builder.emit(relax.op.full_like(x, fill_value, dtype=fill_dtype)) def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -1878,7 +2337,19 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: mask = self.env[node.args[1]] value = node.args[2] rx_value = relax.const(value) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + + x_dtype = x.struct_info.dtype + fill_dtype = None + if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)): + if not ("float" in x_dtype or "bfloat16" in x_dtype): + fill_dtype = "float32" + + values = self.block_builder.emit(relax.op.full_like(x, rx_value, dtype=fill_dtype)) + + # Cast x to match values dtype if necessary + if fill_dtype is not None: + x = self.block_builder.emit(relax.op.astype(x, fill_dtype)) + output = self.block_builder.emit(relax.op.where(mask, values, x)) self.env[node.args[0]] = output return output @@ -1909,10 +2380,43 @@ def _linspace(self, node: fx.Node) -> relax.Var: def _masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] - rx_value = relax.const(node.args[2]) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + value = node.args[2] + rx_value = relax.const(value) + + x_dtype = x.struct_info.dtype + fill_dtype = None + if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)): + if not ("float" in x_dtype or "bfloat16" in x_dtype): + fill_dtype = "float32" + + values = self.block_builder.emit(relax.op.full_like(x, rx_value, dtype=fill_dtype)) + + # Cast x to match values dtype if necessary + if fill_dtype is not None: + x = self.block_builder.emit(relax.op.astype(x, fill_dtype)) + return self.block_builder.emit(relax.op.where(mask, values, x)) + 
def _masked_select(self, node: fx.Node) -> relax.Var: + data = self.env[node.args[0]] + mask = self.env[node.args[1]] + + data_shape = self.shape_of(data) + mask_shape = self.shape_of(mask) + shapes_equal = tvm.ir.structural_equal(data_shape, mask_shape) + + if not shapes_equal: + mask = self.block_builder.emit(relax.op.broadcast_to(mask, data_shape)) + + data_flat = self.block_builder.emit(relax.op.reshape(data, [-1])) + mask_flat = self.block_builder.emit(relax.op.reshape(mask, [-1])) + indices = self.block_builder.emit(relax.op.nonzero(mask_flat)) + indices_1d = self.block_builder.emit(relax.op.squeeze(indices, axis=[0])) + + result = self.block_builder.emit(relax.op.take(data_flat, indices_1d, axis=0)) + + return result + def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] @@ -1997,6 +2501,12 @@ def _getitem(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) assert isinstance(x.struct_info, relax.TensorStructInfo) + if isinstance(node.args[1], int): + return x + if not isinstance(node.args[1], (list, tuple)): + indices = [node.args[1]] + else: + indices = node.args[1] take_indices = [] take_axes = [] stride_begin = [] @@ -2007,10 +2517,10 @@ def _getitem(self, node: fx.Node) -> relax.Var: i = 0 shape = self.shape_of(x) non_ellipsis_cnt = 0 - for index in node.args[1]: + for index in indices: if isinstance(index, (int, slice, torch.fx.Node)): non_ellipsis_cnt += 1 - for index in node.args[1]: + for index in indices: if isinstance(index, int): stride_begin.append(index) stride_end.append(index + 1) @@ -2071,6 +2581,21 @@ def _item(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.take(x, relax.const(0, "int64"), axis=0)) + def _sym_size_int(self, node: fx.Node) -> relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + + # 
Handle case where shape is unknown (None) - this can happen for operations + # with dynamic output shapes. + if shape is None: + return self.block_builder.emit(relax.const(0, "int64")) + + shape_dim = shape[dim] + if hasattr(shape_dim, "value"): + return self.block_builder.emit(relax.const(shape_dim.value, dtype="int32")) + return shape_dim + def _zeros_inplace(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output = self.block_builder.emit(relax.op.zeros_like(x)) diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py index c10019454015..8837d9683511 100644 --- a/python/tvm/relax/frontend/torch/dynamo.py +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -55,8 +55,8 @@ def _relax_backend(graph_module, example_inputs): assert isinstance(graph_module, torch.fx.GraphModule) def to_torch_tensor(nd_tensor): - """A helper function to transfer a NDArray to torch.tensor.""" - if isinstance(nd_tensor, tvm.nd.NDArray): + """A helper function to transfer a Tensor to torch.tensor.""" + if isinstance(nd_tensor, tvm.runtime.Tensor): return torch.from_numpy(nd_tensor.numpy()) elif isinstance(nd_tensor, tvm.ir.Array): return tuple(to_torch_tensor(x) for x in nd_tensor) @@ -64,12 +64,12 @@ def to_torch_tensor(nd_tensor): raise ValueError(f"Unsupported type {type(nd_tensor)}") def to_tvm_tensor(torch_tensor): - """A helper function to transfer a torch.tensor to NDArray.""" + """A helper function to transfer a torch.tensor to Tensor.""" if not isinstance(torch_tensor, torch._subclasses.fake_tensor.FakeTensor): - return tvm.nd.array(torch_tensor.numpy()) + return tvm.runtime.tensor(torch_tensor.numpy()) # Fake Tensor real_tensor = torch.randn(torch_tensor.shape, dtype=torch_tensor.dtype) - return tvm.nd.array(real_tensor.numpy()) + return tvm.runtime.tensor(real_tensor.numpy()) graph_module.graph.eliminate_dead_code() diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py index 1a53a0cbdc72..3d6a632fb20f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -20,9 +20,10 @@ """PyTorch ExportedProgram of Relax.""" from collections import ChainMap, OrderedDict from functools import partial -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Optional, Tuple import torch +from torch import fx import tvm from tvm import relax @@ -32,7 +33,35 @@ class ExportedProgramImporter(BaseFXGraphImporter): """An importer from ExportedProgram to Relax.""" - from torch import fx + @staticmethod + def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor: + """Convert a PyTorch tensor to TVM tensor, handling sparse tensors. + + Parameters + ---------- + tensor_value : torch.Tensor + The PyTorch tensor to convert. + + Returns + ------- + tvm.runtime.Tensor + The converted TVM tensor. + """ + # PyTorch sparse tensors (layout != torch.strided) must be converted to dense. 
+ if tensor_value.layout != torch.strided: + tensor_to_convert = tensor_value.to_dense() + else: + tensor_to_convert = tensor_value + tensor_detached = tensor_to_convert.detach() + + # Try DLPack conversion first (faster) + try: + return tvm.runtime.from_dlpack(tensor_detached) + except (RuntimeError, BufferError): + # Fallback: convert to numpy and then to TVM tensor + # This handles cases where DLPack conversion fails + tensor_cpu = tensor_detached.cpu().contiguous() + return tvm.runtime.tensor(tensor_cpu.numpy()) ########## Unary Ops ########## @@ -64,9 +93,29 @@ def _reciprocal(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x)) + def _sqrt(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = x.struct_info.dtype + + # Check if input is integer type and convert to float32 if needed + if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"): + x = self.block_builder.emit(relax.op.astype(x, "float32")) + + return self.block_builder.emit(relax.op.sqrt(x)) + + def _rsqrt(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = x.struct_info.dtype + + # Check if input is integer type and convert to float32 if needed + if dtype in ("int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"): + x = self.block_builder.emit(relax.op.astype(x, "float32")) + + return self.block_builder.emit(relax.op.rsqrt(x)) + ########## Neural Network ########## - def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var: + def _batch_norm(self, node: fx.Node, training: bool, return_tuple: bool = False) -> relax.Var: import numpy as np x = self.env[node.args[0]] @@ -76,17 +125,30 @@ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var: bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), 
dtype=dtype)) running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) - ignore_running_stats = ( - node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) - ) - track_running_stats = not ignore_running_stats - momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1) - eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) - if track_running_stats: + # After torch.export decomposition, batch_norm shows up as + # _native_batch_norm_legit_* with signature (x, weight, bias, mean, var, momentum, eps). + target_name = getattr(node.target, "__name__", "") + if target_name.startswith("_native_batch_norm_legit_no_training"): + momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) + eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) + training = False + elif target_name.startswith("_native_batch_norm_legit_functional"): + momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) + eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) training = True + else: + ignore_running_stats = ( + node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) + ) + track_running_stats = not ignore_running_stats + momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1) + eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) - return self.block_builder.emit( + if track_running_stats: + training = True + + bn_result = self.block_builder.emit( relax.op.nn.batch_norm( data=x, gamma=weight, @@ -97,21 +159,58 @@ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var: epsilon=eps, momentum=momentum, training=training, - )[0] + ) ) + if return_tuple: + return bn_result + else: + # Return only the output tensor (for backward compatibility) + return self.block_builder.emit(bn_result[0]) + def 
_batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: # This method is called for batch_norm in training mode - # TODO does not have correctness! - # TODO we need to store the running mean and variance returned by the - # previous call to batch_norm and pass it again - training = True - return self._batch_norm(node, training) + bn_tuple = self._batch_norm(node, training=True, return_tuple=True) + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + + output = self.block_builder.emit(bn_tuple[0]) + new_running_mean = self.block_builder.emit(bn_tuple[1]) + reserve = self.block_builder.emit(relax.op.zeros(relax.ShapeExpr([channel]), dtype)) + + return self.block_builder.emit( + relax.Tuple([output, new_running_mean, reserve, reserve, reserve]) + ) def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: - # This method is called for batch_norm in eval mode - training = False - return self._batch_norm(node, training) + return self._batch_norm(node, training=False, return_tuple=False) + + def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var: + import numpy as np + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + eps = node.args[5] if len(node.args) > 5 else node.kwargs.get("eps", 1e-05) + + # Determine axes for instance norm (all spatial dimensions after channel) + dim = len(self.shape_of(x)) + axes = list(range(2, dim)) + + return self.block_builder.emit( + relax.op.nn.instance_norm( + x, + weight, + bias, + channel_axis=1, + axes=axes, + epsilon=eps, + ) + ) def _cross_entropy_default(self, node: fx.Node) -> relax.Expr: preds = self.env[node.args[0]] @@ -141,6 +240,37 @@ def _group_norm(self, node: fx.Node) -> relax.Var: ) ) + def _native_group_norm(self, node: fx.Node) -> 
relax.Var: + # native_group_norm signature: (input, weight, bias, N, C, HxW, group, eps) + x = self.env[node.args[0]] + gamma = self.env.get(node.args[1], None) if len(node.args) > 1 else None + beta = self.env.get(node.args[2], None) if len(node.args) > 2 else None + # args[3] = N (batch size), args[4] = C (channels), args[5] = HxW (spatial size) + num_groups = node.args[6] if len(node.args) > 6 else 1 + eps = node.args[7] if len(node.args) > 7 else 1e-05 + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + + def _native_layer_norm(self, node: fx.Node) -> relax.Var: + # native_layer_norm signature: (input, normalized_shape, weight, bias, eps) + x = self.env[node.args[0]] + normalized_shape = node.args[1] + gamma = self.env.get(node.args[2], None) if len(node.args) > 2 else None + beta = self.env.get(node.args[3], None) if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + def _upsample_impl( self, x: relax.Expr, @@ -179,6 +309,22 @@ def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: x, size=size, scale_factor=scale_factor, method="linear", align_corners=align_corners ) + def _upsample_bilinear2d_aa(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("output_size", None) + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", False) + ) + scale_factor = ( + node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factors", None) + ) + + # Note: TVM's resize2d doesn't have explicit antialias support. + # For upsampling, antialiasing has minimal effect, so we use regular bilinear. 
+ return self._upsample_impl( + x, size=size, scale_factor=scale_factor, method="linear", align_corners=align_corners + ) + def _upsample_nearest2d(self, node: fx.node) -> relax.Var: x = self.env[node.args[0]] size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) @@ -190,11 +336,11 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: ) else: - # TODO figure out why pytorch export passes a list such as - # [scale_factor,scale_factor] instead of just an int for - # scale_factor. Using first element for now + # PyTorch export passes scale_factor as either a scalar or a list/tuple + # (e.g., [2.0, 3.0] for different H and W scaling). + # Pass it as-is to _upsample_impl which handles both cases correctly. scale_factor = ( - node.args[2][0] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1) + node.args[2] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1) ) align_corners = ( node.args[3] if len(node.args) > 3 else node.kwargs.get("align_corners", None) @@ -217,11 +363,11 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: if size is not None: scale_factor = None else: - scale_arg = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) - if isinstance(scale_arg, (list, tuple)): - scale_factor = scale_arg[0] - else: - scale_factor = scale_arg + # PyTorch export passes scale_factor as either a scalar or a list/tuple. + # Pass it as-is to _upsample_impl which handles both cases correctly. 
+ scale_factor = ( + node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) + ) return self._upsample_impl( x, @@ -231,6 +377,527 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: align_corners=align_corners, ) + def _lstm_cell_unroll( + self, + input_reshaped, + weight_ih, + weight_hh, + bias_ih, + bias_hh, + h_prev, + c_prev, + seq_len, + hidden_size, + reverse=False, + ): + """Unroll LSTM cells for a single direction.""" + weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0])) + weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0])) + outputs = [] + time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len) + + for t in time_steps: + x_t = self.block_builder.emit( + relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip") + ) + ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t)) + hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t)) + + gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates)) + if bias_ih is not None: + gates = self.block_builder.emit(relax.op.add(gates, bias_ih)) + if bias_hh is not None: + gates = self.block_builder.emit(relax.op.add(gates, bias_hh)) + + i_gate = self.block_builder.emit( + relax.op.strided_slice(gates, axes=[1], begin=[0], end=[hidden_size]) + ) + f_gate = self.block_builder.emit( + relax.op.strided_slice(gates, axes=[1], begin=[hidden_size], end=[2 * hidden_size]) + ) + g_gate = self.block_builder.emit( + relax.op.strided_slice( + gates, axes=[1], begin=[2 * hidden_size], end=[3 * hidden_size] + ) + ) + o_gate = self.block_builder.emit( + relax.op.strided_slice( + gates, axes=[1], begin=[3 * hidden_size], end=[4 * hidden_size] + ) + ) + + i_t = self.block_builder.emit(relax.op.sigmoid(i_gate)) + f_t = self.block_builder.emit(relax.op.sigmoid(f_gate)) + g_t = self.block_builder.emit(relax.op.tanh(g_gate)) + o_t = 
self.block_builder.emit(relax.op.sigmoid(o_gate)) + + c_t = self.block_builder.emit( + relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t)) + ) + h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t))) + + outputs.append(h_t) + h_prev = h_t + c_prev = c_t + + if reverse: + outputs = outputs[::-1] + + output = self.block_builder.emit(relax.op.stack(outputs, axis=0)) + return output + + def _lstm(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + input_tensor = args[0] + hx = args[1] if len(args) > 1 else None + params = args[2] if len(args) > 2 else None + has_biases = args[3] if len(args) > 3 else True + num_layers = args[4] if len(args) > 4 else 1 + bidirectional = args[7] if len(args) > 7 else False + batch_first = args[8] if len(args) > 8 else False + + if num_layers > 1: + raise NotImplementedError("Multi-layer LSTM is not yet supported") + + input_shape = self.shape_of(input_tensor) + if batch_first: + batch_size, seq_len, input_size = input_shape + else: + seq_len, batch_size, input_size = input_shape + + seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len + batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size + input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size + # Extract hidden size from the LSTM parameters + # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh] + # weight_ih shape: (4 * hidden_size, input_size) + # weight_hh shape: (4 * hidden_size, hidden_size) + if params and len(params) >= 2: + # Extract hidden size from weight dimensions + # weight_ih has shape (4 * hidden_size, input_size) + weight_ih_shape = self.shape_of(params[0]) + hidden_size = weight_ih_shape[0] // 4 + else: + # Fallback to a default hidden size + hidden_size = 16 + # Implement actual LSTM computation using Relax operations + # LSTM equations: + # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi) + # f_t = 
sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf) + # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg) + # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho) + # c_t = f_t * c_{t-1} + i_t * g_t + # h_t = o_t * tanh(c_t) + dtype = input_tensor.struct_info.dtype + params_per_direction = 4 if has_biases else 2 + + # Extract or create forward direction weights + if params and len(params) >= 2: + weight_ih_fwd = params[0] + weight_hh_fwd = params[1] + bias_ih_fwd = params[2] if has_biases and len(params) > 2 else None + bias_hh_fwd = params[3] if has_biases and len(params) > 3 else None + else: + # Fallback: create zero weights + weight_ih_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype) + ) + weight_hh_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype) + ) + bias_ih_fwd = None + bias_hh_fwd = None + + # Extract or create backward direction weights if bidirectional + if bidirectional: + if params and len(params) >= params_per_direction * 2: + weight_ih_bwd = params[params_per_direction] + weight_hh_bwd = params[params_per_direction + 1] + bias_ih_bwd = params[params_per_direction + 2] if has_biases else None + bias_hh_bwd = params[params_per_direction + 3] if has_biases else None + else: + # Fallback: create zero weights + weight_ih_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype) + ) + weight_hh_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype) + ) + bias_ih_bwd = None + bias_hh_bwd = None + else: + weight_ih_bwd = None + weight_hh_bwd = None + bias_ih_bwd = None + bias_hh_bwd = None + + if hx is not None and len(hx) >= 2: + h_0, c_0 = hx[0], hx[1] + h_prev_fwd = self.block_builder.emit( + relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip") + ) + c_prev_fwd = self.block_builder.emit( + relax.op.take(c_0, relax.const(0, "int64"), 
axis=0, mode="clip") + ) + if bidirectional: + h_prev_bwd = self.block_builder.emit( + relax.op.take(h_0, relax.const(1, "int64"), axis=0, mode="clip") + ) + c_prev_bwd = self.block_builder.emit( + relax.op.take(c_0, relax.const(1, "int64"), axis=0, mode="clip") + ) + else: + h_prev_bwd = None + c_prev_bwd = None + else: + h_prev_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + c_prev_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + if bidirectional: + h_prev_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + c_prev_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + else: + h_prev_bwd = None + c_prev_bwd = None + + input_reshaped = ( + self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2])) + if batch_first + else input_tensor + ) + + output_fwd = self._lstm_cell_unroll( + input_reshaped, + weight_ih_fwd, + weight_hh_fwd, + bias_ih_fwd, + bias_hh_fwd, + h_prev_fwd, + c_prev_fwd, + seq_len, + hidden_size, + reverse=False, + ) + + if bidirectional: + output_bwd = self._lstm_cell_unroll( + input_reshaped, + weight_ih_bwd, + weight_hh_bwd, + bias_ih_bwd, + bias_hh_bwd, + h_prev_bwd, + c_prev_bwd, + seq_len, + hidden_size, + reverse=True, + ) + output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2)) + else: + output = output_fwd + + if batch_first: + # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size) + output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2])) + return output + + def _gru_cell_unroll( + self, + input_reshaped, + weight_ih, + weight_hh, + bias_ih, + bias_hh, + h_prev, + seq_len, + hidden_size, + dtype, + reverse=False, + ): + """Unroll GRU cells for a single direction.""" + gate_size = hidden_size + + # Split weights by gates: PyTorch GRU gate 
order: reset, update, new (r, z, n) + # Reset gate weights + weight_ih_r = self.block_builder.emit( + relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size]) + ) + weight_hh_r = self.block_builder.emit( + relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size]) + ) + + # Update gate weights + weight_ih_z = self.block_builder.emit( + relax.op.strided_slice(weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]) + ) + weight_hh_z = self.block_builder.emit( + relax.op.strided_slice(weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]) + ) + + # New gate weights + weight_ih_n = self.block_builder.emit( + relax.op.strided_slice(weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]) + ) + weight_hh_n = self.block_builder.emit( + relax.op.strided_slice(weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]) + ) + + # Transpose weights for matmul + weight_ih_r_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_r, axes=[1, 0])) + weight_hh_r_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_r, axes=[1, 0])) + weight_ih_z_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_z, axes=[1, 0])) + weight_hh_z_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_z, axes=[1, 0])) + weight_ih_n_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_n, axes=[1, 0])) + weight_hh_n_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_n, axes=[1, 0])) + + outputs = [] + time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len) + + for t in time_steps: + # Get input at time t: (batch_size, input_size) + x_t = self.block_builder.emit( + relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip") + ) + + # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr) + r_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_r_t)) + r_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)) + if 
bias_ih is not None and bias_hh is not None: + bias_ih_r = self.block_builder.emit( + relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size]) + ) + bias_hh_r = self.block_builder.emit( + relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size]) + ) + r_t = self.block_builder.emit( + relax.op.sigmoid( + relax.op.add(relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r) + ) + ) + else: + r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh))) + + # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz) + z_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_z_t)) + z_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)) + if bias_ih is not None and bias_hh is not None: + bias_ih_z = self.block_builder.emit( + relax.op.strided_slice( + bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size] + ) + ) + bias_hh_z = self.block_builder.emit( + relax.op.strided_slice( + bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size] + ) + ) + z_t = self.block_builder.emit( + relax.op.sigmoid( + relax.op.add(relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z) + ) + ) + else: + z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh))) + + # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn)) + n_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_n_t)) + n_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)) + if bias_ih is not None and bias_hh is not None: + bias_ih_n = self.block_builder.emit( + relax.op.strided_slice( + bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size] + ) + ) + bias_hh_n = self.block_builder.emit( + relax.op.strided_slice( + bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size] + ) + ) + n_t = self.block_builder.emit( + relax.op.tanh( + relax.op.add( + relax.op.add(n_ih, bias_ih_n), + relax.op.multiply(r_t, 
relax.op.add(n_hh, bias_hh_n)), + ) + ) + ) + else: + n_t = self.block_builder.emit( + relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh))) + ) + + # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1} + one_minus_z = self.block_builder.emit(relax.op.subtract(relax.const(1.0, dtype), z_t)) + h_t = self.block_builder.emit( + relax.op.add(relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)) + ) + + outputs.append(h_t) + h_prev = h_t + + if reverse: + outputs = outputs[::-1] + + output = self.block_builder.emit(relax.op.stack(outputs, axis=0)) + return output + + def _gru(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + input_tensor = args[0] + hx = args[1] if len(args) > 1 else None + params = args[2] if len(args) > 2 else None + has_biases = args[3] if len(args) > 3 else True + num_layers = args[4] if len(args) > 4 else 1 + _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference + _train = args[6] if len(args) > 6 else False # Not used in inference + bidirectional = args[7] if len(args) > 7 else False + batch_first = args[8] if len(args) > 8 else False + + if num_layers > 1: + raise NotImplementedError("Multi-layer GRU is not yet supported") + + input_shape = self.shape_of(input_tensor) + if batch_first: + batch_size, seq_len, input_size = input_shape + else: + seq_len, batch_size, input_size = input_shape + + seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len + batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size + input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size + + # Extract hidden size from parameters + # For bidirectional: params has weights for both directions + # params_per_direction = 4 if has_biases else 2 (weight_ih, weight_hh, [bias_ih, bias_hh]) + params_per_direction = 4 if has_biases else 2 + + if params and len(params) >= 2: + # Extract hidden size from weight dimensions + # weight_ih has 
shape (3 * hidden_size, input_size) + weight_ih_shape = self.shape_of(params[0]) + hidden_size = weight_ih_shape[0] // 3 # 3 gates: reset, update, new + else: + # Fallback to a default hidden size + hidden_size = 16 + + dtype = input_tensor.struct_info.dtype + + # Extract forward direction weights + if params and len(params) >= params_per_direction: + weight_ih_fwd = params[0] + weight_hh_fwd = params[1] + bias_ih_fwd = params[2] if has_biases else None + bias_hh_fwd = params[3] if has_biases else None + else: + # Fallback: create zero weights + weight_ih_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((3 * hidden_size, input_size)), dtype) + ) + weight_hh_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype) + ) + bias_ih_fwd = None + bias_hh_fwd = None + + # Extract or create backward direction weights if bidirectional + if bidirectional: + if params and len(params) >= params_per_direction * 2: + weight_ih_bwd = params[params_per_direction] + weight_hh_bwd = params[params_per_direction + 1] + bias_ih_bwd = params[params_per_direction + 2] if has_biases else None + bias_hh_bwd = params[params_per_direction + 3] if has_biases else None + else: + # Fallback: create zero weights + weight_ih_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((3 * hidden_size, input_size)), dtype) + ) + weight_hh_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype) + ) + bias_ih_bwd = None + bias_hh_bwd = None + else: + weight_ih_bwd = None + weight_hh_bwd = None + bias_ih_bwd = None + bias_hh_bwd = None + + # Initialize hidden states + if hx is not None: + h_prev_fwd = self.block_builder.emit( + relax.op.take(hx, relax.const(0, "int64"), axis=0, mode="clip") + ) + if bidirectional: + h_prev_bwd = self.block_builder.emit( + relax.op.take(hx, relax.const(1, "int64"), axis=0, mode="clip") + ) + else: + h_prev_bwd = None + else: + h_prev_fwd = 
self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + if bidirectional: + h_prev_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + else: + h_prev_bwd = None + + # Reshape input for processing + input_reshaped = ( + self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2])) + if batch_first + else input_tensor + ) + + # Process forward direction + output_fwd = self._gru_cell_unroll( + input_reshaped, + weight_ih_fwd, + weight_hh_fwd, + bias_ih_fwd, + bias_hh_fwd, + h_prev_fwd, + seq_len, + hidden_size, + dtype, + reverse=False, + ) + + # Process backward direction if bidirectional + if bidirectional: + output_bwd = self._gru_cell_unroll( + input_reshaped, + weight_ih_bwd, + weight_hh_bwd, + bias_ih_bwd, + bias_hh_bwd, + h_prev_bwd, + seq_len, + hidden_size, + dtype, + reverse=True, + ) + # Concatenate forward and backward outputs along feature dimension + output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2)) + else: + output = output_fwd + + # Reshape back to batch_first if needed + if batch_first: + # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size) + output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2])) + + return output + ########## Manipulation ########## def _narrow(self, node: fx.Node) -> relax.Var: @@ -247,11 +914,23 @@ def _select(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.take(x, index, dim)) def _slice(self, node: fx.Node) -> relax.Var: + import sys + x = self.env[node.args[0]] - axes = [node.args[1]] - begin = [node.args[2]] - end = [node.args[3]] - stride = [node.args[4] if len(node.args) > 4 else 1] + dim = node.args[1] if len(node.args) > 1 else 0 + start = node.args[2] if len(node.args) > 2 else None + end_val = node.args[3] if len(node.args) > 3 else None + step = node.args[4] if len(node.args) > 4 else 1 + + if start is 
None: + start = 0 + if end_val is None: + end_val = sys.maxsize + + axes = [dim] + begin = [start] + end = [end_val] + stride = [step] return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) def _unflatten(self, node: fx.Node) -> relax.Var: @@ -306,6 +985,83 @@ def _zeros(self, node: fx.Node) -> relax.Var: ) return self.block_builder.emit(relax.op.zeros(size, dtype)) + def _sparse_mm(self, node: fx.Node) -> relax.Var: + """Handle sparse matrix multiplication by converting sparse tensor to dense.""" + args = self.retrieve_args(node) + sparse_input = args[0] + dense_input = args[1] + # Convert sparse tensor to dense if needed + # Note: sparse_input should already be converted to dense in _convert_pytorch_tensor_to_tvm + # Use regular matrix multiplication + return self.block_builder.emit( + relax.op.linear_algebra.matmul(sparse_input, dense_input, out_dtype="float32") + ) + + def _sparse_addmm(self, node: fx.Node) -> relax.Var: + """Handle sparse addmm (beta * input + alpha * sparse_mm(mat1, mat2)).""" + args = self.retrieve_args(node) + input_tensor = args[0] # beta * input + sparse_mat1 = args[1] # sparse matrix + dense_mat2 = args[2] # dense matrix + alpha = node.kwargs.get("alpha", 1.0) + beta = node.kwargs.get("beta", 1.0) + + # Convert sparse tensor to dense if needed + # Note: sparse_mat1 should already be converted to dense in _convert_pytorch_tensor_to_tvm + # Compute alpha * sparse_mm(mat1, mat2) + matmul_result = self.block_builder.emit( + relax.op.linear_algebra.matmul(sparse_mat1, dense_mat2, out_dtype="float32") + ) + + if alpha != 1.0: + alpha_const = relax.const(alpha, matmul_result.struct_info.dtype) + matmul_result = self.block_builder.emit(relax.op.multiply(matmul_result, alpha_const)) + + # Compute beta * input + alpha * matmul_result + if beta != 0.0: + if beta != 1.0: + beta_const = relax.const(beta, input_tensor.struct_info.dtype) + input_scaled = self.block_builder.emit(relax.op.multiply(input_tensor, beta_const)) + 
else: + input_scaled = input_tensor + return self.block_builder.emit(relax.op.add(input_scaled, matmul_result)) + else: + return matmul_result + + def _grid_sampler_2d(self, node: fx.Node) -> relax.Var: + """Convert torch.nn.functional.grid_sample to relax.op.image.grid_sample.""" + args = self.retrieve_args(node) + data = args[0] + grid = args[1] + interp_mode = args[2] if len(args) > 2 else 0 + pad_mode = args[3] if len(args) > 3 else 0 + align_corners = args[4] if len(args) > 4 else False + + interp_map = {0: "bilinear", 1: "nearest", 2: "bicubic"} + pad_map = {0: "zeros", 1: "border", 2: "reflection"} + + method = interp_map.get(interp_mode, "bilinear") + padding_mode = pad_map.get(pad_mode, "zeros") + + return self.block_builder.emit( + relax.op.image.grid_sample( + data, + grid, + method=method, + layout="NCHW", + padding_mode=padding_mode, + align_corners=align_corners, + ) + ) + + def _scalar_tensor(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + scalar_value = args[0] + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + return self.block_builder.emit(relax.const(scalar_value, dtype)) + def _instance_norm(self, node: fx.Node): import numpy as np @@ -329,6 +1085,78 @@ def _instance_norm(self, node: fx.Node): ) ) + def _exponential(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.zeros_like(x)) + + def _max_dim(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + + topk_res = self.block_builder.emit( + relax.op.topk(x, k=1, axis=dim, largest=True, ret_type="both", dtype="int64") + ) + + values = topk_res[0] + indices = topk_res[1] + + if not keepdim: + values = self.block_builder.emit(relax.op.squeeze(values, axis=[dim])) + indices = self.block_builder.emit(relax.op.squeeze(indices, axis=[dim])) + + return 
self.block_builder.emit(relax.Tuple([values, indices])) + + def _alias(self, node: fx.Node) -> relax.Var: + return self.env[node.args[0]] + + def _scatter_value(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = self.env[node.args[2]] + value = node.args[3] + + value_const = relax.const(value, x.struct_info.dtype) + src = self.block_builder.emit(relax.op.broadcast_to(value_const, self.shape_of(index))) + + return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) + + def _as_strided(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + size = args[1] + stride = args[2] + storage_offset = args[3] if len(args) > 3 else node.kwargs.get("storage_offset", 0) + + assert storage_offset == 0, "as_strided with non-zero storage_offset is not supported yet" + + # Only handle view-like cases where the provided strides align with a contiguous layout. + can_check = all(isinstance(dim, (int, tvm.tir.IntImm)) for dim in size) and all( + isinstance(st, (int, tvm.tir.IntImm)) for st in stride + ) + if can_check: + expected_stride = [] + running = 1 + for dim in reversed(size): + dim_int = int(dim) + expected_stride.insert(0, running) + running *= dim_int + + for dim, st, exp in zip(size, stride, expected_stride): + dim_int = int(dim) + if dim_int != 1 and int(st) != exp: + raise AssertionError( + f"as_strided with non-contiguous stride {stride} for" + f"size {size} is not supported" + ) + + return self.block_builder.emit(relax.op.reshape(x, size)) + + ########## Symbolic Shape Constraints ########## + + def _symbolic_comparison(self, _: fx.Node) -> relax.Expr: + return self.block_builder.emit(relax.const(True, dtype="bool")) + ########## Others ########## def create_convert_map( @@ -355,9 +1183,17 @@ def create_convert_map( "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], "dropout_.default": lambda node: self.env[node.args[0]], + 
"native_dropout.default": lambda node: self.env[node.args[0]], "elu.default": self._elu, "erf.default": self._unary_op(relax.op.erf), "exp.default": self._unary_op(relax.op.exp), + "exponential.default": self._exponential, + "expm1.default": lambda node: self.block_builder.emit( + relax.op.subtract( + relax.op.exp(self.env[node.args[0]]), + relax.const(1.0, self.env[node.args[0]].struct_info.dtype), + ) + ), "floor.default": self._unary_op(relax.op.floor), "gelu.default": self._gelu, "hardsigmoid.default": self._hardsigmoid, @@ -376,9 +1212,13 @@ def create_convert_map( "log10.default": self._log10, "log1p.default": self._log1p, "logical_not.default": self._unary_op(relax.op.logical_not), + "logical_and.default": self._binary_op(relax.op.logical_and, operator.and_), "log_softmax.int": self._log_softmax, + "_log_softmax.default": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), "pad.default": self._pad, + "constant_pad_nd.default": self._constant_pad_nd, + "copy.default": self._copy_, "pixel_shuffle.default": self._pixel_shuffle, "prelu.default": self._prelu, "reciprocal.default": self._reciprocal, @@ -387,7 +1227,9 @@ def create_convert_map( "relu6.default": self._unary_op(relax.op.nn.relu6), "relu6_.default": self._unary_op(relax.op.nn.relu6), "round.default": self._round, - "rsqrt.default": self._unary_op(relax.op.rsqrt), + "rsqrt.default": self._rsqrt, + "scalar_tensor.default": self._scalar_tensor, + "scatter.value": self._scatter_value, "rsub.Tensor": self._rsub, "rsub.Scalar": self._rsub, "selu.default": self._unary_op(relax.op.nn.selu), @@ -398,10 +1240,11 @@ def create_convert_map( "sin.default": self._unary_op(relax.op.sin), "sinh.default": self._unary_op(relax.op.sinh), "softmax.int": self._softmax, + "_softmax.default": self._softmax, "softplus.default": self._softplus, "softshrink.default": self._softshrink, "softsign.default": self._softsign, - "sqrt.default": self._unary_op(relax.op.sqrt), + "sqrt.default": self._sqrt, 
"square.default": self._unary_op(relax.op.square), "tan.default": self._unary_op(relax.op.tan), "tanh.default": self._unary_op(relax.op.tanh), @@ -410,11 +1253,17 @@ def create_convert_map( "trunc.default": self._unary_op(relax.op.trunc), # binary "add.Tensor": self._binary_op(relax.op.add, operator.add), + "add.Scalar": self._binary_op(relax.op.add, operator.add), "add_.Tensor": self._binary_op(relax.op.add, operator.add), + "bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_), + "bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_), "bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), "bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), + "bitwise_xor.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor), + "bitwise_xor.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor), "bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), "bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), + "div.Scalar": self._binary_op(relax.op.divide, operator.truediv), "div.Tensor": self._binary_op(relax.op.divide, operator.truediv), "div.Tensor_mode": self._div, "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), @@ -434,13 +1283,20 @@ def create_convert_map( "matmul.default": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), + "mm.default": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), "max.other": self._binary_op(relax.op.maximum, max), "min.other": self._binary_op(relax.op.minimum, min), "max.default": self._unary_op(relax.op.max), "min.default": self._unary_op(relax.op.min), + "maximum.default": self._binary_op(relax.op.maximum, torch.maximum), + "minimum.default": self._binary_op(relax.op.minimum, torch.minimum), "remainder.Tensor": self._binary_op(relax.op.floor_mod, operator.mod), "remainder.Scalar": 
self._binary_op(relax.op.floor_mod, operator.mod), + "mul": self._binary_op(relax.op.multiply, operator.mul), "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul), + "mul.Scalar": self._binary_op(relax.op.multiply, operator.mul), "mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul), "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne), "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne), @@ -451,6 +1307,7 @@ def create_convert_map( "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow), "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), + "sub.Scalar": self._binary_op(relax.op.subtract, operator.sub), "__and__.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_), "__and__.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_), "__or__.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), @@ -460,13 +1317,19 @@ def create_convert_map( # linear algebra "linalg_vector_norm.default": self._norm, # neural network + "_adaptive_avg_pool1d.default": self._adaptive_avg_pool1d, + "_adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, + "_adaptive_avg_pool3d.default": self._adaptive_avg_pool3d, "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, + "_native_batch_norm_legit.no_stats": self._batch_norm_legit_no_stats, "batch_norm.default": self._batch_norm_legit_no_training, "adaptive_avg_pool1d.default": self._adaptive_avg_pool1d, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "adaptive_avg_pool3d.default": self._adaptive_avg_pool3d, "addmm.default": self._addmm, + "_sparse_mm.default": self._sparse_mm, + "_sparse_addmm.default": self._sparse_addmm, "avg_pool1d.default": self._avg_pool1d, "avg_pool2d.default": self._avg_pool2d, "avg_pool3d.default": self._avg_pool3d, @@ -479,6 +1342,7 @@ def 
create_convert_map( "conv1d.default": self._conv1d, "conv2d.default": self._conv2d, "conv3d.default": self._conv3d, + "convolution.default": self._convolution, "cross_entropy_loss.default": self._cross_entropy_default, "einsum.default": self._einsum, "embedding.default": lambda node: self._embedding_impl( @@ -486,23 +1350,33 @@ def create_convert_map( ), "group_norm.default": self._group_norm, "instance_norm.default": self._instance_norm, + "native_group_norm.default": self._native_group_norm, "layer_norm.default": self._layer_norm, + "native_layer_norm.default": self._native_layer_norm, "linear.default": self._linear, + "lstm.input": self._lstm, + "gru.input": self._gru, "max_pool1d.default": self._max_pool1d, "max_pool2d.default": self._max_pool2d, + "max_pool2d_with_indices.default": self._max_pool2d_with_indices, "max_pool3d.default": self._max_pool3d, + "max_pool3d_with_indices.default": self._max_pool3d_with_indices, "scaled_dot_product_attention.default": self._scaled_dot_product_attention, "unbind.int": self._unbind, "upsample_bilinear2d.vec": self._upsample_bilinear2d, + "_upsample_bilinear2d_aa.default": self._upsample_bilinear2d_aa, "upsample_nearest2d.vec": self._upsample_nearest2d, "upsample_bicubic2d.vec": self._upsample_bicubic2d, # statistical + "any.dim": self._any, + "any.dims": self._any, "mean.dim": self._mean, "prod.default": self._prod, "std.correction": self._std, "sum.default": self._sum, "sum.dim_IntList": self._sum, "var.correction": self._var, + "max.dim": self._max_dim, # search "argmax.default": self._argmax_argmin(relax.op.argmax), "argmin.default": self._argmax_argmin(relax.op.argmin), @@ -510,6 +1384,7 @@ def create_convert_map( "bucketize.Tensor": self._bucketize, # tensor manipulation "argsort.default": self._argsort, + "alias.default": self._alias, "broadcast_to.default": self._broadcast_to, "cat.default": self._cat, "chunk.default": self._chunk, @@ -524,6 +1399,7 @@ def create_convert_map( "flip.default": self._flip, 
"gather.default": self._gather, "index.Tensor": self._index_tensor, + "index_put.default": self._index_put, "index_put_.default": self._index_put, "meshgrid.indexing": self._meshgrid, "meshgrid.default": self._meshgrid, @@ -539,6 +1415,7 @@ def create_convert_map( "split_with_sizes.default": self._split, "squeeze.default": self._squeeze, "squeeze.dim": self._squeeze, + "squeeze.dims": self._squeeze, "stack.default": self._stack, "take.default": self._take, "tile.default": self._tile, @@ -551,6 +1428,7 @@ def create_convert_map( "view.default": self._reshape, "reshape.default": self._reshape, "reshape_as.default": self._reshape_as, + "as_strided.default": self._as_strided, # tensor creation "_to_copy.default": self._to_copy, "arange.default": self._arange, @@ -560,7 +1438,13 @@ def create_convert_map( "detach_.default": self._detach, "contiguous.default": lambda node: self.env[node.args[0]], # no-op "clone.default": lambda node: self.env[node.args[0]], + "bernoulli.p": lambda node: self.env[node.args[0]], # Dropout: just return input + "_assert_tensor_metadata.default": lambda node: self.env[ + node.args[0] + ], # metadata assertion: no-op + "empty.default": self._empty, "empty.memory_format": self._empty, + "empty_permuted.default": self._empty, # Similar to empty with permuted layout "empty_like.default": self._empty_like, "eye.default": self._eye, "eye.m": self._eye, @@ -577,6 +1461,7 @@ def create_convert_map( "linspace.default": self._linspace, "masked_fill.Scalar": self._masked_fill, "masked_fill_.Scalar": self._inplace_masked_fill, + "masked_select.default": self._masked_select, "new_ones.default": self._new_ones, "new_zeros.default": self._new_zeros, "one_hot.default": self._one_hot, @@ -587,6 +1472,7 @@ def create_convert_map( "zero_.default": self._zeros_inplace, "zeros.default": self._zeros, "zeros_like.default": self._zeros_like, + "grid_sampler_2d.default": self._grid_sampler_2d, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, @@ -594,16 
+1480,90 @@ def create_convert_map( # other "getitem": self._getitem, "item.default": self._item, + "sym_size.int": self._sym_size_int, + "_local_scalar_dense.default": self._item, + # symbolic shape constraints (no-ops for compilation) + "sym_constrain_range_for_size.default": lambda node: self.env[node.args[0]], + "_assert_scalar.default": lambda node: self.env[node.args[0]], + "ge": self._symbolic_comparison, + "le": self._symbolic_comparison, } + def _process_derived_symbol( + self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] + ) -> Tuple[str, Optional[tvm.tir.PrimExpr]]: + """Process a sympy symbol to generate a descriptive name and TIR expression.""" + import sympy + + if isinstance(symbol, sympy.Symbol): + return str(symbol), None + + if not isinstance(symbol, (sympy.Add, sympy.Mul)): + return str(symbol), None + + tir_expr = None + for arg in symbol.args: + if isinstance(arg, sympy.Integer): + term = tvm.tir.IntImm("int64", int(arg)) + elif isinstance(arg, sympy.Symbol): + term = torch_symbol_to_relax_var.setdefault( + str(arg), tvm.tir.SizeVar(str(arg), "int64") + ) + else: + _, term = self._process_derived_symbol(arg, torch_symbol_to_relax_var) + + if term is None: + return str(symbol), None + + if tir_expr is None: + tir_expr = term + elif isinstance(symbol, sympy.Mul): + tir_expr = tir_expr * term + elif isinstance(symbol, sympy.Add): + tir_expr = tir_expr + term + + if isinstance(tir_expr, tvm.tir.Add): + for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]: + if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var): + return f"{var.name}___{const.value}", tir_expr + + if isinstance(tir_expr, tvm.tir.Mul): + for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]: + if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var): + return f"{var.name}_{const.value}", tir_expr + + return str(symbol), tir_expr + def create_input_vars( self, exported_program: torch.export.ExportedProgram - 
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]: + ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, Optional[int]]]]: """Create relax input vars.""" parameters_buffers_constants = OrderedDict() user_inputs = OrderedDict() torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {} + range_constraints = {} + + if hasattr(exported_program, "range_constraints"): + import math + + for symbol, value_range in exported_program.range_constraints.items(): + if hasattr(value_range, "lower") and hasattr(value_range, "upper"): + try: + # PyTorch uses int_oo (IntInfinity) for unbounded constraints + lower = int(value_range.lower) + upper = ( + None if math.isinf(float(value_range.upper)) else int(value_range.upper) + ) + + symbol_name, _ = self._process_derived_symbol( + symbol, torch_symbol_to_relax_var + ) + range_constraints[symbol_name] = (lower, upper) + except (OverflowError, AttributeError, TypeError): + continue + + named_buffers = OrderedDict(exported_program.named_buffers()) for spec in exported_program.graph_signature.input_specs: name_hint = spec.arg.name if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: @@ -615,18 +1575,29 @@ def create_input_vars( torch_shape = node.meta["tensor_meta"].shape torch_dtype = node.meta["tensor_meta"].dtype break - else: - # PARAMETER or BUFFER + elif spec.kind is torch.export.graph_signature.InputKind.BUFFER: + torch_shape = named_buffers[spec.target].shape + torch_dtype = named_buffers[spec.target].dtype + elif spec.kind is torch.export.graph_signature.InputKind.PARAMETER: torch_shape = exported_program.state_dict[spec.target].shape torch_dtype = exported_program.state_dict[spec.target].dtype - - # TODO(mshr-h): Support range constraints - relax_shape = [ - torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64")) - if isinstance(s, torch.SymInt) - else s - for s in torch_shape - ] + else: + raise ValueError(f"Unsupported input kind: {spec.kind}") + + relax_shape 
= [] + for s in torch_shape: + if isinstance(s, torch.SymInt): + sympy_node = s.node.expr if hasattr(s.node, "expr") else s.node + symbol_name, _ = self._process_derived_symbol( + sympy_node, torch_symbol_to_relax_var + ) + + size_var = torch_symbol_to_relax_var.setdefault( + symbol_name, tvm.tir.SizeVar(symbol_name, "int64") + ) + relax_shape.append(size_var) + else: + relax_shape.append(s) dtype = self._convert_data_type(torch_dtype) relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype)) @@ -635,7 +1606,7 @@ def create_input_vars( else: parameters_buffers_constants[name_hint] = relax_var - return parameters_buffers_constants, user_inputs + return parameters_buffers_constants, user_inputs, range_constraints def from_exported_program( self, @@ -643,19 +1614,45 @@ def from_exported_program( keep_params_as_input: bool, unwrap_unit_return_tuple: bool, no_bind_return_tuple: bool, + custom_convert_map: Optional[ + Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] + ], ) -> tvm.IRModule: """Convert a PyTorch ExportedProgram to a Relax program.""" - from torch import fx # type: ignore + + # Update the conversion map with custom ops if provided. + if custom_convert_map: + custom_ops = set(custom_convert_map.keys()) + self.update_convert_map(custom_convert_map) + else: + custom_ops = set() # Create input variables. - parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program) + ( + parameter_buffer_constant_vars, + user_input_vars, + range_constraints, + ) = self.create_input_vars(exported_program) inputs_vars = user_input_vars.copy() inputs_vars.update(parameter_buffer_constant_vars) # Initialize the block builder with a function and a dataflow block. 
self.block_builder = relax.BlockBuilder() func_name = "main" - func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None + func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else {} + if range_constraints: + func_attrs["tir_var_lower_bound"] = { + var_name: lower for var_name, (lower, _) in range_constraints.items() + } + + upper_bounds = { + var_name: upper + for var_name, (_, upper) in range_constraints.items() + if upper is not None + } + + if upper_bounds: + func_attrs["tir_var_upper_bound"] = upper_bounds nodes: List[fx.Node] = exported_program.graph.nodes @@ -667,7 +1664,6 @@ def from_exported_program( ): output = None with self.block_builder.dataflow(): - # Translate the model. for node in nodes: if node.op == "placeholder": @@ -694,7 +1690,10 @@ def from_exported_program( self.env[node] = getattr(exported_program.graph_module, node.target) elif node.op == "call_function": func_name = node.target.__name__ - self.env[node] = self.convert_map[func_name](node) + if func_name in custom_ops: + self.env[node] = self.convert_map[func_name](node, self) + else: + self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") assert output is not None @@ -715,14 +1714,14 @@ def from_exported_program( if tensor_name == spec.target: bind_name = spec.arg.name break - binding[bind_name] = tvm.nd.from_dlpack(tensor_value.detach()) + binding[bind_name] = self._convert_pytorch_tensor_to_tvm(tensor_value) mod = self.block_builder.get() mod = relax.transform.BindParams("main", binding)(mod) if keep_params_as_input: parameters = dict(exported_program.named_parameters()) - params = [tvm.nd.from_dlpack(p.detach()) for p in parameters.values()] + params = [self._convert_pytorch_tensor_to_tvm(p) for p in parameters.values()] mod["main"] = mod["main"].with_attr("params", params) return mod @@ -734,6 +1733,10 @@ def from_exported_program( keep_params_as_input: bool = False, 
unwrap_unit_return_tuple: bool = False, no_bind_return_tuple: bool = False, + custom_convert_map: Optional[ + Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] + ] = None, + run_ep_decomposition: bool = True, ) -> tvm.IRModule: """Convert a PyTorch ExportedProgram to a Relax program @@ -753,6 +1756,14 @@ def from_exported_program( A boolean flag indicating whether to bind the return tuple as a relax var. If the flag is true and the return value is a tuple, it will not bind it to a var. + custom_convert_map : Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] + A custom op conversion map in the same format as ExportedProgramImporter.convert_map above + + run_ep_decomposition : bool + A boolean flag indicating whether to run PyTorch's decomposition on the + exported program before translation. When True, high-level operators will + be decomposed into their constituent parts. Defaults to True. + Returns ------- output : tvm.IRModule @@ -792,12 +1803,14 @@ def forward(self, input): # Use the importer to import the ExportedProgram to Relax. 
mod: tvm.IRModule = from_exported_program(exported_program) """ - # decompose into Core ATen operators - exported_program.run_decompositions() + # Conditionally decompose into Core ATen operators + if run_ep_decomposition: + exported_program = exported_program.run_decompositions() return ExportedProgramImporter().from_exported_program( exported_program, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple, + custom_convert_map, ) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 754129ffdeb8..f2a6c9e6546b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -33,11 +33,12 @@ class TorchFXImporter(BaseFXGraphImporter): import torch # type: ignore from torch import fx - def __init__(self) -> None: + def __init__(self, default_image_layout: str = "NCHW") -> None: import torch # type: ignore super().__init__() self.named_modules: Dict[str, torch.Module] = None + self.default_image_layout = default_image_layout ########## Utilities ########## @@ -96,6 +97,26 @@ def _log1p(self, node: fx.Node) -> relax.Var: one = relax.const(1, x.struct_info.dtype) return self.block_builder.emit(relax.op.log(relax.op.add(x, one))) + def _sqrt(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = x.struct_info.dtype + + # Check if input is integer type and convert to float32 if needed + if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]: + x = self.block_builder.emit(relax.op.astype(x, "float32")) + + return self.block_builder.emit(relax.op.sqrt(x)) + + def _rsqrt(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = x.struct_info.dtype + + # Check if input is integer type and convert to float32 if needed + if dtype in ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]: + x = self.block_builder.emit(relax.op.astype(x, "float32")) + + return 
self.block_builder.emit(relax.op.rsqrt(x)) + def _log_softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -460,7 +481,6 @@ def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, # recompute_scale_factor=None, antialias=False) - # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout data = self.env[node.args[0]] size = ( node.args[1] @@ -503,13 +523,26 @@ def _interpolate(self, node: fx.Node) -> relax.Var: if size is None: shape = self.shape_of(data) assert isinstance(shape, relax.ShapeExpr) + # Determine spatial dimension indices based on layout + # NCHW: spatial dims are [2, 3, ...] (skip batch and channel) + # NHWC: spatial dims are [1, 2, ...] (skip batch, before channel) + if self.default_image_layout == "NHWC": + spatial_start = 1 + spatial_end = len(shape) - 1 + else: # NCHW or other layouts + spatial_start = 2 + spatial_end = len(shape) + if isinstance(scale_factor, tuple): - assert len(scale_factor) == len(shape) - 2 + assert len(scale_factor) == spatial_end - spatial_start size = tuple( - int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape)) + int(shape[i].value * scale_factor[i - spatial_start]) + for i in range(spatial_start, spatial_end) ) else: - size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + size = tuple( + int(shape[i].value * scale_factor) for i in range(spatial_start, spatial_end) + ) if method.startswith("nearest"): method = "nearest_neighbor" @@ -525,7 +558,11 @@ def _interpolate(self, node: fx.Node) -> relax.Var: return self.block_builder.emit( relax.op.image.resize2d( - data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + data, + size, + layout=self.default_image_layout, + method=method, + coordinate_transformation_mode=coord_trans, ) ) @@ -632,7 
+669,17 @@ def _size(self, node: fx.Node) -> relax.Expr: ########## Creation ########## def _inplace_copy(self, node: fx.Node) -> relax.Var: + dest = self.env[node.args[0]] src = self.env[node.args[1]] + + if src.struct_info.dtype != dest.struct_info.dtype: + src = self.block_builder.emit(relax.op.astype(src, dest.struct_info.dtype)) + + dest_shape = self.shape_of(dest) + src_shape = self.shape_of(src) + if dest_shape != src_shape: + src = self.block_builder.emit(relax.op.broadcast_to(src, dest_shape)) + self.env[node.args[0]] = src return src @@ -710,12 +757,6 @@ def _getattr(self, node: fx.Node) -> relax.Var: return self.shape_of(self.env[node.args[0]]) return getattr(self.env[node.args[0]], node.args[1]) - def _sym_size_int(self, node: fx.Node) -> relax.Expr: - x = self.env[node.args[0]] - shape = self.shape_of(x) - idx = node.args[1] - return self.block_builder.emit(relax.const(shape[idx].value, "int32")) - def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) -> List[relax.Var]: inputs = list() for idx, (shape, dtype) in enumerate(input_info): @@ -825,7 +866,7 @@ def create_convert_map( "relu": self._unary_op(relax.op.nn.relu), "relu6": self._unary_op(relax.op.nn.relu6), "round": self._round, - "rsqrt": self._unary_op(relax.op.rsqrt), + "rsqrt": self._rsqrt, "selu": self._unary_op(relax.op.nn.selu), "sigmoid": self._unary_op(relax.op.sigmoid), "sign": self._unary_op(relax.op.sign), @@ -834,7 +875,7 @@ def create_convert_map( "sinh": self._unary_op(relax.op.sinh), "softmax": self._softmax, "softplus": self._softplus, - "sqrt": self._unary_op(relax.op.sqrt), + "sqrt": self._sqrt, "square": self._unary_op(relax.op.square), "tan": self._unary_op(relax.op.tan), "tanh": self._unary_op(relax.op.tanh), @@ -996,17 +1037,6 @@ def create_convert_map( "item": self._item, } - def update_convert_map(self, custom_convert_map: dict): - """Update self.convert_map with custom convert map - - Parameters - ---------- - custom_convert_map : Dictionary of str to 
Relax op - A custom op conversion map in the same format as self.convert_map - """ - - self.convert_map.update(custom_convert_map) - def from_fx( self, model, @@ -1042,7 +1072,7 @@ def from_fx( dtype = self._convert_data_type(str(param.data.dtype)) inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype))) self.params[param] = inputs[-1] - params.append(tvm.nd.array(param.data.cpu().numpy())) + params.append(tvm.runtime.tensor(param.data.cpu().numpy())) else: func_attrs = None @@ -1126,6 +1156,7 @@ def from_fx( unwrap_unit_return_tuple: bool = False, no_bind_return_tuple: bool = False, custom_convert_map: dict = None, + default_image_layout: str = "NCHW", ) -> tvm.IRModule: """Convert a PyTorch FX GraphModule to a Relax program @@ -1151,6 +1182,10 @@ def from_fx( custom_convert_map : Dictionary of str to Relax op A custom op conversion map in the same format as TorchFXImporter.convert_map + default_image_layout : str + The default layout for image operations (e.g., "NCHW" or "NHWC"). + Default is "NCHW" which is the standard PyTorch layout. + Returns ------- output : tvm.IRModule @@ -1218,7 +1253,7 @@ def forward(self, input): to print out the tabular representation of the PyTorch module, and then check the placeholder rows in the beginning of the tabular. 
""" - return TorchFXImporter().from_fx( + return TorchFXImporter(default_image_layout=default_image_layout).from_fx( model, input_info, keep_params_as_input, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index fd3672368b68..19096decd932 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -27,6 +27,7 @@ call_dps_packed, call_inplace_packed, call_pure_packed, + call_py_func, call_tir, call_tir_inplace, call_tir_with_grad, @@ -154,6 +155,7 @@ tanh, trunc, ) +from .vision import all_class_non_max_suppression def _register_op_make(): diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py index 1d16a024d1d4..867c43e4d85b 100644 --- a/python/tvm/relax/op/_ffi_api.py +++ b/python/tvm/relax/op/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op", __name__) +tvm_ffi.init_ffi_api("relax.op", __name__) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index b0570344e5a0..ffa19fbaa060 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -22,7 +22,7 @@ import tvm import tvm.runtime from tvm.runtime.object import Object -from tvm.runtime import ObjectGeneric +from tvm.runtime import ObjectConvertible from . import _ffi_api from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var @@ -304,6 +304,42 @@ def call_dps_packed( return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore +@args_converter.auto +def call_py_func( + func_name: str, + args: Expr, + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], +) -> Call: + """ + Call a Python function and return the output. + + Parameters + ---------- + func_name : str + The name of the Python function to call. 
This should correspond to a function + in the IRModule's pyfuncs attribute. + + args : Expr + The input arguments. + + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_py_func output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + Returns + ------- + ret: Call + A call node for the call_py_func operator. + """ + args = _wrap_inline_arg_tuple(args) + + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + return _ffi_api.call_py_func(func_name, args, out_sinfo) # type: ignore + + @args_converter.auto def call_builtin_with_ctx( func: Union[str, Expr], @@ -414,7 +450,7 @@ def render_object(val: tvm.Object) -> str: ret: str A string representing the value, ideally human-readable """ - if isinstance(val, tvm.nd.NDArray): + if isinstance(val, tvm.runtime.Tensor): return str(val) if isinstance(val, tvm.ir.Array): fields = ", ".join([render_object(val[i]) for i in range(len(val))]) @@ -422,20 +458,20 @@ def render_object(val: tvm.Object) -> str: return str(val) -@tvm.register_func("relax.run.shape_to_tensor") -def relax_shape_to_tensor(shape_tuple: tvm.runtime.ShapeTuple) -> tvm.nd.NDArray: +@tvm.register_global_func("relax.run.shape_to_tensor") +def relax_shape_to_tensor(shape_tuple: tvm.runtime.ShapeTuple) -> tvm.runtime.Tensor: """ - Takes a ShapeTuple and convert it to NDArray. + Takes a ShapeTuple and convert it to Tensor. Parameters ---------- shape_tuple: tvm.runtime.ShapeTuple - Shape tuple that we want to convert to NDArray at runtime + Shape tuple that we want to convert to Tensor at runtime """ - return tvm.nd.array([int(v) for v in shape_tuple]) + return tvm.runtime.tensor([int(v) for v in shape_tuple]) -@tvm.register_func("relax.run.print") +@tvm.register_global_func("relax.run.print") def relax_print(format_str: str, *format_args: tvm.Object) -> None: """ Takes a list of values to print, formats with the given format string. 
@@ -483,7 +519,7 @@ def print(*values: List[Expr], format: Union[str, Expr] = "") -> Expr: return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member -@tvm.register_func("relax.run.assert_op") +@tvm.register_global_func("relax.run.assert_op") def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None: """ A variadic function. The first value serves as the assertion condition: @@ -514,7 +550,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob if isinstance(condition, (bool, int)): val = condition - elif isinstance(condition, tvm.nd.NDArray): + elif isinstance(condition, tvm.runtime.Tensor): # may happen if the original program had unknown shape or dtype for the tensor's type dtype = condition.dtype if dtype != "bool": @@ -528,7 +564,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob else: # should be guaranteed by the type system raise ValueError( - f"The condition for relax assert must be a bool, int, or NDArray, " + f"The condition for relax assert must be a bool, int, or Tensor, " f"but received a {type(condition)}." ) @@ -744,7 +780,7 @@ def call_pure_packed( sinfo() if callable(sinfo) else sinfo.asobject() - if isinstance(sinfo, ObjectGeneric) + if isinstance(sinfo, ObjectConvertible) else sinfo for sinfo in sinfo_args ] @@ -813,7 +849,7 @@ def to_vdevice(data, dst_vdevice) -> Expr: return _ffi_api.to_vdevice(data, dst_vdevice) # type: ignore -def hint_on_device(data, dst_vdevice) -> Expr: +def hint_on_device(data, dst_vdevice, memory_scope="global") -> Expr: """It provides a hint specifying the device on which the input data should be executed. This hint is utilized by RealizeVDevice to propagate the virtual device." @@ -822,12 +858,15 @@ def hint_on_device(data, dst_vdevice) -> Expr: data : Expr The tensor to be copied. - dst_device : VDevice + dst_device : Device The destination device where the data is supposed to be executed. 
+ memory_scope: String + Memory scope of buffer on target device. + Returns ------- result : Expr The result. """ - return _ffi_api.hint_on_device(data, dst_vdevice) # type: ignore + return _ffi_api.hint_on_device(data, dst_vdevice, memory_scope) # type: ignore diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py index a7f48af57697..0e5955f6e47d 100644 --- a/python/tvm/relax/op/builtin/_ffi_api.py +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.builtin""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op.builtin", __name__) +tvm_ffi.init_ffi_api("relax.op.builtin", __name__) diff --git a/python/tvm/relax/op/ccl/_ffi_api.py b/python/tvm/relax/op/ccl/_ffi_api.py index bf605aae6ab0..f0796d3da318 100644 --- a/python/tvm/relax/op/ccl/_ffi_api.py +++ b/python/tvm/relax/op/ccl/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Operators serving for Collective Communications Library (CCL) operators""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op.ccl", __name__) +tvm_ffi.init_ffi_api("relax.op.ccl", __name__) diff --git a/python/tvm/relax/op/distributed/_ffi_api.py b/python/tvm/relax/op/distributed/_ffi_api.py index 394cb8c262b2..fa1c163794b9 100644 --- a/python/tvm/relax/op/distributed/_ffi_api.py +++ b/python/tvm/relax/op/distributed/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.relax.op.distributed""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op.dist", __name__) +tvm_ffi.init_ffi_api("relax.op.dist", __name__) diff --git a/python/tvm/relax/op/grad/_ffi_api.py b/python/tvm/relax/op/grad/_ffi_api.py index 415d590f01f0..1a8ebb09aa8d 100644 --- a/python/tvm/relax/op/grad/_ffi_api.py +++ b/python/tvm/relax/op/grad/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.op.grad""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op.grad", __name__) +tvm_ffi.init_ffi_api("relax.op.grad", __name__) diff --git a/python/tvm/relax/op/image/__init__.py b/python/tvm/relax/op/image/__init__.py index 10ef635cbfd3..15c1847b28d6 100644 --- a/python/tvm/relax/op/image/__init__.py +++ b/python/tvm/relax/op/image/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """Image operators.""" -from .image import resize2d +from .image import grid_sample, resize2d diff --git a/python/tvm/relax/op/image/_ffi_api.py b/python/tvm/relax/op/image/_ffi_api.py index 8c813231f9a0..8147a155cb76 100644 --- a/python/tvm/relax/op/image/_ffi_api.py +++ b/python/tvm/relax/op/image/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""Constructor APIs""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op.image", __name__) +tvm_ffi.init_ffi_api("relax.op.image", __name__) diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py index afadbf35fb6b..893f7af90fb7 100644 --- a/python/tvm/relax/op/image/image.py +++ b/python/tvm/relax/op/image/image.py @@ -130,3 +130,52 @@ def resize2d( extrapolation_value, out_dtype, ) + + +def grid_sample( + data: Expr, + grid: Expr, + method: str = "bilinear", + layout: str = "NCHW", + padding_mode: str = "zeros", + align_corners: bool = False, +) -> Expr: + """Applies grid sampling to input feature map. + + Given data and grid, the output is computed by sampling from data using + the grid coordinates. + + Parameters + ---------- + data : relax.Expr + The input data tensor with shape [N, C, H, W] for NCHW layout. + + grid : relax.Expr + The grid tensor with shape [N, H_out, W_out, 2]. The values are normalized + to [-1, 1], where (-1, -1) is the top-left corner and (1, 1) is the bottom-right. + + method : str + Interpolation method. Can be 'nearest', 'bilinear', or 'bicubic'. + + layout : str + Layout of the input data. Default is 'NCHW'. + + padding_mode : str + Padding mode for outside grid values. Can be 'zeros', 'border', or 'reflection'. + + align_corners : bool + If True, the corner pixels of the input and output tensors are aligned. + + Returns + ------- + result : relax.Expr + The sampled output tensor with shape [N, C, H_out, W_out]. 
+ """ + return _ffi_api.grid_sample( # type: ignore + data, + grid, + method, + layout, + padding_mode, + align_corners, + ) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index bb134f114855..ee486b0ab69c 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -642,7 +642,7 @@ def index_put( [0.0, 3.0, 0.0], ] """ - if not isinstance(indices, (list, tuple)): + if isinstance(indices, (list, tuple)): indices = RxTuple(indices) return _ffi_api.index_put(data, indices, values, accumulate) # type: ignore diff --git a/python/tvm/relax/op/memory/_ffi_api.py b/python/tvm/relax/op/memory/_ffi_api.py index fb829b7db953..05dbf534c7f5 100644 --- a/python/tvm/relax/op/memory/_ffi_api.py +++ b/python/tvm/relax/op/memory/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.memory""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op.memory", __name__) +tvm_ffi.init_ffi_api("relax.op.memory", __name__) diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py index 95adc782092f..a7f6f91e182a 100644 --- a/python/tvm/relax/op/memory/view.py +++ b/python/tvm/relax/op/memory/view.py @@ -70,7 +70,7 @@ def view( relative_byte_offset: Optional[Expr] - The offset of the output NDArray, relative to the byte offset + The offset of the output Tensor, relative to the byte offset of `data`. If `None`, the offset of the view is the same as the offset of `data`. diff --git a/python/tvm/relax/op/nn/_ffi_api.py b/python/tvm/relax/op/nn/_ffi_api.py index b5f735127ec2..d58fa186fc7c 100644 --- a/python/tvm/relax/op/nn/_ffi_api.py +++ b/python/tvm/relax/op/nn/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""Constructor APIs""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op.nn", __name__) +tvm_ffi.init_ffi_api("relax.op.nn", __name__) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 9c15cdd96613..229a789a45ef 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -16,344 +16,349 @@ # under the License. """The attributes node used for Relax operators""" from tvm.ir import Attrs -import tvm.ffi +import tvm_ffi -@tvm.ffi.register_object("relax.attrs.CallTIRWithGradAttrs") +@tvm_ffi.register_object("relax.attrs.CallTIRWithGradAttrs") class CallTIRWithGradAttrs(Attrs): """Attributes used in call_tir_with_grad operator""" -@tvm.ffi.register_object("relax.attrs.InitAttrs") +@tvm_ffi.register_object("relax.attrs.InitAttrs") class InitAttrs(Attrs): """Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator""" -@tvm.ffi.register_object("relax.attrs.TriluAttrs") +@tvm_ffi.register_object("relax.attrs.TriluAttrs") class TriluAttrs(Attrs): """Attributes used in tril and triu operator""" -@tvm.ffi.register_object("relax.attrs.AstypeAttrs") +@tvm_ffi.register_object("relax.attrs.AstypeAttrs") class AstypeAttrs(Attrs): """Attributes used in astype operator""" -@tvm.ffi.register_object("relax.attrs.TakeAttrs") +@tvm_ffi.register_object("relax.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes used in take operator""" -@tvm.ffi.register_object("relax.attrs.StridedSliceAttrs") +@tvm_ffi.register_object("relax.attrs.StridedSliceAttrs") class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" -@tvm.ffi.register_object("relax.attrs.MatmulAttrs") +@tvm_ffi.register_object("relax.attrs.MatmulAttrs") class MatmulAttrs(Attrs): """Attributes for matmul operator""" -@tvm.ffi.register_object("relax.attrs.Conv2DAttrs") +@tvm_ffi.register_object("relax.attrs.Conv2DAttrs") class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" 
-@tvm.ffi.register_object("relax.attrs.Conv3DAttrs") +@tvm_ffi.register_object("relax.attrs.Conv3DAttrs") class Conv3DAttrs(Attrs): """Attributes for nn.conv3d""" -@tvm.ffi.register_object("relax.attrs.Conv2DTransposeAttrs") +@tvm_ffi.register_object("relax.attrs.Conv2DTransposeAttrs") class Conv2DTransposeAttrs(Attrs): """Attributes for nn.conv2d_transpose""" -@tvm.ffi.register_object("relax.attrs.Pool2DAttrs") +@tvm_ffi.register_object("relax.attrs.Pool2DAttrs") class Pool2DAttrs(Attrs): """Attributes for nn.max_pool2d""" -@tvm.ffi.register_object("relax.attrs.AdaptivePool2DAttrs") +@tvm_ffi.register_object("relax.attrs.AdaptivePool2DAttrs") class AdaptivePool2DAttrs(Attrs): """Attributes for 2d adaptive pool operator""" -@tvm.ffi.register_object("relax.attrs.SoftmaxAttrs") +@tvm_ffi.register_object("relax.attrs.SoftmaxAttrs") class SoftmaxAttrs(Attrs): """Attributes for nn.softmax""" -@tvm.ffi.register_object("relax.attrs.BatchNormAttrs") +@tvm_ffi.register_object("relax.attrs.BatchNormAttrs") class BatchNormAttrs(Attrs): """Attributes used in batch_norm operator""" -@tvm.ffi.register_object("relax.attrs.LayerNormAttrs") +@tvm_ffi.register_object("relax.attrs.LayerNormAttrs") class LayerNormAttrs(Attrs): """Attributes used in layer_norm operator""" -@tvm.ffi.register_object("relax.attrs.InstanceNormAttrs") +@tvm_ffi.register_object("relax.attrs.InstanceNormAttrs") class InstanceNormAttrs(Attrs): """Attributes used in instance_norm operator""" -@tvm.ffi.register_object("relax.attrs.DropoutAttrs") +@tvm_ffi.register_object("relax.attrs.DropoutAttrs") class DropoutAttrs(Attrs): """Attributes for dropout operator""" -@tvm.ffi.register_object("relax.attrs.StatisticalAttrs") +@tvm_ffi.register_object("relax.attrs.StatisticalAttrs") class StatisticalAttrs(Attrs): """Attributes used in statistical operator""" -@tvm.ffi.register_object("relax.attrs.ConcatAttrs") +@tvm_ffi.register_object("relax.attrs.ConcatAttrs") class ConcatAttrs(Attrs): """Attributes for concat 
operator""" -@tvm.ffi.register_object("relax.attrs.ExpandDimsAttrs") +@tvm_ffi.register_object("relax.attrs.ExpandDimsAttrs") class ExpandDimsAttrs(Attrs): """Attributes for expand_dims operator""" -@tvm.ffi.register_object("relax.attrs.PermuteDimsAttrs") +@tvm_ffi.register_object("relax.attrs.PermuteDimsAttrs") class PermuteDimsAttrs(Attrs): """Attributes for permute_dims operator""" -@tvm.ffi.register_object("relax.attrs.SortAttrs") +@tvm_ffi.register_object("relax.attrs.SortAttrs") class SortAttrs(Attrs): """Attributes for sort operator""" -@tvm.ffi.register_object("relax.attrs.ArgsortAttrs") +@tvm_ffi.register_object("relax.attrs.ArgsortAttrs") class ArgsortAttrs(Attrs): """Attributes for argsort operator""" -@tvm.ffi.register_object("relax.attrs.SplitAttrs") +@tvm_ffi.register_object("relax.attrs.SplitAttrs") class SplitAttrs(Attrs): """Attributes used in split operator""" -@tvm.ffi.register_object("relax.attrs.SqueezeAttrs") +@tvm_ffi.register_object("relax.attrs.SqueezeAttrs") class SqueezeAttrs(Attrs): """Attributes for squeeze operator""" -@tvm.ffi.register_object("relax.attrs.StackAttrs") +@tvm_ffi.register_object("relax.attrs.StackAttrs") class StackAttrs(Attrs): """Attributes for concat operator""" -@tvm.ffi.register_object("relax.attrs.IndexPutAttrs") +@tvm_ffi.register_object("relax.attrs.IndexPutAttrs") class IndexPutAttrs(Attrs): """Attributes for index_put operator""" -@tvm.ffi.register_object("relax.attrs.LayoutTransformAttrs") +@tvm_ffi.register_object("relax.attrs.LayoutTransformAttrs") class LayoutTransformAttrs(Attrs): """Attributes used in layout_transform operator""" -@tvm.ffi.register_object("relax.attrs.Resize2DAttrs") +@tvm_ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" -@tvm.ffi.register_object("relax.attrs.ArgmaxArgminAttrs") +@tvm_ffi.register_object("relax.attrs.ArgmaxArgminAttrs") class ArgmaxArgminAttrs(Attrs): """Attributes for argmax/argmin operator""" 
-@tvm.ffi.register_object("relax.attrs.RepeatAttrs") +@tvm_ffi.register_object("relax.attrs.RepeatAttrs") class RepeatAttrs(Attrs): """Attributes for repeat operator""" -@tvm.ffi.register_object("relax.attrs.TileAttrs") +@tvm_ffi.register_object("relax.attrs.TileAttrs") class TileAttrs(Attrs): """Attributes for tile operator""" -@tvm.ffi.register_object("relax.attrs.ScanopAttrs") +@tvm_ffi.register_object("relax.attrs.ScanopAttrs") class ScanopAttrs(Attrs): """Attributes for scan operators""" -@tvm.ffi.register_object("relax.attrs.TopKAttrs") +@tvm_ffi.register_object("relax.attrs.TopKAttrs") class TopKAttrs(Attrs): """Attributes for topk operators""" -@tvm.ffi.register_object("relax.attrs.EinsumAttrs") +@tvm_ffi.register_object("relax.attrs.EinsumAttrs") class EinsumAttrs(Attrs): """Attributes for einsum operator""" -@tvm.ffi.register_object("relax.attrs.FlipAttrs") +@tvm_ffi.register_object("relax.attrs.FlipAttrs") class FlipAttrs(Attrs): """Attributes for flip operator""" -@tvm.ffi.register_object("relax.attrs.PadAttrs") +@tvm_ffi.register_object("relax.attrs.PadAttrs") class PadAttrs(Attrs): """Attributes used in pad operator""" -@tvm.ffi.register_object("relax.attrs.MultinomialFromUniformAttrs") +@tvm_ffi.register_object("relax.attrs.MultinomialFromUniformAttrs") class MultinomialFromUniformAttrs(Attrs): """Attributes for multinomial_from_uniform operator""" -@tvm.ffi.register_object("relax.attrs.CallInplacePackedAttrs") +@tvm_ffi.register_object("relax.attrs.CallInplacePackedAttrs") class CallInplacePackedAttrs(Attrs): """Attributes used in call_inplace_packed operator""" -@tvm.ffi.register_object("relax.attrs.CallTIRInplaceAttrs") +@tvm_ffi.register_object("relax.attrs.CallTIRInplaceAttrs") class CallTIRInplaceAttrs(Attrs): """Attributes used in call_tir_inplace operator""" -@tvm.ffi.register_object("relax.attrs.ToVDeviceAttrs") +@tvm_ffi.register_object("relax.attrs.ToVDeviceAttrs") class ToVDeviceAttrs(Attrs): """Attributes used in to_vdevice operator""" 
-@tvm.ffi.register_object("relax.attrs.HintOnDeviceAttrs") +@tvm_ffi.register_object("relax.attrs.HintOnDeviceAttrs") class HintOnDeviceAttrs(Attrs): """Attributes used in hint_on_device operator""" -@tvm.ffi.register_object("relax.attrs.ScatterCollectiveAttrs") +@tvm_ffi.register_object("relax.attrs.ScatterCollectiveAttrs") class ScatterCollectiveAttrs(Attrs): """Attributes used in scatter collective operators""" -@tvm.ffi.register_object("relax.attrs.AttentionAttrs") +@tvm_ffi.register_object("relax.attrs.AttentionAttrs") class AttentionAttrs(Attrs): """Attributes used in attention operator""" -@tvm.ffi.register_object("relax.attrs.Conv1DAttrs") +@tvm_ffi.register_object("relax.attrs.AllClassNonMaximumSuppressionAttrs") +class AllClassNonMaximumSuppressionAttrs(Attrs): + """Attributes for vision.all_class_non_max_suppression""" + + +@tvm_ffi.register_object("relax.attrs.Conv1DAttrs") class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" -@tvm.ffi.register_object("relax.attrs.Conv1DTransposeAttrs") +@tvm_ffi.register_object("relax.attrs.Conv1DTransposeAttrs") class Conv1DTransposeAttrs(Attrs): """Attributes for nn.conv1d_transpose""" -@tvm.ffi.register_object("relax.attrs.Pool1DAttrs") +@tvm_ffi.register_object("relax.attrs.Pool1DAttrs") class Pool1DAttrs(Attrs): """Attributes for nn.max_pool1d and nn.avg_pool1d""" -@tvm.ffi.register_object("relax.attrs.Pool3DAttrs") +@tvm_ffi.register_object("relax.attrs.Pool3DAttrs") class Pool3DAttrs(Attrs): """Attributes for nn.max_pool3d and nn.avg_pool3d""" -@tvm.ffi.register_object("relax.attrs.AdaptivePool1DAttrs") +@tvm_ffi.register_object("relax.attrs.AdaptivePool1DAttrs") class AdaptivePool1DAttrs(Attrs): """Attributes for 1d adaptive pool operator""" -@tvm.ffi.register_object("relax.attrs.AdaptivePool3DAttrs") +@tvm_ffi.register_object("relax.attrs.AdaptivePool3DAttrs") class AdaptivePool3DAttrs(Attrs): """Attributes for 3d adaptive pool operator""" -@tvm.ffi.register_object("relax.attrs.LeakyReluAttrs") 
+@tvm_ffi.register_object("relax.attrs.LeakyReluAttrs") class LeakyReluAttrs(Attrs): """Attributes used in leaky_relu operator""" -@tvm.ffi.register_object("relax.attrs.SoftplusAttrs") +@tvm_ffi.register_object("relax.attrs.SoftplusAttrs") class SoftplusAttrs(Attrs): """Attributes used in softplus operator""" -@tvm.ffi.register_object("relax.attrs.PReluAttrs") +@tvm_ffi.register_object("relax.attrs.PReluAttrs") class PReluAttrs(Attrs): """Attributes used in prelu operator""" -@tvm.ffi.register_object("relax.attrs.PixelShuffleAttrs") +@tvm_ffi.register_object("relax.attrs.PixelShuffleAttrs") class PixelShuffleAttrs(Attrs): """Attributes used in pixel_shuffle operator""" -@tvm.ffi.register_object("relax.attrs.GroupNormAttrs") +@tvm_ffi.register_object("relax.attrs.GroupNormAttrs") class GroupNormAttrs(Attrs): """Attributes used in group_norm operator""" -@tvm.ffi.register_object("relax.attrs.RMSNormAttrs") +@tvm_ffi.register_object("relax.attrs.RMSNormAttrs") class RMSNormAttrs(Attrs): """Attributes used in rms_norm operator""" -@tvm.ffi.register_object("relax.attrs.NLLLossAttrs") +@tvm_ffi.register_object("relax.attrs.NLLLossAttrs") class NLLLossAttrs(Attrs): """Attributes used in nll_loss operator""" -@tvm.ffi.register_object("relax.attrs.AllReduceAttrs") +@tvm_ffi.register_object("relax.attrs.AllReduceAttrs") class AllReduceAttrs(Attrs): """Attributes used in allreduce operator""" -@tvm.ffi.register_object("relax.attrs.AllGatherAttrs") +@tvm_ffi.register_object("relax.attrs.AllGatherAttrs") class AllGatherAttrs(Attrs): """Attributes used in allgather operator""" -@tvm.ffi.register_object("relax.attrs.WrapParamAttrs") +@tvm_ffi.register_object("relax.attrs.WrapParamAttrs") class WrapParamAttrs(Attrs): """Attributes used in wrap_param operator""" -@tvm.ffi.register_object("relax.attrs.QuantizeAttrs") +@tvm_ffi.register_object("relax.attrs.QuantizeAttrs") class QuantizeAttrs(Attrs): """Attributes used in quantize/dequantize operators""" 
-@tvm.ffi.register_object("relax.attrs.GatherElementsAttrs") +@tvm_ffi.register_object("relax.attrs.GatherElementsAttrs") class GatherElementsAttrs(Attrs): """Attributes for gather_elements operator""" -@tvm.ffi.register_object("relax.attrs.GatherNDAttrs") +@tvm_ffi.register_object("relax.attrs.GatherNDAttrs") class GatherNDAttrs(Attrs): """Attributes for gather_nd operator""" -@tvm.ffi.register_object("relax.attrs.MeshgridAttrs") +@tvm_ffi.register_object("relax.attrs.MeshgridAttrs") class MeshgridAttrs(Attrs): """Attributes for meshgrid operator""" -@tvm.ffi.register_object("relax.attrs.ScatterElementsAttrs") +@tvm_ffi.register_object("relax.attrs.ScatterElementsAttrs") class ScatterElementsAttrs(Attrs): """Attributes for scatter_elements operator""" -@tvm.ffi.register_object("relax.attrs.ScatterNDAttrs") +@tvm_ffi.register_object("relax.attrs.ScatterNDAttrs") class ScatterNDAttrs(Attrs): """Attributes for scatter_nd operator""" -@tvm.ffi.register_object("relax.attrs.SliceScatterAttrs") +@tvm_ffi.register_object("relax.attrs.SliceScatterAttrs") class SliceScatterAttrs(Attrs): """Attributes for slice_scatter operator""" -@tvm.ffi.register_object("relax.attrs.OneHotAttrs") +@tvm_ffi.register_object("relax.attrs.OneHotAttrs") class OneHotAttrs(Attrs): """Attributes for one_hot operator""" diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index ed4b2e2ff928..87fd067e5d1e 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -84,15 +84,15 @@ def unique( ) -@tvm.register_func("relax.run.unique") +@tvm.register_global_func("relax.run.unique") def numpy_unique( - x: tvm.nd.array, + x: tvm.runtime.tensor, sorted: int, return_index: int, return_inverse: int, return_counts: int, axis: Optional[int] = None, -) -> tvm.nd.array: +) -> tvm.runtime.tensor: """Returns the unique elements of the input tensor. Uses numpy.unique to compute unique elements. 
@@ -107,9 +107,9 @@ def numpy_unique( output_sorted_numpy, indices = np.unique(x_numpy, return_index=True, axis=axis) if sorted: - return tvm.nd.array(output_sorted_numpy) + return tvm.runtime.tensor(output_sorted_numpy) output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis) - return tvm.nd.array(output_numpy) + return tvm.runtime.tensor(output_numpy) def nonzero(x: Expr) -> Expr: @@ -143,7 +143,7 @@ def nonzero(x: Expr) -> Expr: return _ffi_api.nonzero(x) # type: ignore -@tvm.register_func("relax.run.nonzero") -def numpy_nonzero(x: tvm.nd.array) -> tvm.nd.array: +@tvm.register_global_func("relax.run.nonzero") +def numpy_nonzero(x: tvm.runtime.tensor) -> tvm.runtime.tensor: np_result = np.atleast_1d(x.numpy()).nonzero() - return tvm.nd.array(np.stack(np_result, axis=0)) + return tvm.runtime.tensor(np.stack(np_result, axis=0)) diff --git a/python/tvm/relax/op/vision/__init__.py b/python/tvm/relax/op/vision/__init__.py new file mode 100644 index 000000000000..be45458d3647 --- /dev/null +++ b/python/tvm/relax/op/vision/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""VISION operators.""" +from .nms import * diff --git a/ffi/scripts/run_tests.sh b/python/tvm/relax/op/vision/_ffi_api.py old mode 100755 new mode 100644 similarity index 64% rename from ffi/scripts/run_tests.sh rename to python/tvm/relax/op/vision/_ffi_api.py index 8fc9eb95d005..8af761dc5a00 --- a/ffi/scripts/run_tests.sh +++ b/python/tvm/relax/op/vision/_ffi_api.py @@ -1,4 +1,3 @@ -#!/bin/bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,12 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -set -euxo pipefail +"""Constructor APIs""" +import tvm_ffi -BUILD_TYPE=RelWithDebugInfo - -rm -rf build/CMakeFiles build/CMakeCache.txt -cmake -G Ninja -S . -B build -DTVM_FFI_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -cmake --build build --parallel 16 --clean-first --config ${BUILD_TYPE} --target tvm_ffi_tests -GTEST_COLOR=1 ctest -V -C ${BUILD_TYPE} --test-dir build --output-on-failure +tvm_ffi.init_ffi_api("relax.op.vision", __name__) diff --git a/python/tvm/relax/op/vision/nms.py b/python/tvm/relax/op/vision/nms.py new file mode 100644 index 000000000000..3714b00b01e2 --- /dev/null +++ b/python/tvm/relax/op/vision/nms.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Non-maximum suppression operator""" +# from tvm import relax # Unused import +from . import _ffi_api + + +def all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="onnx", +): + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. + + Parameters + ---------- + boxes : relax.Expr + 3-D tensor with shape (batch_size, num_boxes, 4) + scores: relax.Expr + 3-D tensor with shape (batch_size, num_classes, num_boxes) + max_output_boxes_per_class : relax.Expr + The maximum number of output selected boxes per class + iou_threshold : relax.Expr + IoU test threshold + score_threshold : relax.Expr + Score threshold to filter out low score boxes early + output_format : str, optional + "onnx" or "tensorflow", see below. + + Returns + ------- + out : relax.Expr + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size + `(batch_size * num_class * num_boxes , 3)` and the second is a scalar tensor + `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + first, in descending order of scores, followed by boxes from batch 0, class 1 etc. Out of + `batch_size * num_class * num_boxes` rows of indices, only the first `num_total_detection` + rows are valid. 
+ + TODO: Implement true dynamic output shapes to match ONNX Runtime behavior exactly. + This would eliminate the need for manual trimming and improve memory efficiency. + If `output_format` is "tensorflow", the output is three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. + """ + return _ffi_api.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format + ) diff --git a/python/tvm/relax/op/vm/_ffi_api.py b/python/tvm/relax/op/vm/_ffi_api.py index f3b6cea13b67..eed64e53f036 100644 --- a/python/tvm/relax/op/vm/_ffi_api.py +++ b/python/tvm/relax/op/vm/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.vm""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.op.vm", __name__) +tvm_ffi.init_ffi_api("relax.op.vm", __name__) diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 37ef6156e4e7..388f9dbb43cd 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -21,7 +21,7 @@ as it is or serves as a basis to do further composition. 
""" # pylint: disable=unused-argument -from typing import Union +from typing import Union, Optional import tvm from tvm import meta_schedule as ms @@ -111,6 +111,7 @@ def static_shape_tuning_pipeline( target: Union[str, tvm.target.Target], work_dir: str = "tuning_logs", cpu_weight_prepack: bool = False, + max_trials_per_task: Optional[int] = None, ): """Tune the static shape model and store the log to database. @@ -128,6 +129,16 @@ def static_shape_tuning_pipeline( cpu_weight_prepack : bool Whether to enable the cpu weight prepack feature. + max_trials_per_task : Optional[int] + The maximum number of trials to run per task. + If not specified, it defaults to the value of `total_trials`, and this + may lead to undersubscribed tuning, potentially skipping some tasks + entirely. Explicitly setting both parameters avoids this issue and + provides deterministic resource allocation across all tasks. + For optimal tuning, set `total_trials` to at least + `max_trials_per_task * number_of_tuning_tasks` to ensure + each task receives adequate tuning resources in one iteration. 
+ Note ---- `cpu_weight_prepack` is expected to be `True` when running on CPU for @@ -142,6 +153,7 @@ def static_shape_tuning_pipeline( target="llvm -num-cores 16", work_dir="tuning_logs", cpu_weight_prepack=True, + max_trials_per_task=64, )(mod) ex = tvm.compile(mod, target=target) @@ -151,7 +163,7 @@ def static_shape_tuning_pipeline( # the name should be f"{func_name}_transform_params" params = vm["main_transform_params"](params["main"]) - input_data = tvm.nd.array(np.random.randn(1, 3, 224, 224).astype("float32")) + input_data = tvm.runtime.tensor(np.random.randn(1, 3, 224, 224).astype("float32")) out = vm["main"](input_data, *params).numpy() """ @@ -177,7 +189,12 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I *pre_tuning_layout_rewrite, # Skip tuning if total_trials is 0 ( - transform.MetaScheduleTuneIRMod({}, work_dir, total_trials) + transform.MetaScheduleTuneIRMod( + params={}, + work_dir=work_dir, + max_trials_global=total_trials, + max_trials_per_task=max_trials_per_task, + ) if total_trials > 0 else tvm.transform.Sequential([]) ), diff --git a/python/tvm/relax/relax_to_pyfunc_converter.py b/python/tvm/relax/relax_to_pyfunc_converter.py new file mode 100644 index 000000000000..e527e3f73bac --- /dev/null +++ b/python/tvm/relax/relax_to_pyfunc_converter.py @@ -0,0 +1,1214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax to Python Function Converter. + +This module provides functionality to convert Relax functions to Python functions +that can be executed directly in Python/PyTorch environment. +""" + +import traceback +from typing import Any, Dict, List, Optional, Union + +import numpy # pylint: disable=unused-import +import torch +import torch.nn.functional as F + +import tvm +from tvm import relax +from tvm import runtime +from tvm.ir import IRModule, Op + + +class RelaxToPyFuncConverter: + """Converter that works with IRModule to convert Relax functions to Python functions. + + This converter transforms Relax functions into Python functions that can be executed + directly in Python/PyTorch environment. The conversion maps Relax operators to + corresponding PyTorch APIs and handles special cases like call_tir and call_dps_packed. + """ + + def __init__(self, ir_module: IRModule): + """Initialize the converter with an IRModule. 
+ + Args: + ir_module: The IRModule containing Relax functions to convert + """ + self.ir_module = ir_module + self.operator_map = self._get_op_map() + # Cache for RelaxExpressionConverter instances to avoid recreating them + self._converter_cache = {} + # Cache for operator mappings to avoid repeated lookups + self._op_cache = {} + + def _create_fallback_tensor( + self, shape_hint: Optional[List[int]] = None, dtype: str = "float32" + ) -> torch.Tensor: + """Create a fallback tensor with reasonable default shape.""" + if shape_hint: + # Use the provided shape hint + return torch.zeros(shape_hint, dtype=getattr(torch, dtype)) + else: + # Use a small default shape + return torch.zeros(1, dtype=getattr(torch, dtype)) + + def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule: + """Convert specified Relax functions to Python functions. + + Args: + relax_function_names: Name(s) of Relax functions to convert + + Returns: + Updated IRModule with converted Python functions stored in pyfuncs + + Example: + >>> converter = RelaxToPyFuncConverter(ir_mod) + >>> # Convert a single function + >>> converted_ir_mod = converter.convert("my_relax_func") + >>> # Convert multiple functions + >>> converted_ir_mod = converter.convert(["func1", "func2"]) + """ + if isinstance(relax_function_names, str): + relax_function_names = [relax_function_names] + + # Create a copy of the current IRModule + new_ir_mod = self.ir_module.clone() + + # Initialize pyfuncs if not exists + if not hasattr(new_ir_mod, "pyfuncs"): + new_ir_mod.pyfuncs = {} + + # Get Relax function names from IRModule + relax_func_names = [] + for global_var, func in self.ir_module.functions_items(): + if isinstance(func, relax.Function): + relax_func_names.append(global_var.name_hint) + + # Convert each Relax function + for func_name in relax_function_names: + if func_name not in relax_func_names: + raise ValueError(f"Relax function '{func_name}' not found in IRModule") + + # Get the Relax function + 
relax_func = None + for global_var, func in self.ir_module.functions_items(): + if global_var.name_hint == func_name and isinstance(func, relax.Function): + relax_func = func + break + + if relax_func is None: + raise ValueError(f"Could not find Relax function '{func_name}'") + + # Convert to Python function + py_func = self._convert_relax_func_to_python(relax_func, func_name) + + # Store in pyfuncs + new_ir_mod.pyfuncs[func_name] = py_func + + return new_ir_mod + + def _convert_relax_func_to_python(self, relax_func: relax.Function, func_name: str) -> callable: + """Convert a single Relax function to a Python function with caching.""" + # Get function parameters + params = relax_func.params + + # Create the Python function + def converted_function(*args, **_kwargs): + """Converted Python function from Relax function.""" + # Handle arguments + if len(args) != len(params): + raise ValueError(f"Expected {len(params)} arguments, got {len(args)}") + + # Use cached converter or create new one + if func_name not in self._converter_cache: + self._converter_cache[func_name] = RelaxExpressionConverter( + self.operator_map, self.ir_module, self._op_cache + ) + + # Execute the converted function body + converter = self._converter_cache[func_name] + converter.current_params = params + return converter.convert_expr(relax_func.body, args) + + # Set function metadata + converted_function.__name__ = func_name + converted_function.__doc__ = f"Converted Python function from Relax function: {func_name}" + + return converted_function + + @staticmethod + def _get_op_map() -> Dict[str, str]: + """Get the mapping from Relax operators to PyTorch operators.""" + return { + # Binary operations + "relax.add": "torch.add", + "relax.subtract": "torch.sub", + "relax.multiply": "torch.mul", + "relax.divide": "torch.div", + "relax.power": "torch.pow", + "relax.maximum": "torch.maximum", + "relax.minimum": "torch.minimum", + "relax.floor_divide": "torch.floor_divide", + "relax.mod": "torch.fmod", + 
"relax.floor_mod": "torch.remainder", + "relax.log_add_exp": "torch.logaddexp", + # Bitwise operations + "relax.bitwise_and": "torch.bitwise_and", + "relax.bitwise_or": "torch.bitwise_or", + "relax.bitwise_xor": "torch.bitwise_xor", + "relax.left_shift": "torch.left_shift", + "relax.right_shift": "torch.right_shift", + # Unary operations + "relax.abs": "torch.abs", + "relax.negative": "torch.neg", + "relax.exp": "torch.exp", + "relax.log": "torch.log", + "relax.sqrt": "torch.sqrt", + "relax.rsqrt": "torch.rsqrt", + "relax.sin": "torch.sin", + "relax.cos": "torch.cos", + "relax.tanh": "torch.tanh", + "relax.sigmoid": "torch.sigmoid", + "relax.square": "torch.square", + "relax.sign": "torch.sign", + "relax.floor": "torch.floor", + "relax.ceil": "torch.ceil", + "relax.round": "torch.round", + "relax.trunc": "torch.trunc", + "relax.clip": "torch.clamp", + "relax.bitwise_not": "torch.bitwise_not", + # Trigonometric functions + "relax.acos": "torch.acos", + "relax.asin": "torch.asin", + "relax.atan": "torch.atan", + "relax.cosh": "torch.cosh", + "relax.sinh": "torch.sinh", + "relax.tan": "torch.tan", + "relax.acosh": "torch.acosh", + "relax.asinh": "torch.asinh", + "relax.atanh": "torch.atanh", + # Special functions + "relax.erf": "torch.erf", + "relax.isfinite": "torch.isfinite", + "relax.isinf": "torch.isinf", + "relax.isnan": "torch.isnan", + # Neural network operations + "relax.nn.relu": "F.relu", + "relax.nn.relu6": "F.relu6", + "relax.nn.gelu": "F.gelu", + "relax.nn.gelu_tanh": "F.gelu", + "relax.nn.softmax": "F.softmax", + "relax.nn.log_softmax": "F.log_softmax", + "relax.nn.dropout": "F.dropout", + "relax.nn.batch_norm": "F.batch_norm", + "relax.nn.layer_norm": "F.layer_norm", + "relax.nn.group_norm": "F.group_norm", + "relax.nn.instance_norm": "F.instance_norm", + "relax.nn.rms_norm": "F.layer_norm", # Approximate mapping + "relax.nn.linear": "F.linear", + "relax.nn.conv1d": "F.conv1d", + "relax.nn.conv2d": "F.conv2d", + "relax.nn.conv3d": "F.conv3d", + 
"relax.nn.conv1d_transpose": "F.conv_transpose1d", + "relax.nn.conv2d_transpose": "F.conv_transpose2d", + "relax.nn.conv3d_transpose": "F.conv_transpose3d", + "relax.nn.max_pool1d": "F.max_pool1d", + "relax.nn.max_pool2d": "F.max_pool2d", + "relax.nn.max_pool3d": "F.max_pool3d", + "relax.nn.avg_pool1d": "F.avg_pool1d", + "relax.nn.avg_pool2d": "F.avg_pool2d", + "relax.nn.avg_pool3d": "F.avg_pool3d", + "relax.nn.adaptive_avg_pool1d": "F.adaptive_avg_pool1d", + "relax.nn.adaptive_avg_pool2d": "F.adaptive_avg_pool2d", + "relax.nn.adaptive_avg_pool3d": "F.adaptive_avg_pool3d", + "relax.nn.leakyrelu": "F.leaky_relu", + "relax.nn.prelu": "F.prelu", + "relax.nn.selu": "F.selu", + "relax.nn.silu": "F.silu", + "relax.nn.softplus": "F.softplus", + "relax.nn.attention": "F.scaled_dot_product_attention", # Approximate mapping + "relax.nn.cross_entropy_with_logits": "F.cross_entropy", + "relax.nn.nll_loss": "F.nll_loss", + "relax.nn.pad": "F.pad", + "relax.nn.pixel_shuffle": "F.pixel_shuffle", + # Tensor operations + "relax.matmul": "torch.matmul", + "relax.linear": "F.linear", + "relax.einsum": "torch.einsum", + "relax.outer": "torch.outer", + "relax.reshape": "reshape", # Special handling needed + "relax.permute_dims": "permute_dims", # Special handling needed + "relax.expand_dims": "expand_dims", # Special handling needed + "relax.squeeze": "squeeze", # Special handling needed + "relax.concat": "concat", # Special handling needed + "relax.split": "split", # Special handling needed + "relax.stack": "stack", # Special handling needed + "relax.tile": "tile", # Special handling needed + "relax.repeat": "repeat", # Special handling needed + "relax.broadcast_to": "torch.broadcast_to", + "relax.flatten": "torch.flatten", + "relax.flip": "flip", # Special handling needed + "relax.roll": "torch.roll", + "relax.rot90": "torch.rot90", + "relax.meshgrid": "torch.meshgrid", + "relax.one_hot": "F.one_hot", + "relax.layout_transform": "torch.permute", # Approximate mapping + # Indexing 
operations + "relax.take": "take", # Special handling needed + "relax.gather_elements": "torch.gather", + "relax.gather_nd": "torch.gather", + "relax.scatter_elements": "torch.scatter", + "relax.scatter_nd": "torch.scatter", + "relax.index_put": "torch.index_put", + "relax.index_tensor": "torch.index_select", + "relax.strided_slice": "torch.slice", + "relax.dynamic_strided_slice": "torch.slice", + "relax.slice_scatter": "torch.scatter", + # Reduction operations + "relax.sum": "sum", # Special handling needed + "relax.mean": "mean", # Special handling needed + "relax.max": "max", # Special handling needed + "relax.min": "min", # Special handling needed + "relax.prod": "torch.prod", + "relax.std": "std", # Special handling needed + "relax.variance": "variance", # Special handling needed + "relax.cumsum": "torch.cumsum", + "relax.cumprod": "torch.cumprod", + "relax.argmax": "torch.argmax", + "relax.argmin": "torch.argmin", + # Comparison operations + "relax.equal": "torch.eq", + "relax.not_equal": "torch.ne", + "relax.greater": "torch.gt", + "relax.greater_equal": "torch.ge", + "relax.less": "torch.lt", + "relax.less_equal": "torch.le", + # Logical operations + "relax.logical_and": "torch.logical_and", + "relax.logical_or": "torch.logical_or", + "relax.logical_not": "torch.logical_not", + "relax.logical_xor": "torch.logical_xor", + # Creation operations + "relax.zeros": "torch.zeros", + "relax.ones": "torch.ones", + "relax.full": "torch.full", + "relax.full_like": "torch.full_like", + "relax.zeros_like": "torch.zeros_like", + "relax.ones_like": "torch.ones_like", + "relax.arange": "torch.arange", + "relax.eye": "torch.eye", + "relax.eye_like": "torch.eye", + "relax.tril": "torch.tril", + "relax.triu": "torch.triu", + "relax.hamming_window": "torch.hamming_window", + # Search operations + "relax.where": "torch.where", + "relax.bucketize": "torch.bucketize", + "relax.nonzero": "torch.nonzero", + "relax.unique": "torch.unique", + # Sorting operations + "relax.sort": 
"torch.sort", + "relax.argsort": "torch.argsort", + "relax.topk": "torch.topk", + # Sampling operations + "relax.multinomial_from_uniform": "torch.multinomial", + # Ternary operations + "relax.ewise_fma": "torch.fma", # Approximate mapping + # Data type operations + "relax.astype": "torch.to", + "relax.wrap_param": "torch.tensor", + # Mask operations + "relax.masked_fill": "torch.masked_fill", + # Quantization operations + "relax.quantize": "torch.quantize_per_tensor", # Approximate mapping + "relax.dequantize": "torch.dequantize", # Approximate mapping + # Special operations (handled separately) + "relax.call_tir": "call_tir", + "relax.call_tir_inplace": "call_tir_inplace", + "relax.call_dps_packed": "call_dps_packed", + "relax.call_pure_packed": "call_pure_packed", + "relax.call_tir_with_grad": "call_tir_with_grad", + "relax.call_builtin_with_ctx": "call_builtin_with_ctx", + "relax.call_inplace_packed": "call_inplace_packed", + "relax.invoke_closure": "invoke_closure", + "relax.invoke_pure_closure": "invoke_pure_closure", + "relax.make_closure": "make_closure", + "relax.null_value": "null_value", + "relax.print": "print", + "relax.shape_of": "shape_of", + "relax.shape_to_tensor": "shape_to_tensor", + "relax.tensor_to_shape": "tensor_to_shape", + "relax.to_vdevice": "to_vdevice", + "relax.hint_on_device": "hint_on_device", + "relax.assert_op": "assert_op", + } + + +class RelaxExpressionConverter: + """Converter that transforms Relax expressions to Python/PyTorch code.""" + + def __init__( + self, + operator_map: Dict[str, str], + ir_module: IRModule = None, + op_cache: Dict[str, str] = None, + ): + """Initialize the expression converter. 
+ + Args: + operator_map: Mapping from Relax operators to PyTorch operators + ir_module: The IRModule containing TIR functions to compile + op_cache: Shared cache for operator mappings to avoid repeated lookups + """ + self.operator_map = operator_map + self.variable_map: Dict[str, Any] = {} + self.current_params: List[relax.Var] = [] + self.ir_module = ir_module + # Use shared operator cache or create new one + self._op_cache = op_cache if op_cache is not None else {} + + def _create_fallback_tensor( + self, shape_hint: Optional[List[int]] = None, dtype: str = "float32" + ) -> torch.Tensor: + """Create a fallback tensor with reasonable default shape.""" + if shape_hint: + return torch.zeros(shape_hint, dtype=getattr(torch, dtype)) + else: + return torch.zeros(1, dtype=getattr(torch, dtype)) + + def convert_expr(self, expr: relax.Expr, args: List[Any]) -> Any: + """Convert a Relax expression to Python/PyTorch equivalent.""" + if isinstance(expr, relax.Var): + return self._convert_var(expr, args) + elif isinstance(expr, relax.Call): + return self._convert_call(expr, args) + elif isinstance(expr, relax.Constant): + return self._convert_constant(expr) + elif isinstance(expr, relax.SeqExpr): + return self._convert_seq_expr(expr, args) + elif isinstance(expr, relax.Tuple): + return self._convert_tuple(expr, args) + elif isinstance(expr, relax.TupleGetItem): + return self._convert_tuple_get_item(expr, args) + elif isinstance(expr, relax.If): + return self._convert_if(expr, args) + elif isinstance(expr, relax.ShapeExpr): + return self._convert_shape_expr(expr) + else: + # Fallback for unknown expression types + return f"" + + def _convert_var(self, var: relax.Var, args: List[Any]) -> Any: + """Convert a Relax variable to Python equivalent.""" + if hasattr(var, "name_hint"): + var_name = var.name_hint + + # Check if it's a function parameter + for i, param in enumerate(self.current_params): + if hasattr(param, "name_hint") and param.name_hint == var_name: + return args[i] 
+ + # Check if it's a bound variable + if var_name in self.variable_map: + return self.variable_map[var_name] + + # Try to infer shape from var's type annotation + if hasattr(var, "struct_info") and hasattr(var.struct_info, "shape"): + shape = var.struct_info.shape + if shape and len(shape) > 0: + # Convert symbolic shapes to concrete values + concrete_shape = [] + for dim in shape: + if isinstance(dim, int): + concrete_shape.append(dim) + else: + # For symbolic dimensions, use a reasonable default + concrete_shape.append(1) + return torch.zeros(concrete_shape, dtype=torch.float32) + + if args and isinstance(args[0], torch.Tensor): + return torch.zeros_like(args[0]) + # Use fallback tensor with shape inference + return self._create_fallback_tensor() + return self._create_fallback_tensor() + + def _convert_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax call to Python/PyTorch equivalent.""" + op = call.op + + # Handle different types of calls + if isinstance(op, relax.GlobalVar): + # Function call + return self._convert_function_call(call, args) + elif isinstance(op, Op): + # Operator call + return self._convert_operator_call(call, args) + elif isinstance(op, relax.ExternFunc): + # External function call (like call_tir, call_dps_packed) + return self._convert_extern_func_call(call, args) + else: + return self._create_fallback_tensor() + + def _convert_function_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax function call.""" + func_name = call.op.name_hint + call_args = [self.convert_expr(arg, args) for arg in call.args] + + # Handle special cases + if func_name in ["call_tir", "call_tir_inplace"]: + return self._convert_call_tir(call, args) + elif func_name in ["call_dps_packed", "call_pure_packed"]: + return self._convert_call_dps_packed(call, args) + else: + # Regular function call - return first argument as fallback + return call_args[0] if call_args else self._create_fallback_tensor() + + def 
_convert_operator_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax operator call to PyTorch equivalent.""" + op_name = call.op.name + call_args = [self.convert_expr(arg, args) for arg in call.args] + + # Use cached operator mapping or look it up + if op_name not in self._op_cache: + self._op_cache[op_name] = self.operator_map.get(op_name) + pytorch_op = self._op_cache[op_name] + if pytorch_op: + try: + # Handle special operations + if pytorch_op == "call_tir": + return self._convert_call_tir(call, args) + elif pytorch_op == "call_tir_inplace": + return self._convert_call_tir(call, args) + elif pytorch_op == "call_dps_packed": + return self._convert_call_dps_packed(call, args) + elif pytorch_op == "call_pure_packed": + return self._convert_call_dps_packed(call, args) + elif pytorch_op == "expand_dims": + return self._convert_expand_dims(call, args) + elif pytorch_op in ["sum", "mean", "max", "min", "std", "variance"]: + return self._convert_reduction_op(call, args, pytorch_op) + elif pytorch_op == "squeeze": + return self._convert_squeeze(call, args) + elif pytorch_op in ["concat", "split", "stack"]: + return self._convert_tensor_ops(call, args, pytorch_op) + elif pytorch_op == "reshape": + return self._convert_reshape(call, args) + elif pytorch_op == "permute_dims": + return self._convert_permute_dims(call, args) + elif pytorch_op == "take": + return self._convert_take(call, args) + elif pytorch_op == "flip": + return self._convert_flip(call, args) + elif pytorch_op == "tile": + return self._convert_tile(call, args) + elif pytorch_op == "repeat": + return self._convert_repeat(call, args) + # Handle special cases for PyTorch operations + elif pytorch_op.startswith("F."): + return self._handle_functional_operation(pytorch_op, call, call_args) + elif pytorch_op.startswith("torch."): + # Regular PyTorch operation + func_name = pytorch_op[6:] # Remove "torch." 
prefix + func = getattr(torch, func_name) + return func(*call_args) + else: + # Direct function reference - use getattr for safer access + if pytorch_op.startswith("torch."): + module = torch + func_name = pytorch_op[6:] # Remove "torch." prefix + elif pytorch_op.startswith("F."): + module = F + func_name = pytorch_op[2:] # Remove "F." prefix + else: + return ( + f"" + ) + + func = getattr(module, func_name, None) + if func is None: + return ( + f"" + ) + return func(*call_args) + except (AttributeError, TypeError, ValueError) as error: + # This allows the test framework to catch and handle the errors appropriately + if pytorch_op.startswith("torch.") or pytorch_op.startswith("F."): + raise error + # Fallback to string representation for non-PyTorch operations + return f"" + else: + # Unknown operator + return f"" + + def _handle_functional_operation( + self, pytorch_op: str, call: relax.Call, call_args: List[Any] + ) -> Any: + """Handle PyTorch functional operations with special parameter handling.""" + # Neural network function + func_name = pytorch_op[2:] # Remove "F." 
prefix + func = getattr(F, func_name) + + # Special handling for functions that need dim parameter + if func_name in ["softmax", "log_softmax"]: + # Extract axis from call.attrs and convert to dim + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + return func(call_args[0], dim=axis) + else: + # Default to last dimension if no axis specified + return func(call_args[0], dim=-1) + else: + return func(*call_args) + + def _convert_extern_func_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert an external function call.""" + func_name = call.op.global_symbol + call_args = [self.convert_expr(arg, args) for arg in call.args] + + if func_name in ["call_tir", "call_tir_inplace"]: + return self._convert_call_tir(call, args) + elif func_name in ["call_dps_packed", "call_pure_packed"]: + return self._convert_call_dps_packed(call, args) + else: + return call_args[0] if call_args else self._create_fallback_tensor() + + def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: + """Convert call_tir to Python equivalent with DLPack conversion.""" + # Extract TIR function name and arguments + tir_func = call.args[0] + tir_args = call.args[1] if len(call.args) > 1 else [] + out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + + # Get function name + if isinstance(tir_func, relax.GlobalVar): + func_name = tir_func.name_hint + else: + # Convert the GlobalVar expression + func_name = self.convert_expr(tir_func, args) + if isinstance(func_name, str) and func_name.startswith("<"): + # If it's a placeholder, extract the name + func_name = str(tir_func) + + # Convert arguments to PyTorch tensors + converted_args = [self.convert_expr(arg, args) for arg in tir_args] + + try: + # First, try to get the TIR function from the current IRModule + tir_function = None + if 
self.ir_module: + # Look for the TIR function in the current IRModule + for global_var, func in self.ir_module.functions.items(): + if global_var.name_hint == func_name and hasattr(func, "body"): + try: + # Compile the TIR function + target = tvm.target.Target("llvm") + with tvm.target.Target(target): + tir_function = tvm.compile(func, target=target) + break + except (RuntimeError, ValueError, TypeError) as compile_e: + print( + f"Warning: Failed to compile TIR function {func_name}: {compile_e}" + ) + continue + + # If not found in current module, try global registry + if tir_function is None: + tir_function = tvm.get_global_func(func_name) + + if tir_function is None: + if len(converted_args) >= 2: + # Simple fallback: just add the tensors + return torch.add(converted_args[0], converted_args[1]) + else: + return converted_args[0] if converted_args else torch.tensor([]) + + # Convert PyTorch tensors to TVM NDArrays via DLPack + tvm_args = [] + for arg in converted_args: + try: + if isinstance(arg, torch.Tensor): + # Convert PyTorch tensor to TVM NDArray via DLPack + tvm_arg = runtime.from_dlpack(torch.to_dlpack(arg)) + tvm_args.append(tvm_arg) + else: + tvm_args.append(arg) + except (AttributeError, TypeError, ValueError): + traceback.print_exc() + tvm_args.append(arg) + + # For call_tir, we need to allocate output tensor + output_shape = None + if out_sinfo and hasattr(out_sinfo, "shape"): + output_shape = out_sinfo.shape + elif converted_args: + # Use the shape of the first input tensor + first_arg = converted_args[0] + if isinstance(first_arg, torch.Tensor): + output_shape = first_arg.shape + + if output_shape is None: + if converted_args and isinstance(converted_args[0], torch.Tensor): + output_shape = converted_args[0].shape + else: + output_shape = (1,) # Default shape + + # Allocate output tensor + output_tensor = runtime.empty(output_shape, dtype="float32") + tvm_args.append(output_tensor) + + # Call the TIR function + try: + tir_function(*tvm_args) + # The 
result is in the output_tensor we allocated + # Convert result back to PyTorch tensor via DLPack + try: + result = torch.from_dlpack(output_tensor.to_dlpack()) + return result + except AttributeError: + # Fallback: convert to numpy then to PyTorch + numpy_result = output_tensor.numpy() + result = torch.from_numpy(numpy_result) + return result + except (RuntimeError, ValueError, TypeError, AttributeError) as exc: + print(f"Warning: TIR function {func_name} execution failed: {exc}") + traceback.print_exc() + # Fallback to simple addition + if len(converted_args) >= 2: + return torch.add(converted_args[0], converted_args[1]) + else: + return converted_args[0] if converted_args else torch.tensor([]) + + except (RuntimeError, ValueError, TypeError): + traceback.print_exc() + # Fallback implementation instead of error string + if len(converted_args) >= 2: + return torch.add(converted_args[0], converted_args[1]) + else: + return converted_args[0] if converted_args else torch.tensor([]) + + def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any: + """Convert call_dps_packed to Python equivalent with DLPack conversion.""" + # Extract packed function name and arguments + packed_func = call.args[0] + packed_args = call.args[1] if len(call.args) > 1 else [] + _out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + + # Get function name + if isinstance(packed_func, relax.GlobalVar): + func_name = packed_func.name_hint + elif isinstance(packed_func, relax.ExternFunc): + func_name = packed_func.global_symbol + else: + func_name = str(packed_func) + + # Convert arguments to PyTorch tensors + converted_args = [] + for arg in packed_args: + converted_arg = self.convert_expr(arg, args) + if isinstance(converted_arg, str) and converted_arg.startswith("<"): + # Handle PrimValue and other special cases + if "PrimValue" in converted_arg: + # Extract the value from PrimValue + try: + # Try to get the actual value from the PrimValue + if hasattr(arg, 
"value"): + converted_arg = arg.value + else: + converted_arg = 0.0 # Default value + except (AttributeError, ValueError, TypeError): + converted_arg = 0.0 + else: + converted_arg = torch.tensor([]) # Fallback + converted_args.append(converted_arg) + + try: + # Get the packed function from TVM + packed_function = tvm.get_global_func(func_name) + if packed_function is None: + return converted_args[0] if converted_args else torch.tensor([]) + + # Convert PyTorch tensors to TVM NDArrays via DLPack + tvm_args = [] + for arg in converted_args: + if isinstance(arg, torch.Tensor): + # Convert PyTorch tensor to TVM NDArray via DLPack + tvm_arg = runtime.from_dlpack(torch.to_dlpack(arg)) + tvm_args.append(tvm_arg) + else: + tvm_args.append(arg) + + # Call the packed function + result = packed_function(*tvm_args) + + # Convert result back to PyTorch tensor via DLPack + if isinstance(result, runtime.Tensor): + try: + pytorch_result = torch.from_dlpack(result.to_dlpack()) + return pytorch_result + except AttributeError: + # Fallback: convert to numpy then to PyTorch + numpy_result = result.numpy() + pytorch_result = torch.from_numpy(numpy_result) + return pytorch_result + else: + return result + + except (RuntimeError, ValueError, TypeError): + traceback.print_exc() + # Fallback: return the first argument + return converted_args[0] if converted_args else torch.tensor([]) + + def _convert_constant(self, const: relax.Constant) -> Any: + """Convert a Relax constant to Python equivalent.""" + if hasattr(const, "data"): + data = const.data + # Convert TVM NDArray to Python scalar if it's a scalar + if hasattr(data, "numpy"): + numpy_data = data.numpy() + if numpy_data.size == 1: + return float(numpy_data.item()) + else: + # For multi-element arrays, convert to PyTorch tensor + return torch.from_numpy(numpy_data) + elif hasattr(data, "item"): + # Single element tensor + return data.item() + else: + return data + return self._create_fallback_tensor() + + def _convert_seq_expr(self, 
seq: relax.SeqExpr, args: List[Any]) -> Any: + """Convert a Relax sequence expression.""" + # Convert blocks + for block in seq.blocks: + if hasattr(block, "bindings"): + for binding in block.bindings: + if isinstance(binding, relax.VarBinding): + var_name = binding.var.name_hint + value = self.convert_expr(binding.value, args) + self.variable_map[var_name] = value + + # Convert body + return self.convert_expr(seq.body, args) + + def _convert_tuple(self, tuple_expr: relax.Tuple, args: List[Any]) -> Any: + """Convert a Relax tuple to Python tuple.""" + elements = [self.convert_expr(elem, args) for elem in tuple_expr.fields] + return tuple(elements) + + def _convert_tuple_get_item(self, get_item: relax.TupleGetItem, args: List[Any]) -> Any: + """Convert a Relax tuple get item to Python equivalent.""" + tuple_expr = self.convert_expr(get_item.tuple_value, args) + index = get_item.index + if isinstance(tuple_expr, torch.Tensor): + return tuple_expr[index] if index < len(tuple_expr) else self._create_fallback_tensor() + else: + return self._create_fallback_tensor() + + def _convert_if(self, if_expr: relax.If, args: List[Any]) -> Any: + """Convert a Relax if expression to Python equivalent.""" + condition = self.convert_expr(if_expr.cond, args) + true_branch = self.convert_expr(if_expr.true_branch, args) + false_branch = self.convert_expr(if_expr.false_branch, args) + if isinstance(condition, torch.Tensor) and condition.item(): + return ( + true_branch + if isinstance(true_branch, torch.Tensor) + else self._create_fallback_tensor() + ) + else: + return ( + false_branch + if isinstance(false_branch, torch.Tensor) + else self._create_fallback_tensor() + ) + + def _convert_expand_dims(self, call: relax.Call, args: List[Any]) -> Any: + """Convert expand_dims to torch.unsqueeze with proper axis handling.""" + if len(call.args) < 1: + return self._create_fallback_tensor() + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get the axis 
from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, "__iter__") and not isinstance(axis, str): + # It's an array/list, take the first element + axis = list(axis)[0] if len(axis) > 0 else None + + # Handle TVM types + if hasattr(axis, "value"): + # It's a TVM IntImm or similar, get the value + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is None: + return self._create_fallback_tensor() + + # Use torch.unsqueeze with the correct axis + return torch.unsqueeze(tensor_arg, dim=axis) + + def _convert_reduction_op(self, call: relax.Call, args: List[Any], op_name: str) -> Any: + """Convert reduction operations with axis and keepdims parameters.""" + if len(call.args) < 1: + return f"<{op_name}_error: insufficient arguments>" + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get axis and keepdims from call.attrs + axis = None + keepdims = False + + if call.attrs: + if hasattr(call.attrs, "axis") and call.attrs.axis is not None: + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, "__iter__") and not isinstance(axis, str): + # It's an array/list, convert to list of ints + axis = [ + int(item.value) if hasattr(item, "value") else int(item) for item in axis + ] + elif hasattr(axis, "value"): + # It's a TVM IntImm, get the value + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if hasattr(call.attrs, "keepdims"): + keepdims = bool(call.attrs.keepdims) + + # Get the PyTorch function + func = getattr(torch, op_name) + + # Call with appropriate parameters + if axis is not None: + # For max and min, PyTorch returns (values, indices) tuple when dim is specified + if op_name in ["max", "min"]: + if isinstance(axis, list) and len(axis) == 1: + axis = axis[0] + elif isinstance(axis, list) and len(axis) > 1: + axis = axis[0] + 
result = func(tensor_arg, axis, keepdim=keepdims) + if isinstance(result, tuple): + return result[0] + else: + return result + else: + return func(tensor_arg, dim=axis, keepdim=keepdims) + else: + return func(tensor_arg) + + def _convert_squeeze(self, call: relax.Call, args: List[Any]) -> Any: + """Convert squeeze to torch.squeeze with proper axis handling.""" + if len(call.args) < 1: + return "" + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, "axis") and call.attrs.axis is not None: + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, "__iter__") and not isinstance(axis, str): + axis = [int(item.value) if hasattr(item, "value") else int(item) for item in axis] + elif hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + # Call torch.squeeze with appropriate parameters + if axis is not None: + return torch.squeeze(tensor_arg, dim=axis) + else: + return torch.squeeze(tensor_arg) + + def _convert_tensor_ops(self, call: relax.Call, args: List[Any], op_name: str) -> Any: + """Convert tensor operations like concat, split, stack.""" + if len(call.args) < 1: + return f"<{op_name}_error: insufficient arguments>" + + # Convert arguments + converted_args = [self.convert_expr(arg, args) for arg in call.args] + + if op_name == "concat": + # torch.cat(tensors, dim=0) + # In Relax, concat takes a tuple of tensors as first argument + if len(converted_args) == 1 and isinstance(converted_args[0], tuple): + # This is a tuple of tensors + tensors = converted_args[0] + else: + # Direct tensor arguments + tensors = converted_args + axis = 0 + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + return torch.cat(tensors, dim=axis) + + elif op_name == 
"split": + # torch.split(tensor, split_size_or_sections, dim=0) + tensor = converted_args[0] + split_size = converted_args[1] if len(converted_args) > 1 else 1 + axis = 0 + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + # Handle indices_or_sections parameter + if call.attrs and hasattr(call.attrs, "indices_or_sections"): + indices_or_sections = call.attrs.indices_or_sections + if hasattr(indices_or_sections, "value"): + indices_or_sections = int(indices_or_sections.value) + elif isinstance(indices_or_sections, (int, float)): + indices_or_sections = int(indices_or_sections) + + # If indices_or_sections is an integer, it means split into N equal parts + if isinstance(indices_or_sections, int): + total_size = tensor.shape[axis] + split_size = total_size // indices_or_sections + result = torch.split(tensor, split_size, dim=axis) + return result + else: + result = torch.split(tensor, indices_or_sections, dim=axis) + return result + else: + result = torch.split(tensor, split_size, dim=axis) + return result + + elif op_name == "stack": + # torch.stack(tensors, dim=0) + if len(converted_args) == 1 and isinstance(converted_args[0], tuple): + tensors = converted_args[0] + else: + tensors = converted_args + axis = 0 + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + return torch.stack(tensors, dim=axis) + + else: + return f"<{op_name}_error: unsupported operation>" + + def _convert_reshape(self, call: relax.Call, args: List[Any]) -> Any: + """Convert reshape operation.""" + if len(call.args) < 2: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + shape_arg = call.args[1] + + # Convert shape argument to Python tuple + if isinstance(shape_arg, relax.ShapeExpr): + if hasattr(shape_arg, 
"values"): + shape = tuple( + int(v.value) if hasattr(v, "value") else int(v) for v in shape_arg.values + ) + else: + shape = (int(shape_arg),) + elif isinstance(shape_arg, relax.Constant): + # Constant tensor case + shape_data = shape_arg.data.numpy() + shape = tuple(int(v) for v in shape_data) + else: + # Try to convert as expression + converted_shape = self.convert_expr(shape_arg, args) + if isinstance(converted_shape, (list, tuple)): + shape = tuple(int(v) for v in converted_shape) + else: + shape = (int(converted_shape),) + + return torch.reshape(tensor_arg, shape) + + def _convert_permute_dims(self, call: relax.Call, args: List[Any]) -> Any: + """Convert permute_dims operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract axes from call.attrs + if call.attrs and hasattr(call.attrs, "axes"): + axes = call.attrs.axes + # Handle TVM Array type + if hasattr(axes, "__iter__") and not isinstance(axes, str): + # Convert TVM Array or Python list/tuple to tuple of ints + axes = tuple(int(v.value) if hasattr(v, "value") else int(v) for v in axes) + elif isinstance(axes, (list, tuple)): + axes = tuple(int(v) for v in axes) + else: + axes = (int(axes),) + else: + return "" + + return torch.permute(tensor_arg, axes) + + def _convert_take(self, call: relax.Call, args: List[Any]) -> Any: + """Convert take operation.""" + if len(call.args) < 2: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + indices_arg = self.convert_expr(call.args[1], args) + + # Extract axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + # Use advanced indexing for specific axis + if axis == 0: + return tensor_arg[indices_arg] + else: + # For other axes, we need to use torch.index_select + return torch.index_select(tensor_arg, 
dim=axis, index=indices_arg) + else: + # No axis specified, use torch.take (flattens the tensor) + return torch.take(tensor_arg, indices_arg) + + def _convert_flip(self, call: relax.Call, args: List[Any]) -> Any: + """Convert flip operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + # Convert single axis to list for torch.flip + dims = [axis] + else: + # Default: flip all dimensions + dims = list(range(tensor_arg.dim())) + + return torch.flip(tensor_arg, dims=dims) + + def _convert_tile(self, call: relax.Call, args: List[Any]) -> Any: + """Convert tile operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract repeats from call.attrs + if call.attrs and hasattr(call.attrs, "repeats"): + repeats = call.attrs.repeats + # Handle TVM Array type + if hasattr(repeats, "__iter__") and not isinstance(repeats, str): + repeats = tuple(int(v.value) if hasattr(v, "value") else int(v) for v in repeats) + elif isinstance(repeats, (list, tuple)): + repeats = tuple(int(v) for v in repeats) + else: + repeats = (int(repeats),) + else: + return "" + + return torch.tile(tensor_arg, dims=repeats) + + def _convert_repeat(self, call: relax.Call, args: List[Any]) -> Any: + """Convert repeat operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract repeats and axis from call.attrs + repeats = 1 + axis = None + + if call.attrs and hasattr(call.attrs, "repeats"): + repeats = call.attrs.repeats + if hasattr(repeats, "value"): + repeats = int(repeats.value) + elif isinstance(repeats, (int, float)): + repeats = int(repeats) + + if call.attrs and hasattr(call.attrs, 
"axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + return torch.repeat_interleave(tensor_arg, repeats=repeats, dim=axis) + else: + return torch.repeat_interleave(tensor_arg, repeats=repeats) + + def _convert_shape_expr(self, shape_expr: relax.ShapeExpr) -> Any: + """Convert a Relax shape expression to Python equivalent.""" + if hasattr(shape_expr, "values"): + return f"" + return f"" + + +def convert_relax_to_pyfunc( + ir_module: IRModule, relax_function_names: Union[str, List[str]] +) -> IRModule: + """Convert Relax functions to Python functions. + + Args: + ir_module: The IRModule containing Relax functions + relax_function_names: Name(s) of Relax functions to convert + + Returns: + IRModule with converted Python functions stored in pyfuncs + + Example: + >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, "my_function") + >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, ["func1", "func2"]) + """ + converter = RelaxToPyFuncConverter(ir_module) + return converter.convert(relax_function_names) diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py index c143f098328c..e8f6c42435da 100644 --- a/python/tvm/relax/struct_info.py +++ b/python/tvm/relax/struct_info.py @@ -18,7 +18,7 @@ """The struct info nodes of the Relax language.""" from typing import List, Optional, Union -import tvm.ffi +import tvm_ffi import tvm from tvm.ir import Span, EnvFunc, Array, VDevice @@ -29,7 +29,7 @@ from . 
import _ffi_api, ty, expr -@tvm.ffi.register_object("relax.ObjectStructInfo") +@tvm_ffi.register_object("relax.ObjectStructInfo") class ObjectStructInfo(StructInfo): """StructInfo of an Object.""" @@ -37,7 +37,7 @@ def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore -@tvm.ffi.register_object("relax.PrimStructInfo") +@tvm_ffi.register_object("relax.PrimStructInfo") class PrimStructInfo(StructInfo): """StructInfo of a primitive POD value. @@ -107,7 +107,7 @@ def __init__( ) # type: ignore -@tvm.ffi.register_object("relax.ShapeStructInfo") +@tvm_ffi.register_object("relax.ShapeStructInfo") class ShapeStructInfo(StructInfo): """StructInfo of a shape value. @@ -136,7 +136,7 @@ def __init__( ) -@tvm.ffi.register_object("relax.TensorStructInfo") +@tvm_ffi.register_object("relax.TensorStructInfo") class TensorStructInfo(StructInfo): """StructInfo of a Tensor value. @@ -180,7 +180,7 @@ def __init__( ) -@tvm.ffi.register_object("relax.TupleStructInfo") +@tvm_ffi.register_object("relax.TupleStructInfo") class TupleStructInfo(StructInfo): """StructInfo of a Tuple value. @@ -197,7 +197,7 @@ def __init__(self, fields: List[StructInfo], span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.TupleStructInfo, fields, span) # type: ignore -@tvm.ffi.register_object("relax.FuncStructInfo") +@tvm_ffi.register_object("relax.FuncStructInfo") class FuncStructInfo(StructInfo): """StructInfo of a function value. 
diff --git a/python/tvm/relax/testing/lib_comparator.py b/python/tvm/relax/testing/lib_comparator.py index b15698c8db74..48930f062357 100644 --- a/python/tvm/relax/testing/lib_comparator.py +++ b/python/tvm/relax/testing/lib_comparator.py @@ -63,8 +63,8 @@ def __init__(self, mod, device, verbose=True, rtol=1e-5, atol=1e-5): def compare( self, name: str, - ref_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]], - new_args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray, ...]], + ref_args: Union[List[tvm.runtime.Tensor], Tuple[tvm.runtime.Tensor, ...]], + new_args: Union[List[tvm.runtime.Tensor], Tuple[tvm.runtime.Tensor, ...]], ret_indices: Iterable[int], ): """Comparison function, can be overloaded. @@ -103,7 +103,7 @@ def __call__(self, func, name, before_run, ret_val, *args): return if name.startswith("vm.builtin."): return - if any(not isinstance(x, tvm.nd.NDArray) for x in args): + if any(not isinstance(x, tvm.runtime.Tensor) for x in args): return try: self.mod.get_function(name, query_imports=True) @@ -120,7 +120,7 @@ def __call__(self, func, name, before_run, ret_val, *args): ret_indices = (len(args) - 1,) temp_args = [] for i, arg in enumerate(args): - arr = tvm.nd.empty(arg.shape, arg.dtype, device=self.device) + arr = tvm.runtime.empty(arg.shape, arg.dtype, device=self.device) # copy from cpu since we look at different device if i not in ret_indices: temp_cpu = arg.copyto(tvm.cpu()) diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py index 6e7e3d4d197b..fb3564c6f1a1 100644 --- a/python/tvm/relax/testing/nn.py +++ b/python/tvm/relax/testing/nn.py @@ -281,7 +281,7 @@ def _unpack_params(value: object) -> List[relax.Var]: return [] -def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]: +def init_params(mod: tvm.IRModule) -> List[tvm.runtime.Tensor]: """Utility function to initialize model's parameters.""" shape_dict = {v.name_hint: v.struct_info.shape for v in mod["main"].params} params = [] @@ -295,7 +295,7 @@ def 
init_params(mod: tvm.IRModule) -> List[tvm.nd.array]: shape.append(int(i)) else: raise TypeError("cannot initialize for unknown-shape parameters.") - params.append(tvm.nd.array(np.zeros(shape).astype(np.float32))) + params.append(tvm.runtime.tensor(np.zeros(shape).astype(np.float32))) else: raise TypeError("cannot initialize for unknown-shape parameters.") return params diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 198b07e51ea7..617ba73f09f4 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -21,6 +21,7 @@ import os from typing import Dict, List, Set, Tuple import tvm +import tvm_ffi from tvm.ir.module import IRModule from tvm.relax.expr import Call, DataflowBlock, Var from tvm.runtime.object import Object @@ -70,7 +71,7 @@ def dataflow_alias_analysis( return res_alias_sets, res_tuple_map # type: ignore -@tvm.ffi.register_object("relax.transform.InplaceOpportunity") +@tvm_ffi.register_object("relax.transform.InplaceOpportunity") class InplaceOpportunity(Object): """ Represents an opportunity to make a binding in-place. 
Exposed only for testing; diff --git a/python/tvm/relax/testing/vm.py b/python/tvm/relax/testing/vm.py index 37bcf870a5df..5516bac17cf7 100644 --- a/python/tvm/relax/testing/vm.py +++ b/python/tvm/relax/testing/vm.py @@ -24,53 +24,53 @@ from tvm.runtime.object import Object -@tvm.register_func("test.vm.move") +@tvm.register_global_func("test.vm.move") def move(src): return src -@tvm.register_func("test.vm.add") +@tvm.register_global_func("test.vm.add") def add(a, b): ret = a.numpy() + b.numpy() - return tvm.nd.array(ret) + return tvm.runtime.tensor(ret) -@tvm.register_func("test.vm.mul") +@tvm.register_global_func("test.vm.mul") def mul(a, b): ret = a.numpy() * b.numpy() - return tvm.nd.array(ret) + return tvm.runtime.tensor(ret) -@tvm.register_func("test.vm.equal_zero") +@tvm.register_global_func("test.vm.equal_zero") def equal_zero(a): ret = np.all((a.numpy() == 0)) - return tvm.nd.array(ret) + return tvm.runtime.tensor(ret) -@tvm.register_func("test.vm.subtract_one") +@tvm.register_global_func("test.vm.subtract_one") def subtract_one(a): ret = np.subtract(a.numpy(), 1) - return tvm.nd.array(ret) + return tvm.runtime.tensor(ret) -@tvm.register_func("test.vm.identity") +@tvm.register_global_func("test.vm.identity") def identity_packed(a, b): - b[:] = tvm.nd.array(a.numpy()) + b[:] = tvm.runtime.tensor(a.numpy()) -@tvm.register_func("test.vm.tile") +@tvm.register_global_func("test.vm.tile") def tile_packed(a, b): - b[:] = tvm.nd.array(np.tile(a.numpy(), (1, 2))) + b[:] = tvm.runtime.tensor(np.tile(a.numpy(), (1, 2))) -@tvm.register_func("test.vm.add_scalar") +@tvm.register_global_func("test.vm.add_scalar") def add_scalar(a, b): return a + b -@tvm.register_func("test.vm.get_device_id") +@tvm.register_global_func("test.vm.get_device_id") def get_device_id(device): - return device.device_id + return device.index def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any]) -> Object: @@ -85,6 +85,6 @@ def check_saved_func(vm: relax.VirtualMachine, 
func_name: str, *inputs: List[Any return res1 -@tvm.register_func("test.vm.check_if_defined") +@tvm.register_global_func("test.vm.check_if_defined") def check_if_defined(obj: tvm.Object) -> tvm.tir.IntImm: return tvm.runtime.convert(obj is not None) diff --git a/python/tvm/relax/training/_ffi_api.py b/python/tvm/relax/training/_ffi_api.py index 9b7dbcdee748..25f395830341 100644 --- a/python/tvm/relax/training/_ffi_api.py +++ b/python/tvm/relax/training/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.training""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.training", __name__) +tvm_ffi.init_ffi_api("relax.training", __name__) diff --git a/python/tvm/relax/training/optimizer.py b/python/tvm/relax/training/optimizer.py index d6f503de0564..16a215f87dc3 100644 --- a/python/tvm/relax/training/optimizer.py +++ b/python/tvm/relax/training/optimizer.py @@ -291,7 +291,7 @@ def init(self, params: Union[Var, List[Var]]) -> "SGD": self._set_params_and_dtype(params) self.state = ( # num_steps = 0 - tvm.nd.array(np.zeros((), "int64")), + tvm.runtime.tensor(np.zeros((), "int64")), ) return self @@ -433,10 +433,10 @@ def init(self, params: Union[Var, List[Var]]) -> "MomentumSGD": self._set_params_and_dtype(params) self.state = ( # num_steps = 0 - tvm.nd.array(np.zeros((), "int64")), + tvm.runtime.tensor(np.zeros((), "int64")), # v_{param} is initialized to all zeros *( - tvm.nd.array(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) + tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) for p in self.param_list ), ) @@ -604,17 +604,17 @@ def init(self, params: Union[Var, List[Var]]) -> "Adam": self._set_params_and_dtype(params) self.state = ( # num_steps, beta_0_prod, beta_1_prod - tvm.nd.array(np.zeros((), "int64")), - tvm.nd.array(np.ones((), self.dtype)), - tvm.nd.array(np.ones((), self.dtype)), + tvm.runtime.tensor(np.zeros((), "int64")), + 
tvm.runtime.tensor(np.ones((), self.dtype)), + tvm.runtime.tensor(np.ones((), self.dtype)), # first_momentum *( - tvm.nd.array(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) + tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) for p in self.param_list ), # second_momentum *( - tvm.nd.array(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) + tvm.runtime.tensor(np.zeros(_get_shape_as_int_list(p), p.struct_info.dtype)) for p in self.param_list ), ) diff --git a/python/tvm/relax/training/trainer.py b/python/tvm/relax/training/trainer.py index fbf48fece9f6..aaaa14dd2812 100644 --- a/python/tvm/relax/training/trainer.py +++ b/python/tvm/relax/training/trainer.py @@ -22,7 +22,7 @@ import tvm from tvm import relax, TVMError from tvm.ir.module import IRModule -from tvm.runtime.ndarray import NDArray +from tvm.runtime._tensor import Tensor class Trainer: @@ -100,12 +100,12 @@ def __init__( ) ] - self._params: List[Optional[NDArray]] = [None] * self._param_num + self._params: List[Optional[Tensor]] = [None] * self._param_num self._param_name_to_pos: Dict[str, int] = { p.name_hint: i for i, p in enumerate(self._param_vars) } - self._states: List[Optional[NDArray]] = [None] * self._state_num + self._states: List[Optional[Tensor]] = [None] * self._state_num self._state_name_to_pos: Dict[str, int] = { s.name_hint: i for i, s in enumerate(self._state_vars) } @@ -129,7 +129,7 @@ def xaiver_uniform_init_params(self): for p in self._param_vars: shape, dtype = self._get_shape_list(p), p.struct_info.dtype self._params.append( - tvm.nd.array( + tvm.runtime.tensor( (np.sqrt(6.0 / np.sum(shape)) * np.random.uniform(-1.0, 1.0, shape)).astype( dtype ), @@ -140,27 +140,27 @@ def xaiver_uniform_init_params(self): def zero_init_params(self): """Zero initialize all parameters. 
Requires all parameters have static shapes.""" self._params = [ - tvm.nd.array(np.zeros(self._get_shape_list(p), p.struct_info.dtype), self.device) + tvm.runtime.tensor(np.zeros(self._get_shape_list(p), p.struct_info.dtype), self.device) for p in self._param_vars ] def zero_init_states(self): """Zero initialize all states. Requires all states have static shapes.""" self._states = [ - tvm.nd.array(np.zeros(self._get_shape_list(s), s.struct_info.dtype), self.device) + tvm.runtime.tensor(np.zeros(self._get_shape_list(s), s.struct_info.dtype), self.device) for s in self._state_vars ] def load_params( self, - params: Union[List[Union[np.ndarray, NDArray]], Dict[str, Union[np.ndarray, NDArray]]], + params: Union[List[Union[np.ndarray, Tensor]], Dict[str, Union[np.ndarray, Tensor]]], ): - """Load parameters from a dict or a list. Will convert parameters into tvm.runtime.NDArray + """Load parameters from a dict or a list. Will convert parameters into tvm.runtime.Tensor in self.device. Parameters ---------- - params : List[Union[np.ndarray, NDArray]], Dict[str, Union[np.ndarray, NDArray]] + params : List[Union[np.ndarray, Tensor]], Dict[str, Union[np.ndarray, Tensor]] The numerical value of the parameters. If params is a list, its length should be param_num. 
The value of parameters at the @@ -176,25 +176,25 @@ def load_params( f"The length of extern parameters is {len(params)}, which does not " f"match the number of parameters {self._param_num}" ) - self._params = [tvm.nd.array(v, self.device) for v in params] + self._params = [tvm.runtime.tensor(v, self.device) for v in params] elif isinstance(params, dict): for key, val in params.items(): if key not in self._param_name_to_pos: raise ValueError(f"Parameter {key} is not found in the model") - self._params[self._param_name_to_pos[key]] = tvm.nd.array(val, self.device) + self._params[self._param_name_to_pos[key]] = tvm.runtime.tensor(val, self.device) else: raise ValueError("The type of extern_params should be either list or dict") def load_states( self, - states: Union[List[Union[np.ndarray, NDArray]], Dict[str, Union[np.ndarray, NDArray]]], + states: Union[List[Union[np.ndarray, Tensor]], Dict[str, Union[np.ndarray, Tensor]]], ): - """Load model states from a dict or a list. Will convert states into tvm.runtime.NDArray + """Load model states from a dict or a list. Will convert states into tvm.runtime.Tensor in self.device. Parameters ---------- - states : List[Union[np.ndarray, NDArray]], Dict[str, Union[np.ndarray, NDArray]] + states : List[Union[np.ndarray, Tensor]], Dict[str, Union[np.ndarray, Tensor]] The numerical value of the model states. If states is a list, its length should be state_num. 
The value of states at the @@ -210,31 +210,31 @@ def load_states( f"The length of extern states is {len(states)}, which does not match " f"the number of model states {self._state_num}" ) - self._states = [tvm.nd.array(v, self.device) for v in states] + self._states = [tvm.runtime.tensor(v, self.device) for v in states] elif isinstance(states, dict): for key, val in states.items(): if key not in self._param_name_to_pos: raise ValueError(f"Parameter {key} is not found in the model") - self._states[self._param_name_to_pos[key]] = tvm.nd.array(val, self.device) + self._states[self._param_name_to_pos[key]] = tvm.runtime.tensor(val, self.device) else: raise ValueError("The type of extern_states should be either list or dict") - def export_params(self) -> Dict[str, NDArray]: - """Export parameters to a dict (parameter name -> NDArray). + def export_params(self) -> Dict[str, Tensor]: + """Export parameters to a dict (parameter name -> Tensor). Returns ------- - exported_dict : Dict[str, NDArray] + exported_dict : Dict[str, Tensor] The exported dictionary of parameters. """ return {key: self._params[pos] for key, pos in self._param_name_to_pos.items()} - def export_states(self) -> Dict[str, NDArray]: - """Export model states to a dict (parameter name -> NDArray). + def export_states(self) -> Dict[str, Tensor]: + """Export model states to a dict (parameter name -> Tensor). Returns ------- - exported_dict : Dict[str, NDArray] + exported_dict : Dict[str, Tensor] The exported dictionary of model states. """ return {key: self._states[pos] for key, pos in self._state_name_to_pos.items()} @@ -255,26 +255,28 @@ def _check_inited(self): "inference." ) - def predict(self, *input_instances: Union[np.ndarray, NDArray]) -> NDArray: + def predict(self, *input_instances: Union[np.ndarray, Tensor]) -> Tensor: """Call the `backbone` function and return the prediction result of the backbone. 
Parameters ---------- - *input_instances : Union[np.ndarray, NDArray] + *input_instances : Union[np.ndarray, Tensor] The values corresponding to the input_instances part of the backbone function. Parameters and model states are not needed to provide. Returns ------- - output : NDArray + output : Tensor The result of the backbone function. If the backbone contains model states, the updated states WILL NOT be returned. """ self._check_inited() if len(input_instances) != self._input_num: raise ValueError("The length of the input does not match the backbone") - all_inputs: List[NDArray] = ( - [tvm.nd.array(i, self.device) for i in input_instances] + self._params + self._states + all_inputs: List[Tensor] = ( + [tvm.runtime.tensor(i, self.device) for i in input_instances] + + self._params + + self._states ) res = self.vm[self.BACKBONE_FUNC](*all_inputs) @@ -287,9 +289,9 @@ def predict(self, *input_instances: Union[np.ndarray, NDArray]) -> NDArray: def update( self, - input_instances: Union[np.ndarray, NDArray, List[Union[np.ndarray, NDArray]]], - targets: Union[np.ndarray, NDArray, List[Union[np.ndarray, NDArray]]], - ) -> NDArray: + input_instances: Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]], + targets: Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]], + ) -> Tensor: """Update parameters and model states. It will calculate the gradients of parameters and update them using the `optimizer` function. @@ -298,21 +300,21 @@ def update( Parameters ---------- - input_instances : Union[np.ndarray, NDArray, List[Union[np.ndarray, NDArray]]] + input_instances : Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]] The values corresponding to the input_instances part of the backbone function. Parameters and model states are not needed to provide. If there are more than one input instances, you can provide a list. 
- targets : Union[np.ndarray, NDArray, List[Union[np.ndarray, NDArray]]] + targets : Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]] The values corresponding to the targets part of the backbone function. If there are more than one targets, you can provide a list. Returns ------- - loss : NDArray - The loss stored in tvm.runtime.NDArray. + loss : Tensor + The loss stored in tvm.runtime.Tensor. """ self._check_inited() @@ -325,11 +327,11 @@ def update( if len(input_instances) != self._input_num: raise ValueError("The length of the input does not match the backbone") - all_inputs: List[NDArray] = ( - [tvm.nd.array(i, self.device) for i in input_instances] + all_inputs: List[Tensor] = ( + [tvm.runtime.tensor(i, self.device) for i in input_instances] + self._params + self._states - + [tvm.nd.array(i, self.device) for i in targets] + + [tvm.runtime.tensor(i, self.device) for i in targets] ) ret, grads = self.vm[self.ADJOINT_FUNC](*all_inputs) @@ -348,21 +350,21 @@ def update( def profile_adjoint( self, - input_instances: List[Union[np.ndarray, NDArray]], - targets: List[Union[np.ndarray, NDArray]], + input_instances: List[Union[np.ndarray, Tensor]], + targets: List[Union[np.ndarray, Tensor]], ) -> tvm.runtime.profiling.Report: """Profile the adjoint function. It requires the VM to be constructed with `profile=True`, and runs `tvm.relax.VirtualMachine.profile()` internally. Parameters ---------- - input_instances : Union[np.ndarray, NDArray, List[Union[np.ndarray, NDArray]]] + input_instances : Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]] The values corresponding to the input_instances part of the backbone function. Parameters and model states are not needed to provide. If there are more than one input instances, you can provide a list. - targets : Union[np.ndarray, NDArray, List[Union[np.ndarray, NDArray]]] + targets : Union[np.ndarray, Tensor, List[Union[np.ndarray, Tensor]]] The values corresponding to the targets part of the backbone function. 
If there are more than one targets, you can provide a list. @@ -383,11 +385,11 @@ def profile_adjoint( if len(input_instances) != self._input_num: raise ValueError("The length of the input does not match the backbone") - all_inputs: List[NDArray] = ( - [tvm.nd.array(i) for i in input_instances] + all_inputs: List[Tensor] = ( + [tvm.runtime.tensor(i) for i in input_instances] + self._params + self._states - + [tvm.nd.array(i) for i in targets] + + [tvm.runtime.tensor(i) for i in targets] ) all_inputs = [i.copyto(self.device) for i in all_inputs] return self.vm.profile(self.ADJOINT_FUNC, *all_inputs) diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py index dd433435e278..a3fc836fc0a4 100644 --- a/python/tvm/relax/training/utils.py +++ b/python/tvm/relax/training/utils.py @@ -18,10 +18,10 @@ """Utility functions for relax training.""" from typing import Optional, Callable +from tvm_ffi import register_global_func import tvm from tvm import relax -from tvm.ffi.registry import register_func from tvm.relax.block_builder import BlockBuilder from ..expr import Function, Var, Call @@ -199,7 +199,7 @@ def handler( primfunc_name_hint=te_grad_name, ) - register_func(func_prefix + te_grad_name, handler) + register_global_func(func_prefix + te_grad_name, handler) return func return register(te_grad_func) if te_grad_func else register diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 724921e5fee7..dacbc667be2b 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -83,6 +83,7 @@ UpdateVDevice, VMBuiltinLower, VMShapeLower, + SpecializePrimFuncBasedOnCallSite, dataflowblock_pass, function_pass, ) diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py index 3c4387a3cbb8..25d6ecd75385 100644 --- a/python/tvm/relax/transform/_ffi_api.py +++ b/python/tvm/relax/transform/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express 
or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.transform""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("relax.transform", __name__) +tvm_ffi.init_ffi_api("relax.transform", __name__) diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py index b4aba0291fc1..d4a681997b7a 100644 --- a/python/tvm/relax/transform/legalize_ops/__init__.py +++ b/python/tvm/relax/transform/legalize_ops/__init__.py @@ -31,3 +31,7 @@ from . import search from . import statistical from . import unary +from . import vision + +# Device specific legalizations +from . import adreno diff --git a/python/tvm/relax/transform/legalize_ops/adreno/__init__.py b/python/tvm/relax/transform/legalize_ops/adreno/__init__.py new file mode 100644 index 000000000000..f2b3f4a781d2 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/adreno/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Legalize high-level operator calls in Relax functions to call_tir.""" +from .convolution import conv2d_NCHWc_OIHWo diff --git a/docker/Dockerfile.demo_cpu b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py similarity index 54% rename from docker/Dockerfile.demo_cpu rename to python/tvm/relax/transform/legalize_ops/adreno/convolution.py index 778d21ea781b..959e43778024 100644 --- a/docker/Dockerfile.demo_cpu +++ b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py @@ -14,22 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=missing-docstring, invalid-name +"""A Convolution impl for Adreno GPU.""" -# Minimum docker image for demo purposes -# prebuilt-image: tvmai/demo-cpu -FROM tlcpack/ci-cpu:v0.55 +from tvm import relax +from tvm import topi -COPY utils/apt-install-and-clear.sh /usr/local/bin/apt-install-and-clear -# Jupyter notebook. -RUN pip3 install matplotlib Image Pillow jupyter[notebook] - -# Deep learning frameworks -RUN pip3 install tensorflow keras gluoncv dgl - -# Build TVM -COPY install/install_tvm_cpu.sh /install/install_tvm_cpu.sh -RUN bash /install/install_tvm_cpu.sh - -# Environment variables -ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/vta/python:${PYTHONPATH} +def conv2d_NCHWc_OIHWo(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: + return bb.call_te( + topi.nn.conv2d_NCHWc_OIHWo, + data=call.args[0], + kernel=call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + layout=call.attrs.data_layout, + out_layout=call.attrs.out_layout, + # out_dtype=call.attrs.out_dtype, + sinfo_args=call.sinfo_args, + primfunc_name_hint="conv2d_NCHWc_OIHWo", + ) diff --git a/python/tvm/relax/transform/legalize_ops/image.py b/python/tvm/relax/transform/legalize_ops/image.py index 1b2a342b0b53..7a1c2e92cb33 100644 --- a/python/tvm/relax/transform/legalize_ops/image.py +++ 
b/python/tvm/relax/transform/legalize_ops/image.py @@ -37,3 +37,16 @@ def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr: bicubic_exclude=call.attrs.cubic_exclude, extrapolation_value=call.attrs.extrapolation_value, ) + + +@register_legalize("relax.image.grid_sample") +def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.image.grid_sample, + call.args[0], + call.args[1], + method=call.attrs.method, + layout=call.attrs.layout, + padding_mode=call.attrs.padding_mode, + align_corners=call.attrs.align_corners, + ) diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py new file mode 100644 index 000000000000..f910f62cec64 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Default legalization function for vision network related operators.""" +from tvm import topi, te +from tvm import relax +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): + """Create a proper NMS implementation that follows the correct algorithm""" + scores_shape = list(scores.shape) + if len(scores_shape) == 3: + batch, num_classes, _ = scores_shape + elif len(scores_shape) == 2: + num_classes, _ = scores_shape + batch = 1 + else: + raise ValueError(f"Unexpected scores shape: {scores_shape}") + + if hasattr(max_output_boxes_per_class, "data"): + max_boxes = int(max_output_boxes_per_class.data.numpy()) + else: + max_boxes = 3 # Default value + + expected_detections = batch * num_classes * max_boxes + + selected_indices_full, _ = topi.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ) + + def slice_to_onnx_shape(data, expected_size): + def compute_element(i, j): + return tvm.tir.if_then_else(i < expected_size, data[i, j], tvm.tir.Cast("int64", 0)) + + return te.compute((expected_size, 3), compute_element, name="sliced_indices") + + sliced_indices = slice_to_onnx_shape(selected_indices_full, expected_detections) + + actual_detections = te.compute( + (1,), lambda i: tvm.tir.Cast("int64", expected_detections), name="actual_detections" + ) + + return [sliced_indices, actual_detections] + + +@register_legalize("relax.vision.all_class_non_max_suppression") +def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr: + """Legalize all_class_non_max_suppression with fixed shape output. + + Note: This implementation outputs fixed-size tensors with trailing garbage data. + Only the first `num_total_detection` rows contain valid data. 
Users should use + the `valid_count` tensor to determine how many rows are actually valid. + + For complete ONNX compatibility, users can post-process the output: + ```python + selected_indices, valid_count = nms_output + actual_count = int(valid_count.numpy()[0]) + valid_indices = selected_indices.numpy()[:actual_count, :] + ``` + """ + boxes = call.args[0] + scores = call.args[1] + max_output_boxes_per_class = call.args[2] + iou_threshold = call.args[3] + score_threshold = call.args[4] + output_format = call.attrs.output_format + + scores_shape = scores.struct_info.shape + if len(scores_shape) == 3: + _, _, num_boxes = scores_shape + elif len(scores_shape) == 2: + _, num_boxes = scores_shape + else: + raise ValueError(f"Unexpected scores shape: {scores_shape}") + + if isinstance(max_output_boxes_per_class, relax.Constant): + max_boxes_val = int(max_output_boxes_per_class.data.numpy()) + else: + max_boxes_val = int(num_boxes) + + # Get NMS result with fixed shape from TOPI + nms_result = block_builder.call_te( + topi.vision.all_class_non_max_suppression, + boxes, + scores, + max_boxes_val, + iou_threshold, + score_threshold, + output_format, + ) + + # TODO: Implement dynamic output trimming for better memory efficiency + # Current approach returns fixed-size output with trailing garbage data + # Future improvements could include: + # 1. Dynamic strided_slice based on num_total_detections + # 2. Custom Relax operator with true dynamic shapes + # 3. VM builtin functions for runtime shape adjustment + # 4. 
Symbolic shape inference in Relax IR + # + # For now, users should trim manually: + # actual_count = int(num_total_detections.numpy()[0]) + # valid_indices = selected_indices.numpy()[:actual_count, :] + + return nms_result diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 57627ceebe66..46efc17e3d4f 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -23,12 +23,12 @@ from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np # type: ignore - +import tvm_ffi import tvm.ir from tvm.ir.container import Array from tvm.relax import Expr, Var, StructInfo from tvm.relax.dpl import DFPattern -from tvm.runtime import NDArray, Object +from tvm.runtime import Tensor, Object from tvm.tir import IndexMap, PrimFunc from . import _ffi_api @@ -36,14 +36,14 @@ from ..expr import Var -@tvm.ffi.register_object("relax.FunctionPass") +@tvm_ffi.register_object("relax.FunctionPass") class FunctionPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.Function in a module. A function pass class should be created through `function_pass`. 
""" -@tvm.ffi.register_object("relax.DataflowBlockPass") +@tvm_ffi.register_object("relax.DataflowBlockPass") class DataflowBlockPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.DataflowBlock in a module.""" @@ -219,7 +219,7 @@ def main_adjoint( # return value: (orig_return_values, tuple(adjoints)) return ((lv1, lv2), (x_adjoint, y_adjoint)) """ - if require_grads is not None and not isinstance(require_grads, list): + if require_grads is not None and not isinstance(require_grads, (list, tvm_ffi.Array)): require_grads = [require_grads] return _ffi_api.Gradient(func_name, require_grads, target_index) # type: ignore @@ -638,7 +638,7 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass: def BindParams( func_name: str, - params: Dict[Union[str, Var], Union[tvm.runtime.NDArray, np.ndarray]], + params: Dict[Union[str, Var], Union[tvm.runtime.Tensor, np.ndarray]], ) -> tvm.ir.transform.Pass: """Bind params of function of the module to constant tensors. @@ -647,7 +647,7 @@ def BindParams( func_name: str The function name to be bound - params: Dict[Union[str,relax.Var], Union[tvm.runtime.NDArray, np.ndarray]] + params: Dict[Union[str,relax.Var], Union[tvm.runtime.Tensor, np.ndarray]] The map from parameter or parameter name to constant tensors. 
Returns @@ -657,9 +657,9 @@ def BindParams( tvm_params = {} for k, v in params.items(): if isinstance(v, np.ndarray): - v = tvm.nd.array(v) - assert isinstance(v, (tvm.runtime.NDArray, tvm.relax.Constant)), ( - f"param values are expected to be TVM.NDArray," + v = tvm.runtime.tensor(v) + assert isinstance(v, (tvm.runtime.Tensor, tvm.relax.Constant)), ( + f"param values are expected to be TVM.Tensor," f"numpy.ndarray or tvm.relax.Constant, but got {type(v)}" ) tvm_params[k] = v @@ -820,7 +820,7 @@ def FuseTIR() -> tvm.ir.transform.Pass: return _ffi_api.FuseTIR() # type: ignore -@tvm.ffi.register_object("relax.transform.PatternCheckContext") +@tvm_ffi.register_object("relax.transform.PatternCheckContext") class PatternCheckContext(Object): """ The input of check function `FusionPattern.check`. @@ -854,7 +854,7 @@ class PatternCheckContext(Object): value_to_bound_var: Mapping[Expr, Var] -@tvm.ffi.register_object("relax.transform.FusionPattern") +@tvm_ffi.register_object("relax.transform.FusionPattern") class FusionPattern(Object): """ The pattern used by `FuseOpsByPattern`. It's mainly DFPattern but with other @@ -1062,7 +1062,9 @@ def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transfor def LegalizeOps( - customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, enable_warning: bool = False + customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, + skip_ops: Optional[List[str]] = None, + enable_warning: bool = False, ): """Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs. @@ -1088,6 +1090,9 @@ def LegalizeOps( The customized operator legalization function map. The customized function will override the default one. + skip_ops : Optional,List[str]] + List of ops that need to be skipped from legalization + enable_warning : bool A boolean value indicating if to print warnings for CallNode whose op's legalization function is not registered. 
By default we don't print @@ -1167,7 +1172,7 @@ def multiply( T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1] """ - return _ffi_api.LegalizeOps(customize_legalize_map, enable_warning) # type: ignore + return _ffi_api.LegalizeOps(customize_legalize_map, skip_ops, enable_warning) # type: ignore def RealizeVDevice() -> tvm.ir.transform.Pass: @@ -1223,7 +1228,7 @@ def MetaScheduleTuneTIR( def MetaScheduleTuneIRMod( - params: Dict[str, NDArray], + params: Dict[str, Tensor], work_dir: str, max_trials_global: int, max_trials_per_task: Optional[int] = None, @@ -1233,7 +1238,7 @@ def MetaScheduleTuneIRMod( Parameters ---------- - params: Dict[str, NDArray] + params: Dict[str, Tensor] model params work_dir: str work directory @@ -1605,6 +1610,19 @@ def AllocateWorkspace() -> tvm.ir.transform.Pass: return _ffi_api.AllocateWorkspace() # type: ignore +def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: + """This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. + Primarily used to update the VDevice information if any changes occured from the caller. + This pass recreates the buffers and updates the map. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index 426695c9f1fe..ebf757f38136 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -16,13 +16,13 @@ # under the License. # pylint: disable=invalid-name, unused-import """The type nodes of the Relax language.""" -import tvm.ffi +import tvm_ffi from tvm.ir import Type, TupleType, FuncType, Span from . import _ffi_api -@tvm.ffi.register_object("relax.ShapeType") +@tvm_ffi.register_object("relax.ShapeType") class ShapeType(Type): """The type of shape in Relax. 
@@ -37,7 +37,7 @@ def __init__(self, ndim: int = -1, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore -@tvm.ffi.register_object("relax.ObjectType") +@tvm_ffi.register_object("relax.ObjectType") class ObjectType(Type): """A type that corresponds to tvm::runtime::Object, is base of all possible object values in TVM.""" @@ -46,7 +46,7 @@ def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) # type: ignore -@tvm.ffi.register_object("relax.DynTensorType") +@tvm_ffi.register_object("relax.DynTensorType") class TensorType(Type): """A dynamic tensor type in Relax. @@ -65,7 +65,7 @@ def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.TensorType, ndim, dtype, span) # type: ignore -@tvm.ffi.register_object("relax.PackedFuncType") +@tvm_ffi.register_object("relax.PackedFuncType") class PackedFuncType(Type): """The type of ExternFunc in Relax.""" diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 192235d595d0..76897eefd707 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -26,6 +26,7 @@ from typing import Any, Callable, List, Dict, Optional import tvm +import tvm_ffi from .. import tir from ..tir import PrimExpr from . import _ffi_api @@ -99,7 +100,7 @@ def convert_to_expr(value: Any) -> Expr: if isinstance(value, float): return PrimValue(tir.FloatImm("float64", value)) - tvm_value = tvm.ffi.convert(value) + tvm_value = tvm_ffi.convert(value) # Case 1 if isinstance(tvm_value, Expr): # type: ignore return tvm_value @@ -346,6 +347,7 @@ def _shape_with_old_tir_var( ) primfunc_attrs = kwargs.pop("primfunc_attrs", None) + custom_out_sinfo = kwargs.pop("sinfo_args", []) te_args = _convert_te_arg(args) te_kwargs = _convert_te_arg(kwargs) @@ -370,14 +372,17 @@ def _shape_with_old_tir_var( # with old set of variables. 
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} - output_sinfo = [ - TensorStructInfo( - _shape_with_old_tir_var(out.shape, tir_var_inverse_map), - out.dtype, - _get_vdevice(args), - ) - for out in outs - ] + if len(custom_out_sinfo) == 1: + output_sinfo = custom_out_sinfo[0] + else: + output_sinfo = [ + TensorStructInfo( + _shape_with_old_tir_var(out.shape, tir_var_inverse_map), + out.dtype, + _get_vdevice(args), + ) + for out in outs + ] tir_vars = None if len(unbound_tir_vars) > 0: diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py index 3b77e7a552e3..80fd79e31348 100644 --- a/python/tvm/rpc/_ffi_api.py +++ b/python/tvm/rpc/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.rpc""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("rpc", __name__) +tvm_ffi.init_ffi_api("rpc", __name__) diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 0bb4e8cb7d29..73e9db3d5b60 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -22,11 +22,12 @@ import struct import time -import tvm.ffi +import tvm_ffi +from tvm_ffi import DLDeviceType + +import tvm.runtime from tvm.base import TVMError from tvm.contrib import utils -from tvm.runtime import ndarray as nd -from tvm.runtime import Device from . import _ffi_api, base, server @@ -86,9 +87,9 @@ def device(self, dev_type, dev_id=0): dev: Device The corresponding encoded remote device. 
""" - dev = nd.device(dev_type, dev_id) + dev = tvm.runtime.device(dev_type, dev_id) encode = (self._tbl_index + 1) * base.RPC_SESS_MASK - dev = nd.device(dev.device_type + encode, dev.device_id) + dev = tvm.runtime.device(dev.dlpack_device_type() + encode, dev.index) dev._rpc_sess = self return dev @@ -216,39 +217,39 @@ def download_linked_module(self, path): def cpu(self, dev_id=0): """Construct CPU device.""" - return self.device(Device.kDLCPU, dev_id) + return self.device(DLDeviceType.kDLCPU, dev_id) def cuda(self, dev_id=0): """Construct CUDA GPU device.""" - return self.device(Device.kDLCUDA, dev_id) + return self.device(DLDeviceType.kDLCUDA, dev_id) def cl(self, dev_id=0): """Construct OpenCL device.""" - return self.device(Device.kDLOpenCL, dev_id) + return self.device(DLDeviceType.kDLOpenCL, dev_id) def vulkan(self, dev_id=0): """Construct Vulkan device.""" - return self.device(Device.kDLVulkan, dev_id) + return self.device(DLDeviceType.kDLVulkan, dev_id) def metal(self, dev_id=0): """Construct Metal device.""" - return self.device(Device.kDLMetal, dev_id) + return self.device(DLDeviceType.kDLMetal, dev_id) def rocm(self, dev_id=0): """Construct ROCm device.""" - return self.device(Device.kDLROCM, dev_id) + return self.device(DLDeviceType.kDLROCM, dev_id) def ext_dev(self, dev_id=0): """Construct extension device.""" - return self.device(Device.kDLExtDev, dev_id) + return self.device(DLDeviceType.kDLExtDev, dev_id) def hexagon(self, dev_id=0): """Construct Hexagon device.""" - return self.device(Device.kDLHexagon, dev_id) + return self.device(DLDeviceType.kDLHexagon, dev_id) def webgpu(self, dev_id=0): """Construct WebGPU device.""" - return self.device(Device.kDLWebGPU, dev_id) + return self.device(DLDeviceType.kDLWebGPU, dev_id) class LocalSession(RPCSession): @@ -263,7 +264,7 @@ def __init__(self): RPCSession.__init__(self, _ffi_api.LocalSession()) -@tvm.ffi.register_func("rpc.PopenSession") +@tvm_ffi.register_global_func("rpc.PopenSession") def 
_popen_session(binary): temp = utils.tempdir() diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py index 2e46965a2050..5dcaffba0b4b 100644 --- a/python/tvm/rpc/minrpc.py +++ b/python/tvm/rpc/minrpc.py @@ -16,6 +16,7 @@ # under the License. """Utils to path.""" import os +import tvm_ffi from tvm import libinfo from tvm.contrib import cc @@ -65,17 +66,20 @@ def with_minrpc(compile_func, server="posix_popen_server", runtime="libtvm"): """ minrpc_dir, server_path = find_minrpc_server_libpath(server) runtime_path = libinfo.find_lib_path([runtime, runtime + ".so", runtime + ".dylib"])[0] + tvm_ffi_path = tvm_ffi.libinfo.find_libtvm_ffi() runtime_dir = os.path.abspath(os.path.dirname(runtime_path)) + tvm_ffi_dir = os.path.abspath(os.path.dirname(tvm_ffi_path)) options = ["-std=c++17"] # Make sure the rpath to the libtvm is set so we can do local tests. # Note that however, this approach won't work on remote. # Always recommend to link statically. options += ["-Wl,-rpath=" + runtime_dir] + options += ["-Wl,-rpath=" + tvm_ffi_dir] options += ["-I" + path for path in libinfo.find_include_path()] options += ["-I" + minrpc_dir] fcompile = cc.cross_compiler( - compile_func, options=options, add_files=[server_path, runtime_path] + compile_func, options=options, add_files=[server_path, runtime_path, tvm_ffi_path] ) fcompile.__name__ = "with_minrpc" fcompile.need_system_lib = True diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index eb345260e300..3ed512e9dd04 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -36,7 +36,7 @@ import time import errno import sys -import tvm.ffi +import tvm_ffi from tvm.base import py_str from tvm.libinfo import find_lib_path @@ -70,11 +70,11 @@ def _server_env(load_library, work_path=None): temp = utils.tempdir() # pylint: disable=unused-variable - @tvm.ffi.register_func("tvm.rpc.server.workpath", override=True) + @tvm_ffi.register_global_func("tvm.rpc.server.workpath", override=True) def 
get_workpath(path): return temp.relpath(path) - @tvm.ffi.register_func("tvm.rpc.server.load_module", override=True) + @tvm_ffi.register_global_func("tvm.rpc.server.load_module", override=True) def load_module(file_name): """Load module from remote side.""" path = temp.relpath(file_name) @@ -82,7 +82,7 @@ def load_module(file_name): logger.info("load_module %s", path) return m - @tvm.ffi.register_func("tvm.rpc.server.download_linked_module", override=True) + @tvm_ffi.register_global_func("tvm.rpc.server.download_linked_module", override=True) def download_linked_module(file_name): """Load module from remote side.""" # pylint: disable=import-outside-toplevel @@ -488,7 +488,7 @@ def server_init_callback(): # must import mypackage here import mypackage - tvm.register_func("function", mypackage.func) + tvm.register_global_func("function", mypackage.func) server = rpc.Server(host, server_init_callback=server_init_callback) """ diff --git a/python/tvm/rpc/testing.py b/python/tvm/rpc/testing.py index ba88c2048443..e3f216563863 100644 --- a/python/tvm/rpc/testing.py +++ b/python/tvm/rpc/testing.py @@ -22,41 +22,41 @@ # RPC test functions to be registered for unit-tests purposes -@tvm.register_func("rpc.test.addone") +@tvm.register_global_func("rpc.test.addone") def _addone(x): return x + 1 -@tvm.register_func("rpc.test.strcat") +@tvm.register_global_func("rpc.test.strcat") def _strcat(name, x): return f"{name}:{x}" -@tvm.register_func("rpc.test.except") +@tvm.register_global_func("rpc.test.except") def _remotethrow(name): raise ValueError(f"{name}") -@tvm.register_func("rpc.test.runtime_str_concat") +@tvm.register_global_func("rpc.test.runtime_str_concat") def _strcat(x, y): return x + y -@tvm.register_func("rpc.test.remote_array_func") -def _remote_array_func(y): +@tvm.register_global_func("rpc.test.remote_tensor_func") +def _remote_tensor_func(y): x = np.ones((3, 4)) np.testing.assert_equal(y.numpy(), x) -@tvm.register_func("rpc.test.add_to_lhs") 
+@tvm.register_global_func("rpc.test.add_to_lhs") def _add_to_lhs(x): return lambda y: x + y -@tvm.register_func("rpc.test.remote_return_nd") +@tvm.register_global_func("rpc.test.remote_return_nd") def _my_module(name): # Use closure to check the ref counter correctness - nd = tvm.nd.array(np.zeros(10).astype("float32")) + nd = tvm.runtime.tensor(np.zeros(10).astype("float32")) if name == "get_arr": return lambda: nd diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index ca70cf0f45a7..4c61e2e06b3a 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -16,20 +16,23 @@ # under the License. """TVM runtime namespace.""" +from tvm_ffi import convert +from tvm_ffi._dtype import dtype as DataType, DataTypeCode + # class exposures from .packed_func import PackedFunc from .object import Object from .script_printer import Scriptable -from .object_generic import ObjectGeneric +from .object_generic import ObjectConvertible from .device import Device -from .ndarray import NDArray +from ._tensor import Tensor, tensor, empty from .module import Module from .profiling import Report from .executable import Executable # function exposures -from .ndarray import device, cpu, cuda, opencl, vulkan, metal -from .ndarray import vpi, rocm, ext_dev +from ._tensor import device, cpu, cuda, opencl, vulkan, metal +from ._tensor import vpi, rocm, ext_dev, from_dlpack from .module import load_module, enabled, system_lib, load_static_library, num_threads from .container import String, ShapeTuple from .object_generic import const @@ -43,4 +46,3 @@ from . 
import disco from .support import _regex_match -from ..ffi import convert, dtype as DataType, DataTypeCode diff --git a/python/tvm/runtime/_ffi_api.py b/python/tvm/runtime/_ffi_api.py index 88a49f3a63d9..c713b379c384 100644 --- a/python/tvm/runtime/_ffi_api.py +++ b/python/tvm/runtime/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.runtime""" -import tvm.ffi +import tvm_ffi # Exports functions registered in runtime namespace. -tvm.ffi._init_api("runtime", __name__) +tvm_ffi.init_ffi_api("runtime", __name__) diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 4a0edd449c24..a4f74864aa2d 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -17,13 +17,13 @@ # pylint: disable=invalid-name, unused-argument """FFI for tvm.node""" -import tvm.ffi -import tvm.ffi.core +import tvm_ffi +import tvm_ffi.core # The implementations below are default ones when the corresponding # functions are not available in the runtime only mode. -# They will be overriden via _init_api to the ones registered +# They will be overriden via tvm_ffi.init_ffi_api to the ones registered def AsRepr(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" @@ -37,4 +37,4 @@ def LoadJSON(json_str): # Exports functions registered in node namespace. -tvm.ffi._init_api("node", __name__) +tvm_ffi.init_ffi_api("node", __name__) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/_tensor.py similarity index 69% rename from python/tvm/runtime/ndarray.py rename to python/tvm/runtime/_tensor.py index 1d960d5dda4a..3affbf55d563 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/_tensor.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=invalid-name, unused-import, redefined-outer-name -"""Runtime NDArray API""" +"""Runtime Tensor API""" import ctypes import warnings from typing import Optional @@ -27,52 +27,41 @@ except ImportError: ml_dtypes = None -from tvm.runtime import Device +import tvm_ffi +from tvm_ffi import device, DLDeviceType -import tvm.ffi +import tvm +from tvm.runtime import Device from . import _ffi_api -from ..ffi import ( - device, - cpu, - cuda, - rocm, - opencl, - metal, - vpi, - vulkan, - ext_dev, - hexagon, - webgpu, -) - - def from_dlpack(ext_tensor): """ - Convert an external tensor to an NDArray. + Convert an external tensor to an Tensor. Parameters ---------- ext_tensor : object The external tensor to convert. - required_alignment : int + require_alignment : int The minimum required alignment to check for the tensor. - required_contiguous : bool + require_contiguous : bool Whether to check for contiguous memory. """ - return tvm.ffi.from_dlpack( + # TODO(tvm-team): change to require_alignment=0 and require_contiguous=False + # once we update the compiler generated code to guard against misaligned access. + return tvm_ffi.from_dlpack( ext_tensor, - required_alignment=64, - required_contiguous=True, + require_alignment=64, + require_contiguous=True, ) -@tvm.ffi.register_object("ffi.NDArray") -class NDArray(tvm.ffi.core.NDArray): - """Lightweight NDArray class of TVM runtime. +@tvm_ffi.register_object("ffi.Tensor") +class Tensor(tvm_ffi.core.Tensor): + """Lightweight Tensor class of TVM runtime. Strictly this is only an Array Container (a buffer object) No arthimetic operations are defined. 
@@ -91,7 +80,7 @@ def __setitem__(self, in_slice, value): or in_slice.stop is not None ): raise ValueError("Array only support set from numpy array") - if isinstance(value, NDArray): + if isinstance(value, Tensor): if not value.same_as(self): value.copyto(self) elif isinstance(value, (np.ndarray, np.generic)): @@ -109,10 +98,10 @@ def copyfrom(self, source_array): Returns ------- - arr : NDArray + arr : Tensor Reference to self. """ - if isinstance(source_array, NDArray): + if isinstance(source_array, Tensor): source_array.copyto(self) return self @@ -124,7 +113,7 @@ def copyfrom(self, source_array): f"array must be an array_like data, type {type(source_array)} is not supported" ) - t = tvm.ffi.dtype(self.dtype) + t = tvm_ffi.dtype(self.dtype) shape, dtype = self.shape, self.dtype if t.lanes > 1: shape = shape + (t.lanes,) @@ -133,9 +122,9 @@ def copyfrom(self, source_array): if source_array.shape != shape: raise ValueError( - f"array shape do not match the shape of NDArray {source_array.shape} vs {shape}" + f"array shape do not match the shape of Tensor {source_array.shape} vs {shape}" ) - numpy_str_map = tvm.ffi.dtype.NUMPY_DTYPE_TO_STR + numpy_str_map = tvm_ffi.dtype._NUMPY_DTYPE_TO_STR np_dtype_str = ( numpy_str_map[source_array.dtype] if source_array.dtype in numpy_str_map @@ -160,14 +149,14 @@ def copyfrom(self, source_array): assert source_array.flags["C_CONTIGUOUS"] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = source_array.size * source_array.dtype.itemsize - _ffi_api.TVMArrayCopyFromBytes(self, data, nbytes) + _ffi_api.TVMTensorCopyFromBytes(self, data, nbytes) return self def __repr__(self): # exception safety handling for chandle=None if self.__chandle__() == 0: return type(self).__name__ + "(chandle=None)" - res = f"\n" + res = f"\n" res += self.numpy().__repr__() return res @@ -182,7 +171,7 @@ def numpy(self): np_arr : numpy.ndarray The corresponding numpy array. 
""" - t = tvm.ffi.dtype(self.dtype) + t = tvm_ffi.dtype(self.dtype) shape, dtype = self.shape, self.dtype old_dtype = dtype if t.lanes > 1: @@ -219,7 +208,7 @@ def numpy(self): # TODO(kathy): revisit and get a mirrored function of ffi::GetDataSize # in Python to replace line below nbytes = np_arr.size if dtype == "bool" else (np_arr.size * old_dtype.bits + 7) // 8 - _ffi_api.TVMArrayCopyToBytes(self, data, nbytes) + _ffi_api.TVMTensorCopyToBytes(self, data, nbytes) if old_dtype == "int4" or old_dtype.startswith("float4_e2m1fn"): length = np_arr.size @@ -239,22 +228,22 @@ def copyto(self, target, mem_scope=None): Parameters ---------- - target : NDArray + target : Tensor The target array to be copied, must have same shape as this array. mem_scope : Optional[str] The memory scope of the array. """ - if isinstance(target, NDArray): + if isinstance(target, Tensor): return self._copyto(target) - if isinstance(target, tvm.ffi.core.Device): + if isinstance(target, tvm_ffi.core.Device): res = empty(self.shape, self.dtype, target, mem_scope) return self._copyto(res) raise ValueError(f"Unsupported target type {type(target)}") def _copyto(self, target_nd): """Internal function that implements copy to target ndarray.""" - _ffi_api.TVMArrayCopyFromTo(self, target_nd) + _ffi_api.TVMTensorCopyFromTo(self, target_nd) return target_nd def _create_view(self, shape, dtype: Optional[str] = None, relative_byte_offset: int = 0): @@ -302,7 +291,7 @@ def _create_view(self, shape, dtype: Optional[str] = None, relative_byte_offset: if dtype is None: dtype = self.dtype - return _ffi_api.TVMArrayCreateView(self, shape, dtype, relative_byte_offset) + return _ffi_api.TVMTensorCreateView(self, shape, dtype, relative_byte_offset) def empty(shape, dtype="float32", device=None, mem_scope=None): @@ -324,19 +313,19 @@ def empty(shape, dtype="float32", device=None, mem_scope=None): Returns ------- - arr : tvm.nd.NDArray + arr : tvm.runtime.Tensor The array tvm supported. 
""" device = device or cpu() if not isinstance(shape, tvm.runtime.ShapeTuple): shape = tvm.runtime.ShapeTuple([int(dim) for dim in shape]) - dtype = tvm.ffi.dtype(dtype) - arr = _ffi_api.TVMArrayAllocWithScope(shape, dtype, device, mem_scope) + dtype = tvm_ffi.dtype(dtype) + arr = _ffi_api.TVMTensorAllocWithScope(shape, dtype, device, mem_scope) return arr -def array(arr, device=None, mem_scope=None): - """Create an array from source arr. +def tensor(arr, device=None, mem_scope=None): + """Create an tensor from source arr. Parameters ---------- @@ -351,15 +340,180 @@ def array(arr, device=None, mem_scope=None): Returns ------- - ret : NDArray + ret : Tensor The created array """ device = device or cpu() - if not isinstance(arr, (np.ndarray, NDArray)): + if not isinstance(arr, (np.ndarray, Tensor)): arr = np.array(arr) return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr) +def cpu(dev_id=0): + """Construct a CPU device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLCPU, dev_id) + + +def cuda(dev_id=0): + """Construct a CUDA GPU device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLCUDA, dev_id) + + +def rocm(dev_id=0): + """Construct a ROCM device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLROCM, dev_id) + + +def opencl(dev_id=0): + """Construct a OpenCL device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLOpenCL, dev_id) + + +def metal(dev_id=0): + """Construct a metal device + + Parameters + ---------- + dev_id : int, optional + The integer device id + 
+ Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLMetal, dev_id) + + +def vpi(dev_id=0): + """Construct a VPI simulated device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLVPI, dev_id) + + +def vulkan(dev_id=0): + """Construct a Vulkan device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLVulkan, dev_id) + + +def ext_dev(dev_id=0): + """Construct a extension device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + + Note + ---- + This API is reserved for quick testing of new + device by plugin device API as ext_dev. + """ + return device(DLDeviceType.kDLExtDev, dev_id) + + +def hexagon(dev_id=0): + """Construct a Hexagon device + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLHexagon, dev_id) + + +def webgpu(dev_id=0): + """Construct a webgpu device. + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + dev : Device + The created device + """ + return device(DLDeviceType.kDLWebGPU, dev_id) + + # Register back to FFI -tvm.ffi.core._set_class_ndarray(NDArray) +tvm_ffi.core._set_class_tensor(Tensor) diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 3bf149d6b2af..f9ddb5e51206 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. 
"""Runtime container structures.""" -from tvm.ffi import String, Shape as ShapeTuple +from tvm_ffi.core import String +from tvm_ffi import Shape as ShapeTuple __all__ = ["ShapeTuple", "String"] diff --git a/python/tvm/runtime/device.py b/python/tvm/runtime/device.py index d9d6abce50fa..b8a3db15f30e 100644 --- a/python/tvm/runtime/device.py +++ b/python/tvm/runtime/device.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name import json -import tvm.ffi +import tvm_ffi from . import _ffi_api @@ -26,7 +26,7 @@ RPC_SESS_MASK = 128 -class Device(tvm.ffi.core.Device): +class Device(tvm_ffi.core.Device): """TVM device strucure.""" def _GetDeviceAttr(self, device_type, device_id, attr_id): @@ -48,7 +48,7 @@ def exist(self): True if the device exists """ - return self._GetDeviceAttr(self.device_type, self.device_id, 0) != 0 + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 0) != 0 @property def max_threads_per_block(self): @@ -64,7 +64,7 @@ def max_threads_per_block(self): The number of threads on each block """ - return self._GetDeviceAttr(self.device_type, self.device_id, 1) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 1) @property def warp_size(self): @@ -81,7 +81,7 @@ def warp_size(self): Number of threads that execute concurrently """ - return self._GetDeviceAttr(self.device_type, self.device_id, 2) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 2) @property def max_shared_memory_per_block(self): @@ -97,7 +97,7 @@ def max_shared_memory_per_block(self): Total amount of shared memory per block in bytes """ - return self._GetDeviceAttr(self.device_type, self.device_id, 3) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 3) @property def compute_version(self): @@ -116,7 +116,7 @@ def compute_version(self): The version string in `major.minor` format. 
""" - return self._GetDeviceAttr(self.device_type, self.device_id, 4) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 4) @property def device_name(self): @@ -132,7 +132,7 @@ def device_name(self): The name of the device. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 5) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 5) @property def max_clock_rate(self): @@ -148,7 +148,7 @@ def max_clock_rate(self): The maximum clock frequency of the device (kHz) """ - return self._GetDeviceAttr(self.device_type, self.device_id, 6) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 6) @property def multi_processor_count(self): @@ -164,7 +164,7 @@ def multi_processor_count(self): Thee number of compute units in the device """ - return self._GetDeviceAttr(self.device_type, self.device_id, 7) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 7) @property def max_thread_dimensions(self): @@ -180,7 +180,7 @@ def max_thread_dimensions(self): The maximum length of threadIdx.x, threadIdx.y, threadIdx.z """ - return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8)) + return json.loads(self._GetDeviceAttr(self.dlpack_device_type(), self.index, 8)) @property def api_version(self): @@ -199,7 +199,7 @@ def api_version(self): The version of the SDK """ - return self._GetDeviceAttr(self.device_type, self.device_id, 11) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 11) @property def driver_version(self): @@ -218,7 +218,7 @@ def driver_version(self): The version string in `major.minor.patch` format. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 12) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 12) @property def l2_cache_size_bytes(self): @@ -236,7 +236,7 @@ def l2_cache_size_bytes(self): ---- The value returned by opencl's API is smaller than actual device L2 cache size. 
""" - return self._GetDeviceAttr(self.device_type, self.device_id, 13) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 13) @property def total_global_memory(self): @@ -250,7 +250,7 @@ def total_global_memory(self): Return the total size of global memory on device in bytes. Return None if the device does not support this feature. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 14) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 14) @property def available_global_memory(self): @@ -264,7 +264,7 @@ def available_global_memory(self): Return the amount of unallocated global memory on device in bytes. Return None if the device does not support this feature. """ - return self._GetDeviceAttr(self.device_type, self.device_id, 15) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 15) def texture_spatial_limit(self): """Returns limits for textures by spatial dimensions @@ -275,7 +275,7 @@ def texture_spatial_limit(self): Maximum size of the texture by spatial dimensions """ - return self._GetDeviceAttr(self.device_type, self.device_id, 12) + return self._GetDeviceAttr(self.dlpack_device_type(), self.index, 12) def create_raw_stream(self): """Create a new runtime stream at the context. 
@@ -319,19 +319,12 @@ def sync(self, stream=None): """ _ffi_api.Device_StreamSync(self, stream or 0) - def _device_type_name_(self): - if self.device_type >= RPC_SESS_MASK: - tbl_id = self.device_type / RPC_SESS_MASK - 1 - dev_type = self.device_type % RPC_SESS_MASK - return f"remote[{tbl_id}]:{Device.DEVICE_TYPE_TO_NAME[dev_type]}" - return Device.DEVICE_TYPE_TO_NAME[self.device_type] - def __device_type_name__(self): - if self.device_type >= RPC_SESS_MASK: - tbl_id = self.device_type / RPC_SESS_MASK - 1 - dev_type = self.device_type % RPC_SESS_MASK - return f"remote[{tbl_id}]:{Device.DEVICE_TYPE_TO_NAME[dev_type]}" - return Device.DEVICE_TYPE_TO_NAME[self.device_type] + if self.dlpack_device_type() >= RPC_SESS_MASK: + tbl_id = self.dlpack_device_type() / RPC_SESS_MASK - 1 + dev_type = self.dlpack_device_type() % RPC_SESS_MASK + return f"remote[{tbl_id}]:{Device._DEVICE_TYPE_TO_NAME[dev_type]}" + return Device._DEVICE_TYPE_TO_NAME[self.dlpack_device_type()] -tvm.ffi.core._set_class_device(Device) +tvm_ffi.core._set_class_device(Device) diff --git a/python/tvm/runtime/disco/_ffi_api.py b/python/tvm/runtime/disco/_ffi_api.py index 79e1a52ad44e..2caeef293ea5 100644 --- a/python/tvm/runtime/disco/_ffi_api.py +++ b/python/tvm/runtime/disco/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs from C++""" -from ...ffi import _init_api +import tvm_ffi -_init_api("runtime.disco", __name__) +tvm_ffi.init_ffi_api("runtime.disco", __name__) diff --git a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py index 8f05f28e9158..975c26fb922f 100644 --- a/python/tvm/runtime/disco/process_pool.py +++ b/python/tvm/runtime/disco/process_pool.py @@ -20,7 +20,7 @@ import subprocess import sys -from tvm.ffi import register_func +from tvm_ffi import register_global_func from tvm.runtime import ShapeTuple @@ -177,7 +177,7 @@ def _kill_child_processes(pid): pass -@register_func("runtime.disco.create_process_pool") +@register_global_func("runtime.disco.create_process_pool") def _create_process_pool(num_workers: int, num_groups: int, entrypoint: str): """Create a process pool where the workers' are [1, num_workers).""" pool = [DiscoPopenWorker(i, num_workers, num_groups, entrypoint) for i in range(1, num_workers)] diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index 49449a451a12..f2c2dfc791ab 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -25,11 +25,11 @@ import numpy as np -from ...ffi import get_global_func, register_func, register_object +from tvm_ffi import get_global_func, register_global_func, register_object from ..device import Device from ..container import ShapeTuple -from ..ndarray import NDArray -from ..ndarray import array as _as_NDArray +from .._tensor import Tensor +from .._tensor import tensor as _as_Tensor from ..object import Object from . import _ffi_api, process_pool # pylint: disable=unused-import @@ -58,20 +58,20 @@ def debug_get_from_remote(self, worker_id: int) -> Any: def debug_copy_from( self, worker_id: int, - value: Union[np.ndarray, NDArray], + value: Union[np.ndarray, Tensor], ) -> None: - """Copy an NDArray value to remote for debugging purposes. + """Copy an Tensor value to remote for debugging purposes. 
Parameters ---------- worker_id : int The id of the worker to be copied to. - value : Union[numpy.ndarray, NDArray] + value : Union[numpy.ndarray, Tensor] The value to be copied. """ - if not isinstance(value, NDArray): - value = _as_NDArray(value) + if not isinstance(value, Tensor): + value = _as_Tensor(value) return _ffi_api.DRefDebugCopyFrom(self, worker_id, value) # type: ignore # pylint: disable=no-member @@ -122,18 +122,18 @@ def empty( worker0_only: bool = False, in_group: bool = True, ) -> DRef: - """Create an empty NDArray on all workers and attach them to a DRef. + """Create an empty Tensor on all workers and attach them to a DRef. Parameters ---------- shape : tuple of int - The shape of the NDArray. + The shape of the Tensor. dtype : str - The data type of the NDArray. + The data type of the Tensor. device : Optional[Device] = None - The device of the NDArray. + The device of the Tensor. worker0_only: bool If False (default), allocate an array on each worker. If @@ -147,7 +147,7 @@ def empty( Returns ------- array : DRef - The created NDArray. + The created Tensor. """ func = self._get_cached_method("runtime.disco.empty") @@ -217,7 +217,7 @@ def call_packed(self, func: DRef, *args) -> DRef: Notes ----- Examples of unsupported types: - - NDArray, DLTensor,; + - Tensor, DLTensor,; - TVM Objects, including PackedFunc, Module and String. """ return _ffi_api.SessionCallPacked(self, 0, 0, func, *args) # type: ignore # pylint: disable=no-member @@ -246,29 +246,29 @@ def sync_worker_0(self) -> None: executing all the existing instructions.""" return self._sync_worker(0) - def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: - """Copy an NDArray from worker-0 to the controller-side NDArray. + def copy_from_worker_0(self, host_array: Tensor, remote_array: DRef) -> None: + """Copy an Tensor from worker-0 to the controller-side Tensor. Parameters ---------- host_array : numpy.ndarray The array to be copied to worker-0. 
- remote_array : NDArray - The NDArray on worker-0. + remote_array : Tensor + The Tensor on worker-0. """ return _ffi_api.SessionCopyFromWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member - def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = None) -> DRef: - """Copy the controller-side NDArray to worker-0. + def copy_to_worker_0(self, host_array: Tensor, remote_array: Optional[DRef] = None) -> DRef: + """Copy the controller-side Tensor to worker-0. Parameters ---------- - host_array : NDArray + host_array : Tensor The array to be copied to worker-0. remote_array : Optiona[DRef] - The destination NDArray on worker-0. + The destination Tensor on worker-0. Returns ------- @@ -329,7 +329,7 @@ def init_ccl(self, ccl: str, *device_ids): def broadcast( self, - src: Union[np.ndarray, NDArray], + src: Union[np.ndarray, Tensor], dst: Optional[DRef] = None, in_group: bool = True, ) -> DRef: @@ -337,7 +337,7 @@ def broadcast( Parameters ---------- - src: Union[np.ndarray, NDArray] + src: Union[np.ndarray, Tensor] The array to be broadcasted. dst: Optional[DRef] @@ -356,8 +356,8 @@ def broadcast( `dst`. Otherwise, it is the newly allocated space. """ - if not isinstance(src, NDArray): - src = _as_NDArray(src) + if not isinstance(src, Tensor): + src = _as_Tensor(src) if dst is None: dst = self.empty(src.shape, src.dtype) @@ -372,7 +372,7 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef, in_group: bool = True) -> Parameters ---------- - src: Union[np.ndarray, NDArray] + src: Union[np.ndarray, Tensor] The array to be broadcasted. 
dst: Optional[DRef] @@ -387,7 +387,7 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef, in_group: bool = True) -> def scatter( self, - src: Union[np.ndarray, NDArray], + src: Union[np.ndarray, Tensor], dst: Optional[DRef] = None, in_group: bool = True, ) -> DRef: @@ -395,7 +395,7 @@ def scatter( Parameters ---------- - src: Union[np.ndarray, NDArray] + src: Union[np.ndarray, Tensor] The array to be scattered. The first dimension of this array, `src.shape[0]`, must be equal to the number of workers. @@ -419,8 +419,8 @@ def scatter( """ assert src.shape[0] == self.num_workers - if not isinstance(src, NDArray): - src = _as_NDArray(src) + if not isinstance(src, Tensor): + src = _as_Tensor(src) if dst is None: dst = self.empty(src.shape[1:], src.dtype) @@ -435,7 +435,7 @@ def scatter_from_worker0(self, from_array: DRef, to_array: DRef, in_group: bool Parameters ---------- - src: Union[np.ndarray, NDArray] + src: Union[np.ndarray, Tensor] The array to be scattered. The first dimension of this array, `src.shape[0]`, must be equal to the number of workers. 
@@ -583,7 +583,7 @@ def _configure_structlog(self) -> None: func(config, os.getpid()) -@register_func("runtime.disco.create_socket_session_local_workers") +@register_global_func("runtime.disco.create_socket_session_local_workers") def _create_socket_session_local_workers(num_workers) -> Session: """Create the local session for each distributed node over socket session.""" return ProcessSession(num_workers) @@ -611,7 +611,7 @@ def __init__( ) -@register_func("runtime.disco._configure_structlog") +@register_global_func("runtime.disco._configure_structlog") def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None: """Configure structlog for all disco workers @@ -646,7 +646,7 @@ def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None: structlog.configure(**structlog_config) -@register_func("runtime.disco._import_python_module") +@register_global_func("runtime.disco._import_python_module") def _import_python_module(module_name: str) -> None: __import__(module_name) diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py index 47c46959be28..a57c1b623183 100644 --- a/python/tvm/runtime/executable.py +++ b/python/tvm/runtime/executable.py @@ -39,7 +39,7 @@ def __getitem__(self, name: str) -> PackedFunc: def __call__(self, *args, **kwargs) -> Any: """Call the executable.""" - return self.jit().entry_func(*args, **kwargs) + return self.jit().main(*args, **kwargs) def jit( self, diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 9cbc06708bd0..71b3bdd94b64 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -22,17 +22,18 @@ from typing import Sequence import numpy as np -from tvm.base import _RUNTIME_ONLY -from tvm.libinfo import find_include_path - -from . 
import _ffi_api -from ..ffi import ( +from tvm_ffi import ( Module as _Module, load_module as _load_module, register_object as _register_object, system_lib, ) +from tvm.base import _RUNTIME_ONLY +from tvm.libinfo import find_include_path + +from . import _ffi_api + class BenchmarkResult: """Runtimes from benchmarking""" @@ -376,8 +377,8 @@ def time_evaluator( feval = _ffi_api.RPCTimeEvaluator( self, func_name, - dev.device_type, - dev.device_id, + dev.dlpack_device_type(), + dev.index, number, repeat, min_repeat_ms, diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index b2fcddc40ad6..c9dcf2d1a8ed 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -17,11 +17,11 @@ # pylint: disable=invalid-name, unused-import """Runtime Object API""" -from tvm.ffi.core import Object -import tvm.ffi.core +from tvm_ffi.core import Object +import tvm_ffi.core from . import _ffi_node_api -tvm.ffi.core._set_class_object(Object) -# override the default repr function for tvm.ffi.core.Object -tvm.ffi.core.__object_repr__ = _ffi_node_api.AsRepr +tvm_ffi.core._set_class_object(Object) +# override the default repr function for tvm_ffi.core.Object +tvm_ffi.core.__object_repr__ = _ffi_node_api.AsRepr diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 4ffea01a3cef..340df0fcea55 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -16,7 +16,7 @@ # under the License. """Common implementation of object generic related logic""" # pylint: disable=unused-import, invalid-name -from tvm.ffi import ObjectGeneric +from tvm_ffi import ObjectConvertible from . 
import _ffi_node_api diff --git a/python/tvm/runtime/packed_func.py b/python/tvm/runtime/packed_func.py index 71a0ba081658..68940103f32a 100644 --- a/python/tvm/runtime/packed_func.py +++ b/python/tvm/runtime/packed_func.py @@ -17,6 +17,6 @@ # pylint: disable=invalid-name, unused-import """Packed Function namespace.""" -from tvm.ffi import Function as PackedFunc +from tvm_ffi import Function as PackedFunc __all__ = ["PackedFunc"] diff --git a/python/tvm/runtime/params.py b/python/tvm/runtime/params.py index af0b4a26173a..f1ea7bda242d 100644 --- a/python/tvm/runtime/params.py +++ b/python/tvm/runtime/params.py @@ -16,15 +16,15 @@ # under the License. # pylint: disable=invalid-name """Helper utility to save and load parameter dicts.""" -from . import _ffi_api, ndarray, NDArray +from . import _ffi_api, tensor, Tensor -def _to_ndarray(params): +def _to_tensor(params): transformed = {} for k, v in params.items(): - if not isinstance(v, NDArray): - transformed[k] = ndarray.array(v) + if not isinstance(v, Tensor): + transformed[k] = tensor(v) else: transformed[k] = v @@ -39,7 +39,7 @@ def save_param_dict(params): Parameters ---------- - params : dict of str to NDArray + params : dict of str to Tensor The parameter dictionary. Returns @@ -59,7 +59,7 @@ def save_param_dict(params): # Pass in byte array to module to directly set parameters tvm.runtime.load_param_dict(param_bytes) """ - return _ffi_api.SaveParams(_to_ndarray(params)) + return _ffi_api.SaveParams(_to_tensor(params)) def save_param_dict_to_file(params, path): @@ -67,13 +67,13 @@ def save_param_dict_to_file(params, path): Parameters ---------- - params : dict of str to NDArray + params : dict of str to Tensor The parameter dictionary. path: str The path to the parameter file. 
""" - return _ffi_api.SaveParamsToFile(_to_ndarray(params), path) + return _ffi_api.SaveParamsToFile(_to_tensor(params), path) def load_param_dict(param_bytes): @@ -86,7 +86,7 @@ def load_param_dict(param_bytes): Returns ------- - params : dict of str to NDArray + params : dict of str to Tensor The parameter dictionary. """ if isinstance(param_bytes, (bytes, str)): @@ -104,7 +104,7 @@ def load_param_dict_from_file(path): Returns ------- - params : dict of str to NDArray + params : dict of str to Tensor The parameter dictionary. """ return _ffi_api.LoadParamsFromFile(path) diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 45189a008495..3ca831ac4200 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -266,7 +266,7 @@ def profile_function(mod, dev, collectors, func_name=None, warmup_iters=10): if func_name is None: func_name = mod.entry_name return _ffi_api.ProfileFunction( - mod, func_name, dev.device_type, dev.device_id, warmup_iters, collectors + mod, func_name, dev.dlpack_device_type(), dev.index, warmup_iters, collectors ) diff --git a/python/tvm/runtime/profiling/_ffi_api.py b/python/tvm/runtime/profiling/_ffi_api.py index 85e5d4ca020c..883e3ca6e778 100644 --- a/python/tvm/runtime/profiling/_ffi_api.py +++ b/python/tvm/runtime/profiling/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI for profiling""" -from ...ffi import _init_api +import tvm_ffi -_init_api("runtime.profiling", __name__) +tvm_ffi.init_ffi_api("runtime.profiling", __name__) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index a00281b435ef..7442cd99172f 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -18,8 +18,8 @@ import os from typing import Dict, List, Optional, Sequence -from tvm.ffi import get_global_func, register_object -from tvm.ffi.access_path import AccessPath +from tvm_ffi import get_global_func, register_object +from tvm_ffi.access_path import AccessPath from tvm.runtime import Object from . import _ffi_node_api diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py index 2669459d71a7..20b7159ed535 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/runtime/support.py @@ -18,12 +18,12 @@ """Runtime support infra of TVM.""" import re -from typing import TypeVar +from typing import TypeVar, Type -import tvm.ffi +import tvm_ffi -@tvm.ffi.register_func("tvm.runtime.regex_match") +@tvm_ffi.register_global_func("tvm.runtime.regex_match") def _regex_match(regex_pattern: str, match_against: str) -> bool: """Check if a pattern matches a regular expression @@ -73,7 +73,7 @@ def _regex_match(regex_pattern: str, match_against: str) -> bool: T = TypeVar("T") -def derived_object(cls: type[T]) -> type[T]: +def derived_object(cls: Type[T]) -> Type[T]: """A decorator to register derived subclasses for TVM objects. 
Parameters @@ -147,10 +147,17 @@ def method(*args, **kwargs): metadata = getattr(base, "_tvm_metadata") fields = metadata.get("fields", []) methods = metadata.get("methods", []) - - class TVMDerivedObject(metadata["cls"]): # type: ignore + base_cls = metadata["cls"] + derived_slots = ( + ("_inst",) + if hasattr(base_cls, "__weakref__") or getattr(base_cls, "__weakrefoffset__", 0) + else ("_inst", "__weakref__") + ) + + class TVMDerivedObject(base_cls): # type: ignore """The derived object to avoid cyclic dependency.""" + __slots__ = derived_slots _cls = cls _type = "TVMDerivedObject" diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index d69c3308fad4..b188c6ca70c7 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -23,7 +23,7 @@ import numpy as np # type: ignore import tvm -from tvm.ffi import register_func +from tvm_ffi import register_global_func from tvm.runtime import Device, Object, PackedFunc from tvm.runtime.profiling import Report @@ -99,7 +99,7 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) devs = [dev] # CPU is required for executing shape functions - if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type: + if devs[-1].dlpack_device_type() % RPC_SESS_MASK != tvm.cpu().dlpack_device_type(): devs.append(tvm.cpu()) default_alloc_type = VirtualMachine.POOLED_ALLOCATOR @@ -117,8 +117,8 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) ) init_args = [] for device in devs: - init_args.append(device.device_type % RPC_SESS_MASK) - init_args.append(device.device_id) + init_args.append(device.dlpack_device_type() % RPC_SESS_MASK) + init_args.append(device.index) alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type init_args.append(alloc_type) self.module["vm_initialization"](*init_args) @@ -134,7 +134,7 @@ def invoke_closure(self, closure: Object, *args: Any) -> Object: closure : Object The VMClosure Object. 
- args : list[tvm.runtime.NDArray] or list[np.ndarray] + args : list[tvm.runtime.Tensor] or list[np.ndarray] The arguments to the closure. Returns @@ -206,9 +206,9 @@ def _gettype(arg): if isinstance(arg, Object): cargs.append(arg) elif isinstance(arg, np.ndarray): - nd_arr = tvm.nd.array(arg, device=tvm.cpu(0)) + nd_arr = tvm.runtime.tensor(arg, device=tvm.cpu(0)) cargs.append(nd_arr) - elif isinstance(arg, tvm.runtime.NDArray): + elif isinstance(arg, tvm.runtime.Tensor): cargs.append(arg) elif isinstance(arg, (tuple, list)): field_args: List[Any] = [] @@ -217,7 +217,7 @@ def _gettype(arg): cargs.append(tuple(field_args)) elif isinstance(arg, (Number, bool)): dtype = _gettype(arg) - value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) + value = tvm.runtime.tensor(np.array(arg, dtype=dtype), device=tvm.cpu(0)) cargs.append(value) elif isinstance(arg, str): cargs.append(arg) @@ -252,7 +252,7 @@ def _convert_func_named_args(self, func_name: str, args: Any, **kwargs: Any) -> def set_input(self, func_name: str, *args: Any, **kwargs: Any) -> None: """Set the inputs to a function. - This interface works when using VM over RPC by internally converting NDArray in + This interface works when using VM over RPC by internally converting Tensor in the arguments to DLTensor, which is supported in RPC where remote could only have a minimal C runtime. @@ -263,9 +263,9 @@ def set_input(self, func_name: str, *args: Any, **kwargs: Any) -> None: ---------- func_name : str The name of the function. - args: List[tvm.runtime.NDArray] or List[np.ndarray] + args: List[tvm.runtime.Tensor] or List[np.ndarray] The arguments to the function. - kwargs: dict of str to tvm.runtime.NDArray or np.ndarray + kwargs: dict of str to tvm.runtime.Tensor or np.ndarray Named arguments to the function. """ cargs: List[Any] = [] @@ -482,7 +482,7 @@ def profile(self, func_name: str, *args): func_name : str The name of the function. 
- args: List of NDArray or other objects supported by PackedFunc. + args: List of Tensor or other objects supported by PackedFunc. The arguments to the function. Returns @@ -499,6 +499,6 @@ def profile(self, func_name: str, *args): return Report.from_json(report_json) -@register_func("vm.builtin.debug_print") +@register_global_func("vm.builtin.debug_print") def _print(lineo: str, array) -> None: print(f"{lineo}: shape = {array.shape}, dtype = {array.dtype}, data =\n{array}") diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/_ffi_api.py index 8ae8f7b7f9a5..1354d3f2ec2c 100644 --- a/python/tvm/script/_ffi_api.py +++ b/python/tvm/script/_ffi_api.py @@ -14,7 +14,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("script", __name__) +tvm_ffi.init_ffi_api("script", __name__) diff --git a/python/tvm/script/ir_builder/_ffi_api.py b/python/tvm/script/ir_builder/_ffi_api.py index 8ee223051986..c8a9597d5292 100644 --- a/python/tvm/script/ir_builder/_ffi_api.py +++ b/python/tvm/script/ir_builder/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("script.ir_builder", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 95b5c5002558..a6bb68e2507c 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -17,7 +17,7 @@ """A generic IRBuilder across the TVM stack""" from typing import Any, Callable, List -from tvm.ffi import register_object as _register_object +from tvm_ffi import register_object as _register_object from tvm.runtime import Object as _Object from . 
import _ffi_api diff --git a/python/tvm/script/ir_builder/ir/_ffi_api.py b/python/tvm/script/ir_builder/ir/_ffi_api.py index 5b9d801a6ed3..e319c3d4612e 100644 --- a/python/tvm/script/ir_builder/ir/_ffi_api.py +++ b/python/tvm/script/ir_builder/ir/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/ir/frame.py b/python/tvm/script/ir_builder/ir/frame.py index d2737fde59a6..45b49221e34b 100644 --- a/python/tvm/script/ir_builder/ir/frame.py +++ b/python/tvm/script/ir_builder/ir/frame.py @@ -16,7 +16,7 @@ # under the License. """Package tvm.script.ir_builder.ir.frame""" -from tvm.ffi import register_object as _register_object +from tvm_ffi import register_object as _register_object from ..base import IRBuilderFrame diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py index 1c767bacc4c5..f6c53336ff4c 100644 --- a/python/tvm/script/ir_builder/relax/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.script.ir_builder.relax""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py index 4d2ba60c2002..b82fa37e8f3f 100644 --- a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder.relax.distributed""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api( +tvm_ffi.init_ffi_api( "script.ir_builder.relax.distributed", __name__ ) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py b/python/tvm/script/ir_builder/relax/distributed/ir.py index 159ad5aea169..465cf6313eb1 100644 --- a/python/tvm/script/ir_builder/relax/distributed/ir.py +++ b/python/tvm/script/ir_builder/relax/distributed/ir.py @@ -29,7 +29,7 @@ from tvm.relax.distributed import DTensorStructInfo from tvm.relax.utils import args_converter from tvm import base as _base -from tvm.runtime import ndarray as _nd +from tvm.runtime import _tensor from tvm.relax.op.distributed import ( redistribute as _redistribute, annotate_sharding as _annotate_sharding, @@ -89,14 +89,14 @@ def call_tir( def const( - value: Union[bool, int, float, _np.ndarray, tvm.nd.NDArray], + value: Union[bool, int, float, _np.ndarray, tvm.runtime.Tensor], struct_info: DTensorStructInfo, ) -> Constant: """Create a constant value. Parameters ---------- - value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray] + value: Union[bool, int, float, numpy.ndarray, tvm.runtime.Tensor] The constant value. 
dtype: Optional[str] @@ -121,10 +121,10 @@ def const( if isinstance(value, (_np.ndarray, _np.generic)): if dtype is not None: value = value.astype(dtype) - value = _nd.array(value) + value = _tensor.tensor(value) - if not isinstance(value, _nd.NDArray): - raise ValueError("value has to be scalar or NDArray") + if not isinstance(value, _tensor.Tensor): + raise ValueError("value has to be scalar or Tensor") return Constant(value, struct_info) diff --git a/python/tvm/script/ir_builder/relax/frame.py b/python/tvm/script/ir_builder/relax/frame.py index 181f62ec4f39..ed4d948ff972 100644 --- a/python/tvm/script/ir_builder/relax/frame.py +++ b/python/tvm/script/ir_builder/relax/frame.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """IR Builder Frame for Relax dialect""" -from tvm.ffi import register_object as _register_object +from tvm_ffi import register_object as _register_object from ..base import IRBuilderFrame diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e61e563b706b..141361a729c4 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -30,6 +30,7 @@ Expr, ExternFunc, ShapeExpr, + StringImm, TupleGetItem, Var, VarBinding, @@ -64,6 +65,7 @@ call_dps_packed, call_inplace_packed, call_pure_packed, + call_py_func as _call_py_func, call_tir, call_tir_inplace, call_tir_with_grad, @@ -135,6 +137,7 @@ multiply, negative, nn, + nonzero, not_equal, null_value, ones, @@ -186,13 +189,14 @@ wrap_param, zeros, zeros_like, + vision, ) from tvm.relax.op.builtin import stop_lift_params from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter, gen_call_tir_inputs from tvm.runtime import Object as tvm_Object -from tvm.runtime import ObjectGeneric -from tvm.runtime.ndarray import ( +from tvm.runtime import ObjectConvertible +from tvm.runtime._tensor import ( cpu, cuda, device, @@ -431,7 +435,7 @@ def 
call_packed( sinfo() if callable(sinfo) else sinfo.asobject() - if isinstance(sinfo, ObjectGeneric) + if isinstance(sinfo, ObjectConvertible) else sinfo ) for sinfo in sinfo_args @@ -451,6 +455,57 @@ def call_packed( return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) +@args_converter.auto +def call_py_func( + py_func_name: py_str, + *args: Expr, + out_sinfo: Union[StructInfo, List[StructInfo]], +) -> Call: + """Create a relax Call, which calls a Python function. + + Parameters + ---------- + py_func_name: str + The name of the Python function to call. This should correspond to a function + in the IRModule's pyfuncs attribute. + *args : Expr + The arguments. + out_sinfo: Union[StructInfo, List[StructInfo]] + The structure info of the call_py_func output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + Returns + ------- + call: Call + The created Relax Call for call_py_func operator. + """ + if isinstance(out_sinfo, py_tuple): # type: ignore + out_sinfo = list(out_sinfo) + elif not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + out_sinfo = [ + ( + sinfo() + if callable(sinfo) + else sinfo.asobject() + if isinstance(sinfo, ObjectConvertible) + else sinfo + ) + for sinfo in out_sinfo + ] + + # Convert string to StringImm + try: + func_name_imm = ( + StringImm(py_func_name) if isinstance(py_func_name, py_str) else py_func_name + ) + except (TypeError, ValueError, AttributeError): + func_name_imm = StringImm(py_func_name) + return _call_py_func(func_name_imm, args, out_sinfo) + + def _sinfo_arg_wrapper(func): """A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args""" @@ -462,7 +517,7 @@ def _convert_tensor_type(args): return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()} if inspect.isfunction(args): args = args() - if isinstance(args, ObjectGeneric): + if isinstance(args, ObjectConvertible): args = args.asobject() 
return args @@ -743,6 +798,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "call_tir_inplace", "call_tir_with_grad", "call_dps_packed", + "call_py_func", "call_builtin_with_ctx", "ceil", "clip", @@ -827,6 +883,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "multinomial_from_uniform", "multiply", "negative", + "nonzero", "not_equal", "null_value", "ones", @@ -896,4 +953,5 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "nn", "ccl", "erf", + "vision", ] diff --git a/python/tvm/script/ir_builder/tir/_ffi_api.py b/python/tvm/script/ir_builder/tir/_ffi_api.py index 69797f986afd..4385b2ec13d0 100644 --- a/python/tvm/script/ir_builder/tir/_ffi_api.py +++ b/python/tvm/script/ir_builder/tir/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index e3ce2e6e2eb1..f43b4cf6ed67 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -17,7 +17,7 @@ """IRBuilder for TIR""" from typing import List, Union -from tvm.ffi import register_object as _register_object +from tvm_ffi import register_object as _register_object from tvm.tir import Buffer, Var from ..base import IRBuilderFrame diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c6549ad104c3..84143e05891f 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -16,11 +16,13 @@ # under the License. 
"""IRBuilder for TIR""" +import contextlib import functools import inspect import sys +import threading from numbers import Integral -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union # isort: off from typing_extensions import Literal @@ -32,7 +34,7 @@ from tvm import ir, tir from tvm.ir import Type from tvm.ir.base import deprecated -from tvm.runtime import String, convert, ndarray +from tvm.runtime import String, convert, tensor from tvm.target import Target # pylint: disable=unused-import @@ -87,6 +89,35 @@ # pylint: enable=unused-import +_block_name_suffix = threading.local() + + +def _get_block_name_suffix() -> str: + """Get the current block name suffix for macro expansion.""" + return getattr(_block_name_suffix, "value", "") + + +@contextlib.contextmanager +def block_name_suffix_context(block_suffix: str): + """Context manager to set block name suffix during macro expansion. + + Parameters + ---------- + block_suffix : str + The suffix to append to block names (e.g., "_1", "_2"). + + Yields + ------ + None + """ + old_suffix = getattr(_block_name_suffix, "value", "") + _block_name_suffix.value = block_suffix + try: + yield + finally: + _block_name_suffix.value = old_suffix + + def buffer( shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], dtype: str = "float32", @@ -352,6 +383,9 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: res : frame.BlockFrame The BlockFrame. 
""" + block_suffix = _get_block_name_suffix() + if block_suffix and name: + name = name + block_suffix return _ffi_api.Block(name, no_realize) # type: ignore[attr-defined] # pylint: disable=no-member @@ -677,7 +711,11 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L def serial( - start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None + start: PrimExpr, + stop: PrimExpr = None, + *, + annotations: Dict[str, Any] = None, + step: Optional[PrimExpr] = None, ) -> frame.ForFrame: """The serial For statement. @@ -692,6 +730,9 @@ def serial( annotations : Dict[str, Any] The optional annotations of the For statement. + step : PrimExpr + The optional step value of iteration. + Returns ------- res : frame.ForFrame @@ -703,11 +744,15 @@ def serial( start = IntImm(start.dtype, 0) else: start = 0 - return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Serial(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def parallel( - start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None + start: PrimExpr, + stop: PrimExpr = None, + *, + annotations: Dict[str, Any] = None, + step: Optional[PrimExpr] = None, ) -> frame.ForFrame: """The parallel For statement. @@ -722,6 +767,9 @@ def parallel( annotations : Dict[str, Any] The optional annotations of the For statement. + step : PrimExpr + The optional step value of iteration. 
+ Returns ------- res : frame.ForFrame @@ -733,11 +781,15 @@ def parallel( start = IntImm(start.dtype, 0) else: start = 0 - return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Parallel(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def vectorized( - start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None + start: PrimExpr, + stop: PrimExpr = None, + *, + annotations: Dict[str, Any] = None, + step: Optional[PrimExpr] = None, ) -> frame.ForFrame: """The vectorized For statement. @@ -752,6 +804,9 @@ def vectorized( annotations : Dict[str, Any] The optional annotations of the For statement. + step : PrimExpr + The optional step value of iteration. + Returns ------- res : frame.ForFrame @@ -763,11 +818,15 @@ def vectorized( start = IntImm(start.dtype, 0) else: start = 0 - return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Vectorized(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def unroll( - start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None + start: PrimExpr, + stop: PrimExpr = None, + *, + annotations: Dict[str, Any] = None, + step: Optional[PrimExpr] = None, ) -> frame.ForFrame: """The unrolled For statement. @@ -782,6 +841,9 @@ def unroll( annotations : Dict[str, Any] The optional annotations of the For statement. + step : PrimExpr + The optional step value of iteration. 
+ Returns ------- res : frame.ForFrame @@ -793,7 +855,7 @@ def unroll( start = IntImm(start.dtype, 0) else: start = 0 - return _ffi_api.Unroll(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Unroll(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def thread_binding( @@ -1054,7 +1116,7 @@ def allocate_const( np_data = np_data.reshape(extents) return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member - ndarray.array(np_data), dtype, extents, annotations + tensor(np_data), dtype, extents, annotations ) @@ -1316,6 +1378,17 @@ def buffer_store( ) +def customized_code(code: str): + """Add a customized code block. + + Parameters + ---------- + code : str + The code block to be added. + """ + return _ffi_api.CustomizedCode(code) # type: ignore[attr-defined] # pylint: disable=no-member + + def evaluate(value: PrimExpr) -> None: """Evaluate the input expression. @@ -1357,173 +1430,348 @@ def func( return func - -# pylint: disable=invalid-name -int8 = func_gen(("Int8")) -int16 = func_gen(("Int16")) -int32 = func_gen(("Int32")) -int64 = func_gen(("Int64")) -int8x4 = func_gen(("Int8x4")) -int16x4 = func_gen(("Int16x4")) -int32x4 = func_gen(("Int32x4")) -int64x4 = func_gen(("Int64x4")) -int8x8 = func_gen(("Int8x8")) -int16x8 = func_gen(("Int16x8")) -int32x8 = func_gen(("Int32x8")) -int64x8 = func_gen(("Int64x8")) -int8x16 = func_gen(("Int8x16")) -int16x16 = func_gen(("Int16x16")) -int32x16 = func_gen(("Int32x16")) -int64x16 = func_gen(("Int64x16")) -int8x32 = func_gen(("Int8x32")) -int16x32 = func_gen(("Int16x32")) -int32x32 = func_gen(("Int32x32")) -int64x32 = func_gen(("Int64x32")) -int8x64 = func_gen(("Int8x64")) -int16x64 = func_gen(("Int16x64")) -int32x64 = func_gen(("Int32x64")) -int64x64 = func_gen(("Int64x64")) - -uint8 = func_gen(("UInt8")) -uint16 = func_gen(("UInt16")) -uint32 = func_gen(("UInt32")) -uint64 = func_gen(("UInt64")) -uint8x4 = 
func_gen(("UInt8x4")) -uint16x4 = func_gen(("UInt16x4")) -uint32x4 = func_gen(("UInt32x4")) -uint64x4 = func_gen(("UInt64x4")) -uint8x8 = func_gen(("UInt8x8")) -uint16x8 = func_gen(("UInt16x8")) -uint32x8 = func_gen(("UInt32x8")) -uint64x8 = func_gen(("UInt64x8")) -uint8x16 = func_gen(("UInt8x16")) -uint16x16 = func_gen(("UInt16x16")) -uint32x16 = func_gen(("UInt32x16")) -uint64x16 = func_gen(("UInt64x16")) -uint8x32 = func_gen(("UInt8x32")) -uint16x32 = func_gen(("UInt16x32")) -uint32x32 = func_gen(("UInt32x32")) -uint64x32 = func_gen(("UInt64x32")) -uint8x64 = func_gen(("UInt8x64")) -uint16x64 = func_gen(("UInt16x64")) -uint32x64 = func_gen(("UInt32x64")) -uint64x64 = func_gen(("UInt64x64")) - -float16 = func_gen(("Float16")) -float32 = func_gen(("Float32")) -float64 = func_gen(("Float64")) -float16x2 = func_gen(("Float16x2")) -float32x2 = func_gen(("Float32x2")) -float64x2 = func_gen(("Float64x2")) -float16x4 = func_gen(("Float16x4")) -float32x4 = func_gen(("Float32x4")) -float64x4 = func_gen(("Float64x4")) -float16x8 = func_gen(("Float16x8")) -float32x8 = func_gen(("Float32x8")) -float64x8 = func_gen(("Float64x8")) -float16x16 = func_gen(("Float16x16")) -float32x16 = func_gen(("Float32x16")) -float64x16 = func_gen(("Float64x16")) -float16x32 = func_gen(("Float16x32")) -float32x32 = func_gen(("Float32x32")) -float64x32 = func_gen(("Float64x32")) -float16x64 = func_gen(("Float16x64")) -float32x64 = func_gen(("Float32x64")) -float64x64 = func_gen(("Float64x64")) - -# Float8 variants -float8_e3m4 = func_gen(("Float8E3M4")) -float8_e3m4x2 = func_gen(("Float8E3M4x2")) -float8_e3m4x4 = func_gen(("Float8E3M4x4")) -float8_e3m4x8 = func_gen(("Float8E3M4x8")) -float8_e3m4x16 = func_gen(("Float8E3M4x16")) -float8_e3m4x32 = func_gen(("Float8E3M4x32")) -float8_e3m4x64 = func_gen(("Float8E3M4x64")) - -float8_e4m3 = func_gen(("Float8E4M3")) -float8_e4m3x2 = func_gen(("Float8E4M3x2")) -float8_e4m3x4 = func_gen(("Float8E4M3x4")) -float8_e4m3x8 = func_gen(("Float8E4M3x8")) 
-float8_e4m3x16 = func_gen(("Float8E4M3x16")) -float8_e4m3x32 = func_gen(("Float8E4M3x32")) -float8_e4m3x64 = func_gen(("Float8E4M3x64")) - -float8_e4m3b11fnuz = func_gen(("Float8E4M3B11FNUZ")) -float8_e4m3b11fnuzx2 = func_gen(("Float8E4M3B11FNUZx2")) -float8_e4m3b11fnuzx4 = func_gen(("Float8E4M3B11FNUZx4")) -float8_e4m3b11fnuzx8 = func_gen(("Float8E4M3B11FNUZx8")) -float8_e4m3b11fnuzx16 = func_gen(("Float8E4M3B11FNUZx16")) -float8_e4m3b11fnuzx32 = func_gen(("Float8E4M3B11FNUZx32")) -float8_e4m3b11fnuzx64 = func_gen(("Float8E4M3B11FNUZx64")) - -float8_e4m3fn = func_gen(("Float8E4M3FN")) -float8_e4m3fnx2 = func_gen(("Float8E4M3FNx2")) -float8_e4m3fnx4 = func_gen(("Float8E4M3FNx4")) -float8_e4m3fnx8 = func_gen(("Float8E4M3FNx8")) -float8_e4m3fnx16 = func_gen(("Float8E4M3FNx16")) -float8_e4m3fnx32 = func_gen(("Float8E4M3FNx32")) -float8_e4m3fnx64 = func_gen(("Float8E4M3FNx64")) - -float8_e4m3fnuz = func_gen(("Float8E4M3FNUZ")) -float8_e4m3fnuzx2 = func_gen(("Float8E4M3FNUZx2")) -float8_e4m3fnuzx4 = func_gen(("Float8E4M3FNUZx4")) -float8_e4m3fnuzx8 = func_gen(("Float8E4M3FNUZx8")) -float8_e4m3fnuzx16 = func_gen(("Float8E4M3FNUZx16")) -float8_e4m3fnuzx32 = func_gen(("Float8E4M3FNUZx32")) -float8_e4m3fnuzx64 = func_gen(("Float8E4M3FNUZx64")) - -float8_e5m2 = func_gen(("Float8E5M2")) -float8_e5m2x2 = func_gen(("Float8E5M2x2")) -float8_e5m2x4 = func_gen(("Float8E5M2x4")) -float8_e5m2x8 = func_gen(("Float8E5M2x8")) -float8_e5m2x16 = func_gen(("Float8E5M2x16")) -float8_e5m2x32 = func_gen(("Float8E5M2x32")) -float8_e5m2x64 = func_gen(("Float8E5M2x64")) - -float8_e5m2fnuz = func_gen(("Float8E5M2FNUZ")) -float8_e5m2fnuzx2 = func_gen(("Float8E5M2FNUZx2")) -float8_e5m2fnuzx4 = func_gen(("Float8E5M2FNUZx4")) -float8_e5m2fnuzx8 = func_gen(("Float8E5M2FNUZx8")) -float8_e5m2fnuzx16 = func_gen(("Float8E5M2FNUZx16")) -float8_e5m2fnuzx32 = func_gen(("Float8E5M2FNUZx32")) -float8_e5m2fnuzx64 = func_gen(("Float8E5M2FNUZx64")) - -float8_e8m0fnu = func_gen(("Float8E8M0FNU")) 
-float8_e8m0fnux2 = func_gen(("Float8E8M0FNUx2")) -float8_e8m0fnux4 = func_gen(("Float8E8M0FNUx4")) -float8_e8m0fnux8 = func_gen(("Float8E8M0FNUx8")) -float8_e8m0fnux16 = func_gen(("Float8E8M0FNUx16")) -float8_e8m0fnux32 = func_gen(("Float8E8M0FNUx32")) -float8_e8m0fnux64 = func_gen(("Float8E8M0FNUx64")) - -# Float6 variants -float6_e2m3fn = func_gen(("Float6E2M3FN")) -float6_e2m3fnx2 = func_gen(("Float6E2M3FNx2")) -float6_e2m3fnx4 = func_gen(("Float6E2M3FNx4")) -float6_e2m3fnx8 = func_gen(("Float6E2M3FNx8")) -float6_e2m3fnx16 = func_gen(("Float6E2M3FNx16")) -float6_e2m3fnx32 = func_gen(("Float6E2M3FNx32")) -float6_e2m3fnx64 = func_gen(("Float6E2M3FNx64")) - -float6_e3m2fn = func_gen(("Float6E3M2FN")) -float6_e3m2fnx2 = func_gen(("Float6E3M2FNx2")) -float6_e3m2fnx4 = func_gen(("Float6E3M2FNx4")) -float6_e3m2fnx8 = func_gen(("Float6E3M2FNx8")) -float6_e3m2fnx16 = func_gen(("Float6E3M2FNx16")) -float6_e3m2fnx32 = func_gen(("Float6E3M2FNx32")) -float6_e3m2fnx64 = func_gen(("Float6E3M2FNx64")) - -# Float4 variants -float4_e2m1fn = func_gen(("Float4E2M1FN")) -float4_e2m1fnx2 = func_gen(("Float4E2M1FNx2")) -float4_e2m1fnx4 = func_gen(("Float4E2M1FNx4")) -float4_e2m1fnx8 = func_gen(("Float4E2M1FNx8")) -float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16")) -float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32")) -float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64")) - -bfloat16 = func_gen(("BFloat16")) -# pylint: enable=invalid-name +if TYPE_CHECKING: + class int8: ... + class int16: ... + class int32: ... + class int64: ... + class int8x4: ... + class int16x4: ... + class int32x4: ... + class int64x4: ... + class int8x8: ... + class int16x8: ... + class int32x8: ... + class int64x8: ... + class int8x16: ... + class int16x16: ... + class int32x16: ... + class int64x16: ... + class int8x32: ... + class int16x32: ... + class int32x32: ... + class int64x32: ... + class int8x64: ... + class int16x64: ... + class int32x64: ... + class int64x64: ... + class uint8: ... + class uint16: ... 
+ class uint32: ... + class uint64: ... + class uint8x4: ... + class uint16x4: ... + class uint32x4: ... + class uint64x4: ... + class uint8x8: ... + class uint16x8: ... + class uint32x8: ... + class uint64x8: ... + class uint8x16: ... + class uint16x16: ... + class uint32x16: ... + class uint64x16: ... + class uint8x32: ... + class uint16x32: ... + class uint32x32: ... + class uint64x32: ... + class uint8x64: ... + class uint16x64: ... + class uint32x64: ... + class uint64x64: ... + class float16: ... + class float32: ... + class float64: ... + class float16x2: ... + class float32x2: ... + class float64x2: ... + class float16x4: ... + class float32x4: ... + class float64x4: ... + class float16x8: ... + class float32x8: ... + class float64x8: ... + class float16x16: ... + class float32x16: ... + class float64x16: ... + class float16x32: ... + class float32x32: ... + class float64x32: ... + class float16x64: ... + class float32x64: ... + class float64x64: ... + class float8_e3m4: ... + class float8_e3m4x2: ... + class float8_e3m4x4: ... + class float8_e3m4x8: ... + class float8_e3m4x16: ... + class float8_e3m4x32: ... + class float8_e3m4x64: ... + class float8_e4m3: ... + class float8_e4m3x2: ... + class float8_e4m3x4: ... + class float8_e4m3x8: ... + class float8_e4m3x16: ... + class float8_e4m3x32: ... + class float8_e4m3x64: ... + class float8_e4m3b11fnuz: ... + class float8_e4m3b11fnuzx2: ... + class float8_e4m3b11fnuzx4: ... + class float8_e4m3b11fnuzx8: ... + class float8_e4m3b11fnuzx16: ... + class float8_e4m3b11fnuzx32: ... + class float8_e4m3b11fnuzx64: ... + class float8_e4m3fn: ... + class float8_e4m3fnx2: ... + class float8_e4m3fnx4: ... + class float8_e4m3fnx8: ... + class float8_e4m3fnx16: ... + class float8_e4m3fnx32: ... + class float8_e4m3fnx64: ... + class float8_e4m3fnuz: ... + class float8_e4m3fnuzx2: ... + class float8_e4m3fnuzx4: ... + class float8_e4m3fnuzx8: ... + class float8_e4m3fnuzx16: ... + class float8_e4m3fnuzx32: ... 
+ class float8_e4m3fnuzx64: ... + class float8_e5m2: ... + class float8_e5m2x2: ... + class float8_e5m2x4: ... + class float8_e5m2x8: ... + class float8_e5m2x16: ... + class float8_e5m2x32: ... + class float8_e5m2x64: ... + class float8_e5m2fnuz: ... + class float8_e5m2fnuzx2: ... + class float8_e5m2fnuzx4: ... + class float8_e5m2fnuzx8: ... + class float8_e5m2fnuzx16: ... + class float8_e5m2fnuzx32: ... + class float8_e5m2fnuzx64: ... + class float8_e8m0fnu: ... + class float8_e8m0fnux2: ... + class float8_e8m0fnux4: ... + class float8_e8m0fnux8: ... + class float8_e8m0fnux16: ... + class float8_e8m0fnux32: ... + class float8_e8m0fnux64: ... + class float6_e2m3fn: ... + class float6_e2m3fnx2: ... + class float6_e2m3fnx4: ... + class float6_e2m3fnx8: ... + class float6_e2m3fnx16: ... + class float6_e2m3fnx32: ... + class float6_e2m3fnx64: ... + class float6_e3m2fn: ... + class float6_e3m2fnx2: ... + class float6_e3m2fnx4: ... + class float6_e3m2fnx8: ... + class float6_e3m2fnx16: ... + class float6_e3m2fnx32: ... + class float6_e3m2fnx64: ... + class float4_e2m1fn: ... + class float4_e2m1fnx2: ... + class float4_e2m1fnx4: ... + class float4_e2m1fnx8: ... + class float4_e2m1fnx16: ... + class float4_e2m1fnx32: ... + class float4_e2m1fnx64: ... + class bfloat16: ... + class bfloat16x2: ... + class bfloat16x4: ... + class bfloat16x8: ... + class bfloat16x16: ... + class bfloat16x32: ... + class bfloat16x64: ... + class tfloat32: ... + class tfloat32x2: ... + class tfloat32x4: ... + class tfloat32x8: ... + class tfloat32x16: ... + class tfloat32x32: ... + class tfloat32x64: ... 
+else: + # pylint: disable=invalid-name + int8 = func_gen(("Int8")) + int16 = func_gen(("Int16")) + int32 = func_gen(("Int32")) + int64 = func_gen(("Int64")) + int8x4 = func_gen(("Int8x4")) + int16x4 = func_gen(("Int16x4")) + int32x4 = func_gen(("Int32x4")) + int64x4 = func_gen(("Int64x4")) + int8x8 = func_gen(("Int8x8")) + int16x8 = func_gen(("Int16x8")) + int32x8 = func_gen(("Int32x8")) + int64x8 = func_gen(("Int64x8")) + int8x16 = func_gen(("Int8x16")) + int16x16 = func_gen(("Int16x16")) + int32x16 = func_gen(("Int32x16")) + int64x16 = func_gen(("Int64x16")) + int8x32 = func_gen(("Int8x32")) + int16x32 = func_gen(("Int16x32")) + int32x32 = func_gen(("Int32x32")) + int64x32 = func_gen(("Int64x32")) + int8x64 = func_gen(("Int8x64")) + int16x64 = func_gen(("Int16x64")) + int32x64 = func_gen(("Int32x64")) + int64x64 = func_gen(("Int64x64")) + + uint8 = func_gen(("UInt8")) + uint16 = func_gen(("UInt16")) + uint32 = func_gen(("UInt32")) + uint64 = func_gen(("UInt64")) + uint8x4 = func_gen(("UInt8x4")) + uint16x4 = func_gen(("UInt16x4")) + uint32x4 = func_gen(("UInt32x4")) + uint64x4 = func_gen(("UInt64x4")) + uint8x8 = func_gen(("UInt8x8")) + uint16x8 = func_gen(("UInt16x8")) + uint32x8 = func_gen(("UInt32x8")) + uint64x8 = func_gen(("UInt64x8")) + uint8x16 = func_gen(("UInt8x16")) + uint16x16 = func_gen(("UInt16x16")) + uint32x16 = func_gen(("UInt32x16")) + uint64x16 = func_gen(("UInt64x16")) + uint8x32 = func_gen(("UInt8x32")) + uint16x32 = func_gen(("UInt16x32")) + uint32x32 = func_gen(("UInt32x32")) + uint64x32 = func_gen(("UInt64x32")) + uint8x64 = func_gen(("UInt8x64")) + uint16x64 = func_gen(("UInt16x64")) + uint32x64 = func_gen(("UInt32x64")) + uint64x64 = func_gen(("UInt64x64")) + + float16 = func_gen(("Float16")) + float32 = func_gen(("Float32")) + float64 = func_gen(("Float64")) + float16x2 = func_gen(("Float16x2")) + float32x2 = func_gen(("Float32x2")) + float64x2 = func_gen(("Float64x2")) + float16x4 = func_gen(("Float16x4")) + float32x4 = 
func_gen(("Float32x4")) + float64x4 = func_gen(("Float64x4")) + float16x8 = func_gen(("Float16x8")) + float32x8 = func_gen(("Float32x8")) + float64x8 = func_gen(("Float64x8")) + float16x16 = func_gen(("Float16x16")) + float32x16 = func_gen(("Float32x16")) + float64x16 = func_gen(("Float64x16")) + float16x32 = func_gen(("Float16x32")) + float32x32 = func_gen(("Float32x32")) + float64x32 = func_gen(("Float64x32")) + float16x64 = func_gen(("Float16x64")) + float32x64 = func_gen(("Float32x64")) + float64x64 = func_gen(("Float64x64")) + + # Float8 variants + float8_e3m4 = func_gen(("Float8E3M4")) + float8_e3m4x2 = func_gen(("Float8E3M4x2")) + float8_e3m4x4 = func_gen(("Float8E3M4x4")) + float8_e3m4x8 = func_gen(("Float8E3M4x8")) + float8_e3m4x16 = func_gen(("Float8E3M4x16")) + float8_e3m4x32 = func_gen(("Float8E3M4x32")) + float8_e3m4x64 = func_gen(("Float8E3M4x64")) + + float8_e4m3 = func_gen(("Float8E4M3")) + float8_e4m3x2 = func_gen(("Float8E4M3x2")) + float8_e4m3x4 = func_gen(("Float8E4M3x4")) + float8_e4m3x8 = func_gen(("Float8E4M3x8")) + float8_e4m3x16 = func_gen(("Float8E4M3x16")) + float8_e4m3x32 = func_gen(("Float8E4M3x32")) + float8_e4m3x64 = func_gen(("Float8E4M3x64")) + + float8_e4m3b11fnuz = func_gen(("Float8E4M3B11FNUZ")) + float8_e4m3b11fnuzx2 = func_gen(("Float8E4M3B11FNUZx2")) + float8_e4m3b11fnuzx4 = func_gen(("Float8E4M3B11FNUZx4")) + float8_e4m3b11fnuzx8 = func_gen(("Float8E4M3B11FNUZx8")) + float8_e4m3b11fnuzx16 = func_gen(("Float8E4M3B11FNUZx16")) + float8_e4m3b11fnuzx32 = func_gen(("Float8E4M3B11FNUZx32")) + float8_e4m3b11fnuzx64 = func_gen(("Float8E4M3B11FNUZx64")) + + float8_e4m3fn = func_gen(("Float8E4M3FN")) + float8_e4m3fnx2 = func_gen(("Float8E4M3FNx2")) + float8_e4m3fnx4 = func_gen(("Float8E4M3FNx4")) + float8_e4m3fnx8 = func_gen(("Float8E4M3FNx8")) + float8_e4m3fnx16 = func_gen(("Float8E4M3FNx16")) + float8_e4m3fnx32 = func_gen(("Float8E4M3FNx32")) + float8_e4m3fnx64 = func_gen(("Float8E4M3FNx64")) + + float8_e4m3fnuz = 
func_gen(("Float8E4M3FNUZ")) + float8_e4m3fnuzx2 = func_gen(("Float8E4M3FNUZx2")) + float8_e4m3fnuzx4 = func_gen(("Float8E4M3FNUZx4")) + float8_e4m3fnuzx8 = func_gen(("Float8E4M3FNUZx8")) + float8_e4m3fnuzx16 = func_gen(("Float8E4M3FNUZx16")) + float8_e4m3fnuzx32 = func_gen(("Float8E4M3FNUZx32")) + float8_e4m3fnuzx64 = func_gen(("Float8E4M3FNUZx64")) + + float8_e5m2 = func_gen(("Float8E5M2")) + float8_e5m2x2 = func_gen(("Float8E5M2x2")) + float8_e5m2x4 = func_gen(("Float8E5M2x4")) + float8_e5m2x8 = func_gen(("Float8E5M2x8")) + float8_e5m2x16 = func_gen(("Float8E5M2x16")) + float8_e5m2x32 = func_gen(("Float8E5M2x32")) + float8_e5m2x64 = func_gen(("Float8E5M2x64")) + + float8_e5m2fnuz = func_gen(("Float8E5M2FNUZ")) + float8_e5m2fnuzx2 = func_gen(("Float8E5M2FNUZx2")) + float8_e5m2fnuzx4 = func_gen(("Float8E5M2FNUZx4")) + float8_e5m2fnuzx8 = func_gen(("Float8E5M2FNUZx8")) + float8_e5m2fnuzx16 = func_gen(("Float8E5M2FNUZx16")) + float8_e5m2fnuzx32 = func_gen(("Float8E5M2FNUZx32")) + float8_e5m2fnuzx64 = func_gen(("Float8E5M2FNUZx64")) + + float8_e8m0fnu = func_gen(("Float8E8M0FNU")) + float8_e8m0fnux2 = func_gen(("Float8E8M0FNUx2")) + float8_e8m0fnux4 = func_gen(("Float8E8M0FNUx4")) + float8_e8m0fnux8 = func_gen(("Float8E8M0FNUx8")) + float8_e8m0fnux16 = func_gen(("Float8E8M0FNUx16")) + float8_e8m0fnux32 = func_gen(("Float8E8M0FNUx32")) + float8_e8m0fnux64 = func_gen(("Float8E8M0FNUx64")) + + # Float6 variants + float6_e2m3fn = func_gen(("Float6E2M3FN")) + float6_e2m3fnx2 = func_gen(("Float6E2M3FNx2")) + float6_e2m3fnx4 = func_gen(("Float6E2M3FNx4")) + float6_e2m3fnx8 = func_gen(("Float6E2M3FNx8")) + float6_e2m3fnx16 = func_gen(("Float6E2M3FNx16")) + float6_e2m3fnx32 = func_gen(("Float6E2M3FNx32")) + float6_e2m3fnx64 = func_gen(("Float6E2M3FNx64")) + + float6_e3m2fn = func_gen(("Float6E3M2FN")) + float6_e3m2fnx2 = func_gen(("Float6E3M2FNx2")) + float6_e3m2fnx4 = func_gen(("Float6E3M2FNx4")) + float6_e3m2fnx8 = func_gen(("Float6E3M2FNx8")) + float6_e3m2fnx16 = 
func_gen(("Float6E3M2FNx16")) + float6_e3m2fnx32 = func_gen(("Float6E3M2FNx32")) + float6_e3m2fnx64 = func_gen(("Float6E3M2FNx64")) + + # Float4 variants + float4_e2m1fn = func_gen(("Float4E2M1FN")) + float4_e2m1fnx2 = func_gen(("Float4E2M1FNx2")) + float4_e2m1fnx4 = func_gen(("Float4E2M1FNx4")) + float4_e2m1fnx8 = func_gen(("Float4E2M1FNx8")) + float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16")) + float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32")) + float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64")) + + bfloat16 = func_gen(("BFloat16")) + bfloat16x2 = func_gen(("BFloat16x2")) + bfloat16x4 = func_gen(("BFloat16x4")) + bfloat16x8 = func_gen(("BFloat16x8")) + bfloat16x16 = func_gen(("BFloat16x16")) + bfloat16x32 = func_gen(("BFloat16x32")) + bfloat16x64 = func_gen(("BFloat16x64")) + + tfloat32 = func_gen(("TensorFloat32")) + tfloat32x2 = func_gen(("TensorFloat32x2")) + tfloat32x4 = func_gen(("TensorFloat32x4")) + tfloat32x8 = func_gen(("TensorFloat32x8")) + tfloat32x16 = func_gen(("TensorFloat32x16")) + tfloat32x32 = func_gen(("TensorFloat32x32")) + tfloat32x64 = func_gen(("TensorFloat32x64")) + # pylint: enable=invalid-name def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimExpr: @@ -1917,6 +2165,8 @@ def wrapped(*args, **kwargs): q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) +continue_loop = _op_wrapper(_tir_op.continue_loop) +break_loop = _op_wrapper(_tir_op.break_loop) round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin rsqrt = _op_wrapper(_tir_op.rsqrt) shift_left = _op_wrapper(_tir_op.shift_left) @@ -2096,6 +2346,19 @@ def wrapped(*args, **kwargs): "uint32x64", "uint64x64", "bfloat16", + "bfloat16x2", + "bfloat16x4", + "bfloat16x8", + "bfloat16x16", + "bfloat16x32", + "bfloat16x64", + "tfloat32", + "tfloat32x2", + "tfloat32x4", + "tfloat32x8", + "tfloat32x16", + "tfloat32x32", + "tfloat32x64", "buffer", 
"buffer_decl", "prim_func", @@ -2105,6 +2368,7 @@ def wrapped(*args, **kwargs): "func_ret", "match_buffer", "block", + "block_name_suffix_context", "init", "where", "reads", @@ -2195,6 +2459,8 @@ def wrapped(*args, **kwargs): "q_multiply_shift", "q_multiply_shift_per_axis", "ret", + "continue_loop", + "break_loop", "reinterpret", "round", "rsqrt", diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py index 74174f066727..f8c400ad1667 100644 --- a/python/tvm/script/parser/core/doc.py +++ b/python/tvm/script/parser/core/doc.py @@ -18,6 +18,7 @@ import ast import inspect +import sys import typing from collections import defaultdict @@ -318,4 +319,150 @@ def __call__(self, node): ) + +def _py_version() -> typing.Tuple[int, int]: + return (sys.version_info.major, sys.version_info.minor) + + +def _register_constant_handling(): + if _py_version() not in [(3, 6), (3, 7)]: + return + + def as_constant(f) -> doc.Constant: + def to_doc_func(x: ast.AST) -> doc.Constant: + return doc.Constant( + value=getattr(x, f) if isinstance(f, str) else f(x), + kind=None, + lineno=x.lineno, + col_offset=x.col_offset, + end_lineno=x.lineno, + end_col_offset=x.col_offset, + ) + + return to_doc_func + + register_to_doc("Str")(as_constant("s")) + register_to_doc("NameConstant")(as_constant("value")) + register_to_doc("Num")(as_constant("n")) + register_to_doc("Bytes")(as_constant("s")) + register_to_doc("Ellipsis")(as_constant(lambda _: ...)) + + +def _register_subscription_handling(): + if _py_version() >= (3, 9): + return + + def subscript_to_doc(x: ast.Subscript) -> doc.Subscript: + if isinstance(x.slice, ast.Slice): + return doc.Subscript( + value=to_doc(x.value), + slice=doc.Slice( + lower=to_doc(x.slice.lower), + upper=to_doc(x.slice.upper), + step=to_doc(x.slice.step), + lineno=getattr(x.slice, "lineno", None), + col_offset=getattr(x.slice, "col_offset", None), + end_lineno=getattr(x.slice, "end_lineno", None), + end_col_offset=getattr(x.slice, 
"end_col_offset", None), + ), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + if isinstance(x.slice, ast.ExtSlice): + return doc.Subscript( + value=to_doc(x.value), + slice=doc.Tuple( + elts=[to_doc(i) for i in x.slice.dims], + ctx=doc.Load( + lineno=None, + col_offset=None, + end_lineno=None, + end_col_offset=None, + ), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + if isinstance(x.slice, ast.Index): + return doc.Subscript( + value=to_doc(x.value), + slice=to_doc(x.slice.value), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + raise TypeError(f"Unknown subscript type: {type(x.slice)}") + + def subscript_from_doc(x: doc.Subscript) -> ast.Subscript: + if isinstance(x.slice, doc.Slice): + result = ast.Subscript( + value=from_doc(x.value), + slice=from_doc(x.slice), + ctx=from_doc(x.ctx), + ) + elif isinstance(x.slice, doc.Tuple): + + def remap_dim(doc_item: doc.Expr) -> ast.Expr: + ast_item = from_doc(doc_item) + if isinstance(ast_item, (ast.Index, ast.Slice)): + return ast_item + return ast.Index(value=ast_item) + + # ast.ExtSlice requires a non-empty list of dims, and each dim must be either + # a Slice or an Index. 
+ if x.slice.elts: + ast_slice = ast.ExtSlice(dims=[*map(remap_dim, x.slice.elts)]) + else: + ast_slice = ast.Index(value=ast.Tuple(elts=[], ctx=from_doc(x.ctx))) + result = ast.Subscript(value=from_doc(x.value), slice=ast_slice, ctx=from_doc(x.ctx)) + else: + result = ast.Subscript( + value=from_doc(x.value), + slice=ast.Index(value=from_doc(x.slice)), + ctx=from_doc(x.ctx), + ) + result.lineno = x.lineno + result.col_offset = x.col_offset + result.end_lineno = x.end_lineno + result.end_col_offset = x.end_col_offset + return result + + register_to_doc("Subscript")(subscript_to_doc) + register_from_doc("Subscript")(subscript_from_doc) + + +def _register_index_handling(): + if _py_version() >= (3, 9): + return + + def index_to_doc(x: ast.Index) -> doc.Expr: + return to_doc(x.value) + + def index_from_doc(x: doc.Expr) -> ast.Index: + result = ast.Index(value=from_doc(x), ctx=from_doc(x.ctx)) + result.lineno = x.lineno + result.col_offset = x.col_offset + result.end_lineno = x.end_lineno + result.end_col_offset = x.end_col_offset + return result + + register_to_doc("Index")(index_to_doc) + register_from_doc("Index")(index_from_doc) + + _register_default() +_register_constant_handling() +_register_subscription_handling() +_register_index_handling() diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index e7a7f98b7651..a6be751b0de8 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Union import tvm +from tvm.relax import ExternFunc from ....ir.module import IRModule from ...ir_builder import IRBuilder from . 
import doc @@ -86,12 +87,14 @@ def parse( extra_vars = _default_globals() ann = {} + all_pyfuncs = {} if inspect.isfunction(program): ann = {program.__name__: program.__annotations__} elif inspect.isclass(program): for name, func in program.__dict__.items(): if inspect.isfunction(func): ann[name] = func.__annotations__ + all_pyfuncs[name] = func source = Source(program) parser = Parser(source, ann) @@ -101,6 +104,10 @@ def parse( except ParserError as err: parser.report_error(err.node, err.args[0]) ret = builder.get() + # Attach pyfuncs to the IRModule + if inspect.isclass(program) and isinstance(ret, IRModule): + _attach_pyfuncs_to_irmodule(ret, all_pyfuncs) + # check well-formedness in both Relax and TIR if check_well_formed: check_ret = ret @@ -122,3 +129,65 @@ def parse( err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}", ) return ret + + +def _create_python_packed_func(pyfunc): + """Create a PackedFunc wrapper for a Python function. + + This function creates a PackedFunc that can be called from TVM runtime + and will execute the original Python function. + + Parameters + ---------- + pyfunc : Callable + The Python function to wrap. + + Returns + ------- + PackedFunc + A PackedFunc that wraps the Python function. 
+ """ + + def packed_func_wrapper(*args, **kwargs): + """Wrapper function that calls the original Python function.""" + try: + result = pyfunc(*args, **kwargs) + return result + except Exception as error: + print(f"Error calling Python function {pyfunc.__name__}: {error}") + raise + + return packed_func_wrapper + + +def _attach_pyfuncs_to_irmodule(irmodule, all_pyfuncs): + """Attach Python functions to IRModule with reduced nesting.""" + if not all_pyfuncs: + return + + if not hasattr(irmodule, "pyfuncs"): + irmodule.pyfuncs = {} + + for global_var, func in irmodule.functions_items(): + if not isinstance(func, ExternFunc): + continue + if not func.attrs.get("is_pyfunc", False): + continue + + pyfunc_name = global_var.name_hint + if pyfunc_name not in all_pyfuncs: + continue + + pyfunc = all_pyfuncs[pyfunc_name] + irmodule.pyfuncs[pyfunc_name] = pyfunc + + try: + source_code = inspect.getsource(pyfunc) + func = func.with_attr("python_source", source_code) + except (OSError, TypeError): + func = func.with_attr("python_source", f"# Source unavailable for {pyfunc_name}") + + packed_func = _create_python_packed_func(pyfunc) + func = func.with_attr("python_packed_func", packed_func) + + irmodule[global_var] = func diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 9d09df3d8e5f..f23c69824bde 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -19,6 +19,8 @@ import ast from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union +import tvm + from . 
import dispatch, doc from .error import ParserError @@ -55,6 +57,7 @@ doc.Not: lambda a: not a, doc.UAdd: lambda a: +a, doc.USub: lambda a: -a, + doc.IfExp: tvm.tir.op.if_then_else, } @@ -172,7 +175,7 @@ def _visit(self, node: doc.AST) -> Any: if ( isinstance(node, doc.Call) and hasattr(node.func, "attr") - and node.func.attr not in ["reads", "writes", "match_buffer", "realize"] + and node.func.attr not in ["reads", "writes", "match_buffer", "realize", "copy"] ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)): if isinstance(node, doc.BinOp): args = [node.left, node.right] @@ -180,11 +183,12 @@ def _visit(self, node: doc.AST) -> Any: args = [node.operand] elif isinstance(node, doc.Compare): args = [node.left, *node.comparators] - else: - if isinstance(node, doc.Call): - args = node.args - elif isinstance(node, doc.BoolOp): - args = node.values + elif isinstance(node, doc.IfExp): + args = [node.test, node.body, node.orelse] + elif isinstance(node, doc.Call): + args = node.args + elif isinstance(node, doc.BoolOp): + args = node.values for arg in args: if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)): if isinstance(arg.slice, doc.Slice): @@ -256,6 +260,8 @@ def _visit(self, node: doc.AST) -> Any: value = self._eval_unary_op(fields) elif isinstance(node, doc.BinOp): value = self._eval_bin_op(fields) + elif isinstance(node, doc.IfExp): + value = self._eval_if_exp(fields) elif isinstance(node, doc.Slice): value = self._eval_slice(fields) else: @@ -319,10 +325,18 @@ def _eval_compare(self, fields: Dict[str, Any]) -> Any: res : Any The evaluation result. 
""" - value = self._eval_expr(fields["left"]) - for op, rhs in zip(fields["ops"], fields["comparators"]): - value = _eval_op(op, values=[value, self._eval_expr(rhs)]) - return value + values = [self._eval_expr(fields["left"])] + values.extend([self._eval_expr(rhs) for rhs in fields["comparators"]]) + result = None + assert len(fields["ops"]) == len(values) - 1 + + for index, op in enumerate(fields["ops"]): + sub_result = _eval_op(op, values=[values[index], values[index + 1]]) + if result is None: + result = sub_result + else: + result = _eval_op(doc.And(), values=[result, sub_result]) + return result def _eval_unary_op(self, fields: Dict[str, Any]) -> Any: """The doc AST unary operation node evaluating method. @@ -364,6 +378,30 @@ def _eval_bin_op(self, fields: Dict[str, Any]) -> Any: ], ) + def _eval_if_exp(self, fields: Dict[str, Any]) -> Any: + """The doc AST if-else expression node evaluating method. + + Parameters + ---------- + fields : Dict[str, Any] + The dictionary of if-else expression information, + e.g., test, body, orelse. + + Returns + ------- + res : Any + The evaluation result. + """ + test = self._eval_expr(fields["test"]) + body = self._eval_expr(fields["body"]) + orelse = self._eval_expr(fields["orelse"]) + if isinstance(test, bool): + return body if test else orelse + elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool": + return tvm.tir.op.if_then_else(test, body, orelse) + else: + raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}") + def _eval_slice(self, fields: Dict[str, Any]) -> slice: """The doc AST slice node evaluating method. @@ -492,14 +530,14 @@ def _eval_expr( def _eval_op( - op: doc.AST, + op_or_type: Union[doc.AST, Type], values: List[Any], ): """Operation expression evaluation implementation for TVMScript parser. Parameters ---------- - op : doc.AST + op_or_type : Union[doc.AST, Type] The root node of AST tree node of operation expression to evaluate. 
values : List[Any] @@ -510,7 +548,9 @@ def _eval_op( res : Any The evaluation result. """ - op_type = type(op) # pylint: disable=protected-access + op_type = ( + type(op_or_type) if isinstance(op_or_type, doc.AST) else op_or_type + ) # pylint: disable=protected-access for i, v in enumerate(values): v_type = getattr(type(v), "_dispatch_type", None) if v_type is None: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 78da15ca1f27..e81ff0657f8b 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -343,6 +343,8 @@ class Parser(doc.NodeVisitor): function_annotations: Optional[Dict[str, Dict[str, Any]]] var_table: VarTable inside_function: bool # whether we are within a function + current_class: Optional[str] = None # current class being parsed + base_py_module_context: bool = False # whether current class inherits from BasePyModule def __init__( self, @@ -414,6 +416,39 @@ def pop_token(): return _deferred(pop_token) + def set_class_context(self, class_name: str, is_base_py_module: bool = False): + """Set the current class context for parsing. + + Parameters + ---------- + class_name : str + The name of the current class being parsed. + is_base_py_module : bool + Whether the current class inherits from BasePyModule. + """ + self.current_class = class_name + self.base_py_module_context = is_base_py_module + + def _get_current_class_context(self) -> Optional[str]: + """Get the current class context. + + Returns + ------- + Optional[str] + The name of the current class, or None if not in a class context. + """ + return self.current_class + + def _is_base_py_module_context(self) -> bool: + """Check if the current class context allows Python functions. + + Returns + ------- + bool + True if Python functions are allowed in the current context. + """ + return self.base_py_module_context + def with_diag_source(self, source: Source): """Add a new source as with statement. 
@@ -837,6 +872,36 @@ def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name """ return _dispatch(self, "Return")(self, node) + def visit_Continue(self, node: doc.Continue) -> Any: # pylint: disable=invalid-name + """The general continue visiting method. + + Parameters + ---------- + node : doc.Continue + The doc AST continue node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "Continue")(self, node) + + def visit_Break(self, node: doc.Break) -> Any: # pylint: disable=invalid-name + """The general break visiting method. + + Parameters + ---------- + node : doc.Break + The doc AST break node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "Break")(self, node) + def visit_Nonlocal(self, node: doc.Nonlocal) -> Any: # pylint: disable=invalid-name """The general nonlocal visiting method. diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index 3a8196288df1..3cc015a405d3 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -18,7 +18,7 @@ from tvm.ir import Range from ...ir_builder.ir import * # pylint: disable=redefined-builtin from . 
import parser as _parser -from .entry import ir_module +from .entry import ir_module, pyfunc __all__ = [ @@ -28,5 +28,6 @@ "dummy_global_info", "Range", "lookup_vdevice", + "pyfunc", "vdevice", ] diff --git a/python/tvm/script/parser/ir/entry.py b/python/tvm/script/parser/ir/entry.py index f91c7701a2eb..0e2adeebe3f2 100644 --- a/python/tvm/script/parser/ir/entry.py +++ b/python/tvm/script/parser/ir/entry.py @@ -17,9 +17,12 @@ """The entry point of TVM parser for ir module.""" import inspect -from typing import Optional, Type +from typing import Callable, Optional, Type -from tvm.ir import IRModule +from tvm.ir import IRModule, GlobalVar +from tvm.relax.expr import ExternFunc +from tvm.relax.base_py_module import BasePyModule +from tvm import cpu, ir from .._core import parse, utils @@ -47,7 +50,86 @@ def ir_module(mod: Optional[Type] = None, check_well_formed: bool = True) -> IRM def decorator_wrapper(mod): if not inspect.isclass(mod): raise TypeError(f"Expect a class, but got: {mod}") + + # Check BasePyModule inheritance + base_py_module_inherited = any(base.__name__ == "BasePyModule" for base in mod.__bases__) + m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed) + + if base_py_module_inherited: + # Collect pyfunc methods + pyfunc_methods = [ + name + for name, attr in mod.__dict__.items() + if hasattr(attr, "dispatch_token") and attr.dispatch_token == "pyfunc" + ] + + mod._pyfunc_methods = pyfunc_methods + + # Create ExternFunc nodes + + for method_name in pyfunc_methods: + try: + existing_gvars = [ + global_var + for global_var in m.get_global_vars() + if global_var.name_hint == method_name + ] + + extern_func = ExternFunc(method_name) + extern_func = extern_func.with_attr("is_pyfunc", True) + extern_func = extern_func.with_attr("function_type", "python") + extern_func = extern_func.with_attr("python_function_name", method_name) + extern_func = extern_func.with_attr( + "python_source", f"# Source for {method_name}" + ) + 
extern_func = extern_func.with_attr("python_packed_func", None) + + if existing_gvars: + m[existing_gvars[0]] = extern_func + else: + m[GlobalVar(method_name)] = extern_func + + except Exception: # pylint: disable=broad-exception-caught + continue + + class ModuleFactory: + """Factory class for creating BasePyModule instances with Python functions.""" + + def __init__(self, module, pyfunc_methods, original_class): + self.ir_module = module + self.pyfunc_methods = pyfunc_methods + self.original_class = original_class + + def __call__(self, device=None, target=None): + + if device is None: + device = cpu(0) + + instance_ir_mod = ir.IRModule() + for global_var, func in self.ir_module.functions_items(): + instance_ir_mod[global_var] = func + + instance = BasePyModule(instance_ir_mod, device, target) + + for method_name in self.pyfunc_methods: + if hasattr(self.original_class, method_name): + method = getattr(self.original_class, method_name) + instance.add_python_function(method_name, method) + + return instance + + def __getattr__(self, name): + if hasattr(self.ir_module, name): + return getattr(self.ir_module, name) + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + factory = ModuleFactory(m, pyfunc_methods, mod) + setattr(factory, "__name__", mod.__name__) + return factory + setattr(m, "__name__", mod.__name__) return m @@ -61,4 +143,10 @@ def decorator_wrapper(mod): return decorator_wrapper -setattr(ir_module, "dispatch_token", "ir") +def pyfunc(func: Callable): + # Set the dispatch_token on the decorated function + setattr(func, "dispatch_token", "pyfunc") + return func + + +setattr(pyfunc, "dispatch_token", "pyfunc") diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 4ea57130f1e2..80d2db87ab42 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -17,6 +17,9 @@ # pylint: disable=unused-argument """The base parser for ir 
module""" +from tvm.ir import GlobalVar +from tvm.relax import ExternFunc + from ...ir_builder import ir as I from .._core import Parser, dispatch, doc @@ -49,7 +52,18 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: fake_module = ModuleWithGlobalVars() self.var_table.add(node.name, fake_module) - # Step 1. Visit non-function stmts, including but not limited to + # Step 1: Check if this class inherits from BasePyModule + is_base_py_module = _check_base_py_module_inheritance(node) + if is_base_py_module: + # Store this information in the IRModule for later use + I.module_attrs({"base_py_module": True}) + # Set the parser context to allow Python functions + self.set_class_context(node.name, True) + else: + # Set the parser context to disallow Python functions + self.set_class_context(node.name, False) + + # Step 2. Visit non-function stmts, including but not limited to # 1. `I.module_attrs` # 2. `I.module_global_infos` with self.with_dispatch_token("ir"): @@ -57,13 +71,13 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: if not isinstance(stmt, doc.FunctionDef): self.visit(stmt) - # Step 2. Visit function stmts to declare the global vars + # Step 3. Visit function stmts to declare the global vars for stmt in node.body: if isinstance(stmt, doc.FunctionDef): global_var = self.visit_tvm_declare_function(stmt) fake_module.__setattr__(stmt.name, global_var) - # Step 3. Visit and parse the functions + # Step 4. 
Visit and parse the functions with self.with_dispatch_token("ir"): for stmt in node.body: if isinstance(stmt, doc.FunctionDef): @@ -125,3 +139,71 @@ def pre_visit_local_function(self: Parser, node: doc.Expr) -> None: @dispatch.register(token="default", type_name="post_visit_local_function") def post_visit_local_function(self: Parser, node: doc.Expr) -> None: pass + + +@dispatch.register(token="pyfunc", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: + """Declare a Python function as an ExternFunc in the IRModule.""" + # Check if Python functions are allowed in this context + # We need to check if we're in a class that inherits from BasePyModule + current_class = self._get_current_class_context() + if current_class and not self._is_base_py_module_context(): + self.report_error( + node, + "@I.pyfunc are only allowed in classes that inherit from BasePyModule. " + f"Class '{current_class}' does not inherit from BasePyModule.", + ) + + # Create ExternFunc with proper attributes for Python functions + func = ExternFunc(node.name) + func = func.with_attr("is_pyfunc", True) + func = func.with_attr("function_type", "python") + func = func.with_attr("python_function_name", node.name) + + # Add placeholder attributes that will be filled in later + func = func.with_attr("python_source", f"# Source will be filled for {node.name}") + func = func.with_attr("python_packed_func", None) # Will be filled in entry.py + + # Store the function name for later retrieval + return I.decl_function(node.name, func) + + +@dispatch.register(token="pyfunc", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + """Visit Python function definition - no need to parse the body.""" + # Python function body is not parsed in TVMScript + + +def _check_base_py_module_inheritance(node: doc.ClassDef) -> bool: + """Check if a class inherits from BasePyModule. 
+ + Parameters + ---------- + node : doc.ClassDef + The class definition node to check. + + Returns + ------- + bool + True if the class inherits from BasePyModule, False otherwise. + """ + if not node.bases: + return False + + # Check each base class + for base in node.bases: + if hasattr(base, "id"): + if base.id == "BasePyModule": + return True + elif hasattr(base, "attr"): + if base.attr == "BasePyModule": + return True + elif hasattr(base, "value") and hasattr(base.value, "id"): + if ( + base.value.id in ["BasePyModule", "tvm", "relax"] + and hasattr(base, "attr") + and base.attr == "BasePyModule" + ): + return True + + return False diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 04a5f985643e..ec140e57ba60 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -35,7 +35,7 @@ TupleStructInfo, ) from tvm.relax.expr import Var -from tvm.runtime import ObjectGeneric +from tvm.runtime import ObjectConvertible from tvm.tir import PrimExpr from .._core import doc, parse, utils @@ -147,7 +147,7 @@ def wrapper(*args, **kwargs): ############################# Struct Info ############################## -class StructInfoProxy(ObjectGeneric): +class StructInfoProxy(ObjectConvertible): def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> StructInfo: raise NotImplementedError() diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index c7d5dc756b32..bcac49733d00 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -21,7 +21,7 @@ from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc -from ...ir_builder.tir import buffer, ptr +from ...ir_builder.tir import block_name_suffix_context, buffer, ptr from .._core import parse, scan_macro, utils from ..core.parser import Parser, ScriptMacro @@ -90,11 +90,25 @@ def decorator_wrapper(func): class 
TIRMacro(ScriptMacro): - """Specialization of the ScriptMacro class for TIR.""" + """Specialization of the ScriptMacro class for TIR. + + Attributes + ---------- + call_count : int + Counter for the number of times this macro has been invoked. + Used to generate unique block name suffixes. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.call_count = 0 def parse_macro(self, parser: Parser) -> None: macro_def = self.get_macro_def() - parser.visit_body(macro_def.body) + suffix = f"_{self.call_count}" if self.call_count > 0 else "" + self.call_count += 1 + with block_name_suffix_context(suffix): + parser.visit_body(macro_def.body) def macro(*args, hygienic: bool = True) -> Callable: diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py index 22f996a4561c..b22b0a7335db 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tir/operation.py @@ -61,6 +61,7 @@ def _auto_broadcast(a, b, op): if ( DataType(b.dtype).type_code == DataTypeCode.INT or DataType(b.dtype).type_code == DataTypeCode.UINT + or DataType(b.dtype).type_code == DataTypeCode.BOOL ): a = IntImm(_get_type_str(b.dtype), a) elif DataType(b.dtype).type_code == DataTypeCode.FLOAT: @@ -80,6 +81,7 @@ def _auto_broadcast(a, b, op): if ( DataType(a.dtype).type_code == DataTypeCode.INT or DataType(a.dtype).type_code == DataTypeCode.UINT + or DataType(a.dtype).type_code == DataTypeCode.BOOL ): b = IntImm(_get_type_str(a.dtype), b) elif DataType(a.dtype).type_code == DataTypeCode.FLOAT: diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index f6141404fa40..92244d5a0472 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -18,11 +18,11 @@ import contextlib from functools import partial -from typing import Any +from typing import Any, Dict, Optional import tvm from tvm.ir import GlobalVar, PrimType -from tvm.tir 
import Buffer, IterVar, PrimExpr, Var +from tvm.tir import Buffer, BufferLoad, IterVar, PrimExpr, Var from ...ir_builder import ir as I from ...ir_builder import tir as T @@ -138,6 +138,9 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - res = value.__enter__() IRBuilder.name(var_name, res) return res + elif isinstance(value, Buffer) and value.scope() == "local.var": + IRBuilder.name(var_name, value) + return BufferLoad(value, indices=[0]) elif isinstance(value, (Buffer, IterVar)) or ( isinstance(value, Var) and not self.var_table.exist(value) ): @@ -168,6 +171,28 @@ def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: b return default +def range_sugar( + start: PrimExpr, + stop: PrimExpr = None, + step: Optional[PrimExpr] = None, + *, + annotations: Dict[str, Any] = None, +) -> T.frame.ForFrame: + """The sugar for python range builtin.""" + + # Since `tir.For` does not support reversed iteration semantics, + # the step must be checked to be a positive integer when using the range sugar + if step is not None: + try: + step = int(step) + if step <= 0: + raise ValueError(f"Only support positive step in range(), get {step}") + except TypeError: # pylint: disable=broad-except + raise ValueError(f"Only support literal step in range(), get {step}") + + return T.serial(start, stop, annotations=annotations, step=step) + + @dispatch.register(token="tir", type_name="For") def visit_for(self: Parser, node: doc.For) -> None: """The for visiting method for tir. 
@@ -255,8 +280,21 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: else: indices = self.eval_expr(lhs.slice) T.buffer_store(self.eval_expr(lhs.value), rhs, indices) - else: - self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + return + + # Handle local.var buffer store + if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): + lhs_value = self.eval_expr(lhs) + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): + T.buffer_store(lhs_value.buffer, rhs, indices=[0]) + return + + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) @dispatch.register(token="tir", type_name="AugAssign") @@ -353,7 +391,8 @@ def visit_with(self: Parser, node: doc.With) -> None: frame = self.eval_expr(item.context_expr) if not isinstance(frame, Frame): self.report_error( - item.context_expr, "Invalid context expression in the with-statement." + item.context_expr, + "Invalid context expression in the with-statement.", ) rhs = stack.enter_context(frame) if item.optional_vars is not None: @@ -378,7 +417,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: privacy = find_decorator_annotation(node, "private", default=False) self.function_annotations = None with self.var_table.with_frame(): - self.var_table.add("range", T.serial) + + self.var_table.add("range", range_sugar) with T.prim_func(is_private=privacy): T.func_name(node.name) if node.returns is not None: @@ -498,7 +538,8 @@ def visit_if(self: Parser, node: doc.If) -> None: self.visit_body(node.orelse) else: self.report_error( - node.test, f"If condition must be a boolean expression, but got {predicate}" + node.test, + f"If condition must be a boolean expression, but got {predicate}", ) @@ -539,6 +580,36 @@ def visit_return(self: Parser, node: doc.Return) -> None: T.evaluate(tvm.tir.ret(value)) +@dispatch.register(token="tir", type_name="Continue") +def 
visit_continue(self: Parser, node: doc.Continue) -> None: # pylint:disable=unused-argument + """The continue visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Continue + The doc AST continue node. + """ + T.evaluate(tvm.tir.continue_loop()) + + +@dispatch.register(token="tir", type_name="Break") +def visit_break(self: Parser, node: doc.Break) -> None: # pylint:disable=unused-argument + """The break visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Break + The doc AST break node. + """ + T.evaluate(tvm.tir.break_loop()) + + @dispatch.register(token="tir", type_name="tvm_declare_function") def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: """The function declaration step for tir diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py index 9cbf6cfdca22..967d0d824ba2 100644 --- a/python/tvm/script/printer/_ffi_api.py +++ b/python/tvm/script/printer/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.script.printer""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("script.printer", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.printer", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 382128ef33d7..62d8c563dd3f 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -19,8 +19,8 @@ from enum import IntEnum, unique from typing import Dict, List, Optional, Sequence, Tuple, Union -from tvm.ffi import register_object -from tvm.ffi.access_path import AccessPath +from tvm_ffi import register_object +from tvm_ffi.access_path import AccessPath from tvm.runtime import Object from tvm.tir import FloatImm, IntImm diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py index 5f1f9800848b..0cfc436b6a6d 100644 --- a/python/tvm/script/printer/doc_printer.py +++ b/python/tvm/script/printer/doc_printer.py @@ -18,7 +18,7 @@ from typing import List, Optional -from tvm.ffi.access_path import AccessPath +from tvm_ffi.access_path import AccessPath from tvm.runtime.script_printer import PrinterConfig from . import _ffi_api diff --git a/python/tvm/support.py b/python/tvm/support.py index 7e0ad5875f83..d0b1540c0417 100644 --- a/python/tvm/support.py +++ b/python/tvm/support.py @@ -22,11 +22,11 @@ import sys import tvm -import tvm.ffi +import tvm_ffi from .runtime.module import Module from . import get_global_func -tvm.ffi._init_api("support", __name__) +tvm_ffi.init_ffi_api("support", __name__) def libinfo(): diff --git a/python/tvm/target/_ffi_api.py b/python/tvm/target/_ffi_api.py index 489b59b4c6ae..8b9f6c73bd4e 100644 --- a/python/tvm/target/_ffi_api.py +++ b/python/tvm/target/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.target""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("target", __name__) +tvm_ffi.init_ffi_api("target", __name__) diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index c9be5531c732..e597c8d147be 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -17,6 +17,9 @@ """Bring Your Own Datatypes custom datatype framework TODO(@gussmith23 @hypercubestart) link to BYODT docs when they exist""" +from tvm_ffi import get_global_func +from tvm_ffi import register_global_func as _register_global_func + import tvm from tvm.runtime import convert, DataType from tvm.tir.expr import ( @@ -26,8 +29,6 @@ BinaryOpExpr as _BinaryOpExpr, ) from tvm.tir.op import call_pure_extern -from tvm.ffi import get_global_func -from tvm.ffi import register_func as _register_func from tvm.tir import call_intrin @@ -215,7 +216,7 @@ class name (e.g. Add, LE, Cast, Call). ) else: lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + src_type_name - tvm.ffi.register_func(lower_func_name, lower_func) + tvm_ffi.register_global_func(lower_func_name, lower_func) def register_min_func(func, type_name): @@ -244,7 +245,7 @@ def register_min_func(func, type_name): type_name : str The name of the custom datatype, e.g. posites2 (but not custom[posites2]32). """ - _register_func("tvm.datatype.min." + type_name, func) + _register_global_func("tvm.datatype.min." + type_name, func) def create_min_lower_func(extern_func_map, type_name): diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py index ec1875eb90a1..5c61de62e4e1 100644 --- a/python/tvm/target/detect_target.py +++ b/python/tvm/target/detect_target.py @@ -17,9 +17,8 @@ """Detect target.""" from typing import Union -from ..ffi import get_global_func -from ..runtime import Device -from ..runtime.ndarray import device +from tvm_ffi import get_global_func +from ..runtime import Device, device from . 
import Target @@ -124,7 +123,7 @@ def detect_target_from_device(dev: Union[str, Device]) -> Target: """ if isinstance(dev, str): dev = device(dev) - device_type = Device.DEVICE_TYPE_TO_NAME[dev.device_type] + device_type = Device._DEVICE_TYPE_TO_NAME[dev.dlpack_device_type()] if device_type not in SUPPORT_DEVICE: raise ValueError( f"Auto detection for device `{device_type}` is not supported. " diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 6c83ef6e5bb2..eb6e25f0450c 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -20,8 +20,8 @@ import warnings from typing import Union -import tvm.ffi -from tvm.ffi import register_func as _register_func +import tvm_ffi +from tvm_ffi import register_global_func as _register_global_func from tvm.runtime import Device from tvm.runtime import Object, convert from tvm.runtime.container import String @@ -30,7 +30,7 @@ from . import _ffi_api -@tvm.ffi.register_object("target.TargetKind") +@tvm_ffi.register_object("target.TargetKind") class TargetKind(Object): """Kind of a compilation target""" @@ -53,7 +53,7 @@ def __getattr__(self, name: str): return _ffi_api.TargetGetFeature(self.target, name) -@tvm.ffi.register_object("target.Target") +@tvm_ffi.register_object("target.Target") class Target(Object): """Target device information, use through TVM API. 
@@ -637,6 +637,14 @@ def riscv_cpu(model="sifive-u54", options=None): "-mabi=lp64d", # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u74 ], + "licheepi3a": [ + "-num-cores=8", + "-mtriple=riscv64-unknown-linux-gnu", + "-mcpu=spacemit-x60", + "-mfloat-abi=hard", + "-mabi=lp64d", + # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gcv -mabi=lp64d -mcpu=spacemit-x60 + ], } pre_defined_opt = trans_table.get(model, ["-model=%s" % model]) @@ -853,7 +861,7 @@ def create(target): return Target(target) -@_register_func("target._load_config_dict") +@_register_global_func("target._load_config_dict") def _load_config_dict(config_dict_str): try: config = json.loads(config_dict_str) diff --git a/python/tvm/target/virtual_device.py b/python/tvm/target/virtual_device.py index b062feb27aeb..e509c5670750 100644 --- a/python/tvm/target/virtual_device.py +++ b/python/tvm/target/virtual_device.py @@ -16,14 +16,13 @@ # under the License. """Python bindings for creating VirtualDevices.""" -import tvm -from tvm.runtime import Object +import tvm_ffi from . import _ffi_api -@tvm.ffi.register_object("target.VirtualDevice") -class VirtualDevice(Object): +@tvm_ffi.register_object("target.VirtualDevice") +class VirtualDevice(tvm_ffi.core.Object): """A compile time representation for where data is to be stored at runtime, and how to compile code to compute it.""" @@ -35,6 +34,5 @@ def __init__(self, device=None, target=None, memory_scope="") -> None: _ffi_api.VirtualDevice_ForDeviceTargetAndMemoryScope, device, target, memory_scope ) - @property - def device_type(self) -> int: + def dlpack_device_type(self) -> int: return self.device_type_int diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py index 177021f1433f..e00dbb437440 100644 --- a/python/tvm/target/x86.py +++ b/python/tvm/target/x86.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. 
"""Common x86 related utilities""" -from ..ffi import register_func +from tvm_ffi import register_global_func from .codegen import target_has_features -@register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") +@register_global_func("tvm.topi.x86.utils.get_simd_32bit_lanes") def get_simd_32bit_lanes(): """X86 SIMD optimal vector length lookup. Parameters diff --git a/python/tvm/te/_ffi_api.py b/python/tvm/te/_ffi_api.py index 98e466e9e88c..172fff01d7ff 100644 --- a/python/tvm/te/_ffi_api.py +++ b/python/tvm/te/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.te""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("te", __name__) +tvm_ffi.init_ffi_api("te", __name__) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 4a5d2425e669..91d3e2b81cc9 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -21,7 +21,6 @@ from numbers import Integral as _Integral from typing import List, Optional, Union -import tvm.ffi import tvm.arith._ffi_api import tvm.tir import tvm.tir._ffi_api @@ -453,7 +452,7 @@ def const(value, dtype="int32", span=None): Parameters ---------- - value : Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray] + value : Union[bool, int, float, numpy.ndarray, tvm.runtime.Tensor] The constant value. dtype : str diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 73b995a45e61..4ef1b67969c8 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -16,15 +16,15 @@ # under the License. """Tensor class for computation declaration.""" # pylint: disable=invalid-name -import tvm.ffi +import tvm_ffi -from tvm.runtime import Object, ObjectGeneric +from tvm.runtime import Object, ObjectConvertible from tvm.tir import expr as _expr, DataProducer from . 
import _ffi_api -class TensorSlice(ObjectGeneric, _expr.ExprOp): +class TensorSlice(ObjectConvertible, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" def __init__(self, tensor, indices): @@ -48,7 +48,7 @@ def dtype(self): return self.tensor.dtype -@tvm.ffi.register_object("te.Tensor") +@tvm_ffi.register_object("te.Tensor") class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" @@ -92,7 +92,7 @@ def name(self): return f"{op.name}.v{self.value_index}" -@tvm.ffi.register_object("te.Operation") +@tvm_ffi.register_object("te.Operation") class Operation(Object): """Represent an operation that generates a tensor""" @@ -122,41 +122,26 @@ def input_tensors(self): return _ffi_api.OpInputTensors(self) -@tvm.ffi.register_object("te.PlaceholderOp") +@tvm_ffi.register_object("te.PlaceholderOp") class PlaceholderOp(Operation): """Placeholder operation.""" -@tvm.ffi.register_object("te.BaseComputeOp") +@tvm_ffi.register_object("te.BaseComputeOp") class BaseComputeOp(Operation): """Compute operation.""" - @property - def axis(self): - """Represent the IterVar axis, defined when it is a ComputeOp""" - return self.__getattr__("axis") - - @property - def reduce_axis(self): - """Represent axis of reductions, only defined when it is a ComputeOp""" - return self.__getattr__("reduce_axis") - -@tvm.ffi.register_object("te.ComputeOp") +@tvm_ffi.register_object("te.ComputeOp") class ComputeOp(BaseComputeOp): """Scalar operation.""" -@tvm.ffi.register_object("te.ScanOp") +@tvm_ffi.register_object("te.ScanOp") class ScanOp(Operation): """Scan operation.""" - @property - def scan_axis(self): - """Represent the scan axis, only defined when it is a ScanOp""" - return self.__getattr__("scan_axis") - -@tvm.ffi.register_object("te.ExternOp") +@tvm_ffi.register_object("te.ExternOp") class ExternOp(Operation): """External operation.""" diff --git a/python/tvm/testing/_ffi_api.py b/python/tvm/testing/_ffi_api.py index 
e3c30d1299a1..b7a0b59fd0e4 100644 --- a/python/tvm/testing/_ffi_api.py +++ b/python/tvm/testing/_ffi_api.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.testing""" -import tvm.ffi +import tvm_ffi +# must import testing before init_ffi_api +import tvm_ffi.testing -tvm.ffi._init_api("testing", __name__) + +tvm_ffi.init_ffi_api("testing", __name__) diff --git a/python/tvm/testing/attrs.py b/python/tvm/testing/attrs.py index ea6f1b1af65c..4e946ce6d4b9 100644 --- a/python/tvm/testing/attrs.py +++ b/python/tvm/testing/attrs.py @@ -16,8 +16,8 @@ # under the License. # pylint: disable=invalid-name, import-outside-toplevel, unused-variable """Testing utilities for attrs""" +from tvm_ffi import register_object from ..ir import Attrs -from ..ffi import register_object @register_object("attrs.TestAttrs") diff --git a/python/tvm/testing/popen_pool.py b/python/tvm/testing/popen_pool.py index 0fc3ce219030..8ff260a62f9c 100644 --- a/python/tvm/testing/popen_pool.py +++ b/python/tvm/testing/popen_pool.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, missing-function-docstring """Common functions for popen_pool test cases""" -import tvm +import tvm_ffi from . 
import _ffi_api TEST_GLOBAL_STATE_1 = 0 @@ -36,19 +36,19 @@ def after_initializer(): return TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3 -@tvm.ffi.register_func("testing.identity_py") +@tvm_ffi.register_global_func("testing.identity_py") def identity_py(arg): return arg def register_ffi(): - @tvm.ffi.register_func("testing.nested_identity_py") + @tvm_ffi.register_global_func("testing.nested_identity_py") def _identity_py(arg): # pylint: disable=unused-variable return arg def call_py_ffi(arg): - _identity_py = tvm.ffi.get_global_func("testing.nested_identity_py") + _identity_py = tvm_ffi.get_global_func("testing.nested_identity_py") return _identity_py(arg) diff --git a/python/tvm/testing/runner.py b/python/tvm/testing/runner.py index a4615f7a465f..be50cc8707c5 100644 --- a/python/tvm/testing/runner.py +++ b/python/tvm/testing/runner.py @@ -24,7 +24,7 @@ import numpy as np from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig - from tvm.runtime import Device, Module, NDArray + from tvm.runtime import Device, Module, Tensor # pylint: disable=import-outside-toplevel,protected-access @@ -32,11 +32,11 @@ def _args_to_device(args, device): import numpy as np - from tvm.runtime.ndarray import NDArray, empty + from tvm.runtime import Tensor, empty uploaded_args = [] for arg in args: - if isinstance(arg, (np.ndarray, NDArray)): + if isinstance(arg, (np.ndarray, Tensor)): uploaded_args.append(empty(arg.shape, dtype=arg.dtype, device=device).copyfrom(arg)) elif isinstance(arg, (int, float)): uploaded_args.append(arg) @@ -46,11 +46,11 @@ def _args_to_device(args, device): def _args_to_numpy(args): - from tvm.runtime.ndarray import NDArray + from tvm.runtime import Tensor downloaded_args = [] for arg in args: - if isinstance(arg, NDArray): + if isinstance(arg, Tensor): downloaded_args.append(arg.numpy()) else: downloaded_args.append(arg) @@ -80,7 +80,7 @@ def export_with(func): def local_run( # pylint: disable=too-many-arguments,too-many-locals mod: 
"Module", device_type: str, - args: List[Union["np.ndarray", "NDArray", int, float]], + args: List[Union["np.ndarray", "Tensor", int, float]], evaluator_config: Optional["EvaluatorConfig"] = None, export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] = "tar", output_format: Optional[str] = None, @@ -93,7 +93,7 @@ def local_run( # pylint: disable=too-many-arguments,too-many-locals The TVM module to run. device_type : str The device type to run the module on. - args : List[Union[np.ndarray, NDArray, int, float]] + args : List[Union[np.ndarray, Tensor, int, float]] The arguments to be fed to the module. evaluator_config : Optional[EvaluatorConfig] The evaluator configuration to use. @@ -109,7 +109,7 @@ def local_run( # pylint: disable=too-many-arguments,too-many-locals Returns ------- - args : List[Union[np.ndarray, NDArray, int, float]] + args : List[Union[np.ndarray, Tensor, int, float]] The results of running the module. profile_result : tvm.runtime.BenchmarkResult The profiling result of running the module. @@ -152,7 +152,7 @@ def local_run( # pylint: disable=too-many-arguments,too-many-locals def rpc_run( # pylint: disable=too-many-arguments,too-many-locals mod: "Module", device_type: str, - args: List[Union["np.ndarray", "NDArray", int, float]], + args: List[Union["np.ndarray", "Tensor", int, float]], evaluator_config: Optional["EvaluatorConfig"] = None, rpc_config: Optional["RPCConfig"] = None, export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] = "tar", @@ -166,7 +166,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals The TVM module to run. device_type : str The device type to run the module on. - args : List[Union[np.ndarray, NDArray, int, float]] + args : List[Union[np.ndarray, Tensor, int, float]] The arguments to be fed to the module. evaluator_config : Optional[EvaluatorConfig] The evaluator configuration to use. 
@@ -189,7 +189,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals Returns ------- - args : List[Union[np.ndarray, NDArray, int, float]] + args : List[Union[np.ndarray, Tensor, int, float]] The results of running the module. profile_result : tvm.runtime.BenchmarkResult The profiling result of running the module. diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 6b047de4460a..828ffe7750f4 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -87,7 +87,6 @@ def test_something(): import tvm.arith import tvm.tir import tvm.te -import tvm.ffi from tvm.target import codegen from tvm.contrib import nvcc, cudnn, rocm @@ -105,7 +104,7 @@ def test_something(): ) -def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): +def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7, verbose=True): """Version of np.testing.assert_allclose with `atol` and `rtol` fields set in reasonable defaults. @@ -116,7 +115,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): actual = np.asanyarray(actual) desired = np.asanyarray(desired) np.testing.assert_allclose(actual.shape, desired.shape) - np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True) + np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=verbose) def check_numerical_grads( @@ -325,7 +324,7 @@ def _compute_body(*us): return tvm.tir.stmt_functor.substitute(expr, vmap) A = tvm.te.compute([r.extent.value for v, r in vranges.items()], _compute_body) - args = [tvm.nd.empty(A.shape, A.dtype)] + args = [tvm.runtime.empty(A.shape, A.dtype)] mod = tvm.compile(tvm.IRModule.from_expr(tvm.te.create_prim_func([A]))) mod(*args) return args[0].numpy() diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 120d652dd817..0a598e5e9bb9 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -50,7 +50,13 @@ from .op import tvm_stack_alloca, tvm_stack_make_shape, 
tvm_stack_make_array from .op import tvm_tuple, handle_add_byte_offset, tvm_struct_get, tvm_struct_set from .op import address_of, lookup_param, assume, undef -from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error +from .op import continue_loop, break_loop +from .op import ( + tvm_thread_allreduce, + type_annotation, + tvm_access_ptr, + tvm_throw_last_error, +) from .op import ( tvm_load_matrix_sync, tvm_store_matrix_sync, @@ -86,7 +92,18 @@ from .op import tan, tanh, atan, atan2, atanh from .op import bitwise_and, bitwise_not, bitwise_or, bitwise_xor from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else +from .op import ( + trunc, + abs, + round, + nextafter, + nearbyint, + power, + pow, + popcount, + fmod, + if_then_else, +) from .op import likely, isnan, isnullptr, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv, logaddexp from .op import comm_reducer, min, max, sum diff --git a/python/tvm/tir/_ffi_api.py b/python/tvm/tir/_ffi_api.py index 8c438557c8c1..4140cda741dd 100644 --- a/python/tvm/tir/_ffi_api.py +++ b/python/tvm/tir/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.tir""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("tir", __name__) +tvm_ffi.init_ffi_api("tir", __name__) diff --git a/python/tvm/tir/analysis/_ffi_api.py b/python/tvm/tir/analysis/_ffi_api.py index 40a7b4caf340..9e5d094c1a82 100644 --- a/python/tvm/tir/analysis/_ffi_api.py +++ b/python/tvm/tir/analysis/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.tir.analysis""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("tir.analysis", __name__) +tvm_ffi.init_ffi_api("tir.analysis", __name__) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 915b7f765c10..8a84d3ee51fa 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -301,6 +301,10 @@ def find_anchor_block(mod: IRModule) -> Block: return _ffi_api.find_anchor_block(mod) # type: ignore # pylint: disable=no-member +def has_if_then_else(stmt: Stmt) -> bool: + return tvm.ffi.get_global_func("tir.schedule.HasIfThenElse")(stmt) + + def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]: """Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size diff --git a/python/tvm/tir/block_dependence_info.py b/python/tvm/tir/block_dependence_info.py index 67a644967e4b..7bd6b418fc72 100644 --- a/python/tvm/tir/block_dependence_info.py +++ b/python/tvm/tir/block_dependence_info.py @@ -18,7 +18,7 @@ to store the block level dependences""" from typing import Union, Optional -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.ir.module import IRModule from tvm.runtime import Object from tvm.tir import Block, PrimFunc diff --git a/python/tvm/tir/block_scope.py b/python/tvm/tir/block_scope.py index b24cca0707a0..d63771fae93e 100644 --- a/python/tvm/tir/block_scope.py +++ b/python/tvm/tir/block_scope.py @@ -18,7 +18,7 @@ from enum import IntEnum from typing import List, Optional, Union -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.runtime import Object from tvm.tir import Block, For diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 1f40520e55be..f333c14986f2 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -17,14 +17,15 @@ """Abstraction for array data structures.""" from numbers import Integral -import tvm.ffi 
+import tvm_ffi +import tvm from tvm.ir import PointerType, PrimExpr, PrimType, Range from tvm.runtime import Object, Scriptable, convert from . import _ffi_api -@tvm.ffi.register_object("tir.Buffer") +@tvm_ffi.register_object("tir.Buffer") class Buffer(Object, Scriptable): """Symbolic data buffer in TVM. @@ -194,6 +195,8 @@ def __getitem__(self, indices): indices = [indices] has_slice = any(isinstance(i, slice) for i in indices) has_step = any(isinstance(i, slice) and i.step is not None for i in indices) + if has_step: + raise RuntimeError("Buffer slicing with step is not supported.") analyzer = Analyzer() if has_slice and not has_step: region = [] @@ -349,6 +352,6 @@ def decl_buffer( ) -@tvm.ffi.register_object("tir.DataProducer") +@tvm_ffi.register_object("tir.DataProducer") class DataProducer(Object): pass diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index 98e549cc9c32..5df2663fc20b 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -22,7 +22,6 @@ import tvm from tvm import ir from tvm.ir.module import IRModule -from tvm.runtime import ndarray from tvm.target import Target from tvm.tir import PrimFunc @@ -206,7 +205,9 @@ def build( if target is not None: if target.host is not None: target_host = target.host - elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type: + elif ( + tvm.device(target.kind.name, 0).dlpack_device_type() == tvm.cpu(0).dlpack_device_type() + ): target_host = target target_host = Target.canon_target(target_host) target_to_bind = target_to_bind.with_host(target_host) @@ -238,4 +239,4 @@ def build( return tir_to_runtime(host_mod, device_mod_dict, target_host) -tvm.register_func("tir.build", build) +tvm.register_global_func("tir.build", build) diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index 39874640ff40..f9c0e0cdc7ce 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -17,13 +17,13 @@ """Data layout.""" from typing 
import Union -import tvm.ffi +import tvm_ffi from tvm.runtime import Object from . import _ffi_api -@tvm.ffi.register_object("tir.Layout") +@tvm_ffi.register_object("tir.Layout") class Layout(Object): """Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and @@ -81,7 +81,7 @@ def factor_of(self, axis): return _ffi_api.LayoutFactorOf(self, axis) # type: ignore -@tvm.ffi.register_object("tir.BijectiveLayout") +@tvm_ffi.register_object("tir.BijectiveLayout") class BijectiveLayout(Object): """Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other. diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 2e07cef9a3d3..ecfd90acc13b 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,14 +27,14 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union -import tvm.ffi +import tvm_ffi import tvm.ir._ffi_api from tvm import ir from tvm.ir import Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import Object, ObjectGeneric, Scriptable, DataType, DataTypeCode, const +from tvm.runtime import Object, ObjectConvertible, Scriptable, DataType, DataTypeCode, const from . import _ffi_api from . import generic as _generic @@ -227,7 +227,7 @@ def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: return _generic.cast(self, dtype, span) -class EqualOp(ObjectGeneric, ExprOp): +class EqualOp(ObjectConvertible, ExprOp): """Deferred equal operator. This is used to support sugar that a == b can either @@ -264,7 +264,7 @@ def asobject(self) -> PrimExpr: return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore -class NotEqualOp(ObjectGeneric, ExprOp): +class NotEqualOp(ObjectConvertible, ExprOp): """Deferred NE operator. 
This is used to support sugar that a != b can either @@ -301,7 +301,7 @@ def asobject(self) -> PrimExpr: return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore -class IntImmEnum(ObjectGeneric): +class IntImmEnum(ObjectConvertible): """Lazily evaluate an IntImm in case the constructor is not available in runtime. @@ -349,7 +349,7 @@ class LogicalExpr(PrimExprWithOp): pass -@tvm.ffi.register_object("tir.Var") +@tvm_ffi.register_object("tir.Var") class Var(PrimExprWithOp): """Symbolic variable. @@ -372,7 +372,7 @@ def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) # type: ignore -@tvm.ffi.register_object("tir.SizeVar") +@tvm_ffi.register_object("tir.SizeVar") class SizeVar(Var): """Symbolic variable to represent a tensor index size which is greater or equal to zero. @@ -394,7 +394,7 @@ def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore -@tvm.ffi.register_object("tir.IterVar") +@tvm_ffi.register_object("tir.IterVar") class IterVar(ExprOp, Object, Scriptable): """Represent iteration variable. @@ -467,7 +467,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.CommReducer") +@tvm_ffi.register_object("tir.CommReducer") class CommReducer(Object, Scriptable): """Commutative reduce operator @@ -507,7 +507,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.Reduce") +@tvm_ffi.register_object("tir.Reduce") class Reduce(PrimExprWithOp): """Reduce node. @@ -558,7 +558,7 @@ def __init__( ) -@tvm.ffi.register_object("ir.FloatImm") +@tvm_ffi.register_object("ir.FloatImm") class FloatImm(ConstExpr): """Float constant. @@ -585,7 +585,7 @@ def __float__(self) -> float: return self.value -@tvm.ffi.register_object("ir.IntImm") +@tvm_ffi.register_object("ir.IntImm") class IntImm(ConstExpr): """Int constant. 
@@ -627,7 +627,7 @@ def __bool__(self) -> bool: return self.__nonzero__() -@tvm.ffi.register_object("tir.StringImm") # type: ignore +@tvm_ffi.register_object("tir.StringImm") # type: ignore class StringImm(ConstExpr): """String constant. @@ -659,7 +659,7 @@ def __hash__(self) -> int: return PrimExpr.__hash__(self) -@tvm.ffi.register_object("tir.Cast") +@tvm_ffi.register_object("tir.Cast") class Cast(PrimExprWithOp): """Cast expression. @@ -681,7 +681,7 @@ def __init__(self, dtype, value, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore -@tvm.ffi.register_object("tir.Add") +@tvm_ffi.register_object("tir.Add") class Add(BinaryOpExpr): """Add node. @@ -701,7 +701,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.Sub") +@tvm_ffi.register_object("tir.Sub") class Sub(BinaryOpExpr): """Sub node. @@ -721,7 +721,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.Mul") +@tvm_ffi.register_object("tir.Mul") class Mul(BinaryOpExpr): """Mul node. @@ -741,7 +741,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.Div") +@tvm_ffi.register_object("tir.Div") class Div(BinaryOpExpr): """Div node. @@ -761,7 +761,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.Mod") +@tvm_ffi.register_object("tir.Mod") class Mod(BinaryOpExpr): """Mod node. 
@@ -781,7 +781,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.FloorDiv") +@tvm_ffi.register_object("tir.FloorDiv") class FloorDiv(BinaryOpExpr): """FloorDiv node. @@ -801,7 +801,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.FloorMod") +@tvm_ffi.register_object("tir.FloorMod") class FloorMod(BinaryOpExpr): """FloorMod node. @@ -821,7 +821,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.Min") +@tvm_ffi.register_object("tir.Min") class Min(BinaryOpExpr): """Min node. @@ -841,7 +841,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.Max") +@tvm_ffi.register_object("tir.Max") class Max(BinaryOpExpr): """Max node. @@ -861,7 +861,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.EQ") +@tvm_ffi.register_object("tir.EQ") class EQ(CmpExpr): """EQ node. @@ -881,7 +881,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.NE") +@tvm_ffi.register_object("tir.NE") class NE(CmpExpr): """NE node. 
@@ -901,7 +901,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.LT") +@tvm_ffi.register_object("tir.LT") class LT(CmpExpr): """LT node. @@ -921,7 +921,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.LE") +@tvm_ffi.register_object("tir.LE") class LE(CmpExpr): """LE node. @@ -941,7 +941,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.GT") +@tvm_ffi.register_object("tir.GT") class GT(CmpExpr): """GT node. @@ -961,7 +961,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.GE") +@tvm_ffi.register_object("tir.GE") class GE(CmpExpr): """GE node. @@ -981,7 +981,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.And") +@tvm_ffi.register_object("tir.And") class And(LogicalExpr): """And node. @@ -1001,7 +1001,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.Or") +@tvm_ffi.register_object("tir.Or") class Or(LogicalExpr): """Or node. @@ -1024,7 +1024,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) # type: ignore -@tvm.ffi.register_object("tir.Not") +@tvm_ffi.register_object("tir.Not") class Not(LogicalExpr): """Not node. 
@@ -1043,7 +1043,7 @@ def __init__(self, a: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Not, a, span) # type: ignore -@tvm.ffi.register_object("tir.Select") +@tvm_ffi.register_object("tir.Select") class Select(PrimExprWithOp): """Select node. @@ -1087,7 +1087,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.BufferLoad") +@tvm_ffi.register_object("tir.BufferLoad") class BufferLoad(PrimExprWithOp): """Buffer load node. @@ -1122,7 +1122,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.ProducerLoad") +@tvm_ffi.register_object("tir.ProducerLoad") class ProducerLoad(PrimExprWithOp): """Producer load node. @@ -1149,7 +1149,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.Ramp") +@tvm_ffi.register_object("tir.Ramp") class Ramp(PrimExprWithOp): """Ramp node. @@ -1180,7 +1180,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.Broadcast") +@tvm_ffi.register_object("tir.Broadcast") class Broadcast(PrimExprWithOp): """Broadcast node. @@ -1203,7 +1203,7 @@ def __init__(self, value: PrimExpr, lanes: PrimExpr, span: Optional[Span] = None self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) # type: ignore -@tvm.ffi.register_object("tir.Shuffle") +@tvm_ffi.register_object("tir.Shuffle") class Shuffle(PrimExprWithOp): """Shuffle node. @@ -1241,7 +1241,7 @@ class CallEffectKind: Opaque = UpdateState -@tvm.ffi.register_object("tir.Call") +@tvm_ffi.register_object("tir.Call") class Call(PrimExprWithOp): """Call node. @@ -1257,6 +1257,9 @@ class Call(PrimExprWithOp): args : list of Expr The input arguments to the call + annotations : Optional[Dict[str, Object]] + Additional annotations about the call. + span : Optional[Span] The location of this expression in the source code. 
""" @@ -1265,7 +1268,12 @@ class Call(PrimExprWithOp): args: List[PrimExpr] def __init__( - self, dtype: str, op: Union[Op, str], args: List[PrimExpr], span: Optional[Span] = None + self, + dtype: str, + op: Union[Op, str], + args: List[PrimExpr], + annotations: Optional[Dict] = None, + span: Optional[Span] = None, ) -> None: if isinstance(op, str): if not op.startswith("tir."): @@ -1278,10 +1286,10 @@ def __init__( % op ) op = Op.get(op) - self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, annotations, span) # type: ignore -@tvm.ffi.register_object("tir.Let") +@tvm_ffi.register_object("tir.Let") class Let(PrimExprWithOp): """Let node. diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index b85fb3952249..5b365e124cfc 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -21,19 +21,20 @@ import inspect from typing import Callable, List, Mapping, Optional, Tuple, Union +import tvm_ffi + import tvm -import tvm.ffi import tvm.runtime from tvm.ir import BaseFunc, Range from tvm.runtime import Object, Scriptable -from ..runtime.ndarray import NDArray +from ..runtime._tensor import Tensor from . import _ffi_api from .buffer import Buffer from .expr import PrimExpr, Var -@tvm.ffi.register_object("tir.PrimFunc") +@tvm_ffi.register_object("tir.PrimFunc") class PrimFunc(BaseFunc, Scriptable): """A function declaration expression. @@ -174,7 +175,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: return _ffi_api.Specialize(self, param_map) # type: ignore -@tvm.ffi.register_object("tir.TensorIntrin") +@tvm_ffi.register_object("tir.TensorIntrin") class TensorIntrin(Object): """A tensor intrinsic. 
@@ -230,7 +231,7 @@ def get(name: str, allow_missing: bool = False) -> Optional["TensorIntrin"]: return _ffi_api.TensorIntrinGet(name, allow_missing) # pylint: type: ignore -@tvm.ffi.register_object("tir.IndexMap") +@tvm_ffi.register_object("tir.IndexMap") class IndexMap(Object): """A mapping from multi-dimensional indices to another set of multi-dimensional indices @@ -489,20 +490,20 @@ def map_shape(self, shape: List[PrimExpr]) -> List[PrimExpr]: """ return _ffi_api.IndexMapMapShape(self, shape) - def map_ndarray(self, arr_src: NDArray) -> NDArray: - """Apply thie index map to transform the layout of the input NDArray + def map_tensor(self, arr_src: Tensor) -> Tensor: + """Apply thie index map to transform the layout of the input Tensor Parameters ---------- - arr_src : runtime.NDArray - The NDArray to be transformed + arr_src : runtime.Tensor + The Tensor to be transformed Returns ------- - arr_dst : runtime.NDArray - The transformed NDArray + arr_dst : runtime.Tensor + The transformed Tensor """ - return _ffi_api.IndexMapMapNDArray(self, arr_src) + return _ffi_api.IndexMapMapTensor(self, arr_src) def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap": """Return the inverse of the map diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index 06985f6645ec..d5bc20b76f9f 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -18,9 +18,8 @@ """The expression and statement functor of TIR.""" from typing import Callable -import tvm +import tvm_ffi from tvm.ir import PrimExpr -from tvm.runtime import Object from tvm.runtime.support import derived_object from . import _ffi_api @@ -144,8 +143,8 @@ def visit_add_(self, op: Add) -> PrimExpr: """ -@tvm.ffi.register_object("tir.PyStmtExprVisitor") -class _PyStmtExprVisitor(Object): +@tvm_ffi.register_object("tir.PyStmtExprVisitor") +class _PyStmtExprVisitor(tvm_ffi.core.Object): """ An internal wrapper to interface between C++ and Python StmtExprVisitor. 
This is the TVM object that wraps PyStmtExprVisitor. @@ -363,7 +362,6 @@ def visit_attr_stmt_(self, op: AttrStmt) -> None: op : AttrStmt The AttrStmt to be visited. """ - print("visit_attr_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_if_then_else_(self, op: IfThenElse) -> None: @@ -376,7 +374,6 @@ def visit_if_then_else_(self, op: IfThenElse) -> None: op : IfThenElse The IfThenElse to be visited. """ - print("visit_if_then_else_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_let_stmt_(self, op: LetStmt) -> None: @@ -389,7 +386,6 @@ def visit_let_stmt_(self, op: LetStmt) -> None: op : LetStmt The LetStmt to be visited. """ - print("visit_let_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_for_(self, op: For) -> None: @@ -402,7 +398,6 @@ def visit_for_(self, op: For) -> None: op : For The For to be visited. """ - print("visit_for_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_while_(self, op: While) -> None: @@ -415,7 +410,6 @@ def visit_while_(self, op: While) -> None: op : While The While to be visited. """ - print("visit_while_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_allocate_(self, op: Allocate) -> None: @@ -428,7 +422,6 @@ def visit_allocate_(self, op: Allocate) -> None: op : Allocate The Allocate to be visited. """ - print("visit_allocate_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_allocate_const_(self, op: AllocateConst) -> None: @@ -441,7 +434,6 @@ def visit_allocate_const_(self, op: AllocateConst) -> None: op : AllocateConst The AllocateConst to be visited. 
""" - print("visit_allocate_const_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_decl_buffer_(self, op: DeclBuffer) -> None: @@ -454,7 +446,6 @@ def visit_decl_buffer_(self, op: DeclBuffer) -> None: op : DeclBuffer The DeclBuffer to be visited. """ - print("visit_decl_buffer_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_buffer_store_(self, op: BufferStore) -> None: @@ -467,7 +458,6 @@ def visit_buffer_store_(self, op: BufferStore) -> None: op : BufferStore The BufferStore to be visited. """ - print("visit_buffer_store_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_buffer_realize_(self, op: BufferRealize) -> None: @@ -480,7 +470,6 @@ def visit_buffer_realize_(self, op: BufferRealize) -> None: op : BufferRealize The BufferRealize to be visited. """ - print("visit_buffer_realize_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_assert_stmt_(self, op: AssertStmt) -> None: @@ -493,7 +482,6 @@ def visit_assert_stmt_(self, op: AssertStmt) -> None: op : AssertStmt The AssertStmt to be visited. """ - print("visit_assert_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_seq_stmt_(self, op: SeqStmt) -> None: @@ -506,7 +494,6 @@ def visit_seq_stmt_(self, op: SeqStmt) -> None: op : SeqStmt The SeqStmt to be visited. """ - print("visit_seq_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_evaluate_(self, op: Evaluate) -> None: @@ -519,7 +506,6 @@ def visit_evaluate_(self, op: Evaluate) -> None: op : Evaluate The Evaluate to be visited. """ - print("visit_evaluate_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_block_(self, op: Block) -> None: @@ -532,7 +518,6 @@ def visit_block_(self, op: Block) -> None: op : Block The Block to be visited. 
""" - print("visit_block_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_block_realize_(self, op: BlockRealize) -> None: @@ -545,7 +530,6 @@ def visit_block_realize_(self, op: BlockRealize) -> None: op : BlockRealize The BlockRealize to be visited. """ - print("visit_block_realize_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_var_(self, op: Var) -> None: @@ -978,8 +962,8 @@ def visit_string_imm_(self, op: StringImm) -> None: _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore -@tvm.ffi.register_object("tir.PyStmtExprMutator") -class _PyStmtExprMutator(Object): +@tvm_ffi.register_object("tir.PyStmtExprMutator") +class _PyStmtExprMutator(tvm_ffi.core.Object): """ A TVM object to support customization of StmtExprMutator on the python side. This is the decorated result returned from stmt_expr_mutator decorator. diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 7a9708848ab4..1e9cb078308a 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -16,7 +16,7 @@ # under the License. """Developer API of IR node builder make function.""" import tvm -from tvm.runtime import ObjectGeneric, const +from tvm.runtime import ObjectConvertible, const from tvm.ir import container as _container from . import stmt as _stmt @@ -39,7 +39,7 @@ def __exit__(self, ptype, value, trace): self._exit_cb() -class BufferVar(ObjectGeneric): +class BufferVar(ObjectConvertible): """Buffer variable with content type, makes load store easily. Do not create it directly, create use IRBuilder. @@ -202,7 +202,7 @@ def scope_attr(self, node, attr_key, value): value = op.max(1, value) self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x)) - def for_range(self, begin, end, name="i", dtype=None, kind="serial"): + def for_range(self, begin, end, name="i", dtype=None, kind="serial", step=None): """Create a for iteration scope. 
Parameters @@ -223,6 +223,10 @@ def for_range(self, begin, end, name="i", dtype=None, kind="serial"): kind : str, optional The special tag on the for loop. + step : PrimExpr + The loop step. Default to none which + represent one. + Returns ------- loop_scope : With.Scope of Var @@ -275,7 +279,7 @@ def _exit_cb(): kind_id = _stmt.ForKind.UNROLLED else: raise ValueError("Unknown kind") - self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq())) + self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), step=step)) return WithScope(loop_var, _exit_cb) @@ -448,7 +452,7 @@ def allocate(self, dtype, shape, name="buf", axis_separators=None, scope=""): ) buffer_var = buffer.data - self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) + self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="bool"), x)) return BufferVar(self, buffer, dtype) def pointer(self, content_type, name="ptr", scope=""): diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 54c70ede7a9b..2e96d98489a8 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -18,7 +18,8 @@ """Operators used in TIR expression.""" from typing import Any, Optional, Union -import tvm.ffi +import tvm_ffi +import tvm from tvm import tir from tvm.ir import Array, Op, PrimExpr from tvm.ir.base import Span @@ -41,7 +42,7 @@ def _pack_buffer(buf, span=None): const(0, dtype=buf.dtype), buf.elem_offset, ] - return Call("handle", Op.get("tir.tvm_stack_make_array"), pack_args, span) + return Call("handle", Op.get("tir.tvm_stack_make_array"), pack_args, span=span) def call_packed_lowered(*args, span=None): @@ -50,7 +51,7 @@ def call_packed_lowered(*args, span=None): The argument is the corresponding POD type when Expr is presented. When the argument is Buffer, the corresponding PackedFunc will recieve an TVMArrayHandle whose content is valid during the callback period. 
- If the PackedFunc is a python callback, then the corresponding argument is NDArray. + If the PackedFunc is a python callback, then the corresponding argument is Tensor. Parameters ---------- @@ -70,7 +71,7 @@ def call_packed_lowered(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span) + return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span=span) def call_cpacked_lowered(*args, span=None): @@ -96,7 +97,7 @@ def call_cpacked_lowered(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span) + return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span=span) def call_packed(*args, span=None): @@ -107,7 +108,7 @@ def call_packed(*args, span=None): When the argument is Buffer, the corresponding PackedFunc will receive an TVMArrayHandle whose content is valid during the callback period. - If the PackedFunc is a python callback, then the corresponding argument is NDArray. + If the PackedFunc is a python callback, then the corresponding argument is Tensor. Parameters ---------- @@ -127,7 +128,7 @@ def call_packed(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_packed"), call_args, span) + return Call("int32", Op.get("tir.tvm_call_packed"), call_args, span=span) def call_cpacked(*args, span=None): @@ -154,10 +155,10 @@ def call_cpacked(*args, span=None): te.extern : Create tensor with extern function call. 
""" call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span) + return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span=span) -def call_intrin(dtype, func_name, *args, span=None): +def call_intrin(dtype, func_name, *args, annotations=None, span=None): """Build expression by calling an intrinsic function. Intrinsics can be overloaded with multiple data types via @@ -174,6 +175,9 @@ def call_intrin(dtype, func_name, *args, span=None): args : list Positional arguments. + annotations : Optional[Dict[str, Object]] + Additional annotations about the call. + span : Optional[Span] The location of this operator in the source code. @@ -182,7 +186,11 @@ def call_intrin(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, func_name, args, span) + + # Convert to TVM Map + if annotations is not None: + annotations = {k: tir.const(v) if isinstance(v, (int, bool)) else v for k, v in annotations.items()} + return Call(dtype, func_name, args, annotations=annotations, span=span) def call_pure_extern(dtype, func_name, *args, span=None): @@ -207,7 +215,7 @@ def call_pure_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span) + return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span=span) def call_extern(dtype, func_name, *args, span=None): @@ -355,7 +363,7 @@ def tvm_stack_make_shape(*args): def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset): - """Allocate a NDArray(DLTensor) on stack, return the handle + """Allocate a Tensor(DLTensor) on stack, return the handle Parameters ---------- @@ -570,11 +578,10 @@ def address_of(obj: Union[Buffer, BufferLoad], span: Optional[Span] = None) -> P The call expression. 
""" if isinstance(obj, Buffer): - n_dim = len(obj.shape) buffer_load = BufferLoad(obj, [0] * n_dim) return call_intrin("handle", "tir.address_of", buffer_load, span=span) - elif isinstance(obj, BufferLoad): + elif isinstance(obj, (BufferLoad, Var)): return call_intrin("handle", "tir.address_of", obj, span=span) else: raise ValueError(f"Invalid object type: {type(obj)}") @@ -1883,7 +1890,7 @@ def ret(val, span=None): def thread_return(span=None): - """Return from a GPU thread. + """Return from a GPU thread Parameters ---------- @@ -1899,6 +1906,40 @@ def thread_return(span=None): return _ffi_api.thread_return(span) +def continue_loop(span=None): + """Create a tir intrinsic call to represent continue expression + + Parameters + ---------- + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + ret : PrimExpr + The continue expression + """ + + return _ffi_api.continue_loop(span) + + +def break_loop(span=None): + """Create a tir intrinsic call to represent break expression + + Parameters + ---------- + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + ret : PrimExpr + The break expression + """ + + return _ffi_api.break_loop(span) + + def any(*args, span=None): """Create a new experssion of the union of all conditions in the arguments @@ -1952,7 +1993,7 @@ def all(*args, span=None): return val -@tvm.ffi.register_func("tvm.default_trace_action") +@tvm_ffi.register_global_func("tvm.default_trace_action") def _tvm_default_trace_action(*args): print(list(args)) @@ -2082,6 +2123,8 @@ def exp(x): The result. 
""" x = tir.convert(x) + if "int" in x.dtype: + x = tir.Cast("float32", x) return call_intrin(x.dtype, "tir.exp", x) @@ -3634,7 +3677,7 @@ def get_active_lane_mask(dtype, base, limit): return call_intrin(dtype, "tir.get_active_lane_mask", base, limit) -def get_vscale_expr(dtype: Union[str, tvm.ffi.dtype], min_size: int = 128) -> PrimExpr: +def get_vscale_expr(dtype: Union[str, tvm_ffi.dtype], min_size: int = 128) -> PrimExpr: """ Create a datatype dependent scalable expression. @@ -3646,7 +3689,7 @@ def get_vscale_expr(dtype: Union[str, tvm.ffi.dtype], min_size: int = 128) -> Pr The minimum size of the scalable vector in bits. """ if isinstance(dtype, str): - dtype = tvm.ffi.dtype(dtype) + dtype = tvm_ffi.dtype(dtype) return min_size // dtype.bits * vscale() diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index ae78b0573822..96ed9dfdbc96 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -31,6 +31,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I pass_ctx = tvm.transform.PassContext.current() config = pass_ctx.config passes = [ + tir.transform.CanonicalizeLoop(), tir.transform.LowerCrossThreadReduction(), tir.transform.LowerInitBlock(), tir.transform.PlanAndUpdateBufferAllocationLocation(), @@ -43,6 +44,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tir.transform.LowerMatchBuffer(), tir.transform.Simplify(), tir.transform.InjectPermutedLayout(), + tir.transform.AnnotateIrregularLoop(), tir.transform.InjectSoftwarePipeline(), tir.transform.TransformMmaBufferLayout(), tir.transform.LowerOpaqueBlock(), diff --git a/python/tvm/tir/schedule/_ffi_api.py b/python/tvm/tir/schedule/_ffi_api.py index b854145beb6a..5087112b892a 100644 --- a/python/tvm/tir/schedule/_ffi_api.py +++ b/python/tvm/tir/schedule/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.tir.schedule""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("tir.schedule", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("tir.schedule", __name__) # pylint: disable=protected-access diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 491a689c9309..66eab497eb5a 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -17,7 +17,7 @@ """Analysis used in TensorIR scheduling""" from typing import List, Optional -import tvm.ffi +import tvm_ffi from tvm.runtime import Object from ..buffer import Buffer @@ -62,7 +62,7 @@ def suggest_index_map( ) -@tvm.ffi.register_object("tir.schedule.TensorizeInfo") +@tvm_ffi.register_object("tir.schedule.TensorizeInfo") class TensorizeInfo(Object): """Necessary information used for tensorization.""" @@ -90,7 +90,7 @@ def get_tensorize_loop_mapping( return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func, allow_padding) # type: ignore -@tvm.ffi.register_object("tir.schedule.AutoTensorizeMappingInfo") +@tvm_ffi.register_object("tir.schedule.AutoTensorizeMappingInfo") class AutoTensorizeMappingInfo(Object): """Necessary information used to perform transformations for tensorization.""" diff --git a/python/tvm/tir/schedule/instruction.py b/python/tvm/tir/schedule/instruction.py index 5a8563e652b6..918292a7bbaa 100644 --- a/python/tvm/tir/schedule/instruction.py +++ b/python/tvm/tir/schedule/instruction.py @@ -17,7 +17,7 @@ """Schedule instructions each corresponds to a schedule primitive""" from typing import TYPE_CHECKING, Any, List, Union -from tvm.ffi import register_object as _register_object +from tvm_ffi import register_object as _register_object from tvm.runtime import Object from . 
import _ffi_api diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 5325ecdc16c4..95effb643fd7 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -18,7 +18,7 @@ import inspect from typing import Callable, Dict, List, Literal, Optional, Tuple, Union -from tvm.ffi import register_object as _register_object +from tvm_ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import GlobalVar, IRModule, PrimExpr from tvm.runtime import Object @@ -1910,7 +1910,8 @@ def resize_cache_index( @type_checked def reindex( - self, block: Union[BlockRV, str], buffer: Union[Tuple[str, int], str, Buffer] + self, block: Union[BlockRV, str], buffer: Union[Tuple[str, int], str, Buffer], + skip_simplify: bool = False, ) -> BlockRV: """Create a block that read/write a buffer region into a read/write cache with reindexing. The layout of the cache will be the same as by the iterators of the block that reads/writes @@ -1942,6 +1943,9 @@ def reindex( If `buffer` is a Buffer object, it must exist within the reads/writes of the block. + skip_simplify: bool + Whether to skip the simplification of the indices. 
+ Returns ------- reindex_block : BlockRV @@ -1997,7 +2001,7 @@ def after_reindex( assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member - self, block, buffer_index, buffer_index_type_enum + self, block, buffer_index, buffer_index_type_enum, skip_simplify ) ########## Schedule: Data movement ########## @@ -2345,6 +2349,33 @@ def after_inline(a: T.handle, c: T.handle) -> None: # pylint: disable-next=no-member _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore + @type_checked + def fuse_reduction_epilogue( + self, + reduction_block: Union[BlockRV, str], + epilogue_block: Union[BlockRV, str], + ) -> None: + """Fuse an epilogue block into a reduction block. + + It requires: + 1) The reduction block is a complete reduction block + 2) The epilogue block only reads from the reduction block's output + 3) The epilogue performs a simple addition: output = reduction_result + bias + + Parameters + ---------- + reduction_block : Union[BlockRV, str] + The reduction block (e.g., matmul) + epilogue_block : Union[BlockRV, str] + The epilogue block to be fused (e.g., bias add) + """ + reduction_block = self._normalize_block_arg(reduction_block) + epilogue_block = self._normalize_block_arg(epilogue_block) + # pylint: disable-next=no-member + _ffi_api.ScheduleFuseReductionEpilogue( + self, reduction_block, epilogue_block + ) # type: ignore + ########## Schedule: Reduction ########## @type_checked @@ -2384,7 +2415,7 @@ def decompose_reduction(self, block: Union[BlockRV, str], loop: LoopRV) -> Block .. 
code-block:: python @T.prim_func - def before_decompose(a: ty.handle, c: ty.handle) -> None: + def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None: A = tir.match_buffer(a, [128, 128]) B = tir.match_buffer(b, [128, 128]) C = tir.match_buffer(c, [128, 128]) @@ -2409,7 +2440,7 @@ def before_decompose(a: ty.handle, c: ty.handle) -> None: .. code-block:: python @T.prim_func - def after_decompose(a: ty.handle, c: ty.handle) -> None: + def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None: A = tir.match_buffer(a, [128, 128]) B = tir.match_buffer(b, [128, 128]) C = tir.match_buffer(c, [128, 128]) diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py index f082a9e92ea7..36436fe95783 100644 --- a/python/tvm/tir/schedule/state.py +++ b/python/tvm/tir/schedule/state.py @@ -20,7 +20,7 @@ from enum import IntEnum from typing import Dict, Optional, Union -from tvm.ffi import register_object +from tvm_ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object from tvm.tir import Block, BlockRealize, For, PrimFunc diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index da3508a42ee0..edc537f3a296 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -18,7 +18,7 @@ import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional -from tvm.ffi import register_object as _register_object +from tvm_ffi import register_object as _register_object from tvm.runtime import Object from ...ir import Array, Map, save_json diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index ffb6fd6a7068..448ace3ade63 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -29,9 +29,9 @@ from enum import IntEnum from typing import List, Mapping, Optional, Union -import tvm.ffi +import tvm_ffi from tvm.ir import PrimExpr, Range, Span -from tvm.runtime import Object, Scriptable, const, NDArray +from tvm.runtime import 
Object, Scriptable, const, Tensor from . import _ffi_api from .buffer import Buffer @@ -42,7 +42,7 @@ class Stmt(Object, Scriptable): """Base class of all the statements.""" -@tvm.ffi.register_object("tir.LetStmt") +@tvm_ffi.register_object("tir.LetStmt") class LetStmt(Stmt): """LetStmt node. @@ -72,7 +72,7 @@ def __init__(self, var: Var, value: PrimExpr, body: Stmt, span: Optional[Span] = ) -@tvm.ffi.register_object("tir.AssertStmt") +@tvm_ffi.register_object("tir.AssertStmt") class AssertStmt(Stmt): """AssertStmt node. @@ -120,7 +120,7 @@ class ForKind(IntEnum): THREAD_BINDING = 4 # pylint: disable=invalid-name -@tvm.ffi.register_object("tir.For") +@tvm_ffi.register_object("tir.For") class For(Stmt): """For node. @@ -145,6 +145,10 @@ class For(Stmt): The thread this loop binds to. Only valid if kind is ThreadBinding + step : PrimExpr + The loop step. Default to none which + represent one. + annotations: Optional[Mapping[str, Object]] Additional annotation hints. @@ -159,6 +163,7 @@ class For(Stmt): body: Stmt thread_binding: Optional[IterVar] annotations: Mapping[str, Object] + step: Optional[PrimExpr] span: Optional[Span] def __init__( @@ -170,6 +175,7 @@ def __init__( body: Stmt, thread_binding: Optional[IterVar] = None, annotations: Optional[Mapping[str, Object]] = None, + step: Optional[PrimExpr] = None, span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( @@ -181,11 +187,12 @@ def __init__( body, thread_binding, annotations, + step, span, ) -@tvm.ffi.register_object("tir.While") +@tvm_ffi.register_object("tir.While") class While(Stmt): """While node. @@ -209,7 +216,7 @@ def __init__(self, condition: PrimExpr, body: Stmt, span: Optional[Span] = None) self.__init_handle_by_constructor__(_ffi_api.While, condition, body, span) # type: ignore -@tvm.ffi.register_object("tir.BufferStore") +@tvm_ffi.register_object("tir.BufferStore") class BufferStore(Stmt): """Buffer store node. 
@@ -252,7 +259,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.BufferRealize") +@tvm_ffi.register_object("tir.BufferRealize") class BufferRealize(Stmt): """Buffer realize node. @@ -293,7 +300,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.Allocate") +@tvm_ffi.register_object("tir.Allocate") class Allocate(Stmt): """Allocate node. @@ -353,7 +360,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.AllocateConst") +@tvm_ffi.register_object("tir.AllocateConst") class AllocateConst(Stmt): """Allocate constant node. @@ -368,8 +375,8 @@ class AllocateConst(Stmt): extents : list of Expr The extents of the allocate - data_or_idx : Union[NDArray, int] - If an NDArray, this is the const data associated with the + data_or_idx : Union[Tensor, int] + If an Tensor, this is the const data associated with the constant. If an integer, this is the index into the "constants" attribute of the `IRModule` that contains the `AllocateConst`. @@ -387,7 +394,7 @@ class AllocateConst(Stmt): buffer_var: Var dtype: str extents: List[PrimExpr] - data: Optional[NDArray] + data: Optional[Tensor] irmod_storage_idx: Optional[int] body: Stmt annotations: Mapping[str, Object] @@ -398,7 +405,7 @@ def __init__( buffer_var: Var, dtype: str, extents: List[PrimExpr], - data_or_idx: Union[NDArray, int], + data_or_idx: Union[Tensor, int], body: Stmt, annotations: Optional[Mapping[str, Object]] = None, span: Optional[Span] = None, @@ -415,7 +422,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.DeclBuffer") +@tvm_ffi.register_object("tir.DeclBuffer") class DeclBuffer(Stmt): """DeclBuffer node. @@ -439,7 +446,7 @@ def __init__(self, buffer: Buffer, body: Stmt, span: Optional[Span] = None) -> N self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body, span) -@tvm.ffi.register_object("tir.AttrStmt") +@tvm_ffi.register_object("tir.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. 
@@ -475,7 +482,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.SeqStmt") +@tvm_ffi.register_object("tir.SeqStmt") class SeqStmt(Stmt): """Sequence of statements. @@ -501,7 +508,7 @@ def __len__(self): return len(self.seq) -@tvm.ffi.register_object("tir.IfThenElse") +@tvm_ffi.register_object("tir.IfThenElse") class IfThenElse(Stmt): """IfThenElse node. @@ -536,7 +543,7 @@ def __init__( ) -@tvm.ffi.register_object("tir.Evaluate") +@tvm_ffi.register_object("tir.Evaluate") class Evaluate(Stmt): """Evaluate node. @@ -556,7 +563,7 @@ def __init__(self, value: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) # type: ignore -@tvm.ffi.register_object("tir.BufferRegion") +@tvm_ffi.register_object("tir.BufferRegion") class BufferRegion(Object, Scriptable): """BufferRegion node. @@ -576,7 +583,7 @@ def __init__(self, buffer: Buffer, region: List[Range]) -> None: self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) # type: ignore -@tvm.ffi.register_object("tir.MatchBufferRegion") +@tvm_ffi.register_object("tir.MatchBufferRegion") class MatchBufferRegion(Object, Scriptable): """MatchBufferRegion node. @@ -598,7 +605,7 @@ def __init__(self, buffer: Buffer, source: BufferRegion) -> None: ) -@tvm.ffi.register_object("tir.Block") +@tvm_ffi.register_object("tir.Block") class Block(Stmt): """Block node. @@ -680,7 +687,7 @@ def __init__( ) # type: ignore -@tvm.ffi.register_object("tir.BlockRealize") +@tvm_ffi.register_object("tir.BlockRealize") class BlockRealize(Stmt): """BlockRealize node. diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py index 564655455245..0a6cf5310c9c 100644 --- a/python/tvm/tir/tensor_intrin/__init__.py +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -20,4 +20,4 @@ from . import cuda if enabled("llvm"): - from . import arm_cpu, x86, rocm, hexagon + from . 
import arm_cpu, x86, rocm, hexagon, riscv_cpu diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 6f964c94370d..7b0c71583b1a 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -18,7 +18,7 @@ """Intrinsics for tensorization on NVIDIA GPU.""" from typing import Dict, Literal, Optional, Tuple -from tvm.ffi import register_func +from tvm_ffi import register_global_func from tvm.runtime import convert from tvm.script import tir as T from tvm.tir import Cast, IntImm, TensorIntrin @@ -46,7 +46,7 @@ def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): return row, col -@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout") +@register_global_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout") def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): i, j = ind[0], ind[1] thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j) @@ -1465,7 +1465,7 @@ def get_index_C(elem_offset, stride): stride_b = stride // 8 bi = i // 8 bj = j // 8 - return (bi // 2) * 2 * stride_b + bi % 2 + bj * 2 + return ((bi // 2) * 2 * stride_b + bi % 2 + bj * 2) * 2 def get_mma_init_intrin( @@ -1746,7 +1746,7 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None: ) -@register_func("tir.index_map_m16n8k8.matrixC") +@register_global_func("tir.index_map_m16n8k8.matrixC") def index_map_m16n8k8_matrixC(ind): i, j = ind[0], ind[1] return convert([(i // 8) // 2, j // 8, (i // 8) % 2, (j % 8) % 2]) diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py new file mode 100644 index 000000000000..e0782ada4cc1 --- /dev/null +++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,line-too-long +"""Intrinsics for RISCV tensorization""" + +import logging +import tvm_ffi + +from tvm.runtime import DataType +from tvm.script import tir as T +from tvm.target.codegen import llvm_get_vector_width, target_has_features, Target +from .. import TensorIntrin + +logger = logging.getLogger(__name__) + + +def get_max_elems(vlen: int, lmul: int, sew: int) -> int: + """Returns number of elements of a given data type (SEW) + that fits multiple (LMUL) of the vector registers (VLEN). + + Args: + vlen (int): VLEN vector length in bits + lmul (int): LMUL vector lenght multiplier + sew (int): SEW standard (single) element width + + Returns: + int: Number of elements + """ + return (vlen // sew) * lmul + + +def rvv_vec_dot_product_kernels( + n_elems: int, + n_lanes: int, + data_dtype: str, + weight_dtype: str, + out_dtype: str, + lmul: int, +): + """Dot product of vector and matrix rows using RISC-V vector instructions. + + These kernels takes two arrays A[ELEMS] and B[ELEMS][MACS] and computes + dot product of A[ELEMS] with each row of B[LANES], accumulating results + with C[LANES]. + + The pseudo code is as follows: + .. 
code-block:: c + void vec_dot_prod(A[ELEMS], B[LANES][ELEMS], C[LANES]){ + for (j = 0; j < LANES; j++) { + for (k = 0; k < ELEMS; k++) { + C[j] += A[k] * B[j][k] + } + } + } + """ + + @T.prim_func + def rvv_vec_dot_prod_desc( + A: T.Buffer((n_elems,), data_dtype, offset_factor=1), + B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1), + C: T.Buffer((n_lanes,), out_dtype, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems]) + T.writes(C[0:n_lanes]) + for j in T.serial(0, n_lanes): + for k in T.serial(0, n_elems): + with T.block("update"): + vj, vk = T.axis.remap("SR", [j, k]) + C[vj] = C[vj] + T.cast(A[vk], out_dtype) * T.cast(B[vj, vk], out_dtype) + + # LLVM only supports ELEN=32 or ELEN=64 + # https://llvm.org/docs//RISCV/RISCVVectorExtension.html + d_dtype_lanes = (64 // DataType(data_dtype).bits) * lmul + w_dtype_lanes = (64 // DataType(weight_dtype).bits) * lmul + # reduction lanes narrows + o_dtype_lanes = (64 // DataType(out_dtype).bits) * lmul // n_lanes + # data type widening case + o_dtype_lanes = max(o_dtype_lanes, 2) + + mask_args = () if data_dtype[0] in ("i", "u") else (T.uint64(7),) + + wide_dtype = out_dtype + if DataType(out_dtype).bits > DataType(data_dtype).bits: + wide_dtype = "".join(c for c in data_dtype if not c.isdigit()) + wide_dtype += str(DataType(data_dtype).bits * 2) + + # fmt: off + @T.prim_func + def rvv_vec_dot_prod_impl( + A: T.Buffer((n_elems,), data_dtype, offset_factor=1), + B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1), + C: T.Buffer((n_lanes,), out_dtype, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems]) + T.writes(C[0:n_lanes]) + + vec_A = T.call_llvm_intrin( + f"{data_dtype}xvscalex{d_dtype_lanes}", + "llvm.riscv.vle", + T.broadcast(T.Cast(data_dtype, 0), T.vscale() * d_dtype_lanes), + T.tvm_access_ptr(T.type_annotation(data_dtype), A.data, 0, n_elems, 1), + 
T.int64(n_elems)) + + for i in range(n_lanes): + with T.block("update"): + T.reads(B[i, 0:n_elems]) + T.writes(C[i]) + + vec_B_row = T.call_llvm_intrin( + f"{weight_dtype}xvscalex{w_dtype_lanes}", + "llvm.riscv.vle", + T.broadcast(T.Cast(data_dtype, 0), T.vscale() * w_dtype_lanes), + T.tvm_access_ptr(T.type_annotation(weight_dtype), B.data, i * n_elems, n_elems, 1), + T.int64(n_elems)) + + product = T.call_llvm_intrin( + f"{wide_dtype}xvscalex{w_dtype_lanes}", + "llvm.riscv.vfmul" if out_dtype[0] == "f" else \ + "llvm.riscv.vwmulsu" if (data_dtype[0] != weight_dtype[0]) else \ + "llvm.riscv.vwmul", + T.broadcast(T.Cast(wide_dtype, 0), T.vscale() * w_dtype_lanes), + vec_B_row, + vec_A, + *mask_args, + T.uint64(n_elems)) + + ini_acc = T.call_llvm_intrin( + f"{out_dtype}xvscalex{o_dtype_lanes}", + "llvm.riscv.vle", + T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes), + T.tvm_access_ptr(T.type_annotation(out_dtype), C.data, i, 1, 1), + T.int64(1)) + + red_sum = T.call_llvm_intrin( + f"{out_dtype}xvscalex{o_dtype_lanes}", + "llvm.riscv.vfredusum" if out_dtype[0] == "f" else \ + "llvm.riscv.vwredsum", + T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes), + product, + ini_acc, + *mask_args, + T.uint64(n_elems)) + + C[i] = T.call_llvm_intrin( + out_dtype, + "llvm.riscv.vfmv.f.s" if out_dtype[0] == "f" else \ + "llvm.riscv.vmv.x.s", + red_sum) + # fmt: on + return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl + + +@tvm_ffi.register_global_func("tir.tensor_intrin.register_rvv_isa_intrinsics") +def register_rvv_isa_intrinsics(target: Target, inventory_only=False) -> dict(): + """Register RISCV V (vector) intrinsics + [x] Implementation follows version 1.0 vector specifications: + https://github.com/riscvarchive/riscv-v-spec/releases/tag/v1.0 + + Args: + target (Target): TVM target + inventory_only (bool): No registration inventory only + + Returns: + dict(): A catalog with registered kernel names and properties + """ + if not target_has_features("v", 
target): + raise RuntimeError("Current target does not support `v` extension.") + + vlen = llvm_get_vector_width(target) + # get maximum reduction lanes (without grouping) + n_lanes = get_max_elems(vlen, lmul=1, sew=32) + + kernels_inventory = {} + + data_dtype = ["uint8", "int8", "float16", "float32"] + weight_dtype = ["int8", "int8", "float16", "float32"] + output_dtype = ["int32", "int32", "float16", "float32"] + + for d_dtype, w_dtype, o_dtype in zip(data_dtype, weight_dtype, output_dtype): + # max elements to grouped registers + max_elems = get_max_elems(vlen, lmul=8, sew=DataType(d_dtype).bits) + # data widening halves available vector registers + if DataType(o_dtype).bits > DataType(d_dtype).bits: + max_elems //= 2 + # compute optimal LMUL for full load + lmul = max_elems // (vlen // DataType(d_dtype).bits) + + n_elems = max_elems + while n_elems >= 4: + + dt = DataType(d_dtype) + wt = DataType(w_dtype) + ot = DataType(o_dtype) + kernel_name = "rvv_dot" + kernel_name += f"_{n_elems}{dt[0]}{dt.bits}" + kernel_name += f"_{n_lanes}x{n_elems}{wt[0]}{wt.bits}" + kernel_name += f"_{n_lanes}{ot[0]}{ot.bits}" + kernels_inventory[kernel_name] = n_elems + + if not inventory_only: + logger.debug(f"Registering kernel {kernel_name}") + desc, impl = rvv_vec_dot_product_kernels( + n_elems, n_lanes, d_dtype, w_dtype, o_dtype, lmul + ) + TensorIntrin.register(kernel_name, desc, impl, override=True) + + n_elems //= 2 + + return kernels_inventory + + +def register_riscv_intrinsics(target: Target): + """Register RISCV intrinsics + + Args: + target (Target): TVM target + """ + + # RISCV `v` 1.0 extension templates + _ = register_rvv_isa_intrinsics(target) + logger.debug("Finished registering riscv intrinsics.") diff --git a/python/tvm/tir/transform/_ffi_api.py b/python/tvm/tir/transform/_ffi_api.py index 8a6607c11af0..67896ec05dda 100644 --- a/python/tvm/tir/transform/_ffi_api.py +++ b/python/tvm/tir/transform/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing 
permissions and limitations # under the License. """FFI APIs for tvm.tir.transform""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("tir.transform", __name__) +tvm_ffi.init_ffi_api("tir.transform", __name__) diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index b679d4ab16ce..a85eabd970e1 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -19,13 +19,13 @@ import functools from typing import Callable, List, Optional, Union -import tvm.ffi +import tvm_ffi from tvm.ir.transform import Pass, PassInfo from . import _ffi_api -@tvm.ffi.register_object("tir.PrimFuncPass") +@tvm_ffi.register_object("tir.PrimFuncPass") class PrimFuncPass(Pass): """A pass that works on each :py:func:`tvm.tir.PrimFunc` in a module. A function pass class should be created through py:func:`tvm.tir.transform.function_pass`. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 93a182ca3bc2..88cf4720d3a6 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -244,7 +244,7 @@ def BF16ComputeLegalize(): return _ffi_api.BF16ComputeLegalize() # type: ignore -def FP8ComputeLegalize(promote_dtype_str: str = "float32"): +def FP8ComputeLegalize(promote_dtype: str = "float32"): """Legalize fp8 compute Ops. Parameters @@ -257,7 +257,7 @@ def FP8ComputeLegalize(promote_dtype_str: str = "float32"): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore + return _ffi_api.FP8ComputeLegalize(promote_dtype) # type: ignore def BF16StorageLegalize(): @@ -373,7 +373,7 @@ def MakePackedAPI(): For static shapes, the `BufferNode::shape`, `BufferNode::strides`, and `BufferNode::elem_offset` member variables are used to generate runtime checks on the corresponding member variables in - the user-provided `DLTensor*` or `tvm.nd.array` argument. (e.g. 
A + the user-provided `DLTensor*` or `tvm.runtime.tensor` argument. (e.g. A PrimFunc that accepts a buffer of shape `[16,32]` validates that the `DLTensor::shape` array is `[16,32]`.) @@ -430,6 +430,19 @@ def AnnotateDeviceRegions(): return _ffi_api.AnnotateDeviceRegions() # type: ignore +def AnnotateIrregularLoop(): + """Annotate irregular loop mark. Loop transformations like + peeling, partition, unroll, etc is not allowed on irregular + loop with internal loop continuation and breaks. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateIrregularLoop() # type: ignore + + def SplitHostDevice(): """Split the function into a host function and device functions. @@ -1052,26 +1065,26 @@ def InjectPTXAsyncCopy(): return _ffi_api.InjectPTXAsyncCopy() # type: ignore -def RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=False): +def RemoveWeightLayoutRewriteBlock(skip_tensor_rewrite=False): """Remove weight layout rewrite block before benchmarking during tuning stage. Parameters ---------- - skip_ndarray_rewrite : bool - If True, exact rewrite of NDArray, according to the given index map, will be skipped. - Only the shape of the NDArray is transformed correctly, and the content of the destination + skip_tensor_rewrite : bool + If True, exact rewrite of Tensor, according to the given index map, will be skipped. + Only the shape of the Tensor is transformed correctly, and the content of the destination array will be filled with random values. - When this pass is called many times during MetaSchedule tuning, the raw data of NDArray, - before and after rewrite, does not matter. Since NDArray layout rewrite, using IndexMap's - MapNDArray, is currently slow, skipping the exact rewrite is sometimes necessary. + When this pass is called many times during MetaSchedule tuning, the raw data of Tensor, + before and after rewrite, does not matter. 
Since Tensor layout rewrite, using IndexMap's + MapTensor, is currently slow, skipping the exact rewrite is sometimes necessary. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite) # type: ignore + return _ffi_api.RemoveWeightLayoutRewriteBlock(skip_tensor_rewrite) # type: ignore def ManifestSharedMemoryLocalStage(): @@ -1158,3 +1171,14 @@ def LowerVtcmAlloc(): The result pass """ return _ffi_api.LowerVtcmAlloc() # type: ignore + + +def CanonicalizeLoop(): + """Canonicalize the loop to start from zero and use trivial step + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.CanonicalizeLoop() # type: ignore diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 9503aea0cd2f..c73e8bf54cf5 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -50,6 +50,7 @@ from . import nn from . import utils from . import image +from . import vision from . import gpu # error reporting diff --git a/python/tvm/topi/cpp/cuda.py b/python/tvm/topi/cpp/cuda.py index 22f97293d38d..21cf554add3b 100644 --- a/python/tvm/topi/cpp/cuda.py +++ b/python/tvm/topi/cpp/cuda.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for CUDA TOPI ops and schedules""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("topi.cuda", "tvm.topi.cpp.cuda") +tvm_ffi.init_ffi_api("topi.cuda", "tvm.topi.cpp.cuda") diff --git a/python/tvm/topi/cpp/generic.py b/python/tvm/topi/cpp/generic.py index 3230d5428bb2..77dfcab58a0f 100644 --- a/python/tvm/topi/cpp/generic.py +++ b/python/tvm/topi/cpp/generic.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI for generic TOPI ops and schedules""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("topi.generic", "tvm.topi.cpp.generic") +tvm_ffi.init_ffi_api("topi.generic", "tvm.topi.cpp.generic") diff --git a/python/tvm/topi/cpp/impl.py b/python/tvm/topi/cpp/impl.py index e5473a7e6602..c1783067951a 100644 --- a/python/tvm/topi/cpp/impl.py +++ b/python/tvm/topi/cpp/impl.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Load Lib for C++ TOPI ops and schedules""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("topi", "tvm.topi.cpp") +tvm_ffi.init_ffi_api("topi", "tvm.topi.cpp") diff --git a/python/tvm/topi/cpp/nn.py b/python/tvm/topi/cpp/nn.py index 2ea1fc371404..32c24dc1ed98 100644 --- a/python/tvm/topi/cpp/nn.py +++ b/python/tvm/topi/cpp/nn.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for NN TOPI ops and schedules""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("topi.nn", "tvm.topi.cpp.nn") +tvm_ffi.init_ffi_api("topi.nn", "tvm.topi.cpp.nn") diff --git a/python/tvm/topi/cpp/rocm.py b/python/tvm/topi/cpp/rocm.py index 771fc3c3f0f3..3eb83fe689c3 100644 --- a/python/tvm/topi/cpp/rocm.py +++ b/python/tvm/topi/cpp/rocm.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for Rocm TOPI ops and schedules""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("topi.rocm", "tvm.topi.cpp.rocm") +tvm_ffi.init_ffi_api("topi.rocm", "tvm.topi.cpp.rocm") diff --git a/python/tvm/topi/cpp/utils.py b/python/tvm/topi/cpp/utils.py index b78a6baa0f01..ecf341fabd5f 100644 --- a/python/tvm/topi/cpp/utils.py +++ b/python/tvm/topi/cpp/utils.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI for TOPI utility functions""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("topi.utils", "tvm.topi.cpp.utils") +tvm_ffi.init_ffi_api("topi.utils", "tvm.topi.cpp.utils") diff --git a/python/tvm/topi/cpp/vision/__init__.py b/python/tvm/topi/cpp/vision/__init__.py index 5fdf1ac4e3a8..467ce70fbd33 100644 --- a/python/tvm/topi/cpp/vision/__init__.py +++ b/python/tvm/topi/cpp/vision/__init__.py @@ -16,8 +16,9 @@ # under the License. """FFI for vision TOPI ops and schedules""" -import tvm.ffi +import tvm_ffi from . import yolo +from ...vision import nms -tvm.ffi._init_api("topi.vision", "tvm.topi.cpp.vision") +tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision") diff --git a/python/tvm/topi/cpp/vision/yolo.py b/python/tvm/topi/cpp/vision/yolo.py index 5d8bdd99d24c..f5aa6d2d0670 100644 --- a/python/tvm/topi/cpp/vision/yolo.py +++ b/python/tvm/topi/cpp/vision/yolo.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for Yolo TOPI ops and schedules""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("topi.vision.yolo", "tvm.topi.cpp.vision.yolo") +tvm_ffi.init_ffi_api("topi.vision.yolo", "tvm.topi.cpp.vision.yolo") diff --git a/python/tvm/topi/cpp/x86.py b/python/tvm/topi/cpp/x86.py index 18de30c668a3..93cb6d96f6b8 100644 --- a/python/tvm/topi/cpp/x86.py +++ b/python/tvm/topi/cpp/x86.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""FFI for x86 TOPI ops and schedules""" -import tvm.ffi +import tvm_ffi -tvm.ffi._init_api("topi.x86", "tvm.topi.cpp.x86") +tvm_ffi.init_ffi_api("topi.x86", "tvm.topi.cpp.x86") diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py index eb48da0a022a..807b23a956e9 100644 --- a/python/tvm/topi/gpu/sort.py +++ b/python/tvm/topi/gpu/sort.py @@ -219,11 +219,22 @@ def compare(a, b): upper_lim = ceil_log2(size) def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count): - first = ib.allocate("int64", (1,), name="first", scope="local") - mid = ib.allocate("int64", (1,), name="mid", scope="local") - last = ib.allocate("int64", (1,), name="last", scope="local") - first[0] = tvm.te.max(0, diag - bCount) - last[0] = tvm.te.min(diag, aCount) + target = tvm.target.Target.current() + is_webgpu = "webgpu" in str(target) + target_dtype = "int32" if is_webgpu else "int64" + + first = ib.allocate(target_dtype, (1,), name="first", scope="local") + mid = ib.allocate(target_dtype, (1,), name="mid", scope="local") + last = ib.allocate(target_dtype, (1,), name="last", scope="local") + max_val = tvm.te.max(0, diag - bCount) + min_val = tvm.te.min(diag, aCount) + if is_webgpu: + first[0] = cast(max_val, target_dtype) + last[0] = cast(min_val, target_dtype) + else: + first[0] = max_val + last[0] = min_val + with ib.while_loop(first[0] < last[0]): mid = (first[0] + last[0]) >> 1 a = source[base_idx + (aStart + mid)] @@ -250,10 +261,20 @@ def serial_merge( first, last, ): - i = ib.allocate("int64", (1,), name="i", scope="local") - j = ib.allocate("int64", (1,), name="j", scope="local") - i[0] = aStart + first - j[0] = bStart + diag - last + target = tvm.target.Target.current() + is_webgpu = "webgpu" in str(target) + target_dtype = "int32" if is_webgpu else "int64" + i = ib.allocate(target_dtype, (1,), name="i", scope="local") + j = ib.allocate(target_dtype, (1,), name="j", scope="local") + i_val = aStart + first + j_val = bStart + diag - last + if 
is_webgpu: + i[0] = cast(i_val, target_dtype) + j[0] = cast(j_val, target_dtype) + else: + i[0] = i_val + j[0] = j_val + with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count: i_idx = base_idx + i[0] j_idx = base_idx + j[0] @@ -287,7 +308,9 @@ def assign_j(): with ib.else_scope(): assign_j() - with ib.for_range(0, cast(upper_lim - lower_lim, "int64"), dtype="int64") as l2_width: + target = tvm.target.Target.current() + target_dtype = "int32" if "webgpu" in str(target) else "int64" + with ib.for_range(0, cast(upper_lim - lower_lim, target_dtype), dtype=target_dtype) as l2_width: width = 2 << (l2_width + lower_lim) # Define and launch the cuda kernel with ib.new_scope(): @@ -359,8 +382,10 @@ def merge(source, dest, source_idx, dest_idx): def mergesort(source, dest, source_idx, dest_idx, size, width, even): # calculate the start, mid, and end points of this section start = width * bz - middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64") - end = cast(tvm.te.min(start + width, size), "int64") + target = tvm.target.Target.current() + target_dtype = "int32" if "webgpu" in str(target) else "int64" + middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), target_dtype) + end = cast(tvm.te.min(start + width, size), target_dtype) with ib.if_scope(start < size): with ib.if_scope(nbx == 1): ## merge the start->middle and middle->end arrays diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py index f51c6718ab99..52406d402cdd 100644 --- a/python/tvm/topi/index_put.py +++ b/python/tvm/topi/index_put.py @@ -1,6 +1,6 @@ # Licensed to the Apache Software Foundation (ASF) under one -# or more contrir_builderutor license agreements. See the NOTICE file -# distrir_builderuted with this work for additional information +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information # regarding copyright ownership. 
The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance @@ -9,7 +9,7 @@ # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, -# software distrir_builderuted under the License is distrir_builderuted on an +# software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations @@ -29,7 +29,8 @@ def index_put(data, indices, values, accumulate=False): The source array to be modified. indices : Tuple[tvm.te.Tensor] - Tuple of 1D index tensors (one for each dimension) specifying positions. + Tuple of index tensors (can be multi-dimensional) specifying positions. + Index tensors are broadcast together following NumPy broadcasting rules. values : tvm.te.Tensor The values to place at the specified indices. 
@@ -60,11 +61,28 @@ def index_put(data, indices, values, accumulate=False): for dim in shape: full_range *= dim - # Check all indices have same length - index_len = len(indices[0]) - for idx in indices[1:]: - if not utils.equal_const_int(len(idx), index_len): - raise ValueError("All index tensors must have same length") + index_shapes = [idx.shape for idx in indices] + broadcast_ndim = max(len(s) for s in index_shapes) + broadcast_shape = [] + + for i in range(broadcast_ndim): + max_dim = 1 + for idx_shape in index_shapes: + # Right-align shapes + dim_idx = len(idx_shape) - broadcast_ndim + i + if dim_idx >= 0: + dim_size = idx_shape[dim_idx] + if not utils.equal_const_int(dim_size, 1): + if utils.equal_const_int(max_dim, 1): + max_dim = dim_size + elif not utils.equal_const_int(dim_size, max_dim): + raise ValueError(f"Cannot broadcast index shapes: {index_shapes}") + broadcast_shape.append(max_dim) + + # Compute total number of elements after broadcasting + index_len = 1 + for dim in broadcast_shape: + index_len *= dim def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): ir_builder = tir.ir_builder.create() @@ -78,12 +96,38 @@ def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): out[i] = data[i] with ir_builder.for_range(0, index_len, "k", kind="parallel") as k: - # Calculate multi-dimensional index + # Decompose k into multi-dimensional broadcast index + k_temp = k + broadcast_indices = [] + for i in range(broadcast_ndim - 1, -1, -1): + broadcast_indices.insert(0, k_temp % broadcast_shape[i]) + k_temp = k_temp // broadcast_shape[i] + flat_index = 0 stride = 1 for dim in range(len(shape) - 1, -1, -1): - # Get index and shift to positive if needed - idx_val = indices[dim][k] + # Get the index for this dimension using broadcasting + idx_shape = index_shapes[dim] + idx_ndim = len(idx_shape) + + # Compute the linear index into this index tensor + idx_offset = 0 + idx_stride = 1 + for i in range(broadcast_ndim - 1, -1, -1): + # Right-align 
the index shape with broadcast shape + dim_idx = idx_ndim - broadcast_ndim + i + if dim_idx >= 0: + dim_size = idx_shape[dim_idx] + # Use broadcasting: if size is 1, use index 0 + # otherwise use broadcast_indices[i] + if utils.equal_const_int(dim_size, 1): + idx_in_dim = 0 + else: + idx_in_dim = broadcast_indices[i] + idx_offset += idx_in_dim * idx_stride + idx_stride *= dim_size + + idx_val = indices[dim][idx_offset] shifted_idx = idx_val + (idx_val < 0) * shape[dim] flat_index += shifted_idx * stride stride *= shape[dim] diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index fb306f9e599b..61b39aad9114 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -450,7 +450,6 @@ def round(x): return te.compute(x.shape, lambda *i: te.round(x(*i))) -@tvm.te.tag_scope(tag=tag.ELEMWISE) def log(x): """Take logarithm of input x. @@ -464,10 +463,11 @@ def log(x): y : tvm.te.Tensor The result. """ - return te.compute(x.shape, lambda *i: te.log(x(*i))) + if x.dtype.startswith("int"): + x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) + return te.compute(x.shape, lambda *i: te.log(x(*i)), tag=tag.ELEMWISE) -@tvm.te.tag_scope(tag=tag.ELEMWISE) def log2(x): """Take logarithm to the base 2 of input x. @@ -481,7 +481,9 @@ def log2(x): y : tvm.te.Tensor The result. 
""" - return te.compute(x.shape, lambda *i: te.log2(x(*i))) + if x.dtype.startswith("int"): + x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) + return te.compute(x.shape, lambda *i: te.log2(x(*i)), tag=tag.ELEMWISE) def log10(x): diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 531c0a6c6663..ce14df8beddf 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -394,6 +394,135 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou ) +def conv2d_NCHWc_OIHWo( + data: te.Tensor, kernel, stride, padding, dilation, layout, out_layout, out_dtype="float32" +): + """Conv2D operator for nChw[x]c layout. + + Parameters + ---------- + data : tvm.te.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + kernel : tvm.te.Tensor + 6-D with shape + [num_filter_chunk, in_channel_chunk, filter_height, filter_width, + num_filter_block] + + stride : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + layout : str + Input data layout + + out_layout : str + Output data layout + + out_dtype : str + output data type + + Returns + ------- + output : tvm.te.Tensor + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride) + dilation_h, dilation_w = ( + dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + ) + + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + kernel_shape = 
get_const_tuple(kernel.shape) + if len(kernel_shape) == 6: # OIHW4i4o + oc_chunk, ic_chunk_group, kernel_height, kernel_width, kernel_ic_bn, oc_bn = kernel_shape + groups = in_channel // (ic_chunk_group * kernel_ic_bn) + else: # OIHW4o + oc_chunk, ic, kernel_height, kernel_width, oc_bn = kernel_shape + groups = in_channel // ic + + num_filter = oc_chunk * oc_bn + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + HPAD = pad_top + pad_down + WPAD = pad_left + pad_right + + # output shape + out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + pad_before = (0, 0, pad_top, pad_left, 0) + pad_after = (0, 0, pad_down, pad_right, 0) + + # DOPAD + DOPAD = HPAD != 0 or WPAD != 0 + if DOPAD: + data_pad = pad(data, pad_before, pad_after, name="conv2d_data_pad") + else: + data_pad = data + + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + + def compute_conv2d(*args): + n, occ, oh, ow, ocb = args + ic = te.reduce_axis((0, in_channel // groups), name="ic") + if groups == 1: + data_pad_ = data_pad[ + n, + idxdiv(ic, ic_bn), + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + idxmod(ic, ic_bn), + ] + else: + data_pad_ = data_pad[ + n, + (occ // (oc_chunk // groups)) * (ic_chunk // groups) + idxdiv(ic, ic_bn), + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + idxmod(ic, ic_bn), + ] + if len(kernel_shape) == 5: + kernel_ = kernel[occ, ic, kh, kw, ocb] + else: + kernel_ = kernel[occ, idxdiv(ic, oc_bn), kh, kw, idxmod(ic, oc_bn), ocb] + + if out_dtype is not None: + data_pad_ = data_pad_.astype(out_dtype) + kernel_ = kernel_.astype(out_dtype) + + return 
te.sum( + data_pad_ * kernel_, + axis=[ic, kh, kw], + ) + + return te.compute( + oshape, + lambda *indices: compute_conv2d(*indices), # pylint: disable=W0108 + name="conv2d_NCHWc_OIHWo", + tag="conv2d_NCHWc_OIHWo", + ) + + def conv2d_NCHWc_int8( data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32", n_elems=4 ): diff --git a/python/tvm/topi/sort.py b/python/tvm/topi/sort.py index f75e5db4b9b1..1ee2964ae9b5 100644 --- a/python/tvm/topi/sort.py +++ b/python/tvm/topi/sort.py @@ -105,8 +105,8 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): s = topi.generic.schedule_argsort(out) f = tvm.compile(s, [data, out], "llvm") dev = tvm.cpu() - tvm_data = tvm.nd.array(np_data, dev) - tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), dev) + tvm_data = tvm.runtime.tensor(np_data, dev) + tvm_out = tvm.runtime.tensor(np.zeros(dshape, dtype=data.dtype), dev) f(tvm_data, tvm_out) """ data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) diff --git a/python/tvm/topi/tensor.py b/python/tvm/topi/tensor.py index 449c599deaf3..9206e876a15a 100644 --- a/python/tvm/topi/tensor.py +++ b/python/tvm/topi/tensor.py @@ -17,6 +17,8 @@ # pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition """Elementwise operators""" +import math as _math + from typing import Optional from tvm import te @@ -57,6 +59,13 @@ def full(shape, dtype, fill_value): y : tvm.te.Tensor The result. 
""" + + if isinstance(fill_value, (int, float)) and ( + _math.isinf(fill_value) or _math.isnan(fill_value) + ): + if not ("float" in dtype or "bfloat16" in dtype): + raise ValueError("Infinite and NaN require a floating-point dtype.") + return cpp.full(shape, dtype, fill_value) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 98cec99a09b7..db09aed05a3c 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -736,7 +736,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): return cpp.sequence_mask(data, valid_length, mask_value, axis) -def ndarray_size(array, dtype="int32"): +def tensor_size(array, dtype="int32"): """Get the number of elements of input array Parameters @@ -752,7 +752,7 @@ def ndarray_size(array, dtype="int32"): result : tvm.te.Tensor The resulting tensor. """ - return cpp.ndarray_size(array, dtype) + return cpp.tensor_size(array, dtype) def where(condition, x, y): diff --git a/python/tvm/topi/vision/__init__.py b/python/tvm/topi/vision/__init__.py new file mode 100644 index 000000000000..f12758bb9c0a --- /dev/null +++ b/python/tvm/topi/vision/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Vision operators.""" +from .nms import * diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py new file mode 100644 index 000000000000..f4aae45ef9c5 --- /dev/null +++ b/python/tvm/topi/vision/nms.py @@ -0,0 +1,500 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-error, invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements, too-many-function-args +"""Non-maximum suppression operator""" +import tvm +from tvm import te + +from tvm.tir import if_then_else + +from ..sort import argsort +from ..math import cast +from ..transform import reshape, gather +from .. import reduction +from ..scan import cumsum +from .nms_util import ( + binary_search, + collect_selected_indices, + collect_selected_indices_and_scores, + run_all_class_nms, +) + + +def get_valid_counts( + data, score_threshold=0, id_index=0, score_index=1 +): # pylint: disable=unused-argument + """Get valid count of bounding boxes given a score threshold. + Also moves valid boxes to the top of input data. + Parameters + ---------- + data : tvm.te.Tensor + Input data. 
3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. + score_threshold : optional, float + Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + score_index: optional, int + Index of the scores/confidence of boxes. + Returns + ------- + valid_count : tvm.te.Tensor + 1-D tensor for valid number of boxes. + out_tensor : tvm.te.Tensor + Rearranged data tensor. + out_indices: tvm.te.Tensor or numpy NDArray + Related index in input data. + """ + if isinstance(score_threshold, (float, int)): + score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype) + # id_index_const = tvm.tir.const(id_index, "int32") # Unused + # score_index_const = tvm.tir.const(score_index, "int32") # Unused + return ( + te.compute((data.shape[0],), lambda i: data.shape[1], name="valid_count"), + data, + te.compute((data.shape[0], data.shape[1]), lambda i, j: j, name="out_indices"), + ) + + +def _nms_loop( + ib, + batch_size, + top_k, + iou_threshold, + max_output_size, + valid_count, + on_new_valid_box_func, + on_new_invalidated_box_func, + needs_bbox_check_func, + calc_overlap_func, + out_scores, + num_valid_boxes, + score_threshold=None, +): + def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local): + on_new_valid_box_func(ib, 0, num_valid_boxes_local[0], i, j) + num_valid_boxes_local[0] += 1 + + num_boxes_to_check = nkeep - (j + 1) + + with ib.for_range(0, num_boxes_to_check, name="_k", kind="parallel") as _k: + k = j + 1 + _k + + with ib.if_scope( + tvm.tir.all( + k < nkeep, + out_scores[i, k] > 0, # is the box k still valid? 
+ needs_bbox_check_func(i, j, k), + ) + ): + iou = calc_overlap_func(i, j, k) + + with ib.if_scope(iou >= iou_threshold): + out_scores[i, k] = -1.0 + on_new_invalidated_box_func(i, k) + + with ib.for_range(0, batch_size, name="i") as i: + nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) + # Use max_output_size directly without if_then_else + # max_output_size = if_then_else(max_output_size > te.const(0), max_output_size, nkeep) + + with ib.if_scope(tvm.tir.all(iou_threshold > te.const(0), valid_count[i] > te.const(0))): + num_valid_boxes_local = ib.allocate( + "int32", (1,), name="num_valid_boxes_local", scope="local" + ) + num_valid_boxes_local[0] = 0 + + # Use for_range to iterate through all boxes, but limit selection count + with ib.for_range(0, nkeep, name="j") as j: + with ib.if_scope( + tvm.tir.all( + out_scores[i, j] > -1.0, # box is still valid + num_valid_boxes_local[0] < max_output_size, # haven't reached max limit + ) + ): + if score_threshold is not None: + with ib.if_scope(out_scores[i, j] > score_threshold[()]): + nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local) + else: + nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local) + + num_valid_boxes[i] = num_valid_boxes_local[0] + + with ib.else_scope(): + num_valid_boxes[i] = 0 + + return ib.get() + + +def _get_valid_box_count(scores, score_threshold): + batch_classes, num_boxes = scores.shape + + def searchsorted_ir(scores, score_thresh, valid_count): + ib = tvm.tir.ir_builder.create() + scores = ib.buffer_ptr(scores) + valid_count = ib.buffer_ptr(valid_count) + + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: + if hasattr(score_threshold, "shape"): + if len(score_threshold.shape) == 0: + score_thresh_scalar = score_thresh[()] + elif len(score_threshold.shape) == 1 and score_threshold.shape[0] > 0: + score_thresh_scalar = score_thresh[0] + else: + score_thresh_scalar = tvm.tir.FloatImm("float32", 0.0) + else: + score_thresh_scalar = 
score_threshold + binary_search(ib, i, num_boxes, scores, score_thresh_scalar, valid_count) + + return ib.get() + + scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8) + searchsorted_buf = tvm.tir.decl_buffer( + (batch_classes,), "int32", "searchsorted", data_alignment=8 + ) + + if hasattr(score_threshold, "shape"): + score_thresh_buf = tvm.tir.decl_buffer( + score_threshold.shape, score_threshold.dtype, "score_thresh_buf", data_alignment=8 + ) + return te.extern( + [(batch_classes,)], + [scores, score_threshold], + lambda ins, outs: searchsorted_ir(ins[0], ins[1], outs[0]), + dtype=["int32"], + in_buffers=[scores_buf, score_thresh_buf], + out_buffers=[searchsorted_buf], + name="searchsorted", + tag="searchsorted", + ) + else: + + def searchsorted_ir_scalar(scores, valid_count): + ib = tvm.tir.ir_builder.create() + scores = ib.buffer_ptr(scores) + valid_count = ib.buffer_ptr(valid_count) + + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: + if isinstance(score_threshold, te.Tensor): + if len(score_threshold.shape) == 0: + score_thresh_tir = score_threshold() + elif len(score_threshold.shape) == 1 and score_threshold.shape[0] == 1: + score_thresh_tir = score_threshold[0] + else: + score_thresh_tir = tvm.tir.FloatImm("float32", 0.0) + else: + score_thresh_tir = tvm.tir.FloatImm("float32", float(score_threshold)) + binary_search(ib, i, num_boxes, scores, score_thresh_tir, valid_count) + + return ib.get() + + return te.extern( + [(batch_classes,)], + [scores], + lambda ins, outs: searchsorted_ir_scalar(ins[0], outs[0]), + dtype=["int32"], + in_buffers=[scores_buf], + out_buffers=[searchsorted_buf], + name="searchsorted", + tag="searchsorted", + ) + + +def _collect_selected_indices_ir( + num_class, selected_indices, num_detections, row_offsets, out, max_output_boxes_per_class=None +): + batch_classes, _ = selected_indices.shape + + ib = tvm.tir.ir_builder.create() + + selected_indices = 
ib.buffer_ptr(selected_indices) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + out = ib.buffer_ptr(out) + + # Initialize output buffer to zero + # Calculate the actual output shape based on max_output_boxes_per_class + if isinstance(max_output_boxes_per_class, int): + max_output_rows = batch_classes * max_output_boxes_per_class + else: + # Fallback to a reasonable default if max_output_boxes_per_class is not an integer + max_output_rows = batch_classes * 10 + with ib.for_range(0, max_output_rows, name="init_i") as init_i: + with ib.for_range(0, 3, name="init_j") as init_j: # 3 columns + out[init_i, init_j] = cast(0, "int64") + + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + if isinstance(max_output_boxes_per_class, int): + limit = tvm.tir.min( + num_detections[i], tvm.tir.IntImm("int32", max_output_boxes_per_class) + ) + elif isinstance(max_output_boxes_per_class, te.Tensor): + if len(max_output_boxes_per_class.shape) == 0: + max_boxes_val = max_output_boxes_per_class[()] + else: + max_boxes_val = max_output_boxes_per_class[0] + limit = tvm.tir.min(num_detections[i], max_boxes_val) + else: + limit = num_detections[i] + + with ib.for_range(0, limit, name="j") as j: + out[row_offsets[i] + j, 0] = batch_id + out[row_offsets[i] + j, 1] = class_id + out[row_offsets[i] + j, 2] = cast(selected_indices[i, j], "int64") + + return ib.get() + + +def _collect_selected_indices_and_scores_ir( + selected_indices, + selected_scores, + num_detections, + row_offsets, + num_total_detections, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = 
ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + zero = cast(0, "int64") + + with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + with ib.for_range(0, num_boxes, name="j") as j: + with ib.if_scope(j < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + j + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64") + collected_scores[batch_id, offset] = selected_scores[i, j] + with ib.else_scope(): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + j + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = 0.0 + + return ib.get() + + +def all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="onnx", + output_shape=None, +): + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. + Parameters + ---------- + boxes : tvm.te.Tensor + 3-D tensor with shape (batch_size, num_boxes, 4) + scores: tvm.te.Tensor + 3-D tensor with shape (batch_size, num_classes, num_boxes) + max_output_boxes_per_class : int or tvm.te.Tensor, optional + The maximum number of output selected boxes per class + iou_threshold : float or tvm.te.Tensor, optional + IoU test threshold + score_threshold : float or tvm.te.Tensor, optional + Score threshold to filter out low score boxes early + output_format : str, optional + "onnx" or "tensorflow", see below. 
+ Returns + ------- + out : list of tvm.te.Tensor + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size + `(batch_size * num_class * num_boxes, 3)` and the second is a scalar tensor + `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + first, in descending order of scores, followed by boxes from batch 0, class 1 etc. Out of + `batch_size * num_class * num_boxes` rows of indices, only the first `num_total_detection` + rows are valid. + + .. note:: + **Important**: The output tensor has a fixed size based on `max_output_boxes_per_class`, + but only the first `num_total_detection` rows contain valid data. The remaining rows + may contain garbage values. When comparing with ONNX Runtime or other implementations + that output dynamic shapes, you should only compare the first + `num_total_detection` rows. + Example: + ```python + selected_indices, valid_count = nms_output + actual_count = int(valid_count.numpy()[0]) + valid_indices = selected_indices.numpy()[:actual_count, :] + ``` + If `output_format` is "tensorflow", the output is three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes, 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. 
+ """ + batch, num_class, num_boxes = scores.shape + scores = reshape(scores, (batch * num_class, num_boxes)) + + sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32") + sorted_scores = gather(scores, 1, sorted_indices) + + if not isinstance(score_threshold, te.Tensor): + score_threshold_tensor = te.compute((), lambda: score_threshold, name="score_threshold") + else: + score_threshold_tensor = score_threshold + + valid_count = _get_valid_box_count(sorted_scores, score_threshold_tensor) + + selected_indices, selected_scores, num_detections = run_all_class_nms( + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_boxes_per_class, + iou_threshold, + _nms_loop, + return_scores=(output_format == "tensorflow"), + score_threshold=score_threshold_tensor, # Passed score_threshold as tensor + ) + + if output_format == "onnx": + row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") + + def _sum_clamped_total(): + if isinstance(max_output_boxes_per_class, int): + k_expr = tvm.tir.IntImm("int32", int(max_output_boxes_per_class)) + clamped = te.compute( + num_detections.shape, + lambda i: tvm.tir.min(num_detections[i], k_expr), + name="clamped_num", + ) + return reduction.sum(cast(clamped, "int64"), axis=0) + if isinstance(max_output_boxes_per_class, tvm.tir.IntImm): + k_expr = tvm.tir.Cast("int32", max_output_boxes_per_class) + clamped = te.compute( + num_detections.shape, + lambda i: tvm.tir.min(num_detections[i], k_expr), + name="clamped_num", + ) + return reduction.sum(cast(clamped, "int64"), axis=0) + if isinstance(max_output_boxes_per_class, te.Tensor): + if len(max_output_boxes_per_class.shape) == 0: + kb = te.compute( + num_detections.shape, + lambda i: cast(max_output_boxes_per_class, "int32"), + name="k_broadcast", + ) + elif ( + len(max_output_boxes_per_class.shape) == 1 + and max_output_boxes_per_class.shape[0] == 1 + ): + kb = te.compute( + num_detections.shape, + lambda i: cast(max_output_boxes_per_class[0], 
"int32"), + name="k_broadcast", + ) + else: + return reduction.sum(cast(num_detections, "int64"), axis=0) + + clamped = te.compute( + num_detections.shape, + lambda i: tvm.tir.min(num_detections[i], kb[i]), + name="clamped_num", + ) + return reduction.sum(cast(clamped, "int64"), axis=0) + return reduction.sum(cast(num_detections, "int64"), axis=0) + + num_total_scalar = _sum_clamped_total() + num_total_detections = reshape(num_total_scalar, (1,)) + + if output_shape is not None: + selected_indices = collect_selected_indices( + num_class, + selected_indices, + num_detections, + row_offsets, + _collect_selected_indices_ir, + max_output_boxes_per_class=max_output_boxes_per_class, + output_shape=output_shape, + ) + else: + # Use num_total_detections to enable dynamic trimming + # Pass image size for intelligent default estimation + input_image_size = None + if hasattr(scores, "shape") and len(scores.shape) >= 3: + # Extract image size from scores shape: (batch, num_classes, num_boxes) + # We can estimate image size from num_boxes (more boxes = larger image) + input_image_size = (scores.shape[2],) # Use num_boxes as proxy for image size + + # TODO: Improve image size estimation by: + # 1. Accepting actual image dimensions as parameters + # 2. Using model metadata to infer typical image sizes + # 3. Learning from historical detection patterns + # 4. 
Providing user-configurable estimation strategies + + selected_indices = collect_selected_indices( + num_class, + selected_indices, + num_detections, + row_offsets, + _collect_selected_indices_ir, + max_output_boxes_per_class=max_output_boxes_per_class, + num_total_detections=num_total_detections, + input_image_size=input_image_size, + ) + return [selected_indices, num_total_detections] + + num_detections_per_batch = reshape(num_detections, (batch, num_class)) + row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1) + num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1) + + selected_indices, selected_scores = collect_selected_indices_and_scores( + selected_indices, + selected_scores, + num_detections_per_batch, + row_offsets, + num_total_detections, + _collect_selected_indices_and_scores_ir, + ) + + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py new file mode 100644 index 000000000000..1633c923e17f --- /dev/null +++ b/python/tvm/topi/vision/nms_util.py @@ -0,0 +1,473 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name +"""Common utilities used in Non-maximum suppression operators""" +import tvm +from tvm import te + + +def _get_boundaries(output, box_idx): + l = tvm.te.min( + output[box_idx], + output[box_idx + 2], + ) + t = tvm.te.min( + output[box_idx + 1], + output[box_idx + 3], + ) + r = tvm.te.max( + output[box_idx], + output[box_idx + 2], + ) + b = tvm.te.max( + output[box_idx + 1], + output[box_idx + 3], + ) + return l, t, r, b + + +def calculate_overlap(out_tensor, box_a_idx, box_b_idx): + """Calculate overlap of two boxes.""" + a_l, a_t, a_r, a_b = _get_boundaries(out_tensor, box_a_idx) + b_l, b_t, b_r, b_b = _get_boundaries(out_tensor, box_b_idx) + + # Overlapping width and height + w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) + h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) + + # Overlapping area + area = h * w + + # total area of the figure formed by box a and box b + # except for overlapping area + u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area + return tvm.tir.Select(u <= 0.0, 0.0, area / u) + + +def binary_search(ib, y, num_boxes, scores, score_threshold, out): + """Binary search for score_threshold on scores sorted in descending order""" + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = num_boxes.astype("int32") + + with ib.while_loop(lo[0] < hi[0]): + mid = (hi[0] + lo[0]) >> 1 + with ib.if_scope(scores[y, mid] > score_threshold): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + out[y] = lo[0] + + +def _estimate_max_detections(batch_class, input_image_size=None): + """Estimate maximum detections based on input image size and number of classes. + + This provides a more intelligent default for production environments. 
+ """ + if input_image_size is not None: + # Estimate based on image size: larger images typically have more objects + if len(input_image_size) >= 2: + height, width = input_image_size[-2], input_image_size[-1] + total_pixels = height * width + + # Base estimation per class based on image size + if total_pixels < 300000: # Small images (< 300k pixels) + base_detections_per_class = min(50, max(10, total_pixels // 2000)) + elif total_pixels < 1000000: # Medium images (< 1M pixels) + base_detections_per_class = min(100, max(25, total_pixels // 3000)) + else: # Large images (>= 1M pixels) + base_detections_per_class = min(200, max(50, total_pixels // 4000)) + + # Scale down for many classes (more realistic for multi-class scenarios) + if batch_class > 20: + # For many classes, reduce per-class detections to avoid explosion + detections_per_class = min(base_detections_per_class, 50) + else: + detections_per_class = base_detections_per_class + else: + detections_per_class = 50 # fallback + else: + # Fallback to class-based estimation + if batch_class == 1: + detections_per_class = 100 # Single class detection + elif batch_class <= 10: + detections_per_class = 50 # Small multi-class + else: + detections_per_class = 25 # Large multi-class (COCO-like) + + return batch_class * detections_per_class + + +def collect_selected_indices( + num_class, + selected_indices, + num_detections, + row_offsets, + ir, + max_output_boxes_per_class=None, + output_shape=None, + num_total_detections=None, + input_image_size=None, +): + """Collect selected indices from the core NMS loop into one linear output + Parameters + ---------- + num_class : int + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of selected boxes by the core NMS loop. 
+ num_detections tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), representing + the number of boxes selected by the core NMS loop, per batch and class + row_offsets tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), this should be the exclusive scan + of num_detections + ir : function + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py + Returns + ------- + out : tvm.te.Tensor + The output is indices of size (batch_size * num_class* num_boxes , 3). + Rows of indices are ordered such that selected boxes from batch 0, class 0 come + first, in descending of scores, followed by boxes from batch 0, class 1 etc. + """ + batch_class, num_boxes = selected_indices.shape + + if output_shape is not None: + return te.extern( + [output_shape], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + # TODO: Implement dynamic trimming based on num_total_detections + if num_total_detections is not None: + if isinstance(max_output_boxes_per_class, int): + out_rows = batch_class * max_output_boxes_per_class + else: + # Smart fallback based on input image size and typical production scenarios + out_rows = _estimate_max_detections(batch_class, input_image_size) + + return te.extern( + [(out_rows, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + if isinstance(max_output_boxes_per_class, int): + out_rows = batch_class * max_output_boxes_per_class + return te.extern( + [(out_rows, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + 
name="collect_indices", + tag="collect_indices", + ) + + if isinstance(max_output_boxes_per_class, te.Tensor): + try: + if len(max_output_boxes_per_class.shape) == 0: + max_boxes_val = int(max_output_boxes_per_class.data.numpy()) + elif ( + len(max_output_boxes_per_class.shape) == 1 + and max_output_boxes_per_class.shape[0] == 1 + ): + max_boxes_val = int(max_output_boxes_per_class.data.numpy()[0]) + else: + max_boxes_val = num_boxes + except (ValueError, IndexError, AttributeError): + max_boxes_val = num_boxes + + out_rows = batch_class * max_boxes_val + return te.extern( + [(out_rows, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + return te.extern( + [(batch_class * num_boxes, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir( + num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class + ), + dtype=["int64"], + name="collect_indices", + tag="collect_indices", + ) + + +def collect_selected_indices_and_scores( + selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir +): + """Collect selected indices and scores from the core NMS loop into one linear output + Parameters + ---------- + num_class : int + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of selected boxes by the core NMS loop. + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the scores + of selected boxes by the core NMS loop. 
+ num_detections tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), representing + the number of boxes selected by the core NMS loop, per batch and class + row_offsets tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), this should be the exclusive scan + of num_detections along axis 1 + ir : function + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors. The first is indices of size + (batch_size, num_class* num_boxes, 2), and the second is scores of size + (batch_size, num_class* num_boxes). + """ + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + return te.extern( + [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], + [selected_indices, selected_scores, num_detections, row_offsets, num_total_detections], + lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]), + dtype=["int64", "float32"], + name="collect_indices_and_scores", + tag="collect_indices_and_scores", + ) + + +def _all_class_nms_ir( + boxes, + sorted_scores, + sorted_indices, + valid_count, + batch_class, + num_class, + num_anchors, + iou_threshold, + max_output_size_per_class, + box_indices, + selected_scores, + num_valid_boxes, + nms_loop, + score_threshold=None, +): + ib = tvm.tir.ir_builder.create() + boxes = ib.buffer_ptr(boxes) + sorted_scores = ib.buffer_ptr(sorted_scores) + sorted_indices = ib.buffer_ptr(sorted_indices) + valid_count = ib.buffer_ptr(valid_count) + box_indices = ib.buffer_ptr(box_indices) + num_valid_boxes = ib.buffer_ptr(num_valid_boxes) + + if selected_scores is not None: + selected_scores = ib.buffer_ptr(selected_scores) + + if isinstance(iou_threshold, float): + iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) + elif isinstance(iou_threshold, te.Tensor): + if len(iou_threshold.shape) == 0: + iou_threshold = iou_threshold() + elif 
len(iou_threshold.shape) == 1 and iou_threshold.shape[0] == 1: + iou_threshold = iou_threshold[0] + else: + iou_threshold = tvm.tir.FloatImm("float32", 0.5) + + if isinstance(max_output_size_per_class, int): + max_output_size_per_class = tvm.tir.const(max_output_size_per_class) + elif isinstance(max_output_size_per_class, te.Tensor): + if len(max_output_size_per_class.shape) == 0: + max_output_size_per_class = max_output_size_per_class() + elif len(max_output_size_per_class.shape) == 1 and max_output_size_per_class.shape[0] == 1: + # Use tensor indexing to get the first element + max_output_size_per_class = max_output_size_per_class[0] + else: + max_output_size_per_class = tvm.tir.const(1000) + + def calc_overlap(i, j, k): + offset_j = sorted_indices[i, j] * 4 + offset_k = sorted_indices[i, k] * 4 + batch_id = i // num_class + base_bbox_idx = batch_id * num_anchors * 4 + return calculate_overlap( + boxes, + base_bbox_idx + offset_j, + base_bbox_idx + offset_k, + ) + + def on_new_valid_box(ib, tid, num_current_valid_box, i, j): + with ib.if_scope(tid + 0 == 0): + box_indices[i, num_current_valid_box] = sorted_indices[i, j] + + if selected_scores is not None: + selected_scores[i, num_current_valid_box] = sorted_scores[i, j] + + def on_new_invalidated_box(*_): + pass + + def needs_bbox_check(*_): + return tvm.tir.const(True) + + return nms_loop( + ib, + batch_class, + tvm.tir.IntImm("int32", -1), # top_k + iou_threshold, + max_output_size_per_class, + valid_count, + on_new_valid_box, + on_new_invalidated_box, + needs_bbox_check, + calc_overlap, + sorted_scores, + num_valid_boxes, + score_threshold, + ) + + +def run_all_class_nms( + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_size_per_class, + iou_threshold, + nms_loop, + return_scores=False, + score_threshold=None, +): + """The core all class NMS routine + Parameters + ---------- + boxes : tvm.te.Tensor + 3-D tensor with shape (batch_size, num_boxes, 4) + sorted_scores: tvm.te.Tensor + 2-D 
tensor with shape (batch_size * num_classes, num_boxes) + One of the outputs from argsort + sorted_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes) + The other output from argsort + valid_count: tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), representing + the number of boxes whose score is above score_threshold, per batch and class + max_output_boxes_per_class : int or tvm.te.Tensor, optional + The maxinum number of output selected boxes per class + iou_threshold : float or tvm.te.Tensor, optionaIl + IoU test threshold + nms_loop : function + A core NMS loop, see its usage in vision/nms.py and cuda/nms.py + return_scores : bool, optional + Whether or not to return selected scores, needed by the tensorflow output format. + Returns + ------- + out : a list of tvm.te.Tensor + The output is three tensors, the first and second are indices and scores of size + (batch_size * num_class, num_boxes), and the third is a tensor + num_selected_boxes of shape (batch_size * num_class,) representing the total number of + selected boxes per batch and class. If return_scores is False, the second output is + None. 
+ """ + batch, num_boxes, _ = boxes.shape + batch_class = sorted_scores.shape[0] + num_class = batch_class // batch + + if return_scores is False: + all_class_num0_buf = tvm.tir.decl_buffer( + (batch_class, num_boxes), "int32", "all_class_nms0", data_alignment=8 + ) + all_class_num1_buf = tvm.tir.decl_buffer( + (batch_class,), "int32", "all_class_nms1", data_alignment=8 + ) + extern_inputs = [boxes, sorted_scores, sorted_indices, valid_count] + if score_threshold is not None: + extern_inputs.append(score_threshold) + + selected_indices, num_detections = te.extern( + [(batch_class, num_boxes), (batch_class,)], + extern_inputs, + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + None, # scores + outs[1], # num_selected_boxes + nms_loop, + ins[4] if score_threshold is not None else None, # score_threshold + ), + out_buffers=[all_class_num0_buf, all_class_num1_buf], + dtype=["int32", "int32"], + name="all_class_nms", + tag="all_class_nms", + ) + return selected_indices, None, num_detections + + extern_inputs = [boxes, sorted_scores, sorted_indices, valid_count] + if score_threshold is not None: + extern_inputs.append(score_threshold) + + return te.extern( + [(batch_class, num_boxes), (batch_class, num_boxes), (batch_class,)], + extern_inputs, + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + outs[1], # selected scores + outs[2], # num_selected_boxes + nms_loop, + ins[4] if score_threshold is not None else None, # score_threshold + ), + dtype=["int32", "float32", "int32"], + name="all_class_nms", + tag="all_class_nms", + ) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc 
index 9c4220ce29b6..8a32225a9022 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include "./scalable_expression.h" #include "const_fold.h" @@ -38,7 +39,21 @@ Analyzer::Analyzer() modular_set(this), rewrite_simplify(this), canonical_simplify(this), - int_set(this) {} + int_set(this), + z3_prover(this) {} + +std::unique_ptr Analyzer::Clone() const { + auto cloned = std::make_unique(); + // Copy per-sub-analyzer states + cloned->const_int_bound.CopyFrom(this->const_int_bound); + cloned->modular_set.CopyFrom(this->modular_set); + cloned->rewrite_simplify.CopyFrom(this->rewrite_simplify); + cloned->canonical_simplify.CopyFrom(this->canonical_simplify); + cloned->int_set.CopyFrom(this->int_set); + cloned->transitive_comparisons.CopyFrom(this->transitive_comparisons); + cloned->z3_prover.CopyFrom(this->z3_prover); + return cloned; +} void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { PrimExpr new_expr = expr; @@ -51,6 +66,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { this->canonical_simplify.Update(var, new_expr, allow_override); this->int_set.Update(var, this->int_set(new_expr), allow_override); this->transitive_comparisons.Bind(var, expr, allow_override); + this->z3_prover.Bind(var, expr, allow_override); } void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { @@ -61,6 +77,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { this->const_int_bound.Bind(var, range, allow_override); this->int_set.Bind(var, range, allow_override); this->transitive_comparisons.Bind(var, range, allow_override); + this->z3_prover.Bind(var, range, allow_override); } // skip modular_set // skip rewrite simplify @@ -103,7 +120,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { // We may consider enhance the sub analyzer to directly take // MarkPositiveVar so their bounds do not overlap 
if (const auto* var_ptr = symbol.as()) { - Var var = GetRef(var_ptr); + Var var = ffi::GetRef(var_ptr); // skip non-index type, keep it to be compatible // with any_dim that do not represent any value if (!IsIndexType(var.dtype())) return; @@ -116,7 +133,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { } } -void Analyzer::Bind(const Map& variables, bool allow_override) { +void Analyzer::Bind(const ffi::Map& variables, bool allow_override) { for (const auto& iter : variables) { this->Bind(iter.first, iter.second, allow_override); } @@ -127,9 +144,10 @@ void ConstraintContext::EnterWithScope() { // entering the scope. recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); - recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_, is_assume_)); recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_, is_assume_)); } void ConstraintContext::ExitWithScope() { @@ -195,14 +213,110 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { } PrimExpr simplified = Simplify(expr); const int64_t* as_int = tir::as_const_int(simplified); - if (as_int && *as_int) return true; + if (as_int && *as_int) { return true; } + + // Structured boolean reasoning for Or/And (and their bitwise counterparts on bool) + // Evaluate children with the same proof strength. 
+ if (const auto* not_node = simplified.as()) { + PrimExpr a = not_node->a; + // Try direct complements on common comparators + if (const auto* p = a.as()) { + return CanProve(tir::GE(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::GT(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::LE(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::LT(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::NE(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::EQ(p->a, p->b), strength); + } + // De Morgan on canonical boolean nodes + if (const auto* or_node = a.as()) { + PrimExpr lhs = tir::Not(or_node->a); + PrimExpr rhs = tir::Not(or_node->b); + return CanProve(tir::And(lhs, rhs), strength); + } + if (const auto* and_node = a.as()) { + PrimExpr lhs = tir::Not(and_node->a); + PrimExpr rhs = tir::Not(and_node->b); + return CanProve(tir::Or(lhs, rhs), strength); + } + // De Morgan on bitwise boolean calls + if (const auto* c = a.as()) { + using namespace tir; + if (c->op.same_as(builtin::bitwise_or()) && c->args.size() == 2 && a.dtype().is_bool()) { + PrimExpr lhs = tir::Not(c->args[0]); + PrimExpr rhs = tir::Not(c->args[1]); + return CanProve(tir::And(lhs, rhs), strength); + } + if (c->op.same_as(builtin::bitwise_and()) && c->args.size() == 2 && a.dtype().is_bool()) { + PrimExpr lhs = tir::Not(c->args[0]); + PrimExpr rhs = tir::Not(c->args[1]); + return CanProve(tir::Or(lhs, rhs), strength); + } + } + if (const auto* inner_not = a.as()) { + // Double negation + return CanProve(inner_not->a, strength); + } + // Fallback: if `a` simplifies to constant false, then Not(a) is true + PrimExpr a_simpl = Simplify(a); + const int64_t* a_const = tir::as_const_int(a_simpl); + if (a_const && *a_const == 0) { return true; } + // Otherwise, cannot conclude true + } + if (const auto* or_node = simplified.as()) { + if 
(CanProve(or_node->a, strength)) { + return true; + } + if (CanProve(or_node->b, strength)) { + return true; + } + } + if (const auto* and_node = simplified.as()) { + bool lhs = CanProve(and_node->a, strength); + bool rhs = CanProve(and_node->b, strength); + if (lhs && rhs) { + return true; + } + } + if (const auto* call = simplified.as()) { + using namespace tir; + if (call->op.same_as(builtin::bitwise_or()) && call->args.size() == 2 && + simplified.dtype().is_bool()) { + if (CanProve(call->args[0], strength) || CanProve(call->args[1], strength)) { + return true; + } + } + if (call->op.same_as(builtin::bitwise_and()) && call->args.size() == 2 && + simplified.dtype().is_bool()) { + bool lhs = CanProve(call->args[0], strength); + bool rhs = CanProve(call->args[1], strength); + if (lhs && rhs) { + return true; + } + } + if (call->op.same_as(builtin::bitwise_not()) && call->args.size() == 1 && + simplified.dtype().is_bool()) { + // Treat as logical not and reuse Not handling by constructing tir::Not + return CanProve(tir::Not(call->args[0]), strength); + } + } if (strength >= ProofStrength::kSymbolicBound) { // NOTE: we intentionally only pattern match common bound predicate i < bound // and put this implementation at the top-level. // This is to avoid repeatitive calling of this function // that causes speed issues. // This strategy can only be called from top-level and not from sub-analyzers. 
- Optional pos_diff; + ffi::Optional pos_diff; int lower_bound = 0; if (const auto* ptr_lt = expr.as()) { pos_diff = ptr_lt->b - ptr_lt->a; @@ -221,11 +335,16 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { lower_bound = 0; } if (pos_diff) { - IntSet iset = this->int_set(this->Simplify(pos_diff.value())); + PrimExpr simplified_diff = this->Simplify(pos_diff.value()); + IntSet iset = this->int_set(simplified_diff); if (iset.HasLowerBound()) { ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min())); if (relaxed_lower_bound->min_value >= lower_bound) return true; } + if (iset.HasUpperBound()) { + ConstIntBound relaxed_upper_bound = this->const_int_bound(this->Simplify(iset.max())); + if (relaxed_upper_bound->max_value < lower_bound) return false; + } } } @@ -238,14 +357,41 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { if (ContainsVscaleCall(simplified)) { if (TargetHasVLA(curr_target)) { auto kVScaleValues = GetVScaleValues(curr_target); - return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues); + if(CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues)) { + return true; + } } - LOG(WARNING) - << "The expression contains scalable values. An attempt to prove by substituting " - "with known values of vscale was not performed. This proof currently only supports " - "VLA targets, but the target was " - << curr_target; + // LOG(WARNING) + // << "The expression contains scalable values. An attempt to prove by substituting " + // "with known values of vscale was not performed. 
This proof currently only supports " + // "VLA targets, but the target was " + // << curr_target; + } + if(z3_prover.CanProve(simplified)) { + // auto msg = z3_prover.GetSMTLIB2(simplified); + // std::stringstream ss; + // ss << msg; + // std::stringstream out; + // std::string tmp; + // while(std::getline(ss, tmp)) { + // out << " " << tmp << "\n"; + // } + // LOG(INFO) << "Proved by Z3: " << simplified << "\n" << out.str(); + return true; } + // if(strength >= ProofStrength::kSymbolicBound && z3_prover.CanProve(simplified)) { + // // The following debug logging is very useful when diagnosing issues with the Z3 prover. + // auto msg = z3_prover.GetSMTLIB2(simplified); + // std::stringstream ss; + // ss << msg; + // std::stringstream out; + // std::string tmp; + // while(std::getline(ss, tmp)) { + // out << " " << tmp << "\n"; + // } + // LOG(INFO) << "Proved by Z3: " << simplified << "\n" << out.str(); + // return true; + // } return false; } @@ -270,102 +416,152 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { return res; } -TVM_FFI_STATIC_INIT_BLOCK({ +std::function Analyzer::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + // Entering the scope. + std::vector> recovery_functions; + recovery_functions.push_back(this->const_int_bound.EnterConstraint(constraint)); + recovery_functions.push_back(this->modular_set.EnterConstraint(constraint)); + recovery_functions.push_back(this->rewrite_simplify.EnterConstraint(constraint, is_assume)); + recovery_functions.push_back(this->int_set.EnterConstraint(constraint)); + recovery_functions.push_back(this->transitive_comparisons.EnterConstraint(constraint)); + recovery_functions.push_back(this->z3_prover.EnterConstraint(constraint)); + + return [recovery_functions]() mutable { + // Exiting the scope. 
+ while (recovery_functions.size()) { + auto& func = recovery_functions.back(); + if (func) { + func(); + } + recovery_functions.pop_back(); + } + }; +} + +namespace { +using FnFactory = tvm::ffi::TypedFunction; +static FnFactory BuildAnalyzerFactory(std::shared_ptr self) { + using tvm::ffi::Function; + return FnFactory([self](std::string name) -> Function { + if (name == "const_int_bound") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->const_int_bound(args[0].cast()); + }); + } else if (name == "modular_set") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->modular_set(args[0].cast()); + }); + } else if (name == "clone") { + return Function([self](tvm::ffi::PackedArgs, tvm::ffi::Any* ret) { + auto cloned_unique = self->Clone(); + auto cloned = std::shared_ptr(cloned_unique.release()); + *ret = BuildAnalyzerFactory(cloned); + }); + } else if (name == "const_int_bound_update") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + self->const_int_bound.Update(args[0].cast(), args[1].cast(), + args[2].cast()); + }); + } else if (name == "const_int_bound_is_bound") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->const_int_bound.IsBound(args[0].cast()); + }); + } else if (name == "Simplify") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + if (args.size() == 1) { + *ret = self->Simplify(args[0].cast()); + } else if (args.size() == 2) { + *ret = self->Simplify(args[0].cast(), args[1].cast()); + } else { + LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; + } + }); + } else if (name == "rewrite_simplify") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->rewrite_simplify(args[0].cast()); + }); + } else if (name == "get_rewrite_simplify_stats") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = 
self->rewrite_simplify.GetStatsCounters(); + }); + } else if (name == "reset_rewrite_simplify_stats") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + self->rewrite_simplify.ResetStatsCounters(); + }); + } else if (name == "canonical_simplify") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->canonical_simplify(args[0].cast()); + }); + } else if (name == "int_set") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->int_set(args[0].cast(), args[1].cast>()); + }); + } else if (name == "bind") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + if (auto opt_range = args[1].try_cast()) { + self->Bind(args[0].cast(), opt_range.value()); + } else { + self->Bind(args[0].cast(), args[1].cast()); + } + }); + } else if (name == "can_prove") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + int strength = args[1].cast(); + *ret = self->CanProve(args[0].cast(), static_cast(strength)); + }); + } else if (name == "enter_constraint_context") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + auto ctx = std::shared_ptr>( + new With(self.get(), args[0].cast())); + auto fexit = [ctx](tvm::ffi::PackedArgs, tvm::ffi::Any*) mutable { ctx.reset(); }; + *ret = tvm::ffi::Function::FromPacked(fexit); + }); + } else if (name == "can_prove_equal") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); + }); + } else if (name == "get_enabled_extensions") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); + }); + } else if (name == "set_enabled_extensions") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + int64_t flags = args[0].cast(); + self->rewrite_simplify.SetEnabledExtensions( + 
static_cast(flags)); + }); + } else if (name == "get_smtlib2") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + auto expr = args[0].cast>(); + *ret = self->z3_prover.GetSMTLIB2(expr); + }); + } else if (name == "get_z3_stats") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->z3_prover.GetStats(); + }); + } else if (name == "set_z3_timeout_ms") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + unsigned timeout_ms = args[0].cast(); + self->z3_prover.SetTimeoutMs(timeout_ms); + }); + } else if (name == "set_z3_rlimit") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + unsigned max_step = args[0].cast(); + self->z3_prover.SetRLimit(max_step); + }); + } + return Function(); + }); +} +} // namespace + +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("arith.CreateAnalyzer", [](ffi::PackedArgs args, ffi::Any* ret) { - using ffi::Function; - using ffi::TypedFunction; + refl::GlobalDef().def_packed("arith.CreateAnalyzer", [](ffi::PackedArgs, ffi::Any* ret) { auto self = std::make_shared(); - auto f = [self](std::string name) -> ffi::Function { - if (name == "const_int_bound") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->const_int_bound(args[0].cast()); - }); - } else if (name == "modular_set") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->modular_set(args[0].cast()); - }); - } else if (name == "const_int_bound_update") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - self->const_int_bound.Update(args[0].cast(), args[1].cast(), - args[2].cast()); - }); - } else if (name == "const_int_bound_is_bound") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->const_int_bound.IsBound(args[0].cast()); - }); - } else if (name == "Simplify") { - return 
ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - if (args.size() == 1) { - *ret = self->Simplify(args[0].cast()); - } else if (args.size() == 2) { - *ret = self->Simplify(args[0].cast(), args[1].cast()); - } else { - LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; - } - }); - } else if (name == "rewrite_simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->rewrite_simplify(args[0].cast()); - }); - } else if (name == "get_rewrite_simplify_stats") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->rewrite_simplify.GetStatsCounters(); - }); - } else if (name == "reset_rewrite_simplify_stats") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - self->rewrite_simplify.ResetStatsCounters(); - }); - } else if (name == "canonical_simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->canonical_simplify(args[0].cast()); - }); - } else if (name == "int_set") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->int_set(args[0].cast(), args[1].cast>()); - }); - } else if (name == "bind") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - if (auto opt_range = args[1].try_cast()) { - self->Bind(args[0].cast(), opt_range.value()); - } else { - self->Bind(args[0].cast(), args[1].cast()); - } - }); - } else if (name == "can_prove") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - int strength = args[1].cast(); - *ret = self->CanProve(args[0].cast(), static_cast(strength)); - }); - } else if (name == "enter_constraint_context") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - // can't use make_shared due to noexcept(false) decl in destructor, - // see https://stackoverflow.com/a/43907314 - auto ctx = std::shared_ptr>( - new With(self.get(), args[0].cast())); - auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) 
mutable { ctx.reset(); }; - *ret = ffi::Function::FromPacked(fexit); - }); - } else if (name == "can_prove_equal") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); - }); - } else if (name == "get_enabled_extensions") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); - }); - } else if (name == "set_enabled_extensions") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - int64_t flags = args[0].cast(); - self->rewrite_simplify.SetEnabledExtensions( - static_cast(flags)); - }); - } - return ffi::Function(); - }; - *ret = ffi::TypedFunction(f); + *ret = BuildAnalyzerFactory(self); }); -}); +} } // namespace arith } // namespace tvm diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index f7720095eb2d..eb9edca36341 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -390,8 +390,8 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. 
-IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, - const Map& relax_map) { +IntSet DeduceBound(PrimExpr v, PrimExpr e, const ffi::Map& hint_map, + const ffi::Map& relax_map) { std::unordered_map hmap; for (auto kv : hint_map) { hmap[kv.first.get()] = kv.second; @@ -403,13 +403,14 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, return DeduceBound(v, e, hmap, rmap); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "arith.DeduceBound", - [](PrimExpr v, PrimExpr cond, const Map hint_map, - const Map relax_map) { return DeduceBound(v, cond, hint_map, relax_map); }); -}); + refl::GlobalDef().def("arith.DeduceBound", + [](PrimExpr v, PrimExpr cond, const ffi::Map hint_map, + const ffi::Map relax_map) { + return DeduceBound(v, cond, hint_map, relax_map); + }); +} } // namespace arith } // namespace tvm diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 7a02a3bedba8..66f8af178a17 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -52,9 +52,8 @@ class CanonicalExprNode : public PrimExprNode { */ virtual PrimExpr Normalize() const = 0; - static constexpr const char* _type_key = "arith.CanonicalExpr"; static constexpr const uint32_t _type_child_slots = 2; - TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("arith.CanonicalExpr", CanonicalExprNode, PrimExprNode); }; inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { @@ -204,13 +203,12 @@ class SplitExprNode : public CanonicalExprNode { /*! 
\brief positive infty */ static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; - static constexpr const char* _type_key = "arith.SplitExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.SplitExpr", SplitExprNode, CanonicalExprNode); }; class SplitExpr : public PrimExpr { public: - TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, PrimExpr, SplitExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SplitExpr, PrimExpr, SplitExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode); }; @@ -390,9 +388,7 @@ class SumExprNode : public CanonicalExprNode { } this->dtype = dtype; } - - static constexpr const char* _type_key = "arith.SumExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.SumExpr", SumExprNode, CanonicalExprNode); private: /*! @@ -524,7 +520,7 @@ class SumExprNode : public CanonicalExprNode { class SumExpr : public PrimExpr { public: - TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SumExpr, PrimExpr, SumExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode); }; @@ -680,7 +676,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (const auto* op = expr.as()) { expr = op->Normalize(); } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = expr.dtype(); n->index = std::move(expr); n->div_mode = kTruncDiv; @@ -717,7 +713,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (auto op = expr.as()) { return op.value(); } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = expr.dtype(); if (const auto* op = expr.as()) { n->base = op->value; @@ -816,7 +812,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { const MulNode* mul = ret.as(); if (mul && mul->a.same_as(op->a) && mul->b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return 
ret; } @@ -825,8 +821,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible) { - auto divisible = make_object(); - auto non_divisible = make_object(); + auto divisible = ffi::make_object(); + auto non_divisible = ffi::make_object(); divisible->dtype = psum->dtype; non_divisible->dtype = psum->dtype; @@ -894,7 +890,7 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, // we just skip to save the time if (prhs->as()) return false; // collect lhs products and try to eliminate by matching them to prod in rhs - Array> lhs_prods; + ffi::Array> lhs_prods; PrimExpr new_rhs = make_const(prhs->dtype(), 1); PrimExpr new_common_scale = make_const(prhs->dtype(), 1); int64_t lhs_cscale = 1, rhs_cscale = 1; @@ -939,7 +935,7 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, // construct prod via canonical form PrimExpr new_lhs = make_const(plhs->dtype(), 1); - for (Optional val : lhs_prods) { + for (ffi::Optional val : lhs_prods) { if (val.defined()) new_lhs = new_lhs * val.value(); } *plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale); @@ -1006,7 +1002,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { return truncdiv(a, b); } if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Div(a, b); } @@ -1066,7 +1062,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { return floordiv(a, b); } if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return FloorDiv(a, b); } @@ -1194,7 +1190,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { } if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Mod(a, b); } @@ -1259,7 +1255,7 @@ PrimExpr 
CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { } if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return FloorMod(a, b); } @@ -1268,7 +1264,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // Simplify reduce expression. PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) { // First simplify the results - Array simplified_result; + ffi::Array simplified_result; for (const auto& res : op->combiner->result) { PrimExpr new_res = this->VisitExpr(res); simplified_result.push_back(new_res); @@ -1311,12 +1307,12 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) } int new_value_index = op->value_index; - Array new_result; - Array new_identity; - Array new_lhs; - Array new_rhs; - Array new_source; - Array new_init; + ffi::Array new_result; + ffi::Array new_identity; + ffi::Array new_lhs; + ffi::Array new_rhs; + ffi::Array new_source; + ffi::Array new_init; // new stuff is old stuff which is used for (size_t i = 0; i < used.size(); ++i) { @@ -1450,3 +1446,16 @@ CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } } // namespace arith } // namespace tvm + +// After class implementations have been defined above +namespace tvm { +namespace arith { + +// Deep copy internal state from another analyzer +void CanonicalSimplifier::CopyFrom(const CanonicalSimplifier& other) { + // Impl derives from RewriteSimplifier::Impl, reuse its copying logic + this->impl_->CopyFromImpl(*other.impl_); +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 2c905dd563ef..5118204db69c 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -48,7 +48,7 @@ namespace arith { * \return std::nullopt if constant fold fails, otherwise return folded result. 
*/ template -inline Optional TryConstFold(PrimExpr a, PrimExpr b); +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b); /*! * \brief Try to run unary compute with constant folding. @@ -60,7 +60,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b); * \return std::nullopt if constant fold fails, otherwise return folded result. */ template -inline Optional TryConstFold(PrimExpr a); +inline ffi::Optional TryConstFold(PrimExpr a); /*! * \brief Check whether type is used to represent index. @@ -128,7 +128,7 @@ inline double GetFoldResultDoubleRepr(float x) { // specialization of constant folders. template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -152,7 +152,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && (pb && pb->dtype.is_uint() && pb->value > 0U))) @@ -178,7 +178,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -214,7 +214,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -250,7 +250,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = 
a.dtype(); if (pa && pb) { @@ -270,7 +270,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -305,7 +305,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -325,7 +325,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); @@ -336,7 +336,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); @@ -347,61 +347,61 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value); }); return std::nullopt; } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if 
(pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value); }); return std::nullopt; } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value); }); return std::nullopt; } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value); }); return std::nullopt; } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value); }); return std::nullopt; } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); - if (fa && 
fb) return IntImm(DataType::UInt(1), fa->value != fb->value); + if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value); }); return std::nullopt; } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return b; @@ -412,7 +412,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return a; @@ -423,10 +423,10 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a) { +inline ffi::Optional TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { - return IntImm(DataType::UInt(1), !(pa->value)); + return IntImm(DataType::Bool(), !(pa->value)); } return std::nullopt; } diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index c2dd8f120a99..96ba778dd894 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -33,16 +33,17 @@ #include "int_operator.h" #include "pattern_match.h" #include "scalable_expression.h" +#include "tvm/tir/op_attr_types.h" namespace tvm { namespace arith { using namespace tir; -TVM_FFI_STATIC_INIT_BLOCK({ ConstIntBoundNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ConstIntBoundNode::RegisterReflection(); } ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) { - auto node = make_object(); + auto node = ffi::make_object(); node->min_value = min_value; node->max_value = max_value; data_ = std::move(node); @@ -52,10 +53,10 @@ ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.ConstIntBound", MakeConstIntBound); -}); +} inline void PrintBoundValue(std::ostream& os, int64_t val) { if (val == ConstIntBound::kPosInf) { @@ -102,6 +103,12 @@ struct ConstIntBoundAnalyzer::Entry { class ConstIntBoundAnalyzer::Impl : public ExprFunctor { public: + explicit Impl(Analyzer* parent) : parent_(parent) {} + void CopyFrom(const Impl& other) { + this->var_map_ = other.var_map_; + this->additional_info_ = other.additional_info_; + this->bound_ = nullptr; + } /*! \brief additional bound info about expr in bound */ struct BoundInfo { /*! \brief The expr */ @@ -224,8 +231,11 @@ class ConstIntBoundAnalyzer::Impl * \param divisor The input divsor entry * \return The processed entry */ - Entry AssumeNoZeroDivisor(Entry divisor) { - ICHECK(!divisor.is_const(0)) << "Find divide by zero"; + std::optional AssumeNoZeroDivisor(Entry divisor) { + // If divisor is constant zero, return nullopt to signal fallback + if (divisor.is_const(0)) { + return std::nullopt; + } // NOTE: here we make the assumption that // divide by zero won't happen in a valid program // this is important for us to get a lot of symbolic shape bound right @@ -268,16 +278,49 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const DivNode* op) final { Entry a = VisitExpr(op->a); - Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); - return HandleDivision(a, b, op->dtype, InfAwareDiv); + auto b = AssumeNoZeroDivisor(VisitExpr(op->b)); + if (!b.has_value()) { + return Everything(op->dtype); + } + return HandleDivision(a, b.value(), op->dtype, InfAwareDiv); } Entry VisitExpr_(const ModNode* op) final { Entry a = VisitExpr(op->a); - Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); + auto b_opt = AssumeNoZeroDivisor(VisitExpr(op->b)); + if (!b_opt.has_value()) { + return Everything(op->dtype); + } + Entry b = b_opt.value(); if (b.min_value > 0) { int64_t b_max_cap = 
InfAwareAdd(b.max_value, -1); + + // Try to get tighter bounds using modular set information + if (parent_ && b.min_value == b.max_value) { + ModularSet mod_a = parent_->modular_set(op->a); + int64_t modulus = b.min_value; + int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus); + + // If gcd_coeff_mod > 1, we can get tighter bounds + // The result will be of the form gcd_coeff_mod * k + (base % modulus) + // where k ranges to cover [0, modulus - gcd_coeff_mod] + // + // Example: expr = (bx * 2048 + tx * 16) % 7168 + // where bx in [0, 3584), tx in [0, 128) + // GCD(16, 7168) = 16 + // Result can only be {0, 16, 32, ..., 7152} + // Without this optimization: bound = [0, 7167] + // With this optimization: bound = [0, 7152] + if (gcd_coeff_mod > 1) { + int64_t base_mod = mod_a->base % modulus; + if (base_mod < 0) base_mod += modulus; + int64_t tight_max = modulus - gcd_coeff_mod + base_mod; + if (tight_max >= modulus) tight_max -= modulus; + return MakeBound(base_mod, tight_max); + } + } + if (a.min_value >= 0) { // 0 <= [a_min, a_max] < b_min if (a.max_value < b.min_value) return a; @@ -288,7 +331,6 @@ class ConstIntBoundAnalyzer::Impl std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); } } else { - ICHECK(!b.is_const(0)) << "mod by zero"; // mod by negative value is rare, // and we just use the simpliest rule. 
return Everything(op->dtype); @@ -297,8 +339,11 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const FloorDivNode* op) final { Entry a = VisitExpr(op->a); - Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); - return HandleDivision(a, b, op->dtype, InfAwareFloorDiv); + auto b = AssumeNoZeroDivisor(VisitExpr(op->b)); + if (!b.has_value()) { + return Everything(op->dtype); + } + return HandleDivision(a, b.value(), op->dtype, InfAwareFloorDiv); } Entry VisitExpr_(const FloorModNode* op) final { @@ -320,10 +365,40 @@ class ConstIntBoundAnalyzer::Impl * That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1) */ Entry a = VisitExpr(op->a); - Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); + auto b_opt = AssumeNoZeroDivisor(VisitExpr(op->b)); + if (!b_opt.has_value()) { + return Everything(op->dtype); + } + Entry b = b_opt.value(); if (b.min_value > 0) { int64_t b_max_cap = InfAwareAdd(b.max_value, -1); + // Try to get tighter bounds using modular set information + if (parent_ && b.min_value == b.max_value) { + ModularSet mod_a = parent_->modular_set(op->a); + int64_t modulus = b.min_value; + int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus); + + // If gcd_coeff_mod > 1, we can get tighter bounds + // The result will be of the form gcd_coeff_mod * k + (base % modulus) + // where k ranges to cover [0, modulus - gcd_coeff_mod] + // + // Example: expr = (bx * 2048 + tx * 16) % 7168 + // where bx in [0, 3584), tx in [0, 128) + // ModularSet(expr) = 16*k (coeff=16, base=0) + // GCD(16, 7168) = 16 + // Result can only be {0, 16, 32, ..., 7152} + // Without this optimization: bound = [0, 7167] + // With this optimization: bound = [0, 7152] + if (gcd_coeff_mod > 1) { + int64_t base_mod = mod_a->base % modulus; + if (base_mod < 0) base_mod += modulus; + int64_t tight_max = modulus - gcd_coeff_mod + base_mod; + if (tight_max >= modulus) tight_max -= modulus; + return MakeBound(base_mod, tight_max); + } + } + if (a.min_value >= 0) { // 0 <= [a_min, a_max] < 
b_min if (a.max_value < b.min_value) return a; @@ -333,7 +408,6 @@ class ConstIntBoundAnalyzer::Impl return MakeBound(0, b_max_cap); } } else { - ICHECK(!b.is_const(0)) << "floormod by zero"; int64_t b_min_cap = InfAwareAdd(b.min_value, 1); int64_t b_max_cap = InfAwareAdd(b.max_value, -1); return Intersect(MakeBound(std::min(static_cast(0), b_min_cap), @@ -377,6 +451,10 @@ class ConstIntBoundAnalyzer::Impl return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); + } else if (op->op.same_as(tir::builtin::bitwise_or())) { + return VisitBitwiseOr(op); + } else if (op->op.same_as(tir::builtin::bitwise_xor())) { + return VisitBitwiseXor(op); } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) { auto kVScaleValues = GetVScaleValues(curr_target); unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); @@ -387,7 +465,7 @@ class ConstIntBoundAnalyzer::Impl } Entry VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); + Var v = ffi::GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { return it->second; @@ -397,7 +475,7 @@ class ConstIntBoundAnalyzer::Impl } Entry VisitExpr_(const SizeVarNode* op) final { - SizeVar v = GetRef(op); + SizeVar v = ffi::GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { return it->second; @@ -443,21 +521,83 @@ class ConstIntBoundAnalyzer::Impl } } + Entry VisitBitwiseOr(const CallNode* op) { + Entry a = VisitExpr(op->args[0]); + Entry b = VisitExpr(op->args[1]); + // For non-negative operands, OR result is also non-negative and + // bounded by (1<= 0 && b.min_value >= 0) { + auto bit_width = [](int64_t v) { + if (v <= 0) return 0; + int bw = 0; + while (v) { + ++bw; + v >>= 1; + } + return bw; + }; + int bw_a = bit_width(a.max_value); + int bw_b = bit_width(b.max_value); + int k = std::max(bw_a, bw_b); + if (k >= 63) { + return Everything(op->dtype); + } + int64_t ub = (static_cast(1) << k) - 1; 
+ return MakeBound(0, ub); + } + return Everything(op->dtype); + } + + Entry VisitBitwiseXor(const CallNode* op) { + Entry a = VisitExpr(op->args[0]); + Entry b = VisitExpr(op->args[1]); + // For non-negative operands (common for index math), + // the result is within [0, (1 << k) - 1], where k is the maximum + // number of bits required to represent either operand's upper bound. + // This is a conservative but safe bound and is sufficient for layout + // index computations. + if (a.min_value >= 0 && b.min_value >= 0) { + // Compute bit width of the larger upper bound; cap at 63 to avoid UB. + auto bit_width = [](int64_t v) { + if (v <= 0) return 0; + int bw = 0; + while (v) { + ++bw; + v >>= 1; + } + return bw; + }; + int bw_a = bit_width(a.max_value); + int bw_b = bit_width(b.max_value); + int k = std::max(bw_a, bw_b); + if (k >= 63) { + // Too wide; fall back to dtype limits. + return Everything(op->dtype); + } + int64_t ub = (static_cast(1) << k) - 1; + return MakeBound(0, ub); + } + // If signs are unknown, avoid incorrect assumptions. + return Everything(op->dtype); + } + std::function EnterConstraint(const PrimExpr& constraint) { std::vector info = DetectBoundInfo(constraint); if (info.size() == 0) return nullptr; size_t old_size = additional_info_.size(); additional_info_.insert(additional_info_.end(), info.begin(), info.end()); - size_t new_size = old_size + info.size(); - auto frecover = [old_size, new_size, this]() { - ICHECK_EQ(additional_info_.size(), new_size); - additional_info_.resize(old_size); + auto frecover = [old_size, this]() { + if (additional_info_.size() > old_size) { + additional_info_.resize(old_size); + } }; return frecover; } private: friend class ConstIntBoundAnalyzer; + // parent analyzer + Analyzer* parent_; // internal variable map std::unordered_map var_map_; // additional bound info @@ -525,6 +665,7 @@ class ConstIntBoundAnalyzer::Impl // If the range of b does not have 0, use BinaryOpBoundary. 
return BinaryOpBoundary(a, b, op); } + /*! * \brief Compute x + y, aware of inf. * \param x The left operand. @@ -678,9 +819,12 @@ class ConstIntBoundAnalyzer::Impl * \return Bound that represent everything dtype can represent. */ static Entry Everything(DataType dtype) { - if (!dtype.is_int() && !dtype.is_uint()) { + if (!dtype.is_int() && !dtype.is_uint() && !dtype.is_bool()) { return MakeBound(kNegInf, kPosInf); } + if (dtype.is_bool()) { + return MakeBound(0, 1); + } Entry ret; int64_t vbits = dtype.bits() - static_cast(dtype.is_int()); if (dtype.is_uint()) { @@ -719,6 +863,9 @@ class ConstIntBoundAnalyzer::Impl }; for (const auto& subexpr : ExtractConstraints(cond)) { + if(SideEffect(subexpr) > tir::CallEffectKind::kPure) { + continue; + } // NOTE: The canonical form always uses <= or <, but a // user-supplied constraint from the python API might not be // canonicalized. @@ -732,6 +879,31 @@ class ConstIntBoundAnalyzer::Impl add_info(x.Eval(), kNegInf, c.Eval()->value - 1); } else if ((x == c).Match(subexpr) || (c == x).Match(subexpr)) { add_info(x.Eval(), c.Eval()->value, c.Eval()->value); + } else if ((!x).Match(subexpr)) { + // Handle not operation: not(expr) + PrimExpr inner = x.Eval(); + PVar inner_x; + PVar inner_c; + + // Handle negated comparisons + if ((inner_c <= inner_x).Match(inner) || (inner_x >= inner_c).Match(inner)) { + // not(x >= c) -> x < c -> x <= c-1 + add_info(inner_x.Eval(), kNegInf, inner_c.Eval()->value - 1); + } else if ((inner_c < inner_x).Match(inner) || (inner_x > inner_c).Match(inner)) { + // not(x > c) -> x <= c + add_info(inner_x.Eval(), kNegInf, inner_c.Eval()->value); + } else if ((inner_x <= inner_c).Match(inner) || (inner_x >= inner_c).Match(inner)) { + // not(x <= c) -> x > c -> x >= c+1 + add_info(inner_x.Eval(), inner_c.Eval()->value + 1, kPosInf); + } else if ((inner_x < inner_c).Match(inner) || (inner_c > inner_x).Match(inner)) { + // not(x < c) -> x >= c + add_info(inner_x.Eval(), inner_c.Eval()->value, kPosInf); + } 
else if ((inner_x == inner_c).Match(inner) || (inner_c == inner_x).Match(inner)) { + // not(x == c) -> x != c + // This is more complex - we can't represent != with a single interval + // For now, we'll just skip this case + } + // Note: We don't recursively call DetectBoundInfo here to avoid infinite recursion } } @@ -744,7 +916,7 @@ class ConstIntBoundAnalyzer::Impl * This expression is used as the implementation of * topi.math.ceil_log2, and can appear in iteration bounds. */ - static Optional FindCeilLog2Arg(const CastNode* op) { + static ffi::Optional FindCeilLog2Arg(const CastNode* op) { if (op->dtype.is_int()) { if (auto as_call = op->value.as()) { if (as_call->op.same_as(Op::Get("tir.ceil"))) { @@ -805,9 +977,14 @@ std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& con return impl_->EnterConstraint(constraint); } -ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {} +ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } +// Deep copy internal state from another analyzer +void ConstIntBoundAnalyzer::CopyFrom(const ConstIntBoundAnalyzer& other) { + this->impl_->CopyFrom(*other.impl_); +} + } // namespace arith } // namespace tvm diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index 3c7d4e0e4bea..70768128e535 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -33,7 +33,7 @@ namespace arith { using namespace tir; -Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { +ffi::Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { // Check the threshold in the range of size_t CHECK_GE(thresh, std::numeric_limits::min()); CHECK_LE(thresh, std::numeric_limits::max()); @@ -63,16 +63,16 @@ Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { } // Return the common sub expr that occur more than thresh times - Map results; + ffi::Map 
results; for (auto& it : semantic_comp_done_by_expr) { if (it.second >= repeat_thr) results.Set(it.first, it.second); } return results; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.DetectCommonSubExpr", DetectCommonSubExpr); -}); +} } // namespace arith } // namespace tvm diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index e6746efd3717..4a0b5f9cf0c3 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -142,14 +142,14 @@ class LinearEqDetector : public ExprFunctor DetectLinearEquation(const PrimExpr& e, const Array& vars) { +ffi::Array DetectLinearEquation(const PrimExpr& e, const ffi::Array& vars) { PrimExpr base = e; - Array coeff; + ffi::Array coeff; for (Var v : vars) { LinearEqEntry ret; if (!LinearEqDetector(v).Detect(base, &ret)) { - return Array(); + return ffi::Array(); } coeff.push_back(ret.coeff); base = std::move(ret.base); @@ -162,7 +162,7 @@ Array DetectLinearEquation(const PrimExpr& e, const Array& vars) vset.insert(vars[i - 1].get()); // The previous coeff contains the variable if (UsesVar(coeff[i - 2], vset_contains)) { - return Array(); + return ffi::Array(); } } coeff.push_back(base); @@ -218,8 +218,8 @@ bool DetectClipBound(const PrimExpr& cond, ret.coeff = analyzer.Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; - Optional min_value; - Optional max_value; + ffi::Optional min_value; + ffi::Optional max_value; if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift min_value = -ret.base; @@ -265,7 +265,7 @@ void SplitCommExpr(const PrimExpr& e, std::vector* ret) { // Detect the lower and upper bound from the expression. // e must be connected by and. 
-Array DetectClipBound(const PrimExpr& e, const Array& vars) { +ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& vars) { std::vector splits; Analyzer analyzer; SplitCommExpr(analyzer.Simplify(e), &splits); @@ -274,9 +274,9 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { rmap[v.get()] = IntervalEntry(); } for (PrimExpr cond : splits) { - if (!DetectClipBound(cond, &rmap)) return Array(); + if (!DetectClipBound(cond, &rmap)) return ffi::Array(); } - Array ret; + ffi::Array ret; for (Var v : vars) { IntervalEntry e = rmap[v.get()]; if (e.min_value.defined()) { @@ -291,12 +291,12 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { return ret; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("arith.DetectLinearEquation", DetectLinearEquation) .def("arith.DetectClipBound", - [](const PrimExpr& e, const Array& vars) { return DetectClipBound(e, vars); }); -}); + [](const PrimExpr& e, const ffi::Array& vars) { return DetectClipBound(e, vars); }); +} } // namespace arith } // namespace tvm diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 96a269d7294f..3fc6d34b7071 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -115,7 +115,7 @@ class BufferTouchedDomain final : public IRVisitorWithAnalyzer { } private: - void Touch(BufferTouches* bounds, const Array& args) { + void Touch(BufferTouches* bounds, const ffi::Array& args) { if (args.size() > bounds->size()) { bounds->resize(args.size()); } @@ -136,25 +136,25 @@ Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores); } -Map> DomainTouchedAccessMap(const PrimFunc& func) { +ffi::Map> DomainTouchedAccessMap(const PrimFunc& func) { auto buffer_access_map = BufferTouchedDomain(func->body).GetAccessedBufferRegions(); - Map> ret; + ffi::Map> ret; auto& 
buffer_map = func->buffer_map; for (auto& var : func->params) { auto& buffer = buffer_map[var]; auto& access = buffer_access_map[buffer.get()]; - Array> loads, stores, combined; + ffi::Array> loads, stores, combined; for (std::vector& touch : std::get(access).set) { - loads.push_back(Array(touch)); + loads.push_back(ffi::Array(touch)); } for (std::vector& touch : std::get(access).set) { - stores.push_back(Array(touch)); + stores.push_back(ffi::Array(touch)); } for (std::vector& touch : std::get(access).set) { - combined.push_back(Array(touch)); + combined.push_back(ffi::Array(touch)); } - Array fields; + ffi::Array fields; fields.push_back(loads); fields.push_back(stores); fields.push_back(combined); @@ -163,12 +163,12 @@ Map> DomainTouchedAccessMap(const PrimFunc& func) { return ret; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("arith.DomainTouched", DomainTouched) .def("arith.DomainTouchedAccessMap", DomainTouchedAccessMap); -}); +} } // namespace arith } // namespace tvm diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index b074e6400aaf..e116ba9e3b7a 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -39,15 +39,16 @@ namespace tvm { namespace arith { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { IntGroupBoundsNode::RegisterReflection(); IntConstraintsNode::RegisterReflection(); IntConstraintsTransformNode::RegisterReflection(); -}); +} -Array AsConditions(const Array& variables, const Map& bounds, - const Array& relations) { - Array res; +ffi::Array AsConditions(const ffi::Array& variables, + const ffi::Map& bounds, + const ffi::Array& relations) { + ffi::Array res; // use variables to keep the order of iteration // so as to get rid of any non-determinism. 
ICHECK_EQ(variables.size(), bounds.size()); @@ -71,11 +72,11 @@ Array AsConditions(const Array& variables, const Map lower, Array equal, - Array upper) { +IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, + ffi::Array equal, ffi::Array upper) { ICHECK(coef.dtype().is_int() || coef.dtype().is_uint()) << "Coefficient in IntGroupBounds must be integers"; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->coef = std::move(coef); node->lower = std::move(lower); node->equal = std::move(equal); @@ -86,9 +87,9 @@ IntGroupBounds::IntGroupBounds(PrimExpr coef, Array lower, Arraymin.dtype(), 1); - Array equal; - Array lower; - Array upper; + ffi::Array equal; + ffi::Array lower; + ffi::Array upper; if (tir::is_one(r->extent)) { equal.push_back(r->min); } else { @@ -100,9 +101,9 @@ IntGroupBounds IntGroupBounds::FromRange(const Range& r) { IntGroupBounds IntGroupBounds::operator+(const Range& r) { Analyzer analyzer; - Array equal; - Array lower; - Array upper; + ffi::Array equal; + ffi::Array lower; + ffi::Array upper; const PrimExpr& coef = operator->()->coef; if (tir::is_one(r->extent)) { equal.push_back(analyzer.Simplify(r->min * coef)); @@ -116,7 +117,7 @@ IntGroupBounds IntGroupBounds::operator+(const Range& r) { return IntGroupBounds(coef, lower, equal, upper); } -IntGroupBounds IntGroupBounds::Substitute(const Map& subst) const { +IntGroupBounds IntGroupBounds::Substitute(const ffi::Map& subst) const { auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; return IntGroupBounds(tir::Substitute(operator->()->coef, subst), tir::UpdateArray(operator->()->lower, apply_fun), @@ -124,7 +125,7 @@ IntGroupBounds IntGroupBounds::Substitute(const Map& subst) const tir::UpdateArray(operator->()->upper, apply_fun)); } -Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { +Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) const { Analyzer analyzer; analyzer.Bind(vranges_addl); @@ 
-133,7 +134,7 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { var_intsets[kv.first.get()] = IntSet::FromRange(kv.second); } - const Array& equal = operator->()->equal; + const ffi::Array& equal = operator->()->equal; const PrimExpr& coef = operator->()->coef; std::vector lowers(equal.begin(), equal.end()); @@ -161,7 +162,7 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { for (const PrimExpr& low : lowers) { for (const PrimExpr& upp : uppers) { // Since diff may depend on some other variables, we compute its overapproximation - Optional diff_over; + ffi::Optional diff_over; PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); IntSet diff_set1 = EvalSet(diff_1, var_intsets); if (diff_set1.HasUpperBound()) { @@ -200,13 +201,12 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("arith.IntGroupBounds", - [](PrimExpr coef, Array lower, Array equal, Array upper) { - return IntGroupBounds(coef, lower, equal, upper); - }) + [](PrimExpr coef, ffi::Array lower, ffi::Array equal, + ffi::Array upper) { return IntGroupBounds(coef, lower, equal, upper); }) .def("arith.IntGroupBounds_from_range", IntGroupBounds::FromRange) .def_packed("arith.IntGroupBounds_FindBestRange", [](ffi::PackedArgs args, ffi::Any* ret) { ICHECK(args.size() == 1 || args.size() == 2); @@ -214,10 +214,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (args.size() == 1) { *ret = bounds.FindBestRange(); } else if (args.size() == 2) { - *ret = bounds.FindBestRange(args[1].cast>()); + *ret = bounds.FindBestRange(args[1].cast>()); } }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -226,14 +226,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", equal=" << op->equal << ", 
upper=" << op->upper << ")"; }); -IntConstraints::IntConstraints(Array variables, Map ranges, - Array relations) { - ObjectPtr node = make_object(); +IntConstraints::IntConstraints(ffi::Array variables, ffi::Map ranges, + ffi::Array relations) { + ObjectPtr node = ffi::make_object(); if (!variables.defined()) { - variables = Array(); + variables = ffi::Array(); } if (!ranges.defined()) { - ranges = Map(); + ranges = ffi::Map(); } ICHECK(relations.defined()); for (const auto& var : variables) { @@ -246,13 +246,14 @@ IntConstraints::IntConstraints(Array variables, Map ranges, data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("arith.IntConstraints", [](Array variables, Map ranges, - Array relations) { - return IntConstraints(variables, ranges, relations); - }); -}); + refl::GlobalDef().def( + "arith.IntConstraints", + [](ffi::Array variables, ffi::Map ranges, ffi::Array relations) { + return IntConstraints(variables, ranges, relations); + }); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -262,9 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst, - Map src_to_dst, - Map dst_to_src) { - ObjectPtr node = make_object(); + ffi::Map src_to_dst, + ffi::Map dst_to_src) { + ObjectPtr node = ffi::make_object(); node->src = std::move(src); node->dst = std::move(dst); node->src_to_dst = std::move(src_to_dst); @@ -275,8 +276,8 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstrai IntConstraintsTransform IntConstraintsTransform::operator+( const IntConstraintsTransform& other) const { ICHECK(other->src.same_as(operator->()->dst)); - Map dst_to_src; - Map src_to_dst; + ffi::Map dst_to_src; + ffi::Map src_to_dst; Analyzer ana_first; ana_first.Bind(operator->()->src->ranges); @@ -292,14 +293,14 @@ 
IntConstraintsTransform IntConstraintsTransform::operator+( return IntConstraintsTransform(operator->()->src, other->dst, src_to_dst, dst_to_src); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.IntConstraintsTransform", - [](IntConstraints src, IntConstraints dst, Map src_to_dst, - Map dst_to_src) { + [](IntConstraints src, IntConstraints dst, + ffi::Map src_to_dst, ffi::Map dst_to_src) { return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 6bd0400673be..2e3c3cbdbe28 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -27,12 +27,15 @@ #include #include #include +#include #include #include +#include #include #include "constraint_extract.h" +#include "int_operator.h" #include "interval_set.h" #include "pattern_match.h" @@ -44,13 +47,13 @@ using tir::is_zero; using tir::make_const; using tir::make_zero; -TVM_FFI_STATIC_INIT_BLOCK({ IntervalSetNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { IntervalSetNode::RegisterReflection(); } PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { - auto node = make_object(); + auto node = ffi::make_object(); node->min_value = std::move(min_value); node->max_value = std::move(max_value); data_ = std::move(node); @@ -60,10 +63,10 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.IntervalSet", MakeIntervalSet); -}); +} IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { 
PrimExpr max_value = min(a->max_value, b->max_value); @@ -109,10 +112,15 @@ TVM_DECLARE_LOGICAL_OP(Not); /*! * \brief Combine two interval set under arithmetic operations. + * \param analyzer The analyzer for simplification and proving + * \param a The first interval set + * \param b The second interval set + * \param op The operation node, used to extract dtype and other properties * \note this can possibly relax the set. */ -template -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) { +template +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) { + DataType dtype = op->dtype; if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr expr; if (auto res = TryConstFold(a->min_value, b->min_value)) { @@ -134,7 +142,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, Dat template <> inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::AddNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } @@ -149,7 +157,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS template <> inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::SubNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } @@ -164,7 +172,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MulNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -198,7 +206,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* 
analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::DivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -232,7 +240,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::ModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -261,7 +269,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::FloorDivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -295,7 +303,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::FloorModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -321,6 +329,29 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int return IntervalSet(tmin, tmax); } } + // Enhanced: Use ModularSet analysis for better bounds + if (auto* div_imm = divisor.as()) { + int64_t div_val = div_imm->value; + + // Analyze the modular properties of the dividend + ModularSet dividend_mod = analyzer->modular_set(op->a); + + if (dividend_mod.defined() && dividend_mod->coeff > 0) { + // Calculate GCD of dividend coefficient and divisor + int64_t gcd = ZeroAwareGCD(dividend_mod->coeff, div_val); + + if (gcd > 1 && div_val % gcd == 0) { + // The dividend is a multiple of gcd, and divisor is also a multiple of gcd + // So the result is also a multiple of gcd, with 
max value = (div_val/gcd - 1) * gcd + int64_t max_quotient = (div_val / gcd) - 1; + int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd); + + if (max_mod_result >= 0 && max_mod_result < div_val) { + return IntervalSet(make_zero(op->dtype), make_const(op->dtype, max_mod_result)); + } + } + } + } return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; @@ -333,7 +364,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int template <> inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MaxNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } @@ -344,7 +375,7 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MinNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } @@ -368,7 +399,7 @@ using namespace tir; // We might use better set analysis in the future to replace the intervalset. 
class IntervalSetEvaluator : public ExprFunctor { public: - IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, + IntervalSetEvaluator(Analyzer* analyzer, const ffi::Map& dom_map, const std::vector>* dom_constraints = nullptr, bool eval_vec = false) : analyzer_(analyzer), @@ -390,13 +421,18 @@ class IntervalSetEvaluator : public ExprFunctor { } IntervalSet VisitExpr_(const IntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(ffi::GetRef(op)); } IntervalSet VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); + + // Detect cyclic dependency: if we're already visiting this var, return conservative estimate + if (visiting_vars_.count(op)) { + return IntervalSet::SinglePoint(var); + } - Array values; + ffi::Array values; if (dom_constraints_) { for (const auto& constraint : *dom_constraints_) { if (var.same_as(constraint.first)) { @@ -426,9 +462,13 @@ class IntervalSetEvaluator : public ExprFunctor { if (res->min_value.same_as(var) && res->max_value.same_as(var)) { return res; } + // Mark this var as being visited to detect cycles + visiting_vars_.insert(op); // recursively evaluate mapped result // in case the domain contains variables to be relaxed. 
- return Eval(res); + IntervalSet result = Eval(res); + visiting_vars_.erase(op); + return result; } IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } @@ -475,23 +515,29 @@ class IntervalSetEvaluator : public ExprFunctor { if (op->lanes->IsInstance()) { int lanes = static_cast(Downcast(op->lanes)->value); if (vstride > 0) { - return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * (lanes - 1))), - op->dtype); + PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); + auto add_op = tir::Add(op->base, stride_expr); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(make_zero(t), stride_expr), add_node); } else { - return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * (lanes - 1)), make_zero(t)), - op->dtype); + PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); + auto add_op = tir::Add(op->base, stride_expr); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(stride_expr, make_zero(t)), add_node); } } else { /* Scalable vector */ if (vstride > 0) { - return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), op->dtype); + auto add_op = tir::Add(op->base, make_zero(t)); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), add_node); } else { - return Combine(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), op->dtype); + auto add_op = tir::Add(op->base, make_zero(t)); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), add_node); } } } - DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); + DLOG(WARNING) << "cannot evaluate set on expression " << ffi::GetRef(op); return IntervalSet::Everything(); } @@ -530,17 +576,17 @@ class IntervalSetEvaluator : public ExprFunctor { // Otherwise return `IntervalSet::everything()` since we have no knowledge on the buffer data. 
for (const PrimExpr& index : op->indices) { if (UsesVar(index, [dom_map = &this->dom_map_](const VarNode* var) { - return dom_map->find(GetRef(var)) != dom_map->end(); + return dom_map->find(ffi::GetRef(var)) != dom_map->end(); })) { return IntervalSet::Everything(); } } - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(ffi::GetRef(op)); } IntervalSet VisitExpr_(const CallNode* op) final { if (op->op.same_as(tir::builtin::vscale())) - return IntervalSet(GetRef(op), GetRef(op)); + return IntervalSet(ffi::GetRef(op), ffi::GetRef(op)); return IntervalSet::Everything(); } @@ -561,25 +607,32 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet a = this->Eval(op->a); IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(ffi::GetRef(op)); } - return Combine(analyzer_, a, b, op->dtype); + return Combine(analyzer_, a, b, op); } // recursive depth int recur_depth_{0}; // analyzer Analyzer* analyzer_; - const Map& dom_map_; + const ffi::Map& dom_map_; const std::vector>* dom_constraints_; bool eval_vec_{false}; + // track variables being visited to detect cyclic dependencies + std::unordered_set visiting_vars_; }; class IntSetAnalyzer::Impl { public: explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} - IntSet Eval(const PrimExpr& expr, const Map& dom_map) const { + void CopyFrom(const Impl& other) { + this->dom_map_ = other.dom_map_; + this->dom_constraints_ = other.dom_constraints_; + } + + IntSet Eval(const PrimExpr& expr, const ffi::Map& dom_map) const { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); } @@ -605,11 +658,11 @@ class IntSetAnalyzer::Impl { // Map of variables to global variable bounds (e.g. loop iterator // ranges) - Map dom_map_; + ffi::Map dom_map_; // List of implicit scope-dependent bounds (e.g. inside the body of // an if-statement). 
Maintained as a list of constraints, rather - // than as a `Map`, to avoid computing an Intersection + // than as a `ffi::Map`, to avoid computing an Intersection // until required. std::vector> dom_constraints_; }; @@ -618,7 +671,7 @@ IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } -IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& dom_map) { +IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const ffi::Map& dom_map) { return impl_->Eval(expr, dom_map); } @@ -709,6 +762,11 @@ std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& cons return frecover; } +// Deep copy internal state from another analyzer +void IntSetAnalyzer::CopyFrom(const IntSetAnalyzer& other) { + this->impl_->CopyFrom(*other.impl_); +} + // Quickly adapt to IntSet interface // TODO(tqchen): revisit IntSet interface as well. Range IntSet::CoverRange(Range max_range) const { @@ -861,7 +919,7 @@ bool IntSet::MatchRange(const Range& b) const { ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); } -IntSet Union(const Array& sets) { +IntSet Union(const ffi::Array& sets) { if (sets.size() == 0) return IntSet::Nothing(); if (sets.size() == 1) return sets[0]; Analyzer ana; @@ -872,16 +930,16 @@ IntSet Union(const Array& sets) { return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } -Array UnionRegion(const Array>& nd_int_sets) { +ffi::Array UnionRegion(const ffi::Array>& nd_int_sets) { if (nd_int_sets.empty()) { return {}; } int n = nd_int_sets.size(); int ndim = nd_int_sets[0].size(); - Array result; + ffi::Array result; result.reserve(ndim); for (int i = 0; i < ndim; ++i) { - Array candidates; + ffi::Array candidates; candidates.reserve(n); for (int j = 0; j < n; ++j) { candidates.push_back(nd_int_sets[j][i]); @@ -891,7 +949,7 @@ Array UnionRegion(const Array>& nd_int_sets) { return result; } -IntSet UnionLowerBound(const Array& sets) { +IntSet 
UnionLowerBound(const ffi::Array& sets) { if (sets.size() == 0) return IntSet::Nothing(); if (sets.size() == 1) return sets[0]; Analyzer analyzer; @@ -925,16 +983,16 @@ IntSet UnionLowerBound(const Array& sets) { return IntSet::Interval(min_inclusive, max_inclusive); } -Array UnionRegionLowerBound(const Array>& nd_int_sets) { +ffi::Array UnionRegionLowerBound(const ffi::Array>& nd_int_sets) { if (nd_int_sets.empty()) { return {}; } int n = nd_int_sets.size(); int ndim = nd_int_sets[0].size(); - Array result; + ffi::Array result; result.reserve(ndim); for (int i = 0; i < ndim; ++i) { - Array candidates; + ffi::Array candidates; candidates.reserve(n); for (int j = 0; j < n; ++j) { candidates.push_back(nd_int_sets[j][i]); @@ -944,7 +1002,7 @@ Array UnionRegionLowerBound(const Array>& nd_int_sets) { return result; } -IntSet Intersect(const Array& sets) { +IntSet Intersect(const ffi::Array& sets) { if (sets.size() == 0) return IntSet::Nothing(); if (sets.size() == 1) return sets[0]; Analyzer ana; @@ -955,23 +1013,23 @@ IntSet Intersect(const Array& sets) { return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } -Map ConvertDomMap(const Map& dom_map) { - Map dmap; +ffi::Map ConvertDomMap(const ffi::Map& dom_map) { + ffi::Map dmap; for (auto kv : dom_map) { dmap.Set(kv.first->var, kv.second); } return dmap; } -Map ConvertDomMap(const std::unordered_map& dom_map) { - Map dmap; +ffi::Map ConvertDomMap(const std::unordered_map& dom_map) { + ffi::Map dmap; for (auto kv : dom_map) { - dmap.Set(GetRef(kv.first), kv.second); + dmap.Set(ffi::GetRef(kv.first), kv.second); } return dmap; } -IntSet EvalSet(PrimExpr e, const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map) { Analyzer ana; return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e); } @@ -983,12 +1041,12 @@ IntSet IntSet::Vector(PrimExpr x) { } else { // vector case. 
Analyzer ana; - Map dmap; + ffi::Map dmap; return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); } } -IntSet EvalSet(PrimExpr e, const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } @@ -996,7 +1054,7 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, const Map& dom_map) { +IntSet EvalSet(Range r, const ffi::Map& dom_map) { Analyzer ana; if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) { return EvalSet(r->min, dom_map); @@ -1012,10 +1070,10 @@ IntSet EvalSet(Range r, const std::unordered_map& dom_ma return EvalSet(r, ConvertDomMap(dom_map)); } -Array EvalSet(const Array& region, const Map& dom_map) { +ffi::Array EvalSet(const ffi::Array& region, const ffi::Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); - Array result; + ffi::Array result; result.reserve(region.size()); for (const Range& r : region) { PrimExpr sum = r->min + (r->extent - 1); @@ -1036,7 +1094,7 @@ IntSet EvalSet(IntSet s, const std::unordered_map& dom_m class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map) + explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const ffi::Map& dom_map) : IntervalSetEvaluator(analyzer, dom_map) {} IntervalSet VisitExpr(const PrimExpr& n) final { @@ -1057,12 +1115,12 @@ ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, return m.expr_map; } -IntSet EvalSet(Range r, const Map& dom_map) { +IntSet EvalSet(Range r, const ffi::Map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } -Map AsIntSet(const Map& var_dom) { - Map result; +ffi::Map AsIntSet(const ffi::Map& var_dom) { + ffi::Map result; for (auto kv : var_dom) { const Var& var = kv.first; const Range& range = kv.second; @@ -1072,8 +1130,8 @@ Map AsIntSet(const Map& var_dom) { } /*! 
\brief Helper function to convert IterSumExpr to the actual touched range. */ -static Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, - Analyzer* analyzer) { +static ffi::Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, + Analyzer* analyzer) { if (analyzer->CanProve(extent == 0)) { return IntSet::Nothing(); } @@ -1105,13 +1163,14 @@ static Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& } } -Optional> EstimateRegionStrictBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, Analyzer* analyzer) { +ffi::Optional> EstimateRegionStrictBound(const ffi::Array& region, + const ffi::Map& var_dom, + const PrimExpr& predicate, + Analyzer* analyzer) { int ndim = region.size(); - Array iter_sum_exprs{nullptr}; + ffi::Array iter_sum_exprs{nullptr}; { - Array affine_indices; + ffi::Array affine_indices; affine_indices.reserve(ndim); for (const Range& range : region) { if (!is_const_number(range->extent)) { @@ -1129,12 +1188,12 @@ Optional> EstimateRegionStrictBound(const Array& region, return std::nullopt; } ICHECK_EQ(iter_sum_exprs.size(), ndim); - Array result; + ffi::Array result; result.reserve(ndim); for (int i = 0; i < ndim; ++i) { const IterSumExpr& sum_expr = iter_sum_exprs[i]; const Range& range = region[i]; - Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer); + ffi::Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer); if (int_set.defined()) { result.push_back(int_set.value()); } else { @@ -1144,22 +1203,23 @@ Optional> EstimateRegionStrictBound(const Array& region, return result; } -Optional> EstimateRegionLowerBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, - arith::Analyzer* analyzer) { +ffi::Optional> EstimateRegionLowerBound(const ffi::Array& region, + const ffi::Map& var_dom, + const PrimExpr& predicate, + arith::Analyzer* analyzer) { return EstimateRegionStrictBound(region, var_dom, predicate, analyzer); } -Array 
EstimateRegionUpperBound(const Array& region, const Map& var_dom, - const PrimExpr& predicate, Analyzer* analyzer) { - if (Optional> result = EstimateRegionStrictBound( +ffi::Array EstimateRegionUpperBound(const ffi::Array& region, + const ffi::Map& var_dom, + const PrimExpr& predicate, Analyzer* analyzer) { + if (ffi::Optional> result = EstimateRegionStrictBound( /*region=*/region, /*var_dom=*/var_dom, /*predicate=*/predicate, /*analyzer=*/analyzer)) { return result.value(); } - Array result; + ffi::Array result; result.reserve(region.size()); // try estimate each dimension independently for (const Range& range : region) { @@ -1178,7 +1238,7 @@ Array EstimateRegionUpperBound(const Array& region, const Map int_set = EvalIterSum(sum_expr, range->extent, analyzer)) { + if (ffi::Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer)) { result.push_back(int_set.value()); continue; } @@ -1196,7 +1256,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "[" << op->min_value << ", " << op->max_value << ']'; }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("arith.intset_single_point", IntSet::SinglePoint) @@ -1207,27 +1267,27 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("arith.IntSetIsNothing", &IntSet::IsNothing) .def_method("arith.IntSetIsEverything", &IntSet::IsEverything) .def("arith.EstimateRegionLowerBound", - [](Array region, Map var_dom, - PrimExpr predicate) -> Optional> { + [](ffi::Array region, ffi::Map var_dom, + PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); }) .def("arith.EstimateRegionStrictBound", - [](Array region, Map var_dom, - PrimExpr predicate) -> Optional> { + [](ffi::Array region, ffi::Map var_dom, + PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer); }) .def("arith.EstimateRegionUpperBound", - [](Array 
region, Map var_dom, - PrimExpr predicate) -> Optional> { + [](ffi::Array region, ffi::Map var_dom, + PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer); }) .def("arith.PosInf", []() { return SymbolicLimits::pos_inf_; }) .def("arith.NegInf", []() { return SymbolicLimits::neg_inf_; }) .def("arith.UnionLowerBound", UnionLowerBound); -}); +} } // namespace arith } // namespace tvm diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 4fadf985db9b..b8597db7aa90 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -75,9 +75,7 @@ class IntervalSetNode : public IntSetNode { } /*! \return whether interval represent everything */ bool IsEverything() const { return is_neg_inf(min_value) && is_pos_inf(max_value); } - - static constexpr const char* _type_key = "arith.IntervalSet"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.IntervalSet", IntervalSetNode, IntSetNode); }; /*! @@ -113,7 +111,7 @@ class IntervalSet : public IntSet { static IntervalSet Empty() { return IntervalSet(pos_inf(), neg_inf()); } TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode); - TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntervalSet, IntSet, IntervalSetNode); }; /*! 
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index d26ac3667620..8dca76b5aed8 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -25,6 +25,7 @@ #include #include #include +#include "tvm/arith/analyzer.h" namespace tvm { namespace arith { @@ -40,14 +41,14 @@ void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) { } } -Array IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const Array& indices, - bool non_trivial_only) { +ffi::Array IRMutatorWithAnalyzer::IterMapSimplifyWithContext( + const ffi::Array& indices, bool non_trivial_only) { PrimExpr pred = const_true(); for (PrimExpr val : iter_predicates_) { pred = pred && val; } int n = indices.size(); - Array simplified = arith::IterMapSimplify( + ffi::Array simplified = arith::IterMapSimplify( indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective, this->analyzer_); if (non_trivial_only) { for (int i = 0; i < n; ++i) { @@ -64,6 +65,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { Range dom = Range::FromMinExtent(op->min, op->extent); analyzer_->Bind(op->loop_var, dom); iter_vars_.Set(op->loop_var, dom); + With ctx(analyzer_, op->extent > 0); return StmtExprMutator::VisitStmt_(op); } @@ -84,7 +86,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { // as sub-class may or maynot choose to replace it. 
Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); @@ -105,7 +107,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { } Stmt then_case; - Optional else_case; + ffi::Optional else_case; { With ctx(analyzer_, real_condition); WithRecordIterPredicate(real_condition, [&] { then_case = this->VisitStmt(op->then_case); }); @@ -121,7 +123,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->condition = std::move(condition); @@ -140,9 +142,15 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { iter_vars_.Set(iv->var, dom); Stmt stmt = StmtExprMutator::VisitStmt_(op); return stmt; - } else { + } + else if(op->attr_key == tir::attr::tilelang_assume) { + auto condition = Downcast(op->node); + With constraint(analyzer_, condition, true); return StmtExprMutator::VisitStmt_(op); } + else { + return StmtExprMutator::VisitStmt_(op); + } } Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { @@ -152,7 +160,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->condition = std::move(condition); @@ -185,9 +193,9 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { } if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { - return GetRef(op); + return ffi::GetRef(op); } else { - return Call(op->dtype, op->op, {cond, true_value, false_value}); + return 
Call(op->dtype, op->op, {cond, true_value, false_value}, op->annotations); } } return StmtExprMutator::VisitExpr_(op); @@ -202,7 +210,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) { // as sub-class may or maynot choose to replace it. PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -228,7 +236,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) { // normal path if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Select(cond, true_value, false_value); } diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index fb01fd19cee7..28f8e600d38e 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -74,7 +74,8 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { * \brief Use internal bound information to perform inter map simplification of indices. * \note Only do this during layout remapping */ - Array IterMapSimplifyWithContext(const Array& indices, bool non_trivial_only); + ffi::Array IterMapSimplifyWithContext(const ffi::Array& indices, + bool non_trivial_only); /*! \brief internal analyzer field. */ Analyzer* analyzer_; @@ -83,9 +84,9 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { // expensive and we only encourage doing them during // necessary cases like layout remapping /*! \brief Recorded loop iterators */ - Map iter_vars_; + ffi::Map iter_vars_; /*! \brief iterator predicates */ - Array iter_predicates_; + ffi::Array iter_predicates_; /*! * \brief Run callback while trying to record iter predicate * \param conditon Condition to be checked. 
@@ -94,7 +95,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { template void WithRecordIterPredicate(PrimExpr condition, FLambda callback) { auto f_use_itervar = [this](const tir::VarNode* v) { - return iter_vars_.count(GetRef(v)); + return iter_vars_.count(ffi::GetRef(v)); }; // simple heuristics for detecting predicate if (tir::UsesVar(condition, f_use_itervar)) { diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index dba4567f88ec..c5960faa7e25 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -69,8 +69,16 @@ void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); + StmtExprVisitor::VisitStmt_(op); + } + else if(op->attr_key == tir::attr::tilelang_assume) { + auto condition = Downcast(op->node); + With constraint(&analyzer_, condition, true); + StmtExprVisitor::VisitStmt_(op); + } + else { + StmtExprVisitor::VisitStmt_(op); } - StmtExprVisitor::VisitStmt_(op); } void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 42b99abd4063..3de431fb9574 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -41,25 +41,25 @@ namespace arith { using namespace tir; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { IterMarkNode::RegisterReflection(); IterSplitExprNode::RegisterReflection(); IterSumExprNode::RegisterReflection(); IterMapResultNode::RegisterReflection(); -}); +} IterMark::IterMark(PrimExpr source, PrimExpr extent) { - auto n = make_object(); + auto n = ffi::make_object(); n->source = std::move(source); n->extent = std::move(extent); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("arith.IterMark", [](PrimExpr source, PrimExpr extent) { return IterMark(source, extent); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -68,7 +68,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); IterSplitExpr::IterSplitExpr(IterMark source) { - auto n = make_object(); + auto n = ffi::make_object(); auto one = make_const(source->source->dtype, 1); n->dtype = source->source->dtype; n->source = std::move(source); @@ -79,7 +79,7 @@ IterSplitExpr::IterSplitExpr(IterMark source) { } IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { - auto n = make_object(); + auto n = ffi::make_object(); auto one = make_const(source->source->dtype, 1); n->dtype = source->source->dtype; n->source = std::move(source); @@ -91,7 +91,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { - auto n = make_object(); + auto n = ffi::make_object(); n->dtype = source->source->dtype; n->source = std::move(source); n->lower_factor = std::move(lower_factor); @@ -100,13 +100,13 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr ex data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.IterSplitExpr", [](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { return IterSplitExpr(source, lower_factor, extent, scale); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -115,20 +115,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", extent=" << op->extent << ", scale=" << op->scale << ")"; }); -IterSumExpr::IterSumExpr(Array args, PrimExpr base) { - auto n = make_object(); +IterSumExpr::IterSumExpr(ffi::Array args, PrimExpr base) { + auto n = ffi::make_object(); n->dtype = 
base->dtype; n->args = std::move(args); n->base = std::move(base); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("arith.IterSumExpr", [](Array args, PrimExpr base) { + refl::GlobalDef().def("arith.IterSumExpr", [](ffi::Array args, PrimExpr base) { return IterSumExpr(args, base); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -152,7 +152,7 @@ class IterMarkSplitCollector { * \brief Collect all mark2splits recursively from indices. * \param indices The iterator of interest. */ - void Collect(const Array& indices) { + void Collect(const ffi::Array& indices) { for (IterSumExpr sum_expr : indices) { for (IterSplitExpr split : sum_expr->args) { this->CollectInternal(split->source); @@ -186,9 +186,9 @@ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, + explicit IterMapRewriter(Analyzer* analyzer, const ffi::Map& input_iters, IterMapLevel check_level, bool simplify_trivial_iterators, - Array* errors) + ffi::Array* errors) : analyzer_(analyzer), check_level_(check_level), errors_(*errors), @@ -227,8 +227,8 @@ class IterMapRewriter : public ExprMutator { } IterSumExpr RewriteIterConstraint(const PrimExpr& expr, - const Optional& predicate_induced_min, - const Optional& predicate_induced_max) { + const ffi::Optional& predicate_induced_min, + const ffi::Optional& predicate_induced_max) { return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min, predicate_induced_max); } @@ -263,7 +263,7 @@ class IterMapRewriter : public ExprMutator { * - bindings = [x / 3] will not pass because x / 3 can not be one split of x * \return whether the bindings are valid */ - bool CheckMapping(const Array& bindings, IterMapLevel check_level) { + bool CheckMapping(const ffi::Array& bindings, 
IterMapLevel check_level) { IterMarkSplitCollector collector; // We can check that for each iter mark: // All the splits that refers to the iter_mark covers its extent. @@ -447,7 +447,7 @@ class IterMapRewriter : public ExprMutator { // Iter map check level IterMapLevel check_level_; // Error messages for each unresolved expression. - Array& errors_; + ffi::Array& errors_; // The var map std::unordered_map var_map_; // input iter marks @@ -568,9 +568,9 @@ class IterMapRewriter : public ExprMutator { * \param check_level Iteration mapping's check level. * \return The normalized splits. */ - Array TryNormalizeSplits(const IterMark& mark, - const std::vector& splits, - IterMapLevel check_level) { + ffi::Array TryNormalizeSplits(const IterMark& mark, + const std::vector& splits, + IterMapLevel check_level) { std::vector used(splits.size(), false); std::vector iters; PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); @@ -586,7 +586,7 @@ class IterMapRewriter : public ExprMutator { if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective if (check_level == IterMapLevel::Bijective) { - return Array(); + return ffi::Array(); } // look for the next split skipping this lower factor // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2] @@ -595,7 +595,7 @@ class IterMapRewriter : public ExprMutator { j = SearchSkipLowerFactor(splits, used, expected_lower_factor); // split not found if (j == splits.size()) { - return Array(); + return ffi::Array(); } } @@ -647,24 +647,24 @@ class IterMapRewriter : public ExprMutator { if (match_full_iter) { if (splits.size() != 1) { ErrorLogger(this) << "Dependent iterations on padding iter space"; - return Array(); + return ffi::Array(); } else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) && !analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { ErrorLogger(this) << "Split on padding iteration is not surjective " << "if the 
split extent equals to the full iter space extent"; - return Array(); + return ffi::Array(); } } else if (match_iter_divisor) { if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { ErrorLogger(this) << "The extent before padding is less than lower factor"; - return Array(); + return ffi::Array(); } } else { ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; return {}; } } - return Array(iters.rbegin(), iters.rend()); + return ffi::Array(iters.rbegin(), iters.rend()); } /*! @@ -674,8 +674,9 @@ class IterMapRewriter : public ExprMutator { * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined. * \return The Normalized expression. */ - IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional predicate_induced_min, - Optional predicate_induced_max) { + IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, + ffi::Optional predicate_induced_min, + ffi::Optional predicate_induced_max) { // normalize to zero base PrimExpr base = expr->base; if (!is_zero(base)) { @@ -685,7 +686,7 @@ class IterMapRewriter : public ExprMutator { if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max.value() - base; } - Optional opt = TryFuseIters(expr, check_level_, false); + ffi::Optional opt = TryFuseIters(expr, check_level_, false); ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 if (opt.defined() && is_one(opt.value()->args[0]->scale)) { @@ -739,7 +740,7 @@ class IterMapRewriter : public ExprMutator { // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); IterSumExprNode* normalized_expr = expr.CopyOnWrite(); - normalized_expr->args = Array({split}); + normalized_expr->args = ffi::Array({split}); normalized_expr->base = base; return expr; } @@ -755,7 +756,7 @@ class IterMapRewriter : public ExprMutator { IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { 
// We are normalizing a regular iter if (expr->args.size() < 1) return expr; - Optional opt = TryFuseIters(expr, check_level_, true); + ffi::Optional opt = TryFuseIters(expr, check_level_, true); if (opt.defined()) { return opt.value(); } else { @@ -820,7 +821,7 @@ class IterMapRewriter : public ExprMutator { return lhs.symbol_prod_count > rhs.symbol_prod_count; }); - Array args; + ffi::Array args; for (const Item& item : items) { args.push_back(item.split); } @@ -857,7 +858,7 @@ class IterMapRewriter : public ExprMutator { * \return Whether we can find one. */ int FindBaseIter(const IterSumExpr& expr, const std::vector& skip_flag, - Optional match_source, int rbegin = -1) { + ffi::Optional match_source, int rbegin = -1) { if (rbegin == -1) { rbegin = static_cast(expr->args.size()) - 1; } @@ -927,7 +928,7 @@ class IterMapRewriter : public ExprMutator { * \return -1 if not no match found, otherwise return the index. */ int FindIterWithExactScale(const IterSumExpr& expr, const std::vector& skip_flag, - const PrimExpr& expected_scale, Optional match_source, + const PrimExpr& expected_scale, ffi::Optional match_source, int rbegin = -1, int first_possible_unit_extent_pos = 0) { if (rbegin == -1) { rbegin = static_cast(expr->args.size()) - 1; @@ -993,7 +994,7 @@ class IterMapRewriter : public ExprMutator { * \param check_level The check level if iter mapping. * \return The sum with the fused IterMark and extra offset if succeed. 
*/ - Optional TryCombineSplitFromSameSource(IterSumExpr expr) { + ffi::Optional TryCombineSplitFromSameSource(IterSumExpr expr) { if (expr->args.size() <= 1) return std::nullopt; std::unordered_map hit_count; // most iter map are small n < 5 @@ -1078,7 +1079,7 @@ class IterMapRewriter : public ExprMutator { IterSumExpr simplified_sum = expr; // flip the order so we preserve the original order simplified_sum.CopyOnWrite()->args = - Array(reverse_flattened_iters.rbegin(), reverse_flattened_iters.rend()); + ffi::Array(reverse_flattened_iters.rbegin(), reverse_flattened_iters.rend()); return simplified_sum; } @@ -1095,8 +1096,8 @@ class IterMapRewriter : public ExprMutator { * (this may cause us to return parameters that are not canonically wrapped as * IterSum(IterMark)) \return The sum with the fused IterMark and extra offset if succeed. */ - Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level, - bool allow_early_skip) { + ffi::Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level, + bool allow_early_skip) { if (auto opt = TryCombineSplitFromSameSource(expr)) { expr = opt.value(); if (expr->args.size() <= 1 && allow_early_skip) { @@ -1146,7 +1147,7 @@ class IterMapRewriter : public ExprMutator { // predicate: j*2 + k < 9 // We need to match the predicate in expr and adjust the expected scale, // otherwise we expect the scale of i to be 2*5=10 - Optional constraint_to_match; + ffi::Optional constraint_to_match; for (const IterSumExpr& iter : constrained_iters_flattened_) { if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) { // find a predicate started from match position @@ -1208,10 +1209,10 @@ class IterMapRewriter : public ExprMutator { // both forms have splits from outermost to innermost IterSumExpr structured_form = expr, flattened_form = expr; flattened_form.CopyOnWrite()->args = - Array(flattened_iters.rbegin(), flattened_iters.rend()); + ffi::Array(flattened_iters.rbegin(), flattened_iters.rend()); 
flattened_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); structured_form.CopyOnWrite()->args = - Array(grouped_iters.rbegin(), grouped_iters.rend()); + ffi::Array(grouped_iters.rbegin(), grouped_iters.rend()); structured_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { @@ -1285,14 +1286,14 @@ struct IterConstraint { // The expr of the iter PrimExpr iter; // The expr of the lower_bound, maybe undefined - Optional lower_bound; + ffi::Optional lower_bound; // The expr of the upper_bound, maybe undefined - Optional upper_bound; + ffi::Optional upper_bound; // The size of the iter, which is the number of nodes size_t expr_size = 0; - IterConstraint(PrimExpr iter, Optional lower_bound, Optional upper_bound, - size_t size) + IterConstraint(PrimExpr iter, ffi::Optional lower_bound, + ffi::Optional upper_bound, size_t size) : iter(std::move(iter)), lower_bound(std::move(lower_bound)), upper_bound(std::move(upper_bound)), @@ -1306,7 +1307,7 @@ struct IterConstraint { * \param result The result of predicate split. * \return A list of IterConstraint, empty if the split failed. 
*/ -bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, +bool MatchBoundConstraints(PrimExpr pred, ffi::Map* input_iters, std::vector* result) { arith::PVar lhs, rhs, rest; for (;;) { @@ -1348,7 +1349,7 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, // determine iter and bound, if we can not distinguish them simply, // try divide (lhs - rhs) into itervar aware and itervar free parts auto f_use_itervar = [&input_iters](const VarNode* v) { - return input_iters->count(GetRef(v)); + return input_iters->count(ffi::GetRef(v)); }; bool bound_at_left; if (UsesVar(lhs_expr, f_use_itervar) || UsesVar(rhs_expr, f_use_itervar)) { @@ -1381,7 +1382,7 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, lhs_expr = analyzer.Simplify(lhs_expr); rhs_expr = analyzer.Simplify(rhs_expr); } - Optional lower_bound = std::nullopt, upper_bound = std::nullopt; + ffi::Optional lower_bound = std::nullopt, upper_bound = std::nullopt; PrimExpr iter; if (is_greater) { if (bound_at_left) { @@ -1427,19 +1428,20 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, return true; } -bool IterRangeSanityCheck(const Map& iter_ranges) { +bool IterRangeSanityCheck(const ffi::Map& iter_ranges) { std::unordered_set iters; for (const auto& it : iter_ranges) iters.insert(it.first); - auto f = [&](const VarNode* var) { return iters.count(GetRef(var)); }; + auto f = [&](const VarNode* var) { return iters.count(ffi::GetRef(var)); }; for (const auto& it : iter_ranges) { if (UsesVar(it.second->min, f) || UsesVar(it.second->extent, f)) return false; } return true; } -IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, IterMapLevel check_level, - arith::Analyzer* analyzer, bool simplify_trivial_iterators) { +IterMapResult DetectIterMap(const ffi::Array& indices, + const ffi::Map& input_iters, const PrimExpr& predicate, + IterMapLevel check_level, arith::Analyzer* analyzer, + bool simplify_trivial_iterators) { IterMapResult 
result; // Overall detection algorithm is divided into two steps: @@ -1449,7 +1451,7 @@ IterMapResult DetectIterMap(const Array& indices, const Maperrors.push_back("Invalid iterators. Iterators may not be expressions of each other."); return result; } - Map constrained_input_iters = input_iters; + ffi::Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { @@ -1484,7 +1486,7 @@ IterMapResult DetectIterMap(const Array& indices, const Map rewrite_indices; + ffi::Array rewrite_indices; rewrite_indices.reserve(indices.size()); bool allow_padding = check_level != IterMapLevel::Bijective; if (allow_padding) { @@ -1522,19 +1524,19 @@ IterMapResult DetectIterMap(const Array& indices, const Map& indices, const Map& input_iters, + [](const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, simplify_trivial_iterators); }); -}); +} -IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iters, +IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input_iters, arith::Analyzer* analyzer) { IterMapResult result; ICHECK(IterRangeSanityCheck(input_iters)) @@ -1550,17 +1552,17 @@ IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iter return rewriter.RewriteToNormalizedIterSum(index); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.NormalizeToIterSum", - [](PrimExpr index, const Map& input_iters) { + [](PrimExpr index, const ffi::Map& input_iters) { arith::Analyzer ana; return NormalizeToIterSum(index, input_iters, &ana); }); -}); +} PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { - auto var = GetRef(op); + auto var = ffi::GetRef(op); auto it = 
var_map_.find(var); if (it != var_map_.end()) return it->second; return var; @@ -1578,7 +1580,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Add(a, b); } @@ -1613,7 +1615,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Sub(a, b); } @@ -1648,7 +1650,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Mul(a, b); } @@ -1657,8 +1659,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. ErrorLogger(this) << "Product of two iterators cannot be represented as an IterMap, " - << "occurs in " << GetRef(op); - return GetRef(op); + << "occurs in " << ffi::GetRef(op); + return ffi::GetRef(op); } if (!a->IsInstance()) { @@ -1961,7 +1963,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return FloorDiv(a, b); } @@ -1969,19 +1971,19 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (b->IsInstance()) { // cannot divide an iterator, mark as unresolved. 
- ErrorLogger(this) << "Cannot represent as an IterMap: the divisor in " << GetRef(op) - << " may not be an iterator"; - return GetRef(op); + ErrorLogger(this) << "Cannot represent as an IterMap: the divisor in " + << ffi::GetRef(op) << " may not be an iterator"; + return ffi::GetRef(op); } IterSumExpr preprocessed = PreprocessDividend(Downcast(a), op->a); if (!preprocessed.defined()) { - return GetRef(op); + return ffi::GetRef(op); } ICHECK_EQ(preprocessed->args.size(), 1U); PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b); if (!remainder.defined()) { - return GetRef(op); + return ffi::GetRef(op); } return remainder; } @@ -2045,7 +2047,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return FloorMod(a, b); } @@ -2054,19 +2056,19 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (b->IsInstance()) { // cannot mod an iterator, mark as unresolved. 
ErrorLogger(this) << "Cannot represent as an IterMap: the right-hand side of FloorMod in " - << GetRef(op) << " may not be an iterator"; - return GetRef(op); + << ffi::GetRef(op) << " may not be an iterator"; + return ffi::GetRef(op); } IterSumExpr preprocessed = PreprocessDividend(Downcast(a), op->a); if (!preprocessed.defined()) { - return GetRef(op); + return ffi::GetRef(op); } ICHECK_EQ(preprocessed->args.size(), 1U); PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b); if (!remainder.defined()) { - return GetRef(op); + return ffi::GetRef(op); } return remainder; } @@ -2152,18 +2154,19 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { return normalizer.Convert(expr); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.NormalizeIterMapToExpr", NormalizeIterMapToExpr); -}); +} -Array IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, IterMapLevel check_level, - arith::Analyzer* ana, bool simplify_trivial_iterators) { +ffi::Array IterMapSimplify(const ffi::Array& indices, + const ffi::Map& input_iters, + const PrimExpr& input_pred, IterMapLevel check_level, + arith::Analyzer* ana, bool simplify_trivial_iterators) { if (!IterRangeSanityCheck(input_iters)) return indices; auto res = DetectIterMap(indices, input_iters, input_pred, check_level, ana, /*simplify_trivial_iterators=*/simplify_trivial_iterators); - Array rewrite = res->indices; + ffi::Array rewrite = res->indices; if (rewrite.empty() && !is_one(input_pred) && check_level != IterMapLevel::Bijective) { // The input predicate may cause detect iter map to fail @@ -2177,24 +2180,24 @@ Array IterMapSimplify(const Array& indices, const Map simplified; + ffi::Array simplified; simplified.reserve(rewrite.size()); IterMapToExprNormalizer converter(ana); for (const auto& expr : rewrite) simplified.push_back(converter.Convert(expr)); return simplified; } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "arith.IterMapSimplify", - [](const Array& indices, const Map& input_iters, + [](const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; return IterMapSimplify(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, simplify_trivial_iterators); }); -}); +} /*! * \brief Divider to divide the bindings into two sets of bindings(outer and inner) @@ -2384,7 +2387,7 @@ class SubspaceDivider { extent *= arg->extent; res.push_back(arg); } - return IterMark(IterSumExpr(Array(res.rbegin(), res.rend()), base), extent); + return IterMark(IterSumExpr(ffi::Array(res.rbegin(), res.rend()), base), extent); } DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) { @@ -2394,7 +2397,7 @@ class SubspaceDivider { // encounter one of them. If we encounter another later, we directly return the record. 
return it->second; } - const Array& splits = collector_.mark2splits_.at(expr->source); + const ffi::Array& splits = collector_.mark2splits_.at(expr->source); if (auto iter_ptr = expr->source->source.as()) { // source is input_iter bool inner = sub_iters_.count(iter_ptr.value()); @@ -2487,15 +2490,16 @@ class SubspaceDivider { PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)}; }; -Array> SubspaceDivide(const Array& bindings, - const Map& input_iters, - const Array& sub_iters, const PrimExpr& predicate, - IterMapLevel check_level, arith::Analyzer* analyzer, - bool simplify_trivial_iterators) { - if (!IterRangeSanityCheck(input_iters)) return Array>(); +ffi::Array> SubspaceDivide(const ffi::Array& bindings, + const ffi::Map& input_iters, + const ffi::Array& sub_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, + bool simplify_trivial_iterators) { + if (!IterRangeSanityCheck(input_iters)) return ffi::Array>(); auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer, simplify_trivial_iterators); - const Array& maps = res->indices; + const ffi::Array& maps = res->indices; if (maps.empty()) return {}; std::unordered_set inner_iter_set; @@ -2507,7 +2511,7 @@ Array> SubspaceDivide(const Array& bindings, collector.Collect(maps); SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set); - std::vector> results; + std::vector> results; for (const IterSumExpr& expr : maps) { SubspaceDivider::DivisionResult res = subspace_divider.DivideIterSumExpr(expr, 0); if (subspace_divider.unresolved_count()) return {}; @@ -2520,30 +2524,31 @@ Array> SubspaceDivide(const Array& bindings, return results; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "arith.SubspaceDivide", [](const Array& bindings, const Map& root_iters, - const Array& sub_iters, const PrimExpr& predicate, - int check_level, bool simplify_trivial_iterators) { 
+ "arith.SubspaceDivide", + [](const ffi::Array& bindings, const ffi::Map& root_iters, + const ffi::Array& sub_iters, const PrimExpr& predicate, int check_level, + bool simplify_trivial_iterators) { arith::Analyzer ana; return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), &ana, simplify_trivial_iterators); }); -}); +} class InverseAffineIterMapTransformer { public: explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {} - Map operator()(const Array& iter_map, - const Array& outputs) { + ffi::Map operator()(const ffi::Array& iter_map, + const ffi::Array& outputs) { ICHECK(iter_map.size() == outputs.size()); std::vector post_dfs_order = ReverseTopologyOrder(iter_map); // initialize back propagation accumulator for (const IterMapExprNode* node : post_dfs_order) { - backprop_.Set(GetRef(node), Integer(0)); + backprop_.Set(ffi::GetRef(node), Integer(0)); } for (size_t i = 0; i < iter_map.size(); i++) { backprop_.Set(iter_map[i], outputs[i]); @@ -2552,10 +2557,10 @@ class InverseAffineIterMapTransformer { // run back propagation for (const IterMapExprNode* node : post_dfs_order) { if (node->IsInstance()) { - Visit_(Downcast(GetRef(node))); + Visit_(Downcast(ffi::GetRef(node))); } else { ICHECK(node->IsInstance()); - Visit_(Downcast(GetRef(node))); + Visit_(Downcast(ffi::GetRef(node))); } } return std::move(inverse_); @@ -2591,7 +2596,8 @@ class InverseAffineIterMapTransformer { } } - std::vector ReverseTopologyOrder(const Array& iter_map) { + std::vector ReverseTopologyOrder( + const ffi::Array& iter_map) { std::vector post_dfs_order; std::unordered_map visited; @@ -2652,20 +2658,20 @@ class InverseAffineIterMapTransformer { } Analyzer* analyzer_; - Map backprop_; // the accumulator of backpropgation - Map inverse_; // the result of inverse transformation + ffi::Map backprop_; // the accumulator of backpropgation + ffi::Map inverse_; // the result of inverse transformation }; -Map 
InverseAffineIterMap(const Array& iter_map, - const Array outputs) { +ffi::Map InverseAffineIterMap(const ffi::Array& iter_map, + const ffi::Array outputs) { Analyzer analyzer; return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.InverseAffineIterMap", InverseAffineIterMap); -}); +} } // namespace arith } // namespace tvm diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index fc082907a6d2..47d8acb14dc7 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -39,10 +39,10 @@ namespace arith { using namespace tir; -TVM_FFI_STATIC_INIT_BLOCK({ ModularSetNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ModularSetNode::RegisterReflection(); } ModularSet::ModularSet(int64_t coeff, int64_t base) { - auto node = make_object(); + auto node = ffi::make_object(); node->coeff = coeff; node->base = base; // finish construction. 
@@ -58,10 +58,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.ModularSet", MakeModularSet); -}); +} // internal entry for const int bound struct ModularSetAnalyzer::Entry { @@ -104,6 +104,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctor(op); + Var v = ffi::GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { return it->second; @@ -399,5 +401,8 @@ ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } +// Deep copy internal state from another analyzer +void ModularSetAnalyzer::CopyFrom(const ModularSetAnalyzer& other) { this->impl_->CopyFrom(*other.impl_); } + } // namespace arith } // namespace tvm diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc index d339b728db2c..d73364cf45ca 100644 --- a/src/arith/narrow_predicate_expression.cc +++ b/src/arith/narrow_predicate_expression.cc @@ -50,14 +50,14 @@ using namespace tir; // with free parameters, and the range of those parameters. 
class ExpressionNarrower : public tir::ExprMutator { public: - static PrimExpr Apply(PrimExpr expr, Map free_parameters) { + static PrimExpr Apply(PrimExpr expr, ffi::Map free_parameters) { ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; ExpressionNarrower mutator(free_parameters); return mutator(expr); } private: - explicit ExpressionNarrower(Map free_parameters) + explicit ExpressionNarrower(ffi::Map free_parameters) : free_parameters_(free_parameters) {} using Parent = tir::ExprMutator; @@ -111,22 +111,22 @@ class ExpressionNarrower : public tir::ExprMutator { PrimExpr VisitExpr_(const GTNode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), OppositeContext(current), current); + return VisitInequality(ffi::GetRef(op), OppositeContext(current), current); } PrimExpr VisitExpr_(const GENode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), OppositeContext(current), current); + return VisitInequality(ffi::GetRef(op), OppositeContext(current), current); } PrimExpr VisitExpr_(const LTNode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), current, OppositeContext(current)); + return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); } PrimExpr VisitExpr_(const LENode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), current, OppositeContext(current)); + return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); } PrimExpr VisitExpr_(const EQNode* op) override { @@ -143,7 +143,7 @@ class ExpressionNarrower : public tir::ExprMutator { PrimExpr VisitExpr_(const SubNode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), current, OppositeContext(current)); + return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); } PrimExpr VisitExpr_(const NotNode* op) override { @@ -154,11 +154,11 @@ class 
ExpressionNarrower : public tir::ExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) override { contains_unknown_expr_ = true; - return GetRef(op); + return ffi::GetRef(op); } PrimExpr VisitExpr_(const VarNode* op) override { - auto it = free_parameters_.find(GetRef(op)); + auto it = free_parameters_.find(ffi::GetRef(op)); if (it == free_parameters_.end()) { return Parent::VisitExpr_(op); } @@ -206,18 +206,18 @@ class ExpressionNarrower : public tir::ExprMutator { }; std::vector context_stack_; - Map free_parameters_; + ffi::Map free_parameters_; bool contains_unknown_expr_{false}; }; -PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameters) { +PrimExpr NarrowPredicateExpression(PrimExpr expr, ffi::Map free_parameters) { return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.NarrowPredicateExpression", NarrowPredicateExpression); -}); +} } // namespace arith } // namespace tvm diff --git a/src/arith/narrow_predicate_expression.h b/src/arith/narrow_predicate_expression.h index 1e452e3ad493..42a7c2cf038f 100644 --- a/src/arith/narrow_predicate_expression.h +++ b/src/arith/narrow_predicate_expression.h @@ -50,7 +50,7 @@ namespace arith { * \returns An expression that, if true, implies that the original * expression is also true. 
*/ -PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameters); +PrimExpr NarrowPredicateExpression(PrimExpr expr, ffi::Map free_parameters); } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 98cf61990d90..7c498d7a9c90 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -214,7 +214,7 @@ class PVar : public Pattern> { typename = typename std::enable_if::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { - return Match_(GetRef(ptr)); + return Match_(ffi::GetRef(ptr)); } else { return false; } @@ -257,7 +257,7 @@ class PVarWithCheck : public arith::Pattern> { typename = typename std::enable_if::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { - return Match_(GetRef(ptr)); + return Match_(ffi::GetRef(ptr)); } else { return false; } @@ -727,7 +727,7 @@ struct PCallExprMatchFunctor { }; struct PCallExprEvalArgsFunctor { - Array args_; + ffi::Array args_; template void operator()(size_t i, const T& pattern) { @@ -778,7 +778,7 @@ class PCallExpr : public Pattern> { // arithemetic intrinsics #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ - static PrimExpr Eval(Array args) { \ + static PrimExpr Eval(ffi::Array args) { \ return tir::Call(args[0].dtype(), GetOp(), args); \ } \ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ @@ -797,7 +797,7 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); // unary intrinsics #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ - static PrimExpr Eval(Array args) { \ + static PrimExpr Eval(ffi::Array args) { \ return tir::Call(args[0].dtype(), GetOp(), args); \ } \ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ @@ -811,7 +811,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else 
struct PIfThenElseOp { - static PrimExpr Eval(Array args) { return tir::Call(args[1].dtype(), GetOp(), args); } + static PrimExpr Eval(ffi::Array args) { + return tir::Call(args[1].dtype(), GetOp(), args); + } static const Op& GetOp() { return tir::builtin::if_then_else(); } }; diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 5674cf4f65bf..f69761259683 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -43,10 +43,9 @@ namespace tvm { namespace arith { -#ifdef TVM_MLIR_VERSION -#if TVM_MLIR_VERSION >= 150 +#if defined(TVM_MLIR_VERSION) && TVM_MLIR_VERSION >= 150 -TVM_FFI_STATIC_INIT_BLOCK({ PresburgerSetNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { PresburgerSetNode::RegisterReflection(); } using namespace tir; static void Update(const PrimExpr& constraint, PresburgerSetNode* intset) { @@ -92,10 +91,10 @@ static void Update(const PrimExpr& constraint, PresburgerSetNode* intset) { } PresburgerSet::PresburgerSet(const PrimExpr& constraint) { - Array vars; + ffi::Array vars; PostOrderVisit(constraint, [&vars](const ObjectRef& obj) { if (const VarNode* new_var = obj.as()) { - auto var = GetRef(new_var); + auto var = ffi::GetRef(new_var); if (!std::any_of(vars.begin(), vars.end(), [&var](const Var& v) { return v.same_as(var); })) { vars.push_back(var); } @@ -105,19 +104,19 @@ PresburgerSet::PresburgerSet(const PrimExpr& constraint) { Analyzer analyzer; PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite); auto space = PresburgerSpace::getRelationSpace(vars.size(), 0, 0, 0); - auto node = make_object(std::move(space), vars); + auto node = ffi::make_object(std::move(space), vars); node->SetVars(vars); Update(simplified_constraint, node.get()); data_ = std::move(node); } PresburgerSet::PresburgerSet(const std::vector& disjuncts, - const Array& vars) { - auto node = make_object(disjuncts, disjuncts[0].getSpace(), vars); + const ffi::Array& vars) { + auto node = 
ffi::make_object(disjuncts, disjuncts[0].getSpace(), vars); data_ = std::move(node); } -void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const Array& vars) { +void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const ffi::Array& vars) { Analyzer analyzer; PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite); Update(simplified_constraint, this); @@ -186,7 +185,7 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const { return constraint; } -PresburgerSet Union(const Array& sets) { +PresburgerSet Union(const ffi::Array& sets) { CHECK_GT(sets.size(), 0); if (sets.size() == 1) return sets[0]; auto relations = sets[0]->disjuncts; @@ -198,7 +197,7 @@ PresburgerSet Union(const Array& sets) { return PresburgerSet(std::move(relations), sets[0]->GetVars()); } -PresburgerSet Intersect(const Array& sets) { +PresburgerSet Intersect(const ffi::Array& sets) { CHECK_GT(sets.size(), 0); if (sets.size() == 1) return sets[0]; auto relations = sets[0]->disjuncts; @@ -217,7 +216,7 @@ PresburgerSet Intersect(const Array& sets) { } IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { - Array tvm_coeffs = DetectLinearEquation(e, set->GetVars()); + ffi::Array tvm_coeffs = DetectLinearEquation(e, set->GetVars()); #if TVM_MLIR_VERSION >= 190 SmallVector coeffs; #elif TVM_MLIR_VERSION >= 160 @@ -270,15 +269,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}"; }); -#endif // TVM_MLIR_VERSION >= 150 -#endif // TVM_MLIR_VERSION +#else // defined(TVM_MLIR_VERSION) && TVM_MLIR_VERSION >= 150 PresburgerSet MakePresburgerSet(const PrimExpr& constraint) { return PresburgerSet(constraint); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + PresburgerSetNode::RegisterReflection(); refl::GlobalDef().def("arith.PresburgerSet", MakePresburgerSet); -}); +} + +#endif // defined(TVM_MLIR_VERSION) && TVM_MLIR_VERSION >= 150 } // namespace arith } // 
namespace tvm diff --git a/src/arith/presburger_set.h b/src/arith/presburger_set.h index 3a7114048f92..2404f36428f6 100644 --- a/src/arith/presburger_set.h +++ b/src/arith/presburger_set.h @@ -60,10 +60,10 @@ using namespace presburger; class PresburgerSetNode : public IntSetNode { public: PresburgerSetNode() : space(PresburgerSpace::getRelationSpace()) {} - explicit PresburgerSetNode(const PresburgerSpace& space, const Array& vars) + explicit PresburgerSetNode(const PresburgerSpace& space, const ffi::Array& vars) : disjuncts({}), space(space), vars(vars) {} explicit PresburgerSetNode(const std::vector& disjuncts, - const PresburgerSpace& space, const Array& vars) + const PresburgerSpace& space, const ffi::Array& vars) : disjuncts(disjuncts), space(space), vars(vars) {} /*! \brief Represent the union of multiple IntegerRelation */ @@ -91,7 +91,7 @@ class PresburgerSetNode : public IntSetNode { * \param constraint The added constraint to the PresburgerSet. * \param vars The specified domain vars in constraint expression. */ - void UpdateConstraint(const PrimExpr& constraint, const Array& vars); + void UpdateConstraint(const PrimExpr& constraint, const ffi::Array& vars); /*! * \brief Generate expression that represents the constraint @@ -103,25 +103,23 @@ class PresburgerSetNode : public IntSetNode { * \brief Set domain vars * \param new_vars Vars that will be taken as the domain vars */ - void SetVars(const Array& new_vars) { vars = new_vars; } + void SetVars(const ffi::Array& new_vars) { vars = new_vars; } /*! * \brief Get the current domain vars * \return The current doamin vars */ - Array GetVars() const { return vars; } + ffi::Array GetVars() const { return vars; } /*! 
\return whether integer set is empty */ bool IsEmpty() const { return std::all_of(disjuncts.begin(), disjuncts.end(), std::mem_fn(&IntegerRelation::isIntegerEmpty)); } - - static constexpr const char* _type_key = "arith.PresburgerSet"; - TVM_DECLARE_FINAL_OBJECT_INFO(PresburgerSetNode, IntSetNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.PresburgerSet", PresburgerSetNode, IntSetNode); private: - Array vars; + ffi::Array vars; }; /*! @@ -136,7 +134,7 @@ class PresburgerSet : public IntSet { * \param vars The variables that the constraint describes about. * \return The created PresburgerSet. */ - TVM_DLL PresburgerSet(const std::vector& disjuncts, const Array& vars); + TVM_DLL PresburgerSet(const std::vector& disjuncts, const ffi::Array& vars); /*! * \brief Make a new instance of PresburgerSet, collect all vars as space vars. @@ -146,7 +144,7 @@ class PresburgerSet : public IntSet { TVM_DLL PresburgerSet(const PrimExpr& constraint); TVM_DEFINE_OBJECT_REF_COW_METHOD(PresburgerSetNode); - TVM_DEFINE_OBJECT_REF_METHODS(PresburgerSet, IntSet, PresburgerSetNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PresburgerSet, IntSet, PresburgerSetNode); }; #endif // TVM_MLIR_VERSION >= 150 #else // TVM_MLIR_VERSION @@ -158,9 +156,7 @@ class PresburgerSetNode : public IntSetNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "arith.PresburgerSet"; - TVM_DECLARE_FINAL_OBJECT_INFO(PresburgerSetNode, IntSetNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.PresburgerSet", PresburgerSetNode, IntSetNode); }; class PresburgerSet : public IntSet { @@ -178,14 +174,14 @@ class PresburgerSet : public IntSet { * \param sets The sets to be combined * \return the set after union */ -PresburgerSet Union(const Array& sets); +PresburgerSet Union(const ffi::Array& sets); /*! 
* \brief Create an intersected set of all sets * \param sets The sets to be intersected * \return The intersect set */ -PresburgerSet Intersect(const Array& sets); +PresburgerSet Intersect(const ffi::Array& sets); /*! * \brief Evaluate the range of given expression based on the constraint diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 66720a579233..0b23edd422ad 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -44,7 +45,7 @@ namespace arith { using namespace tir; -TVM_FFI_STATIC_INIT_BLOCK({ RewriteSimplifierStatsNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RewriteSimplifierStatsNode::RegisterReflection(); } // Note: When using matches_one_of or PMatchesOneOf alongside these // macros, be careful which patterns are used in the ResExpr. While @@ -498,13 +499,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { return ret; } -std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { +std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint, bool is_assume) { size_t old_literal_size = literal_constraints_.size(); // we will compare the already simplified result with the constraint, // so simplify the constraint as well PrimExpr new_constraint = operator()(constraint); for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) { - if (SideEffect(subconstraint) <= CallEffectKind::kPure) { + if (is_assume || SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); PrimExpr negation; if (subconstraint.dtype().is_bool()) { @@ -774,13 +775,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { // Pattern var for lanes in broadcast and ramp PVar lanes; - // x / 2.0 = x * 0.5 - if (const FloatImmNode* ptr = op->b.as()) { - ICHECK(op->dtype.is_float() || 
op->dtype.is_bfloat16() || - datatype::Registry::Global()->GetTypeRegistered(op->dtype.code())); - return op->a * make_const(op->b.dtype(), 1.0 / ptr->value); - } - // Vector rules if (op->dtype.is_scalable_or_fixed_length_vector()) { // NOTE: use div as the pattern also works for float. @@ -821,6 +815,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { return make_const(op->dtype, truncdiv(c1val, c2val)); } + // x % c1 // c2 => 0 if 0 < c1 < c2 && x >= 0 + TVM_TRY_REWRITE_IF(truncdiv(truncmod(x, c1), c2), ZeroWithTypeLike(x), + c1.Eval()->value > 0 && c2.Eval()->value > c1.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); + // while it is always true for trunc div // restrict to common case(positive div) TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1), c2), truncdiv(x, c1 * c2), @@ -1166,7 +1165,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -1221,8 +1220,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { c2.Eval()->value % c1.Eval()->value == 0 && CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y * c2 + z, c3), floormod(x * floordiv(c1, c2) + y, floordiv(c3, c2)) * c2 + z, + c2.Eval()->value > 0 && c3.Eval()->value > 0 && + c3.Eval()->value % c2.Eval()->value == 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveEqual(floordiv(z.Eval(), c2.Eval()), 0)); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), - c2.Eval()->value > 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); // (x + 5) % 2 -> (x + 1) %2, (x + 3) % 3 => x TVM_TRY_REWRITE_IF( @@ -1652,7 +1657,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { return ret; } -Optional 
RewriteSimplifier::Impl::TryMatchLiteralConstraint(const PrimExpr& expr) const { +ffi::Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint( + const PrimExpr& expr) const { PrimExpr negation = Not(expr); ExprDeepEqual expr_equal; @@ -1946,7 +1952,110 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1); TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y); - auto merge_constants = [&]() -> Optional { + // If (base + offset) is compared against a multiple of k, and `offset` + // is known to be in [0, k), then the comparison is equivalent to just + // comparing `base` against the multiple of k. + // + // Example: + // tx * 4 + i < N ==> tx * 4 < N + // when 0 <= i < 4 and N % 4 == 0. + auto is_multiple_of = [&](PrimExpr expr, int64_t expr_gcd, int64_t factor) -> bool { + if (factor <= 1) return false; + if (expr_gcd % factor == 0) return true; + + PrimExpr factor_expr = make_const(expr.dtype(), factor); + PrimExpr cond = floormod(expr, factor_expr) == make_zero(expr.dtype()); + if (auto match = TryMatchLiteralConstraint(cond)) { + if (const int64_t* as_int = as_const_int(match.value())) { + return *as_int != 0; + } + } + return analyzer_->CanProve(cond); + }; + + auto eliminate_bounded_offset = [&](PrimExpr base, PrimExpr offset, + PrimExpr rhs) -> ffi::Optional { + ConstIntBound offset_bound = analyzer_->const_int_bound(offset); + if (!offset_bound.defined()) return std::nullopt; + if (offset_bound->min_value < 0) return std::nullopt; + + auto base_mod = analyzer_->modular_set(base); + auto rhs_mod = analyzer_->modular_set(rhs); + + int64_t base_gcd = ZeroAwareGCD(base_mod->base, base_mod->coeff); + int64_t rhs_gcd = ZeroAwareGCD(rhs_mod->base, rhs_mod->coeff); + + // Prefer the largest factor known from modular analysis of both sides. + // If rhs modular information isn't available (e.g. 
constraints nested in + // `and` aren't propagated to ModularSetAnalyzer), fall back to the + // factor known from the base expression and use literal-constraint + // matching to prove rhs alignment. + int64_t common_factor = ZeroAwareGCD(base_gcd, rhs_gcd); + int64_t factor = common_factor > 1 ? common_factor : base_gcd; + if (factor <= 1) return std::nullopt; + + if (offset_bound->max_value >= factor) return std::nullopt; + if (!is_multiple_of(rhs, rhs_gcd, factor)) return std::nullopt; + + return RecursiveRewrite(base < rhs); + }; + + if (const auto* add = ret->a.as()) { + if (auto simplified = + eliminate_bounded_offset(add->a, add->b, ret->b)) { + return simplified.value(); + } + if (auto simplified = + eliminate_bounded_offset(add->b, add->a, ret->b)) { + return simplified.value(); + } + } + + // If `lhs` and `base` are multiples of k, then the comparison + // lhs < base + offset + // can sometimes be simplified depending on the bounds of `offset`. + // + // Example: + // z < x * 4 + y ==> z <= x * 4 + // when 1 <= y < 4 and z % 4 == 0. + auto eliminate_bounded_offset_rhs = + [&](PrimExpr lhs, PrimExpr base, PrimExpr offset) -> ffi::Optional { + ConstIntBound offset_bound = analyzer_->const_int_bound(offset); + if (!offset_bound.defined()) return std::nullopt; + if (offset_bound->min_value < 0) return std::nullopt; + + auto base_mod = analyzer_->modular_set(base); + auto lhs_mod = analyzer_->modular_set(lhs); + + int64_t base_gcd = ZeroAwareGCD(base_mod->base, base_mod->coeff); + int64_t lhs_gcd = ZeroAwareGCD(lhs_mod->base, lhs_mod->coeff); + + int64_t common_factor = ZeroAwareGCD(base_gcd, lhs_gcd); + int64_t factor = common_factor > 1 ? 
common_factor : base_gcd; + if (factor <= 1) return std::nullopt; + + if (offset_bound->max_value >= factor) return std::nullopt; + if (!is_multiple_of(lhs, lhs_gcd, factor)) return std::nullopt; + + if (offset_bound->min_value > 0) { + return RecursiveRewrite(lhs <= base); + } + if (offset_bound->min_value == 0 && offset_bound->max_value == 0) { + return RecursiveRewrite(lhs < base); + } + return std::nullopt; + }; + + if (const auto* add = ret->b.as()) { + if (auto simplified = eliminate_bounded_offset_rhs(ret->a, add->a, add->b)) { + return simplified.value(); + } + if (auto simplified = eliminate_bounded_offset_rhs(ret->a, add->b, add->a)) { + return simplified.value(); + } + } + + auto merge_constants = [&]() -> ffi::Optional { auto [lhs, lhs_offset] = ExtractConstantOffset(ret->a); auto [rhs, rhs_offset] = ExtractConstantOffset(ret->b); if (lhs_offset == 0 && rhs_offset == 0) { @@ -1970,6 +2079,16 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { return RecursiveRewrite(merge_constants.value()); } + auto contains_floordiv = [](const PrimExpr& expr) -> bool { + bool found = false; + PostOrderVisit(expr, [&found](const ObjectRef& obj) { + if (obj.as()) { + found = true; + } + }); + return found; + }; + auto common_factor = [&]() -> int64_t { auto modular_a = analyzer_->modular_set(ret->a); auto modular_b = analyzer_->modular_set(ret->b); @@ -1978,7 +2097,15 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { return ZeroAwareGCD(gcd_lhs, gcd_rhs); }(); if (common_factor > 1) { - return RecursiveRewrite(floordiv(ret->a, common_factor) < floordiv(ret->b, common_factor)); + PrimExpr lhs = VisitExpr(floordiv(ret->a, common_factor)); + PrimExpr rhs = VisitExpr(floordiv(ret->b, common_factor)); + + // Don't introduce floordiv in the comparison if it cannot be + // eliminated after simplification. Keeping `x * k < N` can be + // preferable to rewriting to `x < N // k` even when `N % k == 0`. 
+ if (!contains_floordiv(lhs) && !contains_floordiv(rhs)) { + return RecursiveRewrite(lhs < rhs); + } } } return ret; @@ -2051,7 +2178,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { // Otherwise, follow ExprMutator's convention of returning the // original object. if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return And(a, b); } @@ -2160,7 +2287,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { - PrimExpr orig = GetRef(op); + PrimExpr orig = ffi::GetRef(op); PrimExpr ret = [&]() -> PrimExpr { // If this extension isn't enabled, just delegate out. @@ -2200,7 +2327,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { // Otherwise, follow ExprMutator's convention of returning the // original object. if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Or(a, b); } @@ -2350,7 +2477,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (op->dtype == DataType::Bool()) { if (auto match = TryMatchLiteralConstraint(var)) { return match.value(); @@ -2361,7 +2488,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { if (it != var_map_.end()) { return it->second; } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { @@ -2388,7 +2515,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) { } PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -2410,8 +2537,8 @@ void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_ impl_->Update(var, info, 
allow_override); } -std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) { - return impl_->EnterConstraint(constraint); +std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + return impl_->EnterConstraint(constraint, is_assume); } void RewriteSimplifier::SetEnabledExtensions(Extension flags) { @@ -2433,6 +2560,22 @@ RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) RewriteSimplifier::~RewriteSimplifier() { delete impl_; } +// Impl state copy +void RewriteSimplifier::Impl::CopyFromImpl(const RewriteSimplifier::Impl& other) { + this->var_map_ = other.var_map_; + this->literal_constraints_ = other.literal_constraints_; + this->enabled_extensions_ = other.enabled_extensions_; + this->maximum_rewrite_steps_ = other.maximum_rewrite_steps_; + this->stats_ = other.stats_; + this->recur_depth_ = 0; + this->recursively_visiting_boolean_ = false; +} + +// Deep copy internal state from another analyzer +void RewriteSimplifier::CopyFrom(const RewriteSimplifier& other) { + this->impl_->CopyFromImpl(*other.impl_); +} + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* ptr = node.as(); diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index b4bd799a2933..d27d750e0615 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -64,17 +64,17 @@ struct RewriteSimplifierStatsNode : Object { .def_ro("max_recursive_depth", &RewriteSimplifierStatsNode::max_recursive_depth) .def_ro("num_recursive_rewrites", &RewriteSimplifierStatsNode::num_recursive_rewrites); } - - static constexpr const char* _type_key = "arith.RewriteSimplifierStats"; - TVM_DECLARE_FINAL_OBJECT_INFO(RewriteSimplifierStatsNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.RewriteSimplifierStats", RewriteSimplifierStatsNode, + Object); }; struct RewriteSimplifierStats : ObjectRef { explicit 
RewriteSimplifierStats(RewriteSimplifierStatsNode data) { - data_ = make_object(data); + data_ = ffi::make_object(data); } - TVM_DEFINE_OBJECT_REF_METHODS(RewriteSimplifierStats, ObjectRef, RewriteSimplifierStatsNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RewriteSimplifierStats, ObjectRef, + RewriteSimplifierStatsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(RewriteSimplifierStatsNode); }; @@ -116,7 +116,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CastNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; - std::function EnterConstraint(const PrimExpr& constraint); + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume=false); + + // Copy internal state from another Impl instance (used by Analyzer cloning) + void CopyFromImpl(const Impl& other); /*! \brief Enable an optional extension or extensions * @@ -193,7 +196,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { * matches a constraint, return the boolean it should be replaced * with. Otherwise, return false. */ - Optional TryMatchLiteralConstraint(const PrimExpr& expr) const; + ffi::Optional TryMatchLiteralConstraint(const PrimExpr& expr) const; /*! 
\brief Rewrite rules for Less Than comparisons * diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index 1937b9c34e03..5c968966e2f0 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -86,7 +86,7 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr return can_prove_expr; } -bool TargetHasVLA(Optional target) { +bool TargetHasVLA(ffi::Optional target) { if (!target.defined()) { target = Target::Current(); } @@ -102,7 +102,7 @@ bool TargetHasVLA(Optional target) { return has_vla; } -const std::vector GetVScaleValues(Optional target) { +const std::vector GetVScaleValues(ffi::Optional target) { unsigned int vector_width = 0; std::vector kVScaleValues; if (!target.defined()) { diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index 2470d5dcd827..88c140288734 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -81,14 +81,14 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr * \param target The target to check. * \return Whether VLA is supported */ -bool TargetHasVLA(Optional target = std::nullopt); +bool TargetHasVLA(ffi::Optional target = std::nullopt); /*! * \brief Get a list of known vscale values to try for an VLA target. * \param target The target to check. 
* \return A list of vscale values as std::vector */ -const std::vector GetVScaleValues(Optional target = std::nullopt); +const std::vector GetVScaleValues(ffi::Optional target = std::nullopt); } // namespace arith } // namespace tvm diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 5d1f102a5b7e..8143892d9abd 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -209,10 +209,11 @@ void SmithNormalFormDiag(std::vector>* S, std::vector InferRange(const Map& vars_to_infer, const Array& ori_vars, - const Map& ori_ranges) { +ffi::Map InferRange(const ffi::Map& vars_to_infer, + const ffi::Array& ori_vars, + const ffi::Map& ori_ranges) { // The resulting ranges - Map new_ranges; + ffi::Map new_ranges; std::unordered_set ori_vset; for (const Var& v : ori_vars) { @@ -260,7 +261,7 @@ void DebugPrint(const std::vector>& S, } std::cout << "\n"; } - std::cout << "V_inv x:\n" << Array(V_inv_x); + std::cout << "V_inv x:\n" << ffi::Array(V_inv_x); std::cout << "\n" << std::endl; } @@ -298,8 +299,8 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol for (const PrimExpr& equation : system_to_solve->relations) { if (const tir::EQNode* eq = equation.as()) { // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] - Array coeffs = arith::DetectLinearEquation(analyzer_problem.Simplify(eq->a - eq->b), - system_to_solve->variables); + ffi::Array coeffs = arith::DetectLinearEquation( + analyzer_problem.Simplify(eq->a - eq->b), system_to_solve->variables); if (!coeffs.empty()) { std::vector row; for (size_t j = 0; j < coeffs.size() - 1; ++j) { @@ -337,10 +338,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // Uy is U \times y SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy); - Array new_vars; - Array new_relations; - Map new_to_old_map; - Map old_to_new_map; + ffi::Array new_vars; + ffi::Array new_relations; + ffi::Map new_to_old_map; + ffi::Map 
old_to_new_map; // Simplify right hand sides for (PrimExpr r : Uy) { @@ -372,7 +373,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol } } - Array solution_for_V_inv_x; + ffi::Array solution_for_V_inv_x; // Now create new variables or directly solve the equations // suppose the rank of A is r, aka r = # of non-zeros in S // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b @@ -421,7 +422,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol } // The resulting ranges - Map new_ranges = + ffi::Map new_ranges = InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges); Analyzer analyzer_solution; analyzer_solution.Bind(new_ranges); @@ -455,16 +456,16 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol return transform; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "arith.SolveLinearEquations", [](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveLinearEquations(args[0].cast()); } else if (args.size() == 3) { - auto opt_vars = args[0].cast>>(); - auto opt_map = args[1].cast>>(); - auto opt_relations = args[2].cast>>(); + auto opt_vars = args[0].cast>>(); + auto opt_map = args[1].cast>>(); + auto opt_relations = args[2].cast>>(); IntConstraints problem(opt_vars.value_or({}), opt_map.value_or({}), opt_relations.value_or({})); *ret = SolveLinearEquations(problem); @@ -472,7 +473,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); } }); -}); +} } // namespace arith } // namespace tvm diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index bf50a0ea52ec..a46f9e520176 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -133,7 +133,7 @@ void ClassifyByPolarity(const Var& var, const std::vector& 
current_ine // and store to coef_pos and coef_neg respectively. for (const PrimExpr& ineq : current_ineq_set) { if (const LENode* le = ineq.as()) { - Array coef = arith::DetectLinearEquation(le->a, {var}); + ffi::Array coef = arith::DetectLinearEquation(le->a, {var}); if (!coef.empty() && is_const_int(coef[0])) { int64_t coef0 = *as_const_int(coef[0]); if (coef0 == 0) { @@ -147,7 +147,7 @@ void ClassifyByPolarity(const Var& var, const std::vector& current_ine continue; } } else if (const EQNode* eq = ineq.as()) { - Array coef = arith::DetectLinearEquation(eq->a, {var}); + ffi::Array coef = arith::DetectLinearEquation(eq->a, {var}); if (!coef.empty() && is_const_int(coef[0])) { int64_t coef0 = *as_const_int(coef[0]); if (coef0 == 0) { @@ -218,7 +218,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t &analyzer); } - Map res_bounds; + ffi::Map res_bounds; for (const Var& v : system_to_solve->variables) { ICHECK(!res_bounds.count(v)) << "Variable " << v @@ -329,16 +329,16 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Write it to the result. 
IntGroupBounds bnds(make_const(v.dtype(), coef_lcm), - Array(lower_bounds.begin(), lower_bounds.end()), - Array(equal_list.begin(), equal_list.end()), - Array(upper_bounds.begin(), upper_bounds.end())); + ffi::Array(lower_bounds.begin(), lower_bounds.end()), + ffi::Array(equal_list.begin(), equal_list.end()), + ffi::Array(upper_bounds.begin(), upper_bounds.end())); res_bounds.Set(v, bnds); std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); } // Everything that is left goes to res.relations - Array other_conditions; + ffi::Array other_conditions; for (const PrimExpr& e : current_ineq_set_to_solve) { PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite); if (is_const_int(e_simp, 0)) { @@ -366,17 +366,17 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are // not in the inequalities->variables but are in inequalities->ranges // It will be useful when solving Jacobian axes jac_xxx) - Map res_ranges; + ffi::Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - Map solved_bounds = solved_system.first; - Array solved_other_relations = solved_system.second; + ffi::Map solved_bounds = solved_system.first; + ffi::Array solved_other_relations = solved_system.second; - Array res_relations; + ffi::Array res_relations; // this keeps being updated during determining the range of each variable. 
- Map vranges; + ffi::Map vranges; for (std::pair vr : inequalities->ranges) { vranges.Set(vr.first, vr.second); } @@ -441,21 +441,21 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) - Map res_ranges; + ffi::Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - Map solved_bounds = solved_system.first; - Array solved_other_relations = solved_system.second; + ffi::Map solved_bounds = solved_system.first; + ffi::Array solved_other_relations = solved_system.second; arith::Analyzer analyzer; - Map res_src_to_dst; - Map res_dst_to_src; - Array res_variables; - Array res_relations; + ffi::Map res_src_to_dst; + ffi::Map res_dst_to_src; + ffi::Array res_variables; + ffi::Array res_relations; // this keeps being updated during determining the range of each variable. 
- Map vranges; + ffi::Map vranges; for (std::pair vr : inequalities->ranges) { vranges.Set(vr.first, vr.second); } @@ -528,7 +528,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ } // Reverse the axis so that it matches the order of the original variables - res_variables = Array(res_variables.rbegin(), res_variables.rend()); + res_variables = ffi::Array(res_variables.rbegin(), res_variables.rend()); IntConstraints new_inequalities(res_variables, res_ranges, res_relations); IntConstraintsTransform transform(inequalities, new_inequalities, res_src_to_dst, res_dst_to_src); @@ -536,7 +536,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ return transform; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed( @@ -548,8 +548,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ problem = args[0].cast(); ret_ineq = SolveLinearInequalities(problem); } else if (args.size() == 3) { - problem = IntConstraints(args[0].cast>(), args[1].cast>(), - args[2].cast>()); + problem = IntConstraints(args[0].cast>(), + args[1].cast>(), + args[2].cast>()); ret_ineq = SolveLinearInequalities(problem); } else { LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " @@ -562,9 +563,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (args.size() == 1) { *ret = SolveInequalitiesToRange(args[0].cast()); } else if (args.size() == 3) { - auto opt_map = args[1].cast>>(); - IntConstraints problem(args[0].cast>(), opt_map.value_or({}), - args[2].cast>()); + auto opt_map = args[1].cast>>(); + IntConstraints problem(args[0].cast>(), opt_map.value_or({}), + args[2].cast>()); *ret = SolveInequalitiesToRange(problem); } else { LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " @@ -575,16 +576,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (args.size() == 1) { *ret = SolveInequalitiesDeskewRange(args[0].cast()); } else if (args.size() == 3) { - 
auto opt_map = args[1].cast>>(); - IntConstraints problem(args[0].cast>(), opt_map.value_or({}), - args[2].cast>()); + auto opt_map = args[1].cast>>(); + IntConstraints problem(args[0].cast>(), opt_map.value_or({}), + args[2].cast>()); *ret = SolveInequalitiesDeskewRange(problem); } else { LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " << args.size(); } }); -}); +} } // namespace arith } // namespace tvm diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index 52010ec322c8..ec0173ca996e 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -82,6 +82,9 @@ class TransitiveComparisonAnalyzer::Impl { */ std::function EnterConstraint(const PrimExpr& expr); + // Copy internal state from another Impl (for Analyzer cloning) + void CopyFrom(const Impl& other); + private: /* \brief Internal representation of a PrimExpr * @@ -276,7 +279,7 @@ class TransitiveComparisonAnalyzer::Impl { * Tracked separatedly to handle the `allow_override` option used by * all sub-analyzers when binding variables. */ - Map prev_bindings_; + ffi::Map prev_bindings_; /*! 
\brief Known comparisons based on definitionally-true statements * @@ -600,6 +603,11 @@ std::function TransitiveComparisonAnalyzer::Impl::EnterConstraint(const return frecover; } +// Deep copy internal state from another analyzer +void TransitiveComparisonAnalyzer::CopyFrom(const TransitiveComparisonAnalyzer& other) { + this->impl_->CopyFrom(*other.impl_); +} + CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, const PrimExpr& rhs_expr, bool propagate_inequalities) const { @@ -872,5 +880,13 @@ CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons( return result; } +// Implementation of the CopyFrom helper +void TransitiveComparisonAnalyzer::Impl::CopyFrom(const Impl& other) { + prev_bindings_ = other.prev_bindings_; + knowns_ = other.knowns_; + scoped_knowns_ = other.scoped_knowns_; + expr_to_key = other.expr_to_key; +} + } // namespace arith } // namespace tvm diff --git a/src/arith/unwrap_vector_expr.cc b/src/arith/unwrap_vector_expr.cc index 6a3e8c3d434c..c074eb5c935a 100644 --- a/src/arith/unwrap_vector_expr.cc +++ b/src/arith/unwrap_vector_expr.cc @@ -47,7 +47,7 @@ class Scalarizer : public ExprMutator { PrimExpr VisitExpr_(const BroadcastNode* op) final { return op->value; } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto it = let_var_remap_.find(op); if (it != let_var_remap_.end()) { diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h index f582f6416d93..dc2d5d1ef9a1 100644 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ b/src/contrib/msc/core/codegen/base_codegen.h @@ -53,41 +53,41 @@ class BaseOpCode { * \brief The constructor of BaseOpCode * \param func_name the function name for the node. */ - explicit BaseOpCode(const String& func_name) : func_name_(func_name) {} + explicit BaseOpCode(const ffi::String& func_name) : func_name_(func_name) {} virtual ~BaseOpCode() = default; /*! 
\brief Config the BaseOpCode*/ void Config(const MSCJoint& node, const std::shared_ptr config, - const Map& prims) { + const ffi::Map& prims) { node_ = node; config_ = config; prims_ = prims; } /*! \brief Get docs for the node*/ - virtual const Array GetDocs() = 0; + virtual const ffi::Array GetDocs() = 0; /*! \brief Get return describe for default node*/ - virtual const String IdxNode() { return IdxNodeBase(node_); } + virtual const ffi::String IdxNode() { return IdxNodeBase(node_); } /*! \brief Get describe for default node input*/ - const String IdxInput(int idx = 0, bool process = true) { + const ffi::String IdxInput(int idx = 0, bool process = true) { return IdxInputBase(node_, idx, process); } /*! \brief Get describe for default node output*/ - const String IdxOutput(int idx = 0) { return IdxOutputBase(node_, idx); } + const ffi::String IdxOutput(int idx = 0) { return IdxOutputBase(node_, idx); } /*! \brief Get describe for default node weight*/ - const String IdxWeight(const String& wtype, bool process = true) { + const ffi::String IdxWeight(const ffi::String& wtype, bool process = true) { return IdxWeightBase(node_, wtype, process); } /*! \brief Get the node attr as doc*/ - const ExprDoc GetAttrDoc(const String& key, const String& type) { + const ExprDoc GetAttrDoc(const ffi::String& key, const ffi::String& type) { if (StringUtils::StartsWith(type, "list")) { - const String& ele_type = + const ffi::String& ele_type = StringUtils::Replace(StringUtils::Replace(type, "list(", ""), ")", ""); if (ele_type == "bool") { return DocUtils::ToList(node_->GetTypeArrayAttr(key)); @@ -115,16 +115,16 @@ class BaseOpCode { } /*! \brief Get comment for default node*/ - const String Comment() { return Comment(node_); } + const ffi::String Comment() { return Comment(node_); } /*! \brief Get func_name for the default node*/ - const String func_name() { return func_name_; } + const ffi::String func_name() { return func_name_; } /*! 
\brief Get valid func name for the default node*/ - virtual const String callee_name() { return func_name(); } + virtual const ffi::String callee_name() { return func_name(); } /*! \brief Get valid return name for the default node*/ - virtual const String ret_name() { return IdxNode(); } + virtual const ffi::String ret_name() { return IdxNode(); } /*! \brief Get the default node*/ const MSCJoint node() { return node_; } @@ -132,7 +132,7 @@ class BaseOpCode { CODEGEN_MEMBERS; private: - String func_name_; + ffi::String func_name_; MSCJoint node_; }; @@ -170,7 +170,8 @@ class BaseCodeGen { virtual ~BaseCodeGen() = default; /*! \brief Get sources*/ - virtual const Map GetSources(const std::string& print_options = "") = 0; + virtual const ffi::Map GetSources( + const std::string& print_options = "") = 0; CODEGEN_MEMBERS; @@ -210,7 +211,7 @@ class BaseCodeGen { } /*! \brief Get the optype for op codegen*/ - const String GetOpType(const MSCJoint& node) { + const ffi::String GetOpType(const MSCJoint& node) { if (config_->use_plugin && IsPlugin(node->optype)) { return "plugin"; } @@ -218,10 +219,10 @@ class BaseCodeGen { } /*! \brief Get the docs for the op*/ - virtual const Array GetOpCodes(const MSCJoint& node) = 0; + virtual const ffi::Array GetOpCodes(const MSCJoint& node) = 0; /*! \brief Describe the prim*/ - virtual const String DescribePrim(const MSCPrim& prim) { + virtual const ffi::String DescribePrim(const MSCPrim& prim) { if (prim->optype == "Int") { return prim->GetTypeAttr("value"); } @@ -247,14 +248,14 @@ class BaseCodeGen { const MSCGraph graph() const { return graph_; } /*! \brief Get the scopes*/ - const std::stack> scopes() const { return scopes_; } + const std::stack> scopes() const { return scopes_; } /*! 
\brief The stack of codes*/ CodeStack stack_; private: MSCGraph graph_; - std::stack> scopes_; + std::stack> scopes_; }; } // namespace msc diff --git a/src/contrib/msc/core/codegen/code_stack.cc b/src/contrib/msc/core/codegen/code_stack.cc index 041ffe7091b2..e1b34f7d28b7 100644 --- a/src/contrib/msc/core/codegen/code_stack.cc +++ b/src/contrib/msc/core/codegen/code_stack.cc @@ -27,16 +27,16 @@ namespace tvm { namespace contrib { namespace msc { -const Array BaseStack::GetDocs() const { +const ffi::Array BaseStack::GetDocs() const { ICHECK(blocks_.size() == 1) << "Has incomplete blocks, please check"; return TopBlock(); } void BaseStack::Line(const Doc& doc) { PushDoc(doc); } -void BaseStack::Line(const String& line) { Line(IdDoc(line)); } +void BaseStack::Line(const ffi::String& line) { Line(IdDoc(line)); } -void BaseStack::Comment(const String& comment, bool attach) { +void BaseStack::Comment(const ffi::String& comment, bool attach) { if (attach) { const auto& doc = TopDoc(); ICHECK(doc->IsInstance()) << "Only stmt doc support attach comments"; @@ -47,38 +47,39 @@ void BaseStack::Comment(const String& comment, bool attach) { } } -void BaseStack::Declare(const String& type, const String& variable, size_t len, +void BaseStack::Declare(const ffi::String& type, const ffi::String& variable, size_t len, bool use_constructor) { PushDoc(DocUtils::ToDeclare(type, variable, len, use_constructor)); } void BaseStack::DeclareArgBase(const ExprDoc& value) { const auto& declare = PopCheckedDoc(); - Array init_args = declare->init_args; + ffi::Array init_args = declare->init_args; init_args.push_back(value); PushDoc(DeclareDoc(declare->type, declare->variable, init_args, declare->use_constructor)); } -void BaseStack::FuncDef(const String& func_name, const String& ret_type) { +void BaseStack::FuncDef(const ffi::String& func_name, const ffi::String& ret_type) { if (ret_type.size() > 0) { - PushDoc(FunctionDoc(IdDoc(func_name), Array(), Array(), IdDoc(ret_type), - Array())); + 
PushDoc(FunctionDoc(IdDoc(func_name), ffi::Array(), ffi::Array(), + IdDoc(ret_type), ffi::Array())); } else { - PushDoc(FunctionDoc(IdDoc(func_name), Array(), Array(), std::nullopt, - Array())); + PushDoc(FunctionDoc(IdDoc(func_name), ffi::Array(), ffi::Array(), + std::nullopt, ffi::Array())); } } -void BaseStack::FuncArg(const String& arg, const String& annotation, const String& value) { +void BaseStack::FuncArg(const ffi::String& arg, const ffi::String& annotation, + const ffi::String& value) { const auto& func = PopCheckedDoc(); - Array args = func->args; + ffi::Array args = func->args; args.push_back(DocUtils::ToAssign(arg, value, annotation)); PushDoc(FunctionDoc(func->name, args, func->decorators, func->return_type, func->body)); } -void BaseStack::FuncDecorator(const String& decorator) { +void BaseStack::FuncDecorator(const ffi::String& decorator) { const auto& func = PopCheckedDoc(); - Array decorators = func->decorators; + ffi::Array decorators = func->decorators; decorators.push_back(IdDoc(decorator)); PushDoc(FunctionDoc(func->name, func->args, decorators, func->return_type, func->body)); } @@ -95,13 +96,13 @@ void BaseStack::FuncEnd() { PushDoc(FunctionDoc(func->name, func->args, func->decorators, func->return_type, body)); } -void BaseStack::ClassDef(const String& class_name) { - PushDoc(ClassDoc(IdDoc(class_name), Array(), Array())); +void BaseStack::ClassDef(const ffi::String& class_name) { + PushDoc(ClassDoc(IdDoc(class_name), ffi::Array(), ffi::Array())); } -void BaseStack::ClassDecorator(const String& decorator) { +void BaseStack::ClassDecorator(const ffi::String& decorator) { const auto& class_doc = PopCheckedDoc(); - Array decorators = class_doc->decorators; + ffi::Array decorators = class_doc->decorators; decorators.push_back(IdDoc(decorator)); PushDoc(ClassDoc(class_doc->name, decorators, class_doc->body)); } @@ -118,8 +119,8 @@ void BaseStack::ClassEnd() { PushDoc(ClassDoc(class_doc->name, class_doc->decorators, body)); } -void 
BaseStack::StructStart(const String& struct_name) { - PushDoc(StructDoc(IdDoc(struct_name), Array(), Array())); +void BaseStack::StructStart(const ffi::String& struct_name) { + PushDoc(StructDoc(IdDoc(struct_name), ffi::Array(), ffi::Array())); BlockStart(); } @@ -130,13 +131,14 @@ void BaseStack::StructEnd() { PushDoc(StructDoc(struct_doc->name, struct_doc->decorators, body)); } -void BaseStack::ConstructorDef(const String& constructor_name) { - PushDoc(ConstructorDoc(IdDoc(constructor_name), Array(), Array())); +void BaseStack::ConstructorDef(const ffi::String& constructor_name) { + PushDoc(ConstructorDoc(IdDoc(constructor_name), ffi::Array(), ffi::Array())); } -void BaseStack::ConstructorArg(const String& arg, const String& annotation, const String& value) { +void BaseStack::ConstructorArg(const ffi::String& arg, const ffi::String& annotation, + const ffi::String& value) { const auto& func = PopCheckedDoc(); - Array args = func->args; + ffi::Array args = func->args; args.push_back(DocUtils::ToAssign(arg, value, annotation)); PushDoc(ConstructorDoc(func->name, args, func->body)); } @@ -153,20 +155,22 @@ void BaseStack::ConstructorEnd() { PushDoc(ConstructorDoc(func->name, func->args, body)); } -void BaseStack::LambdaDef(const String& lambda_name) { - PushDoc(LambdaDoc(IdDoc(lambda_name), Array(), Array(), Array())); +void BaseStack::LambdaDef(const ffi::String& lambda_name) { + PushDoc(LambdaDoc(IdDoc(lambda_name), ffi::Array(), ffi::Array(), + ffi::Array())); } -void BaseStack::LambdaArg(const String& arg, const String& annotation, const String& value) { +void BaseStack::LambdaArg(const ffi::String& arg, const ffi::String& annotation, + const ffi::String& value) { const auto& lambda = PopCheckedDoc(); - Array args = lambda->args; + ffi::Array args = lambda->args; args.push_back(DocUtils::ToAssign(arg, value, annotation)); PushDoc(LambdaDoc(lambda->name, args, lambda->refs, lambda->body)); } -void BaseStack::LambdaRef(const String& ref) { +void 
BaseStack::LambdaRef(const ffi::String& ref) { const auto& lambda = PopCheckedDoc(); - Array refs = lambda->refs; + ffi::Array refs = lambda->refs; refs.push_back(IdDoc(ref)); PushDoc(LambdaDoc(lambda->name, lambda->args, refs, lambda->body)); } @@ -176,7 +180,7 @@ void BaseStack::LambdaStart() { BlockStart(); } -void BaseStack::LambdaEnd(const String& ret_val) { +void BaseStack::LambdaEnd(const ffi::String& ret_val) { if (ret_val.size() > 0) { PushDoc(ReturnDoc(IdDoc(ret_val))); } @@ -191,13 +195,15 @@ void BaseStack::LambdaEnd(const ExprDoc& ret_val) { LambdaEnd(""); } -void BaseStack::FuncCall(const String& callee, Optional assign_to, - Optional caller) { +void BaseStack::FuncCall(const ffi::String& callee, ffi::Optional assign_to, + ffi::Optional caller) { if (!caller.defined()) { - PushDoc(CallDoc(IdDoc(callee), Array(), Array(), Array())); + PushDoc(CallDoc(IdDoc(callee), ffi::Array(), ffi::Array(), + ffi::Array())); } else { const auto& new_access = AttrAccessDoc(caller.value(), callee); - PushDoc(CallDoc(new_access, Array(), Array(), Array())); + PushDoc(CallDoc(new_access, ffi::Array(), ffi::Array(), + ffi::Array())); } if (assign_to.defined()) { const auto& last_call = PopCheckedDoc(); @@ -211,14 +217,15 @@ void BaseStack::FuncCall(const String& callee, Optional assign_to, } } -void BaseStack::FuncCall(const String& callee, const String& assign_to, const String& caller) { - Optional assign_doc; +void BaseStack::FuncCall(const ffi::String& callee, const ffi::String& assign_to, + const ffi::String& caller) { + ffi::Optional assign_doc; if (assign_to.size() == 0) { assign_doc = std::nullopt; } else { assign_doc = IdDoc(assign_to); } - Optional caller_doc; + ffi::Optional caller_doc; if (caller.size() == 0) { caller_doc = std::nullopt; } else { @@ -227,26 +234,27 @@ void BaseStack::FuncCall(const String& callee, const String& assign_to, const St FuncCall(callee, assign_doc, caller_doc); } -void BaseStack::MethodCall(const String& callee, bool new_line) { 
+void BaseStack::MethodCall(const ffi::String& callee, bool new_line) { const auto& host = PopDoc(); if (host->IsInstance()) { const auto& v_callee = callee + (new_line ? DocSymbol::NextLine() : ""); FuncCall(v_callee, std::nullopt, Downcast(host)); } else if (const auto* a_node = host.as()) { ICHECK(a_node->rhs.defined()) << "Can not find rhs for inplace host"; - FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, Array(), true), + FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, ffi::Array(), true), a_node->rhs); } else { LOG(FATAL) << "Unexpected host type for inplace " << host->GetTypeKey(); } } -void BaseStack::InplaceStart(const String& callee, Optional assign_to, - Optional caller) { +void BaseStack::InplaceStart(const ffi::String& callee, ffi::Optional assign_to, + ffi::Optional caller) { FuncCall(callee, assign_to, caller); } -void BaseStack::InplaceStart(const String& callee, const String& assign_to, const String& caller) { +void BaseStack::InplaceStart(const ffi::String& callee, const ffi::String& assign_to, + const ffi::String& caller) { FuncCall(callee, assign_to, caller); } @@ -266,7 +274,7 @@ void BaseStack::InplaceEnd() { } } -void BaseStack::PopNest(const String& key) { +void BaseStack::PopNest(const ffi::String& key) { const auto& last = PopDoc(); if (last->IsInstance()) { CallArgBase(Downcast(last), key); @@ -275,11 +283,11 @@ void BaseStack::PopNest(const String& key) { } } -void BaseStack::CallArgBase(const ExprDoc& value, const String& key) { +void BaseStack::CallArgBase(const ExprDoc& value, const ffi::String& key) { const auto& last = PopDoc(); - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; // get args and kwargs if (const auto* call = last.as()) { args = call->args; @@ -313,16 +321,16 @@ void BaseStack::CallArgBase(const ExprDoc& value, const String& key) { } } -void BaseStack::ConditionIf(const String& predicate) { - Array 
else_branch{ExprStmtDoc(IdDoc("pass"))}; - PushDoc(IfDoc(IdDoc(predicate), Array(), else_branch)); +void BaseStack::ConditionIf(const ffi::String& predicate) { + ffi::Array else_branch{ExprStmtDoc(IdDoc("pass"))}; + PushDoc(IfDoc(IdDoc(predicate), ffi::Array(), else_branch)); BlockStart(); } void BaseStack::ConditionElse() { const auto& block = PopBlock(); const auto& if_doc = PopCheckedDoc(); - PushDoc(IfDoc(if_doc->predicate, DocUtils::ToStmts(block), Array())); + PushDoc(IfDoc(if_doc->predicate, DocUtils::ToStmts(block), ffi::Array())); BlockStart(); } @@ -331,7 +339,7 @@ void BaseStack::ConditionEnd() { const auto& if_doc = PopCheckedDoc(); const auto& branch = DocUtils::ToStmts(block); if (if_doc->then_branch.size() == 0) { - PushDoc(IfDoc(if_doc->predicate, branch, Array())); + PushDoc(IfDoc(if_doc->predicate, branch, ffi::Array())); } else { PushDoc(IfDoc(if_doc->predicate, if_doc->then_branch, branch)); } @@ -344,8 +352,8 @@ void BaseStack::ForEnd() { PushDoc(ForDoc(for_doc->lhs, for_doc->rhs, body)); } -void BaseStack::WhileStart(const String& predicate) { - PushDoc(WhileDoc(IdDoc(predicate), Array())); +void BaseStack::WhileStart(const ffi::String& predicate) { + PushDoc(WhileDoc(IdDoc(predicate), ffi::Array())); BlockStart(); } @@ -356,20 +364,20 @@ void BaseStack::WhileEnd() { PushDoc(WhileDoc(while_doc->predicate, body)); } -void BaseStack::SwitchStart(const String& predicate) { - Array predicates; +void BaseStack::SwitchStart(const ffi::String& predicate) { + ffi::Array predicates; predicates.push_back(IdDoc(predicate)); - PushDoc(SwitchDoc(predicates, Array>(), Array())); + PushDoc(SwitchDoc(predicates, ffi::Array>(), ffi::Array())); BlockStart(); } -void BaseStack::SwitchCase(const String& predicate) { +void BaseStack::SwitchCase(const ffi::String& predicate) { const auto& block = PopBlock(); const auto& switch_doc = PopCheckedDoc(); auto branchs = switch_doc->branchs; branchs.push_back(DocUtils::ToStmts(block)); if (predicate.size() == 0) { - Array 
default_branch{ExprStmtDoc(IdDoc("pass"))}; + ffi::Array default_branch{ExprStmtDoc(IdDoc("pass"))}; PushDoc(SwitchDoc(switch_doc->predicates, branchs, default_branch)); } else { auto predicates = switch_doc->predicates; @@ -392,7 +400,7 @@ void BaseStack::SwitchEnd() { } void BaseStack::BlockStart() { - Array block; + ffi::Array block; blocks_.push(block); } @@ -407,11 +415,11 @@ void BaseStack::BlockEnd(bool block_docs) { } } -void BaseStack::ScopeStart(const String& scope_def, const String& scope_ref) { +void BaseStack::ScopeStart(const ffi::String& scope_def, const ffi::String& scope_ref) { if (scope_ref.size() > 0) { - PushDoc(ScopeDoc(IdDoc(scope_ref), IdDoc(scope_def), Array())); + PushDoc(ScopeDoc(IdDoc(scope_ref), IdDoc(scope_def), ffi::Array())); } else { - PushDoc(ScopeDoc(std::nullopt, IdDoc(scope_def), Array())); + PushDoc(ScopeDoc(std::nullopt, IdDoc(scope_def), ffi::Array())); } BlockStart(); } @@ -424,12 +432,12 @@ void BaseStack::ScopeEnd() { bool BaseStack::HasBlock() const { return blocks_.size() > 0; } -const Array BaseStack::TopBlock() const { +const ffi::Array BaseStack::TopBlock() const { ICHECK(HasBlock()) << "No block found"; return blocks_.top(); } -const Array BaseStack::PopBlock() { +const ffi::Array BaseStack::PopBlock() { const auto& block = TopBlock(); blocks_.pop(); return block; diff --git a/src/contrib/msc/core/codegen/code_stack.h b/src/contrib/msc/core/codegen/code_stack.h index ff4e6b58247a..d588c3cf4f31 100644 --- a/src/contrib/msc/core/codegen/code_stack.h +++ b/src/contrib/msc/core/codegen/code_stack.h @@ -59,24 +59,24 @@ class BaseStack { } /*! \brief Get the docs*/ - const Array GetDocs() const; + const ffi::Array GetDocs() const; protected: /*! \brief Push Id Doc*/ void Line(const Doc& doc); - void Line(const String& line = ""); + void Line(const ffi::String& line = ""); /*! 
\brief Push Comment Doc*/ - void Comment(const String& comment, bool attach = false); + void Comment(const ffi::String& comment, bool attach = false); /*! \brief Push Assign Doc*/ template - inline void Assign(const LT& lhs, const RT& rhs, const String& annotation = "") { + inline void Assign(const LT& lhs, const RT& rhs, const ffi::String& annotation = "") { PushDoc(DocUtils::ToAssign(lhs, rhs, annotation)); } /*! \brief Push declare Doc*/ - void Declare(const String& type, const String& variable, size_t len = 0, + void Declare(const ffi::String& type, const ffi::String& variable, size_t len = 0, bool use_constructor = true); /*! \brief Cache declare argument*/ @@ -89,10 +89,10 @@ class BaseStack { } /*! \brief Cache class Doc*/ - void ClassDef(const String& class_name); + void ClassDef(const ffi::String& class_name); /*! \brief Cache class decorator*/ - void ClassDecorator(const String& decorator); + void ClassDecorator(const ffi::String& decorator); /*! \brief Start class body block*/ void ClassStart(); @@ -101,19 +101,20 @@ class BaseStack { void ClassEnd(); /*! \brief Start struct body block*/ - void StructStart(const String& struct_name); + void StructStart(const ffi::String& struct_name); /*! \brief End struct body block*/ void StructEnd(); /*! \brief Cache function Doc*/ - void FuncDef(const String& func_name, const String& ret_type = ""); + void FuncDef(const ffi::String& func_name, const ffi::String& ret_type = ""); /*! \brief Cache function argument*/ - void FuncArg(const String& arg, const String& annotation = "", const String& value = ""); + void FuncArg(const ffi::String& arg, const ffi::String& annotation = "", + const ffi::String& value = ""); /*! \brief Cache function decorator*/ - void FuncDecorator(const String& decorator); + void FuncDecorator(const ffi::String& decorator); /*! \brief Start function body block*/ void FuncStart(); @@ -128,10 +129,11 @@ class BaseStack { } /*! 
\brief Cache constructor Doc*/ - void ConstructorDef(const String& constructor_name); + void ConstructorDef(const ffi::String& constructor_name); /*! \brief Cache constructor argument*/ - void ConstructorArg(const String& arg, const String& annotation = "", const String& value = ""); + void ConstructorArg(const ffi::String& arg, const ffi::String& annotation = "", + const ffi::String& value = ""); /*! \brief Start constructor body block*/ void ConstructorStart(); @@ -140,52 +142,55 @@ class BaseStack { void ConstructorEnd(); /*! \brief Cache lambda Doc*/ - void LambdaDef(const String& lambda_name); + void LambdaDef(const ffi::String& lambda_name); /*! \brief Cache lambda argument*/ - void LambdaArg(const String& arg, const String& annotation = "", const String& value = ""); + void LambdaArg(const ffi::String& arg, const ffi::String& annotation = "", + const ffi::String& value = ""); /*! \brief Cache lambda reference*/ - void LambdaRef(const String& ref); + void LambdaRef(const ffi::String& ref); /*! \brief Start lambda body block*/ void LambdaStart(); /*! \brief End lambda body block*/ - void LambdaEnd(const String& ret_val = ""); + void LambdaEnd(const ffi::String& ret_val = ""); void LambdaEnd(const ExprDoc& ret_val); /*! \brief Push call and maybe assign Doc*/ - void FuncCall(const String& callee, Optional assign_to, - Optional caller = std::nullopt); - void FuncCall(const String& callee, const String& assign_to = "", const String& caller = ""); + void FuncCall(const ffi::String& callee, ffi::Optional assign_to, + ffi::Optional caller = std::nullopt); + void FuncCall(const ffi::String& callee, const ffi::String& assign_to = "", + const ffi::String& caller = ""); /*! \brief Push method call Doc*/ - void MethodCall(const String& callee, bool new_line = false); + void MethodCall(const ffi::String& callee, bool new_line = false); /*! 
\brief Push inplace call and maybe assign Doc*/ - void InplaceStart(const String& callee, Optional assign_to, - Optional caller = std::nullopt); - void InplaceStart(const String& callee, const String& assign_to = "", const String& caller = ""); + void InplaceStart(const ffi::String& callee, ffi::Optional assign_to, + ffi::Optional caller = std::nullopt); + void InplaceStart(const ffi::String& callee, const ffi::String& assign_to = "", + const ffi::String& caller = ""); /*! \brief End inplace call*/ void InplaceEnd(); /*! \brief Push nested expr to last Doc*/ - void PopNest(const String& key = ""); + void PopNest(const ffi::String& key = ""); /*! \brief Cache call typed argument*/ - void CallArgBase(const ExprDoc& value, const String& key = ""); + void CallArgBase(const ExprDoc& value, const ffi::String& key = ""); /*! \brief Cache call normal argument*/ template - inline void CallArg(T value, const String& key = "") { + inline void CallArg(T value, const ffi::String& key = "") { const auto& doc_value = DocUtils::ToDoc(value); if (doc_value.defined()) { CallArgBase(doc_value, key); } } - inline void CallArg(const Array& values) { + inline void CallArg(const ffi::Array& values) { for (const auto& v : values) { if (v.defined()) { CallArgBase(v); @@ -194,7 +199,7 @@ class BaseStack { } /*! \brief Push if to cache and start if block*/ - void ConditionIf(const String& predicate); + void ConditionIf(const ffi::String& predicate); /*! \brief Push then branch to cached and start block*/ void ConditionElse(); @@ -205,15 +210,15 @@ class BaseStack { /*! \brief Push for to cache and start for block*/ template void ForStart(const LT& lhs, const RT& rhs) { - PushDoc(ForDoc(DocUtils::ToDoc(lhs), DocUtils::ToDoc(rhs), Array())); + PushDoc(ForDoc(DocUtils::ToDoc(lhs), DocUtils::ToDoc(rhs), ffi::Array())); BlockStart(); } /*! 
\brief Push for range to cache and start for block*/ template - void ForStart(const String& lhs, const ST& start, const ET& end) { - Array range{DocUtils::ToDoc(start), DocUtils::ToDoc(end)}; - PushDoc(ForDoc(IdDoc(lhs), TupleDoc(range), Array())); + void ForStart(const ffi::String& lhs, const ST& start, const ET& end) { + ffi::Array range{DocUtils::ToDoc(start), DocUtils::ToDoc(end)}; + PushDoc(ForDoc(IdDoc(lhs), TupleDoc(range), ffi::Array())); BlockStart(); } @@ -221,16 +226,16 @@ class BaseStack { void ForEnd(); /*! \brief Push while to cache and start while block*/ - void WhileStart(const String& predicate); + void WhileStart(const ffi::String& predicate); /*! \brief End a while block*/ void WhileEnd(); /*! \brief Push switch to cache and start switch block*/ - void SwitchStart(const String& predicate); + void SwitchStart(const ffi::String& predicate); /*! \brief Add new case to switch*/ - void SwitchCase(const String& predicate = ""); + void SwitchCase(const ffi::String& predicate = ""); /*! \brief Push switch to cached*/ void SwitchEnd(); @@ -242,7 +247,7 @@ class BaseStack { void BlockEnd(bool block_docs = true); /*! \brief Start a new scope*/ - void ScopeStart(const String& scope_def = "", const String& scope_ref = ""); + void ScopeStart(const ffi::String& scope_def = "", const ffi::String& scope_ref = ""); /*! \brief End a scope*/ void ScopeEnd(); @@ -252,10 +257,10 @@ class BaseStack { bool HasBlock() const; /*! \brief Get the last the block*/ - const Array TopBlock() const; + const ffi::Array TopBlock() const; /*! \brief Pop last the block*/ - const Array PopBlock(); + const ffi::Array PopBlock(); /*! \brief Check if doc left*/ bool HasDoc(); @@ -274,237 +279,239 @@ class BaseStack { void PushDoc(const Doc& doc); /*! 
\brief The blocks, each has docs array*/ - std::stack> blocks_; + std::stack> blocks_; }; -#define COMMON_WRAPPERS(Stack) \ - Stack& line(const Doc& doc) { \ - Line(doc); \ - return *this; \ - } \ - Stack& line(const String& line = "") { \ - Line(line); \ - return *this; \ - } \ - Stack& comment(const String& comment, bool attach = false) { \ - Comment(comment, attach); \ - return *this; \ - } \ - template \ - Stack& assign(const LT& lhs, const RT& rhs, const String& annotation = "") { \ - Assign(lhs, rhs, annotation); \ - return *this; \ - } \ - Stack& declare(const String& type, const String& variable, size_t len = 0, \ - bool use_constructor = true) { \ - Declare(type, variable, len, use_constructor); \ - return *this; \ - } \ - template \ - Stack& declare_arg(const T& value) { \ - DeclareArg(value); \ - return *this; \ - } \ - Stack& class_def(const String& class_name) { \ - ClassDef(class_name); \ - return *this; \ - } \ - Stack& class_decorator(const String& decorator) { \ - ClassDecorator(decorator); \ - return *this; \ - } \ - Stack& class_start() { \ - ClassStart(); \ - return *this; \ - } \ - Stack& class_end() { \ - ClassEnd(); \ - return *this; \ - } \ - Stack& struct_start(const String& struct_name) { \ - StructStart(struct_name); \ - return *this; \ - } \ - Stack& struct_end() { \ - StructEnd(); \ - return *this; \ - } \ - Stack& func_def(const String& func_name, const String& ret_type = "") { \ - FuncDef(func_name, ret_type); \ - return *this; \ - } \ - Stack& func_arg(const String& arg, const String& annotation = "", const String& value = "") { \ - FuncArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& func_decorator(const String& decorator) { \ - FuncDecorator(decorator); \ - return *this; \ - } \ - Stack& func_start() { \ - FuncStart(); \ - return *this; \ - } \ - Stack& func_end() { \ - FuncEnd(); \ - return *this; \ - } \ - template \ - Stack& func_end(const T& ret_val) { \ - FuncEnd(ret_val); \ - return *this; \ - } \ - Stack& 
func_call(const String& callee, Optional assign_to, \ - Optional caller = std::nullopt) { \ - FuncCall(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& func_call(const String& callee, const String& assign_to = "", \ - const String& caller = "") { \ - FuncCall(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& method_call(const String& callee, bool new_line = false) { \ - MethodCall(callee, new_line); \ - return *this; \ - } \ - Stack& inplace_start(const String& callee, Optional assign_to, \ - Optional caller = std::nullopt) { \ - InplaceStart(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& inplace_start(const String& callee, const String& assign_to = "", \ - const String& caller = "") { \ - InplaceStart(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& inplace_end() { \ - InplaceEnd(); \ - return *this; \ - } \ - Stack& constructor_def(const String& func_name) { \ - ConstructorDef(func_name); \ - return *this; \ - } \ - Stack& constructor_arg(const String& arg, const String& annotation = "", \ - const String& value = "") { \ - ConstructorArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& constructor_start() { \ - ConstructorStart(); \ - return *this; \ - } \ - Stack& constructor_end() { \ - ConstructorEnd(); \ - return *this; \ - } \ - Stack& lambda_def(const String& lambda_name) { \ - LambdaDef(lambda_name); \ - return *this; \ - } \ - Stack& lambda_arg(const String& arg, const String& annotation = "", const String& value = "") { \ - LambdaArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& lambda_ref(const String& ref) { \ - LambdaRef(ref); \ - return *this; \ - } \ - Stack& lambda_start() { \ - LambdaStart(); \ - return *this; \ - } \ - Stack& lambda_end(const String& ret_val = "") { \ - LambdaEnd(ret_val); \ - return *this; \ - } \ - Stack& lambda_end(const ExprDoc& ret_val) { \ - LambdaEnd(ret_val); \ - return *this; \ - } \ - Stack& pop_nest(const String& key = "") { \ - 
PopNest(key); \ - return *this; \ - } \ - template \ - Stack& call_arg(T value, const String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_arg(const ExprDoc& value, const String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_arg(const Array& values) { \ - CallArg(values); \ - return *this; \ - } \ - Stack& cond_if(const String& predicate) { \ - ConditionIf(predicate); \ - return *this; \ - } \ - Stack& cond_else() { \ - ConditionElse(); \ - return *this; \ - } \ - Stack& cond_end() { \ - ConditionEnd(); \ - return *this; \ - } \ - template \ - Stack& for_start(const LT& lhs, const RT& rhs) { \ - ForStart(lhs, rhs); \ - return *this; \ - } \ - template \ - Stack& for_start(const String& lhs, const ST& start, const ET& end) { \ - ForStart(lhs, start, end); \ - return *this; \ - } \ - Stack& for_start(const String& lhs, const String& start, const String& end) { \ - ForStart(lhs, start, end); \ - return *this; \ - } \ - Stack& for_end() { \ - ForEnd(); \ - return *this; \ - } \ - Stack& while_start(const String& predicate) { \ - WhileStart(predicate); \ - return *this; \ - } \ - Stack& while_end() { \ - WhileEnd(); \ - return *this; \ - } \ - Stack& switch_start(const String& predicate) { \ - SwitchStart(predicate); \ - return *this; \ - } \ - Stack& switch_case(const String& predicate = "") { \ - SwitchCase(predicate); \ - return *this; \ - } \ - Stack& switch_end() { \ - SwitchEnd(); \ - return *this; \ - } \ - Stack& block_start() { \ - BlockStart(); \ - return *this; \ - } \ - Stack& block_end(bool block_docs = true) { \ - BlockEnd(block_docs); \ - return *this; \ - } \ - Stack& scope_start(const String& scope_def = "", const String& scope_ref = "") { \ - ScopeStart(scope_def, scope_ref); \ - return *this; \ - } \ - Stack& scope_end() { \ - ScopeEnd(); \ - return *this; \ +#define COMMON_WRAPPERS(Stack) \ + Stack& line(const Doc& doc) { \ + Line(doc); \ + return *this; \ + } \ + Stack& line(const 
ffi::String& line = "") { \ + Line(line); \ + return *this; \ + } \ + Stack& comment(const ffi::String& comment, bool attach = false) { \ + Comment(comment, attach); \ + return *this; \ + } \ + template \ + Stack& assign(const LT& lhs, const RT& rhs, const ffi::String& annotation = "") { \ + Assign(lhs, rhs, annotation); \ + return *this; \ + } \ + Stack& declare(const ffi::String& type, const ffi::String& variable, size_t len = 0, \ + bool use_constructor = true) { \ + Declare(type, variable, len, use_constructor); \ + return *this; \ + } \ + template \ + Stack& declare_arg(const T& value) { \ + DeclareArg(value); \ + return *this; \ + } \ + Stack& class_def(const ffi::String& class_name) { \ + ClassDef(class_name); \ + return *this; \ + } \ + Stack& class_decorator(const ffi::String& decorator) { \ + ClassDecorator(decorator); \ + return *this; \ + } \ + Stack& class_start() { \ + ClassStart(); \ + return *this; \ + } \ + Stack& class_end() { \ + ClassEnd(); \ + return *this; \ + } \ + Stack& struct_start(const ffi::String& struct_name) { \ + StructStart(struct_name); \ + return *this; \ + } \ + Stack& struct_end() { \ + StructEnd(); \ + return *this; \ + } \ + Stack& func_def(const ffi::String& func_name, const ffi::String& ret_type = "") { \ + FuncDef(func_name, ret_type); \ + return *this; \ + } \ + Stack& func_arg(const ffi::String& arg, const ffi::String& annotation = "", \ + const ffi::String& value = "") { \ + FuncArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& func_decorator(const ffi::String& decorator) { \ + FuncDecorator(decorator); \ + return *this; \ + } \ + Stack& func_start() { \ + FuncStart(); \ + return *this; \ + } \ + Stack& func_end() { \ + FuncEnd(); \ + return *this; \ + } \ + template \ + Stack& func_end(const T& ret_val) { \ + FuncEnd(ret_val); \ + return *this; \ + } \ + Stack& func_call(const ffi::String& callee, ffi::Optional assign_to, \ + ffi::Optional caller = std::nullopt) { \ + FuncCall(callee, assign_to, caller); \ 
+ return *this; \ + } \ + Stack& func_call(const ffi::String& callee, const ffi::String& assign_to = "", \ + const ffi::String& caller = "") { \ + FuncCall(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& method_call(const ffi::String& callee, bool new_line = false) { \ + MethodCall(callee, new_line); \ + return *this; \ + } \ + Stack& inplace_start(const ffi::String& callee, ffi::Optional assign_to, \ + ffi::Optional caller = std::nullopt) { \ + InplaceStart(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& inplace_start(const ffi::String& callee, const ffi::String& assign_to = "", \ + const ffi::String& caller = "") { \ + InplaceStart(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& inplace_end() { \ + InplaceEnd(); \ + return *this; \ + } \ + Stack& constructor_def(const ffi::String& func_name) { \ + ConstructorDef(func_name); \ + return *this; \ + } \ + Stack& constructor_arg(const ffi::String& arg, const ffi::String& annotation = "", \ + const ffi::String& value = "") { \ + ConstructorArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& constructor_start() { \ + ConstructorStart(); \ + return *this; \ + } \ + Stack& constructor_end() { \ + ConstructorEnd(); \ + return *this; \ + } \ + Stack& lambda_def(const ffi::String& lambda_name) { \ + LambdaDef(lambda_name); \ + return *this; \ + } \ + Stack& lambda_arg(const ffi::String& arg, const ffi::String& annotation = "", \ + const ffi::String& value = "") { \ + LambdaArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& lambda_ref(const ffi::String& ref) { \ + LambdaRef(ref); \ + return *this; \ + } \ + Stack& lambda_start() { \ + LambdaStart(); \ + return *this; \ + } \ + Stack& lambda_end(const ffi::String& ret_val = "") { \ + LambdaEnd(ret_val); \ + return *this; \ + } \ + Stack& lambda_end(const ExprDoc& ret_val) { \ + LambdaEnd(ret_val); \ + return *this; \ + } \ + Stack& pop_nest(const ffi::String& key = "") { \ + PopNest(key); \ + return *this; \ 
+ } \ + template \ + Stack& call_arg(T value, const ffi::String& key = "") { \ + CallArg(value, key); \ + return *this; \ + } \ + Stack& call_arg(const ExprDoc& value, const ffi::String& key = "") { \ + CallArg(value, key); \ + return *this; \ + } \ + Stack& call_arg(const ffi::Array& values) { \ + CallArg(values); \ + return *this; \ + } \ + Stack& cond_if(const ffi::String& predicate) { \ + ConditionIf(predicate); \ + return *this; \ + } \ + Stack& cond_else() { \ + ConditionElse(); \ + return *this; \ + } \ + Stack& cond_end() { \ + ConditionEnd(); \ + return *this; \ + } \ + template \ + Stack& for_start(const LT& lhs, const RT& rhs) { \ + ForStart(lhs, rhs); \ + return *this; \ + } \ + template \ + Stack& for_start(const ffi::String& lhs, const ST& start, const ET& end) { \ + ForStart(lhs, start, end); \ + return *this; \ + } \ + Stack& for_start(const ffi::String& lhs, const ffi::String& start, const ffi::String& end) { \ + ForStart(lhs, start, end); \ + return *this; \ + } \ + Stack& for_end() { \ + ForEnd(); \ + return *this; \ + } \ + Stack& while_start(const ffi::String& predicate) { \ + WhileStart(predicate); \ + return *this; \ + } \ + Stack& while_end() { \ + WhileEnd(); \ + return *this; \ + } \ + Stack& switch_start(const ffi::String& predicate) { \ + SwitchStart(predicate); \ + return *this; \ + } \ + Stack& switch_case(const ffi::String& predicate = "") { \ + SwitchCase(predicate); \ + return *this; \ + } \ + Stack& switch_end() { \ + SwitchEnd(); \ + return *this; \ + } \ + Stack& block_start() { \ + BlockStart(); \ + return *this; \ + } \ + Stack& block_end(bool block_docs = true) { \ + BlockEnd(block_docs); \ + return *this; \ + } \ + Stack& scope_start(const ffi::String& scope_def = "", const ffi::String& scope_ref = "") { \ + ScopeStart(scope_def, scope_ref); \ + return *this; \ + } \ + Stack& scope_end() { \ + ScopeEnd(); \ + return *this; \ } /*! @@ -542,35 +549,37 @@ class OpCodeStack : public BaseStack { COMMON_WRAPPERS(OpCodeStack) /*! 
\brief Push op_call Doc*/ - OpCodeStack& op_call(const String& callee = "msc::auto", - const String& assign_to = "msc::auto") { - const String& v_callee = callee == "msc::auto" ? codegen_->callee_name() : callee; - const String& v_assign = assign_to == "msc::auto" ? codegen_->ret_name() : assign_to; + OpCodeStack& op_call(const ffi::String& callee = "msc::auto", + const ffi::String& assign_to = "msc::auto") { + const ffi::String& v_callee = callee == "msc::auto" ? codegen_->callee_name() : callee; + const ffi::String& v_assign = assign_to == "msc::auto" ? codegen_->ret_name() : assign_to; return func_call(v_callee, v_assign); } /*! \brief Push op comment Doc*/ - OpCodeStack& op_comment(const String& comment_str = "msc::auto") { - const String& v_comment = (comment_str == "msc::auto" ? codegen_->Comment() : comment_str); + OpCodeStack& op_comment(const ffi::String& comment_str = "msc::auto") { + const ffi::String& v_comment = (comment_str == "msc::auto" ? codegen_->Comment() : comment_str); return comment(v_comment); } /*! \brief Cache typed attribute as argument*/ template - OpCodeStack& op_arg(const String& attr_key, const String& key = "msc::auto") { + OpCodeStack& op_arg(const ffi::String& attr_key, + const ffi::String& key = "msc::auto") { T attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const String& valid_key = key == "msc::auto" ? attr_key : key; + const ffi::String& valid_key = key == "msc::auto" ? attr_key : key; return call_arg(attr_val, valid_key); } return *this; } /*! \brief Cache str attribute as argument*/ - OpCodeStack& op_str_arg(const String& attr_key, const String& key = "msc::auto") { + OpCodeStack& op_str_arg(const ffi::String& attr_key, + const ffi::String& key = "msc::auto") { std::string attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const String& valid_key = key == "msc::auto" ? attr_key : key; + const ffi::String& valid_key = key == "msc::auto" ? 
attr_key : key; return call_arg(DocUtils::ToStr(attr_val), valid_key); } return *this; @@ -578,24 +587,25 @@ class OpCodeStack : public BaseStack { /*! \brief Cache list attribute as argument*/ template - OpCodeStack& op_list_arg(const String& attr_key, const String& key = "msc::auto", + OpCodeStack& op_list_arg(const ffi::String& attr_key, + const ffi::String& key = "msc::auto", bool allow_empty = false) { std::vector attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const String& valid_key = key == "msc::auto" ? attr_key : key; + const ffi::String& valid_key = key == "msc::auto" ? attr_key : key; return call_arg(DocUtils::ToList(attr_val, allow_empty), valid_key); } return *this; } /*! \brief Cache input as argument*/ - OpCodeStack& op_input_arg(int idx = 0, const String& key = "") { + OpCodeStack& op_input_arg(int idx = 0, const ffi::String& key = "") { return call_arg(codegen_->IdxInput(idx, true), key); } /*! \brief Cache inputs as argument*/ - OpCodeStack& op_inputs_arg(bool as_list = true, const String& key = "") { - Array inputs; + OpCodeStack& op_inputs_arg(bool as_list = true, const ffi::String& key = "") { + ffi::Array inputs; for (size_t i = 0; i < codegen_->node()->inputs.size(); i++) { inputs.push_back(codegen_->IdxInput(i, true)); } @@ -607,12 +617,12 @@ class OpCodeStack : public BaseStack { } /*! \brief Cache output as argument*/ - OpCodeStack& op_output_arg(int idx = 0, const String& key = "") { + OpCodeStack& op_output_arg(int idx = 0, const ffi::String& key = "") { return call_arg(codegen_->IdxOutput(idx), key); } /*! \brief Cache weight as argument*/ - OpCodeStack& op_weight_arg(const String& wtype, const String& key = "") { + OpCodeStack& op_weight_arg(const ffi::String& wtype, const ffi::String& key = "") { if (codegen_->node()->weights.count(wtype)) { return call_arg(codegen_->IdxWeight(wtype, true), key); } @@ -620,15 +630,15 @@ class OpCodeStack : public BaseStack { } /*! 
\brief Cache name as argument*/ - OpCodeStack& op_name_arg(const String& key = "msc::auto", - const String& name = "msc::auto") { - const String& valid_key = key == "msc::auto" ? "name" : key; - const String& valid_name = name == "msc::auto" ? codegen_->node()->name : name; + OpCodeStack& op_name_arg(const ffi::String& key = "msc::auto", + const ffi::String& name = "msc::auto") { + const ffi::String& valid_key = key == "msc::auto" ? "name" : key; + const ffi::String& valid_name = name == "msc::auto" ? codegen_->node()->name : name; return call_arg(DocUtils::ToStr(valid_name), valid_key); return *this; } - OpCodeStack& op_dtype_arg(const DataType& dtype, const String& key = "") { + OpCodeStack& op_dtype_arg(const DataType& dtype, const ffi::String& key = "") { return call_arg(codegen_->DType(dtype), key); } diff --git a/src/contrib/msc/core/codegen/codegen_json.cc b/src/contrib/msc/core/codegen/codegen_json.cc index 7bbe576b6bfe..6ccec35b78b4 100644 --- a/src/contrib/msc/core/codegen/codegen_json.cc +++ b/src/contrib/msc/core/codegen/codegen_json.cc @@ -50,11 +50,11 @@ std::vector MSCJSONSerializer::VisitExpr_(const CallNode* ca } global_options_set_ = true; } - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } -void MSCJSONSerializer::AddNodeAttr(JSONGraphObjectPtr node, const String& key, - const String& value) { +void MSCJSONSerializer::AddNodeAttr(JSONGraphObjectPtr node, const ffi::String& key, + const ffi::String& value) { std::vector array_value{std::string(value)}; std::vector dmlc_value; dmlc_value.emplace_back(array_value); diff --git a/src/contrib/msc/core/codegen/codegen_json.h b/src/contrib/msc/core/codegen/codegen_json.h index dfc2d699a968..08a834bdaa27 100644 --- a/src/contrib/msc/core/codegen/codegen_json.h +++ b/src/contrib/msc/core/codegen/codegen_json.h @@ -69,7 +69,7 @@ class MSCJSONSerializer : public JSONSerializer { * \brief Constructor * \param constant_names The names of all constants in the original 
module. */ - explicit MSCJSONSerializer(const Map& constant_names, + explicit MSCJSONSerializer(const ffi::Map& constant_names, const std::string& options) : JSONSerializer(constant_names) { MSCCompileConfig config; @@ -86,19 +86,19 @@ class MSCJSONSerializer : public JSONSerializer { std::vector VisitExpr_(const CallNode* call_node) final; - const String GetOption(const String& key) { + const ffi::String GetOption(const ffi::String& key) { ICHECK(options_.count(key)) << "Can not find option " << key; return options_[key]; } - const Map GetOptions() { return options_; } + const ffi::Map GetOptions() { return options_; } protected: - void AddNodeAttr(JSONGraphObjectPtr node, const String& key, const String& value); + void AddNodeAttr(JSONGraphObjectPtr node, const ffi::String& key, const ffi::String& value); private: MSCGraph graph_; - Map options_; + ffi::Map options_; bool global_options_set_; }; diff --git a/src/contrib/msc/core/codegen/codegen_utils.cc b/src/contrib/msc/core/codegen/codegen_utils.cc index 741b729bd015..768c9f276e9e 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.cc +++ b/src/contrib/msc/core/codegen/codegen_utils.cc @@ -27,13 +27,13 @@ namespace tvm { namespace contrib { namespace msc { -const String CodeGenUtils::IdxNode(const MSCJoint& node, const String& prefix, - const String& suffix) { +const ffi::String CodeGenUtils::IdxNode(const MSCJoint& node, const ffi::String& prefix, + const ffi::String& suffix) { return prefix + std::to_string(node->index) + suffix; } -const String CodeGenUtils::IdxOutput(const MSCJoint& node, const String& prefix, int idx, - const String& suffix) { +const ffi::String CodeGenUtils::IdxOutput(const MSCJoint& node, const ffi::String& prefix, int idx, + const ffi::String& suffix) { const auto& idx_node = IdxNode(node, prefix, suffix); size_t output_size = node->outputs.size(); if (output_size == 1 && node->optype != "tuple") { @@ -43,20 +43,20 @@ const String CodeGenUtils::IdxOutput(const MSCJoint& node, const 
String& prefix, return idx_node + "[" + std::to_string(v_index) + "]"; } -const String CodeGenUtils::IdxInput(const MSCJoint& node, const String& prefix, int idx, - const String& suffix) { +const ffi::String CodeGenUtils::IdxInput(const MSCJoint& node, const ffi::String& prefix, int idx, + const ffi::String& suffix) { const auto& pair = node->ProducerAndIdxOf(idx); return IdxOutput(pair.first, prefix, pair.second, suffix); } -const String CodeGenUtils::IdxWeight(const MSCJoint& node, const String& wtype, - const String& suffix) { +const ffi::String CodeGenUtils::IdxWeight(const MSCJoint& node, const ffi::String& wtype, + const ffi::String& suffix) { return wtype + "_" + std::to_string(node->index) + suffix; } -const Array CodeGenUtils::GetPrims(const MSCTensor& tensor, - const Map& prims) { - Array dims; +const ffi::Array CodeGenUtils::GetPrims( + const MSCTensor& tensor, const ffi::Map& prims) { + ffi::Array dims; if (tensor->prims.size() == 0) { for (size_t i = 0; i < tensor->Ndim(); i++) { dims.push_back(StringUtils::ToString(tensor->DimAt(i))); @@ -70,9 +70,9 @@ const Array CodeGenUtils::GetPrims(const MSCTensor& tensor, return dims; } -const String CodeGenUtils::CommentNode(const MSCJoint& node, const String& prefix, - const Map& prims) { - String comment = node->name + "(" + node->optype + "): <"; +const ffi::String CodeGenUtils::CommentNode(const MSCJoint& node, const ffi::String& prefix, + const ffi::Map& prims) { + ffi::String comment = node->name + "(" + node->optype + "): <"; for (size_t i = 0; i < node->inputs.size(); i++) { comment = comment + IdxInput(node, prefix, i) + (i == node->inputs.size() - 1 ? 
"> -> <" : ","); } diff --git a/src/contrib/msc/core/codegen/codegen_utils.h b/src/contrib/msc/core/codegen/codegen_utils.h index 09b44af894e4..6fbaa96dd698 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.h +++ b/src/contrib/msc/core/codegen/codegen_utils.h @@ -86,39 +86,42 @@ using namespace tvm::script::printer; this->DescribePrim(prim->ParentAt(1)) + ")"; \ } -#define CODEGEN_MEMBERS \ - public: \ - virtual const String DType(const DataType& dtype) { return runtime::DLDataTypeToString(dtype); } \ - \ - protected: \ - const std::shared_ptr config() { return config_; } \ - const Map prims() { return prims_; } \ - const String IdxNodeBase(const MSCJoint& node) { \ - return helper_.IdxNodeBase(node, config()->prefix, ""); \ - } \ - const String IdxInputBase(const MSCJoint& node, int idx = 0, bool process = true) { \ - return helper_.IdxInputBase(node, config()->prefix, idx, "", process && config()->use_tools); \ - } \ - const String IdxOutputBase(const MSCJoint& node, int idx = 0, bool mark_exit = false) { \ - return helper_.IdxOutputBase(node, config()->prefix, idx, "", \ - mark_exit && config()->use_tools); \ - } \ - const String IdxWeightBase(const MSCJoint& node, const String& wtype, bool process = true) { \ - return helper_.IdxWeightBase(node, wtype, "", process && config()->use_tools); \ - } \ - const Array GetPrims(const MSCTensor& tensor) { \ - return CodeGenUtils::GetPrims(tensor, prims_); \ - } \ - const String Comment(const MSCJoint& node) { \ - return helper_.Comment(node, config()->prefix, prims_); \ - } \ - int CompareVersion(size_t major, size_t minor, size_t patch) { \ - return CommonUtils::CompareVersion(config()->version, {major, minor, patch}); \ - } \ - \ - private: \ - std::shared_ptr config_; \ - Map prims_; \ +#define CODEGEN_MEMBERS \ + public: \ + virtual const ffi::String DType(const DataType& dtype) { \ + return runtime::DLDataTypeToString(dtype); \ + } \ + \ + protected: \ + const std::shared_ptr config() { return config_; } \ + 
const ffi::Map prims() { return prims_; } \ + const ffi::String IdxNodeBase(const MSCJoint& node) { \ + return helper_.IdxNodeBase(node, config()->prefix, ""); \ + } \ + const ffi::String IdxInputBase(const MSCJoint& node, int idx = 0, bool process = true) { \ + return helper_.IdxInputBase(node, config()->prefix, idx, "", process && config()->use_tools); \ + } \ + const ffi::String IdxOutputBase(const MSCJoint& node, int idx = 0, bool mark_exit = false) { \ + return helper_.IdxOutputBase(node, config()->prefix, idx, "", \ + mark_exit && config()->use_tools); \ + } \ + const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, \ + bool process = true) { \ + return helper_.IdxWeightBase(node, wtype, "", process && config()->use_tools); \ + } \ + const ffi::Array GetPrims(const MSCTensor& tensor) { \ + return CodeGenUtils::GetPrims(tensor, prims_); \ + } \ + const ffi::String Comment(const MSCJoint& node) { \ + return helper_.Comment(node, config()->prefix, prims_); \ + } \ + int CompareVersion(size_t major, size_t minor, size_t patch) { \ + return CommonUtils::CompareVersion(config()->version, {major, minor, patch}); \ + } \ + \ + private: \ + std::shared_ptr config_; \ + ffi::Map prims_; \ HelperType helper_; /*! @@ -130,42 +133,42 @@ class CodeGenUtils { * \brief Get indexed node string. * \return The String. */ - TVM_DLL static const String IdxNode(const MSCJoint& node, const String& prefix, - const String& suffix = ""); + TVM_DLL static const ffi::String IdxNode(const MSCJoint& node, const ffi::String& prefix, + const ffi::String& suffix = ""); /*! * \brief Get indexed output string. * \return The String. */ - TVM_DLL static const String IdxOutput(const MSCJoint& node, const String& prefix, int idx = 0, - const String& suffix = ""); + TVM_DLL static const ffi::String IdxOutput(const MSCJoint& node, const ffi::String& prefix, + int idx = 0, const ffi::String& suffix = ""); /*! * \brief Get indexed input string. * \return The String. 
*/ - TVM_DLL static const String IdxInput(const MSCJoint& node, const String& prefix, int idx = 0, - const String& suffix = ""); + TVM_DLL static const ffi::String IdxInput(const MSCJoint& node, const ffi::String& prefix, + int idx = 0, const ffi::String& suffix = ""); /*! * \brief Get indexed weight string. * \return The String. */ - TVM_DLL static const String IdxWeight(const MSCJoint& node, const String& wtype, - const String& suffix = ""); + TVM_DLL static const ffi::String IdxWeight(const MSCJoint& node, const ffi::String& wtype, + const ffi::String& suffix = ""); /*! * \brief Infer prims of tensor. * \return The prims. */ - TVM_DLL static const Array GetPrims(const MSCTensor& tensor, - const Map& prims); + TVM_DLL static const ffi::Array GetPrims( + const MSCTensor& tensor, const ffi::Map& prims); /*! * \brief Get comment of a node. * \return The String. */ - TVM_DLL static const String CommentNode(const MSCJoint& node, const String& prefix, - const Map& prims); + TVM_DLL static const ffi::String CommentNode(const MSCJoint& node, const ffi::String& prefix, + const ffi::Map& prims); }; /*! @@ -173,16 +176,17 @@ class CodeGenUtils { */ class BaseCodeGenHelper { public: - const String GetSuffix(const MSCJoint& node, bool process = false) { + const ffi::String GetSuffix(const MSCJoint& node, bool process = false) { return process ? 
"c" + std::to_string(node->index) : ""; } - virtual const String IdxNodeBase(const MSCJoint& node, const String& prefix = "", - const String& suffix = "") { + virtual const ffi::String IdxNodeBase(const MSCJoint& node, const ffi::String& prefix = "", + const ffi::String& suffix = "") { return CodeGenUtils::IdxNode(node, prefix, suffix); } - virtual const String IdxInputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool process = false) { + virtual const ffi::String IdxInputBase(const MSCJoint& node, const ffi::String& prefix = "", + int idx = 0, const ffi::String& suffix = "", + bool process = false) { const auto& pair = node->ProducerAndIdxOf(idx); size_t output_size = pair.first->outputs.size(); if (process && (output_size > 1 || pair.first->optype == "tuple")) { @@ -190,8 +194,9 @@ class BaseCodeGenHelper { } return CodeGenUtils::IdxInput(node, prefix, idx, suffix + GetSuffix(node, process)); } - virtual const String IdxOutputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool mark_exit = false) { + virtual const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", + int idx = 0, const ffi::String& suffix = "", + bool mark_exit = false) { if (mark_exit) { if (node->outputs.size() > 1 || node->optype == "tuple") { return CodeGenUtils::IdxNode(node, prefix, suffix) + "_" + std::to_string(idx) + "_exit"; @@ -200,12 +205,13 @@ class BaseCodeGenHelper { } return CodeGenUtils::IdxOutput(node, prefix, idx, suffix); } - virtual const String IdxWeightBase(const MSCJoint& node, const String& wtype, - const String& suffix = "", bool process = false) { + virtual const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, + const ffi::String& suffix = "", bool process = false) { return CodeGenUtils::IdxWeight(node, wtype, suffix + GetSuffix(node, process)); } - virtual const String Comment(const MSCJoint& node, const String& prefix 
= "", - const Map& prims = Map()) { + virtual const ffi::String Comment( + const MSCJoint& node, const ffi::String& prefix = "", + const ffi::Map& prims = ffi::Map()) { return CodeGenUtils::CommentNode(node, prefix, prims); } }; diff --git a/src/contrib/msc/core/codegen/cpp_codegen.h b/src/contrib/msc/core/codegen/cpp_codegen.h index 260bd27ca35a..99988d689a95 100644 --- a/src/contrib/msc/core/codegen/cpp_codegen.h +++ b/src/contrib/msc/core/codegen/cpp_codegen.h @@ -69,9 +69,10 @@ class CppCodeGen : public BaseCodeGen { virtual void CodeGenCmake() = 0; /*! \brief Get sources*/ - virtual const Map GetSources(const std::string& print_options = "") { - Map sources; - auto add_source = [&print_options, &sources, this](const String& file) { + virtual const ffi::Map GetSources( + const std::string& print_options = "") { + ffi::Map sources; + auto add_source = [&print_options, &sources, this](const ffi::String& file) { CppPrinter printer(print_options); for (const auto& d : this->stack_.GetDocs()) { printer.Append(d); @@ -96,7 +97,7 @@ class CppCodeGen : public BaseCodeGen { protected: /*! \brief Describe the prim*/ - virtual const String DescribePrim(const MSCPrim& prim) { + virtual const ffi::String DescribePrim(const MSCPrim& prim) { // binary ops DESCRIBE_PRIM_BINARY("Min", "std::min", true) DESCRIBE_PRIM_BINARY("Max", "std::max", true) @@ -152,8 +153,8 @@ class CppCodeGen : public BaseCodeGen { } /*! \brief Get the tensor context for codegen_tensor*/ - virtual const Map GetTensorCtx(const MSCTensor& tensor) { - Map tensor_ctx; + virtual const ffi::Map GetTensorCtx(const MSCTensor& tensor) { + ffi::Map tensor_ctx; MSCJoint producer; if (this->graph()->weight_holders.count(tensor->name)) { producer = this->graph()->FindProducer(tensor); @@ -175,8 +176,8 @@ class CppCodeGen : public BaseCodeGen { } /*! 
\brief Get the step context for codegen_step*/ - virtual const Map GetStepCtx() { - Map step_ctx; + virtual const ffi::Map GetStepCtx() { + ffi::Map step_ctx; std::string version = ""; for (size_t i = 0; i < this->config()->version.size(); i++) { version += std::to_string(this->config()->version[i]) + diff --git a/src/contrib/msc/core/codegen/py_codegen.h b/src/contrib/msc/core/codegen/py_codegen.h index af75f0e4233d..460818089f82 100644 --- a/src/contrib/msc/core/codegen/py_codegen.h +++ b/src/contrib/msc/core/codegen/py_codegen.h @@ -70,8 +70,9 @@ class PyCodeGen : public BaseCodeGen { } /*! \brief Get sources*/ - virtual const Map GetSources(const std::string& print_options = "") { - Map sources; + virtual const ffi::Map GetSources( + const std::string& print_options = "") { + ffi::Map sources; PythonPrinter printer(print_options); CodeGenScript(); for (const auto& d : this->stack_.GetDocs()) { @@ -83,7 +84,7 @@ class PyCodeGen : public BaseCodeGen { protected: /*! \brief Describe the prim*/ - virtual const String DescribePrim(const MSCPrim& prim) { + virtual const ffi::String DescribePrim(const MSCPrim& prim) { // binary ops DESCRIBE_PRIM_BINARY("Min", "min", true) DESCRIBE_PRIM_BINARY("Max", "max", true) @@ -216,7 +217,7 @@ class PyCodeGen : public BaseCodeGen { virtual void CodeGenInference() = 0; /*! 
\brief Get tensor type of the framework*/ - virtual const String TensorType() const { return "np.ndarray"; } + virtual const ffi::String TensorType() const { return "np.ndarray"; } private: std::set graph_outputs_; diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index dff38aade5aa..6e69e66bca01 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -36,9 +36,10 @@ namespace tvm { namespace contrib { namespace msc { -MSCTensor::MSCTensor(const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias, const Array& prims) { - ObjectPtr n = make_object(); +MSCTensor::MSCTensor(const ffi::String& name, const DataType& dtype, const ffi::String& layout, + const ffi::Array& shape, const ffi::String& alias, + const ffi::Array& prims) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->alias = std::move(alias); n->dtype = std::move(dtype); @@ -49,13 +50,13 @@ MSCTensor::MSCTensor(const String& name, const DataType& dtype, const String& la } MSCTensor::MSCTensor(const JsonMSCTensor& j_tensor) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_tensor); data_ = std::move(n); } MSCTensor::MSCTensor(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -107,23 +108,23 @@ const Integer MSCTensorNode::DimAt(int index) const { return shape[v_index]; } -const Integer MSCTensorNode::DimAt(const String& axis) const { +const Integer MSCTensorNode::DimAt(const ffi::String& axis) const { auto index = layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); return DimAt(index); } -const String MSCTensorNode::PrimAt(int index) const { +const ffi::String MSCTensorNode::PrimAt(int index) const { if (prims.size() == 0) { return ""; } return prims[CommonUtils::GetIndex(index, Ndim())]; } -const String MSCTensorNode::PrimAt(const String& axis) const { 
+const ffi::String MSCTensorNode::PrimAt(const ffi::String& axis) const { return PrimAt(layout.IndexOf(tvm::tir::LayoutAxis::Get(axis))); } -int32_t MSCTensorNode::LayoutOf(const String& axis) const { +int32_t MSCTensorNode::LayoutOf(const ffi::String& axis) const { return layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); } @@ -135,7 +136,7 @@ const Integer MSCTensorNode::GetSize() const { return size; } -const String MSCTensorNode::DTypeName() const { return runtime::DLDataTypeToString(dtype); } +const ffi::String MSCTensorNode::DTypeName() const { return runtime::DLDataTypeToString(dtype); } size_t BaseJointNode::AddChild(const BaseJoint& child) const { for (size_t i = 0; i < children.size(); i++) { @@ -157,9 +158,9 @@ const BaseJoint BaseJointNode::ChildAt(int index) const { return Downcast(children[v_index]); } -bool BaseJointNode::HasAttr(const String& key) const { return attrs.count(key); } +bool BaseJointNode::HasAttr(const ffi::String& key) const { return attrs.count(key); } -bool BaseJointNode::GetAttr(const String& key, std::string* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::string* val) const { if (attrs.count(key) && attrs[key].size() > 0) { *val = attrs[key]; return true; @@ -167,7 +168,7 @@ bool BaseJointNode::GetAttr(const String& key, std::string* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, int* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, int* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -184,7 +185,7 @@ bool BaseJointNode::GetAttr(const String& key, int* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, int64_t* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, int64_t* val) const { std::string val_str; if (GetAttr(key, &val_str)) { try { @@ -197,7 +198,7 @@ bool BaseJointNode::GetAttr(const String& key, int64_t* val) const { return false; } -bool BaseJointNode::GetAttr(const 
String& key, float* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, float* val) const { std::string val_str; if (GetAttr(key, &val_str)) { try { @@ -210,7 +211,7 @@ bool BaseJointNode::GetAttr(const String& key, float* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, bool* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, bool* val) const { int val_int; if (GetAttr(key, &val_int)) { *val = (val_int != 0); @@ -219,7 +220,7 @@ bool BaseJointNode::GetAttr(const String& key, bool* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -238,7 +239,7 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) co return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -257,7 +258,7 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { try { @@ -275,7 +276,7 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const } return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -294,7 +295,7 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool 
BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -313,20 +314,22 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { return false; } -MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref, const String& optype, - const Map& attrs, const Array& scope, +MSCJoint::MSCJoint(int index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& optype, const ffi::Map& attrs, + const ffi::Array& scope, const std::vector>& inputs, - const Array& outputs, const Map& weights) { - ObjectPtr n = make_object(); + const ffi::Array& outputs, + const ffi::Map& weights) { + ObjectPtr n = ffi::make_object(); n->index = index; n->name = std::move(name); n->shared_ref = std::move(shared_ref); n->optype = std::move(optype); n->attrs = std::move(attrs); n->scope = std::move(scope); - Array parents; - Array> array_inputs; - Array added_parents; + ffi::Array parents; + ffi::Array> array_inputs; + ffi::Array added_parents; for (const auto& pair : inputs) { // const auto& parent=Downcast(pair.first); const auto& p_name = pair.first->name; @@ -342,7 +345,7 @@ MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref, cons added_parents.push_back(p_name); p_idx = added_parents.size() - 1; } - Array input{Integer(p_idx), Integer(pair.second)}; + ffi::Array input{Integer(p_idx), Integer(pair.second)}; array_inputs.push_back(input); } n->parents = std::move(parents); @@ -352,14 +355,14 @@ MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref, cons data_ = std::move(n); } -MSCJoint::MSCJoint(const JsonMSCJoint& j_joint, const Map& nodes) { - ObjectPtr n = make_object(); +MSCJoint::MSCJoint(const JsonMSCJoint& j_joint, const ffi::Map& nodes) { + ObjectPtr n = ffi::make_object(); n->FromJson(j_joint, nodes); data_ = std::move(n); } -MSCJoint::MSCJoint(const std::string& json_str, const Map& 
nodes) { - ObjectPtr n = make_object(); +MSCJoint::MSCJoint(const std::string& json_str, const ffi::Map& nodes) { + ObjectPtr n = ffi::make_object(); n->FromJson(json_str, nodes); data_ = std::move(n); } @@ -397,7 +400,8 @@ const JsonMSCJoint MSCJointNode::ToJson() const { return j_joint; } -void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, const Map& nodes) { +void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, + const ffi::Map& nodes) { index = j_joint.index; name = j_joint.name; shared_ref = j_joint.shared_ref; @@ -413,7 +417,7 @@ void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, const Map= 0) << "Can not find parent for " << in_name; - Array input{Integer(p_idx), Integer(std::stol(index_str))}; + ffi::Array input{Integer(p_idx), Integer(std::stol(index_str))}; inputs.push_back(input); } for (const auto& o : j_joint.outputs) { @@ -434,7 +438,8 @@ void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, const Map& nodes) { +void MSCJointNode::FromJson(const std::string& json_str, + const ffi::Map& nodes) { std::istringstream is(json_str); dmlc::JSONReader reader(&is); JsonMSCJoint j_joint; @@ -449,8 +454,8 @@ const MSCTensor MSCJointNode::InputAt(int index) const { return ParentAt(p_idx->value)->OutputAt(out_idx->value); } -const Array MSCJointNode::GetInputs() const { - Array t_inputs; +const ffi::Array MSCJointNode::GetInputs() const { + ffi::Array t_inputs; for (size_t i = 0; i < inputs.size(); i++) { t_inputs.push_back(InputAt(i)); } @@ -462,15 +467,15 @@ const MSCTensor MSCJointNode::OutputAt(int index) const { return outputs[v_index]; } -const Array MSCJointNode::GetOutputs() const { - Array t_outputs; +const ffi::Array MSCJointNode::GetOutputs() const { + ffi::Array t_outputs; for (size_t i = 0; i < outputs.size(); i++) { t_outputs.push_back(OutputAt(i)); } return t_outputs; } -const MSCTensor MSCJointNode::WeightAt(const String& wtype) const { +const MSCTensor MSCJointNode::WeightAt(const ffi::String& wtype) const { 
ICHECK(weights.count(wtype)) << "Can not find " << wtype << " from weights"; return weights[wtype]; } @@ -490,7 +495,7 @@ const MSCJoint MSCJointNode::ProducerOf(int index) const { return pair.first; } -const MSCJoint MSCJointNode::ProducerOf(const String& input_name) const { +const MSCJoint MSCJointNode::ProducerOf(const ffi::String& input_name) const { const auto& pair = ProducerAndIdxOf(input_name); return pair.first; } @@ -505,7 +510,7 @@ const std::pair MSCJointNode::ProducerAndIdxOf(int index) cons return std::make_pair(ParentAt(p_idx->value), inputs[v_index][1]->value); } -const std::pair MSCJointNode::ProducerAndIdxOf(const String& name) const { +const std::pair MSCJointNode::ProducerAndIdxOf(const ffi::String& name) const { for (size_t i = 0; i < inputs.size(); i++) { if (InputAt(i)->name == name) { return ProducerAndIdxOf(i); @@ -518,9 +523,10 @@ const std::pair MSCJointNode::ProducerAndIdxOf(const MSCTensor return ProducerAndIdxOf(input->name); } -MSCPrim::MSCPrim(int index, const String& name, const String& optype, - const Array& parents, const Map& attrs) { - ObjectPtr n = make_object(); +MSCPrim::MSCPrim(int index, const ffi::String& name, const ffi::String& optype, + const ffi::Array& parents, + const ffi::Map& attrs) { + ObjectPtr n = ffi::make_object(); n->index = index; n->name = std::move(name); n->optype = std::move(optype); @@ -531,14 +537,14 @@ MSCPrim::MSCPrim(int index, const String& name, const String& optype, data_ = std::move(n); } -MSCPrim::MSCPrim(const JsonMSCPrim& j_prim, const Map& prims) { - ObjectPtr n = make_object(); +MSCPrim::MSCPrim(const JsonMSCPrim& j_prim, const ffi::Map& prims) { + ObjectPtr n = ffi::make_object(); n->FromJson(j_prim, prims); data_ = std::move(n); } -MSCPrim::MSCPrim(const std::string& json_str, const Map& prims) { - ObjectPtr n = make_object(); +MSCPrim::MSCPrim(const std::string& json_str, const ffi::Map& prims) { + ObjectPtr n = ffi::make_object(); n->FromJson(json_str, prims); data_ = std::move(n); } @@ 
-557,7 +563,8 @@ const JsonMSCPrim MSCPrimNode::ToJson() const { return j_prim; } -void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, const Map& prims) { +void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, + const ffi::Map& prims) { index = j_prim.index; name = j_prim.name; optype = j_prim.optype; @@ -570,7 +577,8 @@ void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, const Map& prims) { +void MSCPrimNode::FromJson(const std::string& json_str, + const ffi::Map& prims) { std::istringstream is(json_str); dmlc::JSONReader reader(&is); JsonMSCPrim j_prim; @@ -588,11 +596,12 @@ const MSCPrim MSCPrimNode::ChildAt(int index) const { return Downcast(children[v_index]); } -WeightJoint::WeightJoint(int index, const String& name, const String& shared_ref, - const String& weight_type, const MSCTensor& weight, - const Array parents, const Map& attrs, - const Array& friends) { - ObjectPtr n = make_object(); +WeightJoint::WeightJoint(int index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& weight_type, const MSCTensor& weight, + const ffi::Array parents, + const ffi::Map& attrs, + const ffi::Array& friends) { + ObjectPtr n = ffi::make_object(); n->index = index; n->name = std::move(name); n->shared_ref = std::move(shared_ref); @@ -606,14 +615,16 @@ WeightJoint::WeightJoint(int index, const String& name, const String& shared_ref data_ = std::move(n); } -WeightJoint::WeightJoint(const JsonWeightJoint& j_joint, const Map& nodes) { - ObjectPtr n = make_object(); +WeightJoint::WeightJoint(const JsonWeightJoint& j_joint, + const ffi::Map& nodes) { + ObjectPtr n = ffi::make_object(); n->FromJson(j_joint, nodes); data_ = std::move(n); } -WeightJoint::WeightJoint(const std::string& json_str, const Map& nodes) { - ObjectPtr n = make_object(); +WeightJoint::WeightJoint(const std::string& json_str, + const ffi::Map& nodes) { + ObjectPtr n = ffi::make_object(); n->FromJson(json_str, nodes); data_ = std::move(n); } @@ -639,7 +650,7 @@ const 
JsonWeightJoint WeightJointNode::ToJson() const { } void WeightJointNode::FromJson(const JsonWeightJoint& j_joint, - const Map& nodes) { + const ffi::Map& nodes) { index = j_joint.index; name = j_joint.name; shared_ref = j_joint.shared_ref; @@ -654,7 +665,8 @@ void WeightJointNode::FromJson(const JsonWeightJoint& j_joint, } } -void WeightJointNode::FromJson(const std::string& json_str, const Map& nodes) { +void WeightJointNode::FromJson(const std::string& json_str, + const ffi::Map& nodes) { std::istringstream is(json_str); dmlc::JSONReader reader(&is); JsonWeightJoint j_joint; @@ -672,14 +684,14 @@ const WeightJoint WeightJointNode::ChildAt(int index) const { return Downcast(children[v_index]); } -const bool BaseGraphNode::HasNode(const String& name) const { +const bool BaseGraphNode::HasNode(const ffi::String& name) const { return nodes.count(name) ? true : false; } -MSCGraph::MSCGraph(const String& name, const Array& nodes, - const Array& input_names, const Array& output_names, - const Array& prims) { - ObjectPtr n = make_object(); +MSCGraph::MSCGraph(const ffi::String& name, const ffi::Array& nodes, + const ffi::Array& input_names, + const ffi::Array& output_names, const ffi::Array& prims) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); for (const auto& node : nodes) { n->node_names.push_back(node->name); @@ -696,13 +708,13 @@ MSCGraph::MSCGraph(const String& name, const Array& nodes, } MSCGraph::MSCGraph(const JsonMSCGraph& j_graph) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_graph); data_ = std::move(n); } MSCGraph::MSCGraph(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -735,7 +747,7 @@ void MSCGraphNode::FromJson(const JsonMSCGraph& j_graph) { for (const auto& o : j_graph.outputs) { output_names.push_back(o); } - Map loaded_nodes; + ffi::Map loaded_nodes; for (const auto& n : j_graph.nodes) { const 
auto& node = MSCJoint(n, loaded_nodes); loaded_nodes.Set(node->name, node); @@ -745,7 +757,7 @@ void MSCGraphNode::FromJson(const JsonMSCGraph& j_graph) { node_names.push_back(node->name); nodes.Set(node->name, node); } - Map loaded_prims; + ffi::Map loaded_prims; for (const auto& n : j_graph.prims) { const auto& prim = MSCPrim(n, loaded_prims); loaded_prims.Set(prim->name, prim); @@ -766,13 +778,13 @@ void MSCGraphNode::FromJson(const std::string& json_str) { FromJson(j_graph); } -const String MSCGraphNode::ToPrototxt() const { +const ffi::String MSCGraphNode::ToPrototxt() const { PrototxtPrinter printer; - printer.Append(Map{{"name", name}}); + printer.Append(ffi::Map{{"name", name}}); for (const auto& n : node_names) { const auto& node = FindNode(n); // define layer - std::vector> layer; + std::vector> layer; layer.push_back(std::make_pair("name", node->name)); layer.push_back(std::make_pair("type", StringUtils::Replace(node->optype, ".", "_"))); layer.push_back(std::make_pair("top", node->name)); @@ -780,7 +792,7 @@ const String MSCGraphNode::ToPrototxt() const { layer.push_back(std::make_pair("bottom", Downcast(p)->name)); } // define layer param - Map param; + ffi::Map param; param.Set("idx", Integer(node->index)); for (size_t i = 0; i < node->inputs.size(); i++) { param.Set("input_" + std::to_string(i), node->InputAt(i)); @@ -796,17 +808,17 @@ const String MSCGraphNode::ToPrototxt() const { } layer.push_back(std::make_pair("layer_param", PrototxtPrinter::ToDictDoc(param))); // Append the layer Map - printer.Append(Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); + printer.Append(ffi::Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); } return printer.GetString(); } -const MSCJoint MSCGraphNode::FindNode(const String& name) const { +const MSCJoint MSCGraphNode::FindNode(const ffi::String& name) const { ICHECK(nodes.count(name)) << "Can not find node " << name; return Downcast(nodes[name]); } -const MSCPrim MSCGraphNode::FindPrim(const String& name) const 
{ +const MSCPrim MSCGraphNode::FindPrim(const ffi::String& name) const { ICHECK(prims.count(name)) << "Can not find prim " << name; return prims[name]; } @@ -816,8 +828,8 @@ const MSCTensor MSCGraphNode::InputAt(int index) const { return FindTensor(input_names[v_index]); } -const Array MSCGraphNode::GetInputs() const { - Array t_inputs; +const ffi::Array MSCGraphNode::GetInputs() const { + ffi::Array t_inputs; for (size_t i = 0; i < input_names.size(); i++) { t_inputs.push_back(InputAt(i)); } @@ -829,25 +841,25 @@ const MSCTensor MSCGraphNode::OutputAt(int index) const { return FindTensor(output_names[v_index]); } -const Array MSCGraphNode::GetOutputs() const { - Array t_outputs; +const ffi::Array MSCGraphNode::GetOutputs() const { + ffi::Array t_outputs; for (size_t i = 0; i < output_names.size(); i++) { t_outputs.push_back(OutputAt(i)); } return t_outputs; } -const Array MSCGraphNode::GetEntries() const { - Array entries; +const ffi::Array MSCGraphNode::GetEntries() const { + ffi::Array entries; for (size_t i = 0; i < input_names.size(); i++) { entries.push_back(FindProducer(input_names[i])); } return entries; } -const Array MSCGraphNode::GetExits() const { - Array exits; - std::set setted_exits; +const ffi::Array MSCGraphNode::GetExits() const { + ffi::Array exits; + std::set setted_exits; for (size_t i = 0; i < output_names.size(); i++) { const auto& exit = FindProducer(output_names[i]); if (setted_exits.count(exit->name)) { @@ -859,18 +871,18 @@ const Array MSCGraphNode::GetExits() const { return exits; } -const bool MSCGraphNode::HasTensor(const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const bool MSCGraphNode::HasTensor(const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? 
tensor_alias[name] : name; if (weight_holders.count(tensor_name)) { return true; } - String host, index; + ffi::String host, index; std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); return nodes.count(host) > 0 ? true : false; } -const MSCTensor MSCGraphNode::FindTensor(const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const MSCTensor MSCGraphNode::FindTensor(const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; if (weight_holders.count(tensor_name)) { const auto& node = FindNode(weight_holders[tensor_name][0]); for (const auto& pair : node->weights) { @@ -884,8 +896,8 @@ const MSCTensor MSCGraphNode::FindTensor(const String& name) const { return pair.first->OutputAt(pair.second); } -const MSCJoint MSCGraphNode::FindProducer(const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const MSCJoint MSCGraphNode::FindProducer(const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; if (weight_holders.count(tensor_name)) { return FindNode(weight_holders[tensor_name][0]); } @@ -897,10 +909,10 @@ const MSCJoint MSCGraphNode::FindProducer(const MSCTensor& tensor) const { return FindProducer(tensor->name); } -const std::pair MSCGraphNode::FindProducerAndIdx(const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const std::pair MSCGraphNode::FindProducerAndIdx(const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? 
tensor_alias[name] : name; ICHECK(!weight_holders.count(tensor_name)) << "Weight " << name << " has no producer with index"; - String host, index; + ffi::String host, index; std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); if (index.size() == 0) { const auto& node = FindNode(host); @@ -914,9 +926,9 @@ const std::pair MSCGraphNode::FindProducerAndIdx(const MSCTens return FindProducerAndIdx(tensor->name); } -const Array MSCGraphNode::FindConsumers(const String& name) const { - Array consumers; - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const ffi::Array MSCGraphNode::FindConsumers(const ffi::String& name) const { + ffi::Array consumers; + const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; if (weight_holders.count(tensor_name)) { for (const auto& h : weight_holders[tensor_name]) { consumers.push_back(FindNode(h)); @@ -930,13 +942,13 @@ const Array MSCGraphNode::FindConsumers(const String& name) const { return consumers; } -const Array MSCGraphNode::FindConsumers(const MSCTensor& tensor) const { +const ffi::Array MSCGraphNode::FindConsumers(const MSCTensor& tensor) const { return FindConsumers(tensor->name); } const std::vector> MSCGraphNode::FindConsumersAndIndices( - const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; + const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? 
tensor_alias[name] : name; ICHECK(!weight_holders.count(tensor_name)) << "Weight has no index"; std::vector> consumers; for (const auto& c : FindConsumers(name)) { @@ -987,11 +999,11 @@ void MSCGraphNode::AnalysisGraph() { for (const auto& pair : node->weights) { const auto& w_name = pair.second->name; if (weight_holders.count(w_name)) { - Array holders = weight_holders[w_name]; + ffi::Array holders = weight_holders[w_name]; holders.push_back(n); weight_holders.Set(w_name, holders); } else { - weight_holders.Set(w_name, Array({n})); + weight_holders.Set(w_name, ffi::Array({n})); if (pair.second->alias.size() > 0) { tensor_alias.Set(pair.second->alias, pair.second->name); } @@ -1000,28 +1012,30 @@ void MSCGraphNode::AnalysisGraph() { } } -WeightGraph::WeightGraph(const MSCGraph& graph, const Map>& main_wtypes, - const Map& relation_wtypes) { - ObjectPtr n = make_object(); +WeightGraph::WeightGraph(const MSCGraph& graph, + const ffi::Map>& main_wtypes, + const ffi::Map& relation_wtypes) { + ObjectPtr n = ffi::make_object(); n->name = graph->name + "_weights"; n->Build(graph, main_wtypes, relation_wtypes); data_ = std::move(n); } WeightGraph::WeightGraph(const JsonWeightGraph& j_graph) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_graph); data_ = std::move(n); } WeightGraph::WeightGraph(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } -void WeightGraphNode::Build(const MSCGraph& graph, const Map>& main_wtypes, - const Map& relation_wtypes) { +void WeightGraphNode::Build(const MSCGraph& graph, + const ffi::Map>& main_wtypes, + const ffi::Map& relation_wtypes) { auto sort_nodes = [&graph](const BaseJoint& node_a, const BaseJoint& node_b) { return graph->FindProducer(node_a->name)->index < graph->FindProducer(node_b->name)->index; }; @@ -1058,7 +1072,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Map parents_array; + 
ffi::Array parents_array; if (parents.size() > 1) { std::sort(parents.begin(), parents.end(), sort_nodes); } @@ -1089,7 +1103,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapoptype]) { if (node->weights.count(wtype)) { const auto& weight = node->WeightAt(wtype); - Map attrs; + ffi::Map attrs; attrs.Set("producer_type", node->optype); attrs.Set("weight_strategy", "main"); const auto& w_node = @@ -1104,7 +1118,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapweights) { if (!nodes.count(pair.second->name)) { - Map attrs; + ffi::Map attrs; attrs.Set("producer_type", node->optype); attrs.Set("weight_strategy", "follow"); const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first, @@ -1116,7 +1130,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapoptype)) { const auto& tensor = node->OutputAt(0); - Map attrs; + ffi::Map attrs; attrs.Set("producer_type", node->optype); if (node->optype == "reshape") { // TODO(archermmt): check non-passby reshape @@ -1134,7 +1148,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapweights.size() > 0) { for (const auto& pair : node->weights) { if (!nodes.count(pair.second->name)) { - Map attrs; + ffi::Map attrs; attrs.Set("producer_type", node->optype); attrs.Set("weight_strategy", "follow"); const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first, @@ -1151,7 +1165,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Map(nodes[name]); } @@ -1168,7 +1182,7 @@ const JsonWeightGraph WeightGraphNode::ToJson() const { void WeightGraphNode::FromJson(const JsonWeightGraph& j_graph) { name = j_graph.name; - Map loaded_nodes; + ffi::Map loaded_nodes; for (const auto& n : j_graph.nodes) { const auto& node = WeightJoint(n, loaded_nodes); loaded_nodes.Set(node->name, node); @@ -1196,13 +1210,13 @@ void WeightGraphNode::FromJson(const std::string& json_str) { FromJson(j_graph); } -const String WeightGraphNode::ToPrototxt() const 
{ +const ffi::String WeightGraphNode::ToPrototxt() const { PrototxtPrinter printer; - printer.Append(Map{{"name", name}}); + printer.Append(ffi::Map{{"name", name}}); for (const auto& n : node_names) { const auto& node = FindNode(n); // define layer - std::vector> layer; + std::vector> layer; layer.push_back(std::make_pair("name", node->name)); layer.push_back(std::make_pair("type", node->weight_type)); layer.push_back(std::make_pair("top", node->name)); @@ -1210,7 +1224,7 @@ const String WeightGraphNode::ToPrototxt() const { layer.push_back(std::make_pair("bottom", Downcast(p)->name)); } // define layer param - Map param; + ffi::Map param; param.Set("idx", Integer(node->index)); param.Set("weight", node->weight); for (size_t i = 0; i < node->friends.size(); i++) { @@ -1221,14 +1235,15 @@ const String WeightGraphNode::ToPrototxt() const { } layer.push_back(std::make_pair("layer_param", PrototxtPrinter::ToDictDoc(param))); // Append the layer Map - printer.Append(Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); + printer.Append(ffi::Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); } return printer.GetString(); } -MSCGraph PruneWeights(const MSCGraph& graph, const Map& pruned_tensors) { - Array nodes; - std::unordered_map> inputs_map; +MSCGraph PruneWeights(const MSCGraph& graph, + const ffi::Map& pruned_tensors) { + ffi::Array nodes; + std::unordered_map> inputs_map; for (const auto& name : graph->node_names) { const auto& node = graph->FindNode(name); // define inputs @@ -1238,20 +1253,20 @@ MSCGraph PruneWeights(const MSCGraph& graph, const Map& prune inputs.push_back(inputs_map[input->name]); } // define outputs - Array outputs; + ffi::Array outputs; for (const auto& out : node->outputs) { const auto& output = pruned_tensors.count(out->name) ? 
pruned_tensors[out->name] : out; outputs.push_back(output); } // define weights - Map weights; + ffi::Map weights; for (const auto& pair : node->weights) { const auto& weight = pruned_tensors.count(pair.second->name) ? pruned_tensors[pair.second->name] : pair.second; weights.Set(pair.first, weight); } // define attributes - Map attrs = node->attrs; + ffi::Map attrs = node->attrs; if (node->optype == "reshape" && attrs.count("shape") && pruned_tensors.count(node->OutputAt(0)->name)) { const auto& new_shape = pruned_tensors[node->OutputAt(0)->name]->shape; @@ -1268,7 +1283,7 @@ MSCGraph PruneWeights(const MSCGraph& graph, const Map& prune Downcast(p)->AddChild(new_node); } } - Array prims; + ffi::Array prims; for (const auto& name : graph->prim_names) { prims.push_back(graph->FindPrim(name)); } @@ -1421,7 +1436,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { MSCTensorNode::RegisterReflection(); BaseJointNode::RegisterReflection(); MSCJointNode::RegisterReflection(); @@ -1430,19 +1445,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ BaseGraphNode::RegisterReflection(); MSCGraphNode::RegisterReflection(); WeightGraphNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.MSCTensor", - [](const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias, - const Array& prims) -> MSCTensor { + [](const ffi::String& name, const DataType& dtype, const ffi::String& layout, + const ffi::Array& shape, const ffi::String& alias, + const ffi::Array& prims) -> MSCTensor { return MSCTensor(name, dtype, layout, shape, alias, prims); }) .def("msc.core.MSCTensorToJson", - [](const MSCTensor& tensor) -> String { + [](const MSCTensor& tensor) -> ffi::String { const auto& tensor_json = tensor->ToJson(); std::ostringstream os; dmlc::JSONWriter writer(&os); @@ -1450,12 +1465,13 @@ 
TVM_FFI_STATIC_INIT_BLOCK({ return os.str(); }) .def("msc.core.MSCTensorFromJson", - [](const String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }) + [](const ffi::String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }) .def("msc.core.MSCJoint", - [](Integer index, const String& name, const String& shared_ref, const String& optype, - const Map& attrs, const Array& scope, - const Array& parents, const Array out_indices, - const Array& outputs, const Map& weights) -> MSCJoint { + [](Integer index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& optype, const ffi::Map& attrs, + const ffi::Array& scope, const ffi::Array& parents, + const ffi::Array out_indices, const ffi::Array& outputs, + const ffi::Map& weights) -> MSCJoint { std::vector> inputs; for (size_t i = 0; i < parents.size(); i++) { inputs.push_back(std::make_pair(parents[i], out_indices[i]->value)); @@ -1464,19 +1480,21 @@ TVM_FFI_STATIC_INIT_BLOCK({ weights); }) .def("msc.core.MSCPrim", - [](Integer index, const String& name, const String& optype, - const Map& attrs, const Array& parents) -> MSCPrim { - Array b_parents; + [](Integer index, const ffi::String& name, const ffi::String& optype, + const ffi::Map& attrs, + const ffi::Array& parents) -> MSCPrim { + ffi::Array b_parents; for (const auto& p : parents) { b_parents.push_back(p); } return MSCPrim(index->value, name, optype, b_parents, attrs); }) .def("msc.core.WeightJoint", - [](Integer index, const String& name, const String& shared_ref, - const String& weight_type, const MSCTensor& weight, const Array parents, - const Map& attrs, const Array& friends) -> WeightJoint { - Array b_parents, b_friends; + [](Integer index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& weight_type, const MSCTensor& weight, + const ffi::Array parents, const ffi::Map& attrs, + const ffi::Array& friends) -> WeightJoint { + ffi::Array b_parents, b_friends; for (const auto& p : parents) { 
b_parents.push_back(p); } @@ -1486,55 +1504,60 @@ TVM_FFI_STATIC_INIT_BLOCK({ return WeightJoint(index->value, name, shared_ref, weight_type, weight, b_parents, attrs, b_friends); }) - .def("msc.core.WeightJointSetAttr", [](const WeightJoint& node, const String& key, - const String& value) { node->attrs.Set(key, value); }) + .def("msc.core.WeightJointSetAttr", + [](const WeightJoint& node, const ffi::String& key, const ffi::String& value) { + node->attrs.Set(key, value); + }) .def("msc.core.MSCGraph", - [](const String& name, const Array& nodes, const Array& input_names, - const Array& output_names, const Array& prims) -> MSCGraph { + [](const ffi::String& name, const ffi::Array& nodes, + const ffi::Array& input_names, + const ffi::Array& output_names, + const ffi::Array& prims) -> MSCGraph { return MSCGraph(name, nodes, input_names, output_names, prims); }) .def("msc.core.WeightGraph", - [](const MSCGraph& graph, const Map>& main_wtypes, - const Map& relation_wtypes) -> WeightGraph { + [](const MSCGraph& graph, + const ffi::Map>& main_wtypes, + const ffi::Map& relation_wtypes) -> WeightGraph { return WeightGraph(graph, main_wtypes, relation_wtypes); }); -}); +} // MSC Graph APIS -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.MSCGraphHasNode", - [](const MSCGraph& graph, const String& name) -> Bool { + [](const MSCGraph& graph, const ffi::String& name) -> Bool { return Bool(graph->HasNode(name)); }) .def("msc.core.MSCGraphFindNode", - [](const MSCGraph& graph, const String& name) -> MSCJoint { + [](const MSCGraph& graph, const ffi::String& name) -> MSCJoint { return graph->FindNode(name); }) .def("msc.core.MSCGraphFindPrim", - [](const MSCGraph& graph, const String& name) -> MSCPrim { + [](const MSCGraph& graph, const ffi::String& name) -> MSCPrim { return graph->FindPrim(name); }) .def("msc.core.MSCGraphHasTensor", - [](const MSCGraph& graph, const String& name) -> Bool { + 
[](const MSCGraph& graph, const ffi::String& name) -> Bool { return Bool(graph->HasTensor(name)); }) .def("msc.core.MSCGraphFindTensor", - [](const MSCGraph& graph, const String& name) -> MSCTensor { + [](const MSCGraph& graph, const ffi::String& name) -> MSCTensor { return graph->FindTensor(name); }) .def("msc.core.MSCGraphSetTensorAlias", - [](const MSCGraph& graph, const MSCTensor& tensor, const String& alias) { + [](const MSCGraph& graph, const MSCTensor& tensor, const ffi::String& alias) { tensor->alias = alias; graph->tensor_alias.Set(alias, tensor->name); }) .def("msc.core.MSCGraphFindProducer", - [](const MSCGraph& graph, const String& name) -> MSCJoint { + [](const MSCGraph& graph, const ffi::String& name) -> MSCJoint { return graph->FindProducer(name); }) .def("msc.core.MSCGraphFindConsumers", - [](const MSCGraph& graph, const String& name) -> Array { + [](const MSCGraph& graph, const ffi::String& name) -> ffi::Array { return graph->FindConsumers(name); }) .def("msc.core.MSCGraphInputAt", @@ -1542,11 +1565,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("msc.core.MSCGraphOutputAt", [](const MSCGraph& graph, int index) -> MSCTensor { return graph->OutputAt(index); }) .def("msc.core.MSCGraphGetInputs", - [](const MSCGraph& graph) -> Array { return graph->GetInputs(); }) + [](const MSCGraph& graph) -> ffi::Array { return graph->GetInputs(); }) .def("msc.core.MSCGraphGetOutputs", - [](const MSCGraph& graph) -> Array { return graph->GetOutputs(); }) + [](const MSCGraph& graph) -> ffi::Array { return graph->GetOutputs(); }) .def("msc.core.MSCGraphToJson", - [](const MSCGraph& graph) -> String { + [](const MSCGraph& graph) -> ffi::String { const auto& graph_json = graph->ToJson(); std::ostringstream os; dmlc::JSONWriter writer(&os); @@ -1554,25 +1577,25 @@ TVM_FFI_STATIC_INIT_BLOCK({ return os.str(); }) .def("msc.core.MSCGraphFromJson", - [](const String& graph_json) -> MSCGraph { return MSCGraph(graph_json); }) + [](const ffi::String& graph_json) -> MSCGraph { return 
MSCGraph(graph_json); }) .def("msc.core.MSCGraphToPrototxt", - [](const MSCGraph& graph) -> String { return graph->ToPrototxt(); }); -}); + [](const MSCGraph& graph) -> ffi::String { return graph->ToPrototxt(); }); +} // Weight Graph APIS -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.WeightGraphHasNode", - [](const WeightGraph& graph, const String& name) -> Bool { + [](const WeightGraph& graph, const ffi::String& name) -> Bool { return Bool(graph->HasNode(name)); }) .def("msc.core.WeightGraphFindNode", - [](const WeightGraph& graph, const String& name) -> WeightJoint { + [](const WeightGraph& graph, const ffi::String& name) -> WeightJoint { return graph->FindNode(name); }) .def("msc.core.WeightGraphToJson", - [](const WeightGraph& graph) -> String { + [](const WeightGraph& graph) -> ffi::String { const auto& graph_json = graph->ToJson(); std::ostringstream os; dmlc::JSONWriter writer(&os); @@ -1580,48 +1603,50 @@ TVM_FFI_STATIC_INIT_BLOCK({ return os.str(); }) .def("msc.core.WeightGraphFromJson", - [](const String& graph_json) -> WeightGraph { return WeightGraph(graph_json); }) + [](const ffi::String& graph_json) -> WeightGraph { return WeightGraph(graph_json); }) .def("msc.core.WeightGraphToPrototxt", - [](const WeightGraph& graph) -> String { return graph->ToPrototxt(); }) + [](const WeightGraph& graph) -> ffi::String { return graph->ToPrototxt(); }) .def("msc.core.MSCJointInputAt", [](const MSCJoint& node, int index) -> MSCTensor { return node->InputAt(index); }) .def("msc.core.MSCJointOutputAt", [](const MSCJoint& node, int index) -> MSCTensor { return node->OutputAt(index); }) .def("msc.core.MSCJointWeightAt", - [](const MSCJoint& node, const String& wtype) -> MSCTensor { + [](const MSCJoint& node, const ffi::String& wtype) -> MSCTensor { return node->WeightAt(wtype); }) .def("msc.core.MSCJointGetInputs", - [](const MSCJoint& node) -> Array { return node->GetInputs(); }) + 
[](const MSCJoint& node) -> ffi::Array { return node->GetInputs(); }) .def("msc.core.MSCJointGetOutputs", - [](const MSCJoint& node) -> Array { return node->GetOutputs(); }) + [](const MSCJoint& node) -> ffi::Array { return node->GetOutputs(); }) .def("msc.core.MSCJointGetWeights", - [](const MSCJoint& node) -> Map { return node->weights; }) + [](const MSCJoint& node) -> ffi::Map { return node->weights; }) .def("msc.core.MSCJointHasAttr", - [](const MSCJoint& node, const String& key) -> Bool { return Bool(node->HasAttr(key)); }) + [](const MSCJoint& node, const ffi::String& key) -> Bool { + return Bool(node->HasAttr(key)); + }) .def("msc.core.MSCJointGetAttrs", - [](const MSCJoint& node) -> Map { return node->attrs; }) + [](const MSCJoint& node) -> ffi::Map { return node->attrs; }) .def("msc.core.WeightJointHasAttr", - [](const WeightJoint& node, const String& key) -> Bool { + [](const WeightJoint& node, const ffi::String& key) -> Bool { return Bool(node->HasAttr(key)); }) - .def("msc.core.WeightJointGetAttrs", - [](const WeightJoint& node) -> Map { return node->attrs; }) + .def( + "msc.core.WeightJointGetAttrs", + [](const WeightJoint& node) -> ffi::Map { return node->attrs; }) .def("msc.core.MSCTensorDTypeName", - [](const MSCTensor& tensor) -> String { return tensor->DTypeName(); }) + [](const MSCTensor& tensor) -> ffi::String { return tensor->DTypeName(); }) .def("msc.core.MSCTensorDimAt", - [](const MSCTensor& tensor, const String& axis) -> Integer { + [](const MSCTensor& tensor, const ffi::String& axis) -> Integer { return tensor->DimAt(axis); }) .def("msc.core.MSCTensorGetSize", [](const MSCTensor& tensor) -> Integer { return tensor->GetSize(); }) .def("msc.core.MSCTensorSetAlias", - [](const MSCTensor& tensor, const String& alias) { tensor->alias = alias; }) + [](const MSCTensor& tensor, const ffi::String& alias) { tensor->alias = alias; }) .def("msc.core.PruneWeights", - [](const MSCGraph& graph, const Map& pruned_tensors) -> MSCGraph { - return 
PruneWeights(graph, pruned_tensors); - }); -}); + [](const MSCGraph& graph, const ffi::Map& pruned_tensors) + -> MSCGraph { return PruneWeights(graph, pruned_tensors); }); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index a8587a2e5ed8..d795bea7fa1b 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -342,17 +342,17 @@ struct JsonWeightGraph { class MSCTensorNode : public Object { public: /*! \brief The name of tensor. */ - String name; + ffi::String name; /*! \brief The alias of tensor, can be changed. */ - mutable String alias; + mutable ffi::String alias; /*! \brief The data type of tensor. */ DataType dtype; /*! \brief The layout of tensor. */ tvm::tir::Layout layout; /*! \brief The shape of tensor. */ - Array shape; + ffi::Array shape; /*! \brief The prims of tensor. */ - Array prims; + ffi::Array prims; /*! \brief Export tensor to json. */ const JsonMSCTensor ToJson() const; /*! \brief Load tensor from json struct. */ @@ -364,17 +364,17 @@ class MSCTensorNode : public Object { /*! \brief Get dim at given index. */ const Integer DimAt(int index) const; /*! \brief Get dim at given axis. */ - const Integer DimAt(const String& axis) const; + const Integer DimAt(const ffi::String& axis) const; /*! \brief Get prim at given index. */ - const String PrimAt(int index) const; + const ffi::String PrimAt(int index) const; /*! \brief Get prim at given axis. */ - const String PrimAt(const String& axis) const; + const ffi::String PrimAt(const ffi::String& axis) const; /*! \brief Get layout index of given axis. */ - int32_t LayoutOf(const String& axis) const; + int32_t LayoutOf(const ffi::String& axis) const; /*! \brief Get size of the tensor. */ const Integer GetSize() const; /*! \brief Get name of the dtype. 
*/ - const String DTypeName() const; + const ffi::String DTypeName() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -388,8 +388,7 @@ class MSCTensorNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "msc.core.MSCTensor"; - TVM_DECLARE_FINAL_OBJECT_INFO(MSCTensorNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.MSCTensor", MSCTensorNode, Object); }; /*! @@ -407,9 +406,9 @@ class MSCTensor : public ObjectRef { * \param alias The alias of the tensor. * \param prims The prims of the tensor shape. */ - TVM_DLL MSCTensor(const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias = "", - const Array& prims = Array()); + TVM_DLL MSCTensor(const ffi::String& name, const DataType& dtype, const ffi::String& layout, + const ffi::Array& shape, const ffi::String& alias = "", + const ffi::Array& prims = ffi::Array()); /*! * \brief The json constructor. @@ -423,7 +422,7 @@ class MSCTensor : public ObjectRef { */ TVM_DLL MSCTensor(const std::string& json_str); - TVM_DEFINE_OBJECT_REF_METHODS(MSCTensor, ObjectRef, MSCTensorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MSCTensor, ObjectRef, MSCTensorNode); }; /*! @@ -435,15 +434,15 @@ class BaseJointNode : public Object { /*! \brief The index of node, can be changed. */ mutable int index; /*! \brief The name of node. */ - String name; + ffi::String name; /*! \brief The shared_ref of node, can be changed. */ - String shared_ref; + ffi::String shared_ref; /*! \brief The attributes of node. */ - mutable Map attrs; + mutable ffi::Map attrs; /*! \brief The parents of node. */ - Array parents; + ffi::Array parents; /*! \brief The children of node. */ - mutable Array children; + mutable ffi::Array children; /*! \brief Add child to the node. */ size_t AddChild(const BaseJoint& child) const; /*! \brief Get parent from the node. 
*/ @@ -451,27 +450,27 @@ class BaseJointNode : public Object { /*! \brief Get child from the node. */ const BaseJoint ChildAt(int index) const; /*! \brief Check if has the attribute. */ - bool HasAttr(const String& key) const; + bool HasAttr(const ffi::String& key) const; /*! \brief Get the attribute by type. */ - bool GetAttr(const String& key, std::string* val) const; - bool GetAttr(const String& key, int* val) const; - bool GetAttr(const String& key, int64_t* val) const; - bool GetAttr(const String& key, float* val) const; - bool GetAttr(const String& key, bool* val) const; - bool GetAttr(const String& key, std::vector* val) const; - bool GetAttr(const String& key, std::vector* val) const; - bool GetAttr(const String& key, std::vector* val) const; - bool GetAttr(const String& key, std::vector* val) const; - bool GetAttr(const String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::string* val) const; + bool GetAttr(const ffi::String& key, int* val) const; + bool GetAttr(const ffi::String& key, int64_t* val) const; + bool GetAttr(const ffi::String& key, float* val) const; + bool GetAttr(const ffi::String& key, bool* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; /*! \brief Check and get the attribute by type. 
*/ template - const T GetTypeAttr(const String& key) const { + const T GetTypeAttr(const ffi::String& key) const { T val; ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; return val; } template - const std::vector GetTypeArrayAttr(const String& key) const { + const std::vector GetTypeArrayAttr(const ffi::String& key) const { std::vector val; ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; return val; @@ -489,9 +488,8 @@ class BaseJointNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "msc.core.BaseJoint"; static constexpr const uint32_t _type_child_slots = 2; - TVM_DECLARE_BASE_OBJECT_INFO(BaseJointNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("msc.core.BaseJoint", BaseJointNode, Object); }; /*! @@ -500,7 +498,7 @@ class BaseJointNode : public Object { */ class BaseJoint : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(BaseJoint, ObjectRef, BaseJointNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseJoint, ObjectRef, BaseJointNode); }; /*! @@ -510,42 +508,42 @@ class MSCJoint; class MSCJointNode : public BaseJointNode { public: /*! \brief The op type of node. */ - String optype; + ffi::String optype; /*! \brief The scope of node. */ - Array scope; + ffi::Array scope; /*! \brief The inputs of node, can be changed. */ - Array> inputs; + ffi::Array> inputs; /*! \brief The outputs of node. */ - Array outputs; + ffi::Array outputs; /*! \brief The weights of node. */ - Map weights; + ffi::Map weights; /*! \brief Export node to json. */ const JsonMSCJoint ToJson() const; /*! \brief Load node from json struct. */ - void FromJson(const JsonMSCJoint& j_joint, const Map& nodes); + void FromJson(const JsonMSCJoint& j_joint, const ffi::Map& nodes); /*! \brief Load node from json string. */ - void FromJson(const std::string& json_str, const Map& nodes); + void FromJson(const std::string& json_str, const ffi::Map& nodes); /*! 
\brief Get input from the node. */ const MSCTensor InputAt(int index) const; /*! \brief Get inputs from the node. */ - const Array GetInputs() const; + const ffi::Array GetInputs() const; /*! \brief Get output from the node. */ const MSCTensor OutputAt(int index) const; /*! \brief Get outputs from the node. */ - const Array GetOutputs() const; + const ffi::Array GetOutputs() const; /*! \brief Get weight from the node. */ - const MSCTensor WeightAt(const String& wtype) const; + const MSCTensor WeightAt(const ffi::String& wtype) const; /*! \brief Get parent from the node. */ const MSCJoint ParentAt(int index) const; /*! \brief Get child from the node. */ const MSCJoint ChildAt(int index) const; /*! \brief Get Producer of the input. */ const MSCJoint ProducerOf(int index) const; - const MSCJoint ProducerOf(const String& input_name) const; + const MSCJoint ProducerOf(const ffi::String& input_name) const; const MSCJoint ProducerOf(const MSCTensor& input) const; /*! \brief Get Producer and out index of the input. */ const std::pair ProducerAndIdxOf(int index) const; - const std::pair ProducerAndIdxOf(const String& name) const; + const std::pair ProducerAndIdxOf(const ffi::String& name) const; const std::pair ProducerAndIdxOf(const MSCTensor& input) const; static void RegisterReflection() { @@ -559,8 +557,7 @@ class MSCJointNode : public BaseJointNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "msc.core.MSCJoint"; - TVM_DECLARE_FINAL_OBJECT_INFO(MSCJointNode, BaseJointNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.MSCJoint", MSCJointNode, BaseJointNode); }; /*! @@ -580,28 +577,30 @@ class MSCJoint : public BaseJoint { * \param outputs The outputs of the node. * \param weights The weights of the node. 
*/ - TVM_DLL MSCJoint(int index, const String& name, const String& shared_ref, const String& optype, - const Map& attrs, const Array& scope, + TVM_DLL MSCJoint(int index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& optype, const ffi::Map& attrs, + const ffi::Array& scope, const std::vector>& inputs, - const Array& outputs, const Map& weights); + const ffi::Array& outputs, + const ffi::Map& weights); /*! * \brief The json constructor. * \param j_joint The json describe of the node. */ - TVM_DLL MSCJoint(const JsonMSCJoint& j_joint, const Map& nodes); + TVM_DLL MSCJoint(const JsonMSCJoint& j_joint, const ffi::Map& nodes); /*! * \brief The json constructor. * \param json_str The json describe of the node. */ - TVM_DLL MSCJoint(const std::string& json_str, const Map& nodes); + TVM_DLL MSCJoint(const std::string& json_str, const ffi::Map& nodes); /*! \brief Clone the node. */ TVM_DLL static const MSCJoint Clone(const MSCJoint& node, const std::vector>& inputs); - TVM_DEFINE_OBJECT_REF_METHODS(MSCJoint, BaseJoint, MSCJointNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MSCJoint, BaseJoint, MSCJointNode); }; /*! @@ -611,13 +610,13 @@ class MSCPrim; class MSCPrimNode : public BaseJointNode { public: /*! \brief The op of prim. */ - String optype; + ffi::String optype; /*! \brief Export prim to json. */ const JsonMSCPrim ToJson() const; /*! \brief Load prim from json struct. */ - void FromJson(const JsonMSCPrim& j_prim, const Map& prims); + void FromJson(const JsonMSCPrim& j_prim, const ffi::Map& prims); /*! \brief Load prim from json string. */ - void FromJson(const std::string& json_str, const Map& prims); + void FromJson(const std::string& json_str, const ffi::Map& prims); /*! \brief Get parent from the prim. */ const MSCPrim ParentAt(int index) const; /*! \brief Get child from the prim. 
*/ @@ -627,9 +626,7 @@ class MSCPrimNode : public BaseJointNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("optype", &MSCPrimNode::optype); } - - static constexpr const char* _type_key = "msc.core.MSCPrim"; - TVM_DECLARE_FINAL_OBJECT_INFO(MSCPrimNode, BaseJointNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.MSCPrim", MSCPrimNode, BaseJointNode); }; /*! @@ -646,23 +643,24 @@ class MSCPrim : public BaseJoint { * \param parents The parents of the prim. * \param attrs The attributes of the prim. */ - TVM_DLL MSCPrim(int index, const String& name, const String& optype, - const Array& parents, - const Map& attrs = Map()); + TVM_DLL MSCPrim( + int index, const ffi::String& name, const ffi::String& optype, + const ffi::Array& parents, + const ffi::Map& attrs = ffi::Map()); /*! * \brief The json constructor. * \param j_prim The json describe of the prim. */ - TVM_DLL MSCPrim(const JsonMSCPrim& j_prim, const Map& prims); + TVM_DLL MSCPrim(const JsonMSCPrim& j_prim, const ffi::Map& prims); /*! * \brief The json constructor. * \param json_str The json describe of the prim. */ - TVM_DLL MSCPrim(const std::string& json_str, const Map& prims); + TVM_DLL MSCPrim(const std::string& json_str, const ffi::Map& prims); - TVM_DEFINE_OBJECT_REF_METHODS(MSCPrim, BaseJoint, MSCPrimNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MSCPrim, BaseJoint, MSCPrimNode); }; /*! @@ -672,17 +670,17 @@ class WeightJoint; class WeightJointNode : public BaseJointNode { public: /*! \brief The weight reference of weight node. */ - String weight_type; + ffi::String weight_type; /*! \brief The weight of weight node. */ MSCTensor weight; /*! \brief The friends of weight node. */ - mutable Array friends; + mutable ffi::Array friends; /*! \brief Export node to json. */ const JsonWeightJoint ToJson() const; /*! \brief Load node from json struct. 
*/ - void FromJson(const JsonWeightJoint& j_joint, const Map& nodes); + void FromJson(const JsonWeightJoint& j_joint, const ffi::Map& nodes); /*! \brief Load node from json string. */ - void FromJson(const std::string& json_str, const Map& nodes); + void FromJson(const std::string& json_str, const ffi::Map& nodes); /*! \brief Get parent from the node. */ const WeightJoint ParentAt(int index) const; /*! \brief Get child from the node. */ @@ -695,9 +693,7 @@ class WeightJointNode : public BaseJointNode { .def_ro("weight", &WeightJointNode::weight) .def_ro("friends", &WeightJointNode::friends); } - - static constexpr const char* _type_key = "msc.core.WeightJoint"; - TVM_DECLARE_FINAL_OBJECT_INFO(WeightJointNode, BaseJointNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.WeightJoint", WeightJointNode, BaseJointNode); }; /*! @@ -717,25 +713,26 @@ class WeightJoint : public BaseJoint { * \param attrs The attributes of the node. * \param friends The friends of the node. */ - TVM_DLL WeightJoint(int index, const String& name, const String& shared_ref, - const String& weight_type, const MSCTensor& weight, - const Array parents, - const Map& attrs = Map(), - const Array& friends = Array()); + TVM_DLL WeightJoint( + int index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& weight_type, const MSCTensor& weight, const ffi::Array parents, + const ffi::Map& attrs = ffi::Map(), + const ffi::Array& friends = ffi::Array()); /*! * \brief The json constructor. * \param j_joint The json describe of the node. */ - TVM_DLL WeightJoint(const JsonWeightJoint& j_joint, const Map& nodes); + TVM_DLL WeightJoint(const JsonWeightJoint& j_joint, + const ffi::Map& nodes); /*! * \brief The json constructor. * \param json_str The json describe of the node. 
*/ - TVM_DLL WeightJoint(const std::string& json_str, const Map& nodes); + TVM_DLL WeightJoint(const std::string& json_str, const ffi::Map& nodes); - TVM_DEFINE_OBJECT_REF_METHODS(WeightJoint, BaseJoint, WeightJointNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(WeightJoint, BaseJoint, WeightJointNode); }; /*! @@ -744,13 +741,13 @@ class WeightJoint : public BaseJoint { class BaseGraphNode : public Object { public: /*! \brief The name of graph. */ - String name; + ffi::String name; /*! \brief The node names in graph, can be changed. */ - Array node_names; + ffi::Array node_names; /*! \brief The nodes in graph, can be changed. */ - Map nodes; + ffi::Map nodes; /*! \brief Check if node in the graph. */ - const bool HasNode(const String& name) const; + const bool HasNode(const ffi::String& name) const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -761,10 +758,9 @@ class BaseGraphNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "msc.core.BaseGraph"; static constexpr const uint32_t _type_child_slots = 2; - TVM_DECLARE_BASE_OBJECT_INFO(BaseGraphNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("msc.core.BaseGraph", BaseGraphNode, Object); }; /*! @@ -773,7 +769,7 @@ class BaseGraphNode : public Object { */ class BaseGraph : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(BaseGraph, ObjectRef, BaseGraphNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseGraph, ObjectRef, BaseGraphNode); }; /*! @@ -783,17 +779,17 @@ class MSCGraph; class MSCGraphNode : public BaseGraphNode { public: /*! \brief The shape node names in graph. */ - Array prim_names; + ffi::Array prim_names; /*! \brief The shape nodes in graph. */ - Map prims; + ffi::Map prims; /*! \brief The input names of graph. */ - Array input_names; + ffi::Array input_names; /*! \brief The output names of graph. */ - Array output_names; + ffi::Array output_names; /*! 
\brief The tensor alias in graph, get by AnalysisGraph. */ - mutable Map tensor_alias; + mutable ffi::Map tensor_alias; /*! \brief The weights in graph, get by AnalysisGraph. */ - Map> weight_holders; + ffi::Map> weight_holders; /*! \brief Export graph to json. */ const JsonMSCGraph ToJson() const; /*! \brief Load graph from json. */ @@ -801,41 +797,42 @@ class MSCGraphNode : public BaseGraphNode { /*! \brief Load graph from json string. */ void FromJson(const std::string& json_str); /*! \brief Export graph to prototxt. */ - const String ToPrototxt() const; + const ffi::String ToPrototxt() const; /*! \brief Find node in graph. */ - const MSCJoint FindNode(const String& name) const; + const MSCJoint FindNode(const ffi::String& name) const; /*! \brief Find prim in graph. */ - const MSCPrim FindPrim(const String& name) const; + const MSCPrim FindPrim(const ffi::String& name) const; /*! \brief Get input from the graph. */ const MSCTensor InputAt(int index) const; /*! \brief Get inputs from the graph. */ - const Array GetInputs() const; + const ffi::Array GetInputs() const; /*! \brief Get output from the graph. */ const MSCTensor OutputAt(int index) const; /*! \brief Get outputs from the graph. */ - const Array GetOutputs() const; + const ffi::Array GetOutputs() const; /*! \brief Get entries from the graph. */ - const Array GetEntries() const; + const ffi::Array GetEntries() const; /*! \brief Get exits from the graph. */ - const Array GetExits() const; + const ffi::Array GetExits() const; /*! \brief Check if tensor in the graph. */ - const bool HasTensor(const String& name) const; + const bool HasTensor(const ffi::String& name) const; /*! \brief Find tensor from the graph. */ - const MSCTensor FindTensor(const String& name) const; + const MSCTensor FindTensor(const ffi::String& name) const; /*! \brief Find producer of tensor from the graph. */ - const MSCJoint FindProducer(const String& name) const; + const MSCJoint FindProducer(const ffi::String& name) const; /*! 
\brief Find producer of tensor from the graph. */ const MSCJoint FindProducer(const MSCTensor& tensor) const; /*! \brief Find producer and output index of tensor from the graph. */ - const std::pair FindProducerAndIdx(const String& name) const; + const std::pair FindProducerAndIdx(const ffi::String& name) const; /*! \brief Find producer and output index of tensor from the graph. */ const std::pair FindProducerAndIdx(const MSCTensor& tensor) const; /*! \brief Find consumers of tensor from the graph. */ - const Array FindConsumers(const String& name) const; + const ffi::Array FindConsumers(const ffi::String& name) const; /*! \brief Find consumers of tensor from the graph. */ - const Array FindConsumers(const MSCTensor& tensor) const; + const ffi::Array FindConsumers(const MSCTensor& tensor) const; /*! \brief Find consumers and input indices of tensor from the graph. */ - const std::vector> FindConsumersAndIndices(const String& name) const; + const std::vector> FindConsumersAndIndices( + const ffi::String& name) const; /*! \brief Find consumers and input indices of tensor from the graph. */ const std::vector> FindConsumersAndIndices( const MSCTensor& tensor) const; @@ -851,9 +848,7 @@ class MSCGraphNode : public BaseGraphNode { .def_ro("output_names", &MSCGraphNode::output_names) .def_ro("weight_holders", &MSCGraphNode::weight_holders); } - - static constexpr const char* _type_key = "msc.core.MSCGraph"; - TVM_DECLARE_FINAL_OBJECT_INFO(MSCGraphNode, BaseGraphNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.MSCGraph", MSCGraphNode, BaseGraphNode); }; /*! @@ -870,9 +865,10 @@ class MSCGraph : public BaseGraph { * \param output_names The output names of the graph. * \param prims The prims in the graph. 
*/ - TVM_DLL MSCGraph(const String& name, const Array& nodes, - const Array& input_names, const Array& output_names, - const Array& prims = Array()); + TVM_DLL MSCGraph(const ffi::String& name, const ffi::Array& nodes, + const ffi::Array& input_names, + const ffi::Array& output_names, + const ffi::Array& prims = ffi::Array()); /*! * \brief The json constructor. @@ -886,7 +882,7 @@ class MSCGraph : public BaseGraph { */ TVM_DLL MSCGraph(const std::string& json_str); - TVM_DEFINE_OBJECT_REF_METHODS(MSCGraph, BaseGraph, MSCGraphNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MSCGraph, BaseGraph, MSCGraphNode); }; /*! @@ -895,10 +891,11 @@ class MSCGraph : public BaseGraph { class WeightGraphNode : public BaseGraphNode { public: /*! \brief build from MSCGraph. */ - void Build(const MSCGraph& graph, const Map>& prunable_types, - const Map& relation_types); + void Build(const MSCGraph& graph, + const ffi::Map>& prunable_types, + const ffi::Map& relation_types); /*! \brief Find node in graph. */ - const WeightJoint FindNode(const String& name) const; + const WeightJoint FindNode(const ffi::String& name) const; /*! \brief Export graph to json. */ const JsonWeightGraph ToJson() const; /*! \brief Load graph from json. */ @@ -906,15 +903,13 @@ class WeightGraphNode : public BaseGraphNode { /*! \brief Load graph from json string. */ void FromJson(const std::string& json_str); /*! \brief Export graph to prototxt. */ - const String ToPrototxt() const; + const ffi::String ToPrototxt() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "msc.core.WeightGraph"; - TVM_DECLARE_FINAL_OBJECT_INFO(WeightGraphNode, BaseGraphNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.WeightGraph", WeightGraphNode, BaseGraphNode); }; /*! @@ -929,8 +924,9 @@ class WeightGraph : public BaseGraph { * \param prunable_types The prunable types. * \param relation_types The relation types. 
*/ - TVM_DLL WeightGraph(const MSCGraph& graph, const Map>& prunable_types, - const Map& relation_types); + TVM_DLL WeightGraph(const MSCGraph& graph, + const ffi::Map>& prunable_types, + const ffi::Map& relation_types); /*! * \brief The json constructor. @@ -944,10 +940,11 @@ class WeightGraph : public BaseGraph { */ TVM_DLL WeightGraph(const std::string& json_str); - TVM_DEFINE_OBJECT_REF_METHODS(WeightGraph, BaseGraph, WeightGraphNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(WeightGraph, BaseGraph, WeightGraphNode); }; -MSCGraph PruneWeights(const MSCGraph& graph, const Map& pruned_tensors); +MSCGraph PruneWeights(const MSCGraph& graph, + const ffi::Map& pruned_tensors); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 7f84978105ea..df7a1520ebfa 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -34,7 +34,7 @@ namespace msc { using namespace tvm::relax; -const std::string GetScalarStr(const runtime::NDArray& data, int float_precision) { +const std::string GetScalarStr(const runtime::Tensor& data, int float_precision) { std::string scalar_str; if (data->dtype.code == kDLFloat) { const float val = ExprUtils::GetScalar(data); @@ -50,13 +50,13 @@ const std::string GetScalarStr(const runtime::NDArray& data, int float_precision void FuncAttrGetter::VisitExpr_(const CallNode* op) { if (op->attrs.defined()) { - Map attrs; + ffi::Map attrs; AttrGetter getter(&attrs); getter(op->attrs); for (const auto& pair : attrs) { if (attrs_.count(pair.first)) { int cnt = 1; - String rep_key = pair.first; + ffi::String rep_key = pair.first; while (attrs_.count(rep_key + "_" + std::to_string(cnt))) { cnt++; } @@ -87,7 +87,7 @@ void FuncValueGetter::VisitExpr_(const CallNode* op) { } void FuncParamsFinder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - local_funcs_.Set(binding->var, GetRef(val)); + 
local_funcs_.Set(binding->var, ffi::GetRef(val)); } void FuncParamsFinder::VisitExpr_(const CallNode* call_node) { @@ -112,7 +112,7 @@ void FuncParamsFinder::VisitExpr_(const CallNode* call_node) { } void LayoutsFinder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); } void LayoutsFinder::VisitExpr_(const CallNode* call_node) { @@ -126,7 +126,8 @@ void LayoutsFinder::VisitExpr_(const CallNode* call_node) { func = local_funcs_[call_node->op]; } if (func.defined()) { - const auto& layouts_opt = func->GetAttr>(msc_attr::kInputLayouts); + const auto& layouts_opt = + func->GetAttr>(msc_attr::kInputLayouts); if (layouts_opt.defined()) { for (const auto& pair : layouts_opt.value()) { layouts_.Set(pair.first, pair.second); @@ -137,8 +138,8 @@ void LayoutsFinder::VisitExpr_(const CallNode* call_node) { const MSCGraph GraphBuilder::Build(const Function& func) { // Add input nodes and record inputs; - Array input_names, output_names; - std::set added_inputs; + ffi::Array input_names, output_names; + std::set added_inputs; // Add prims for (const auto& p : func->params) { if (!p->struct_info_.defined()) { @@ -148,11 +149,11 @@ const MSCGraph GraphBuilder::Build(const Function& func) { const auto& shape = ExprUtils::GetShape(p, false); for (size_t i = 0; i < shape.size(); i++) { if (shape[i]->IsInstance()) { - Map attrs; + ffi::Map attrs; attrs.Set("producer", p->name_hint()); attrs.Set("out_idx", "0"); attrs.Set("dim", std::to_string(i)); - MatchOrCreatePrim(shape[i], "shape", Array(), attrs); + MatchOrCreatePrim(shape[i], "shape", ffi::Array(), attrs); } } } else { @@ -169,7 +170,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { } if (func_params_.count(p) && func_params_[p]->IsInstance()) { const auto& tuple = Downcast(func_params_[p]); - Array tuple_names; + ffi::Array tuple_names; for (const auto& f : tuple->fields) { if 
(expr_tensor_map_.count(f)) { LOG_INFO << "Replica tuple input " << f; @@ -200,8 +201,8 @@ const MSCGraph GraphBuilder::Build(const Function& func) { << "Can not find seqexpr body " << func->body->body; output_names = expr_tensor_map_[func->body->body]; // remove const nodes as weights - Array valid_nodes; - std::set ignore_inputs; + ffi::Array valid_nodes; + std::set ignore_inputs; for (const auto& n : nodes_) { if (weights_.count(n->name) || ignore_nodes_.count(n->name)) { for (const auto& o : n->outputs) { @@ -218,7 +219,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { } } // remove uselese inputs - Array valid_inputs; + ffi::Array valid_inputs; for (const auto& i : input_names) { if (!ignore_inputs.count(i)) { valid_inputs.push_back(i); @@ -255,12 +256,12 @@ const MSCGraph GraphBuilder::Build(const Function& func) { return graph; } -const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& binding_var, - const String& name) { +const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional& binding_var, + const ffi::String& name) { // Get optype, node_name and layout - String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName); - String optype = "unknown"; - String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + ffi::String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName); + ffi::String optype = "unknown"; + ffi::String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); if (func_params_.count(expr) && func_params_[expr]->IsInstance()) { node_name = SpanUtils::GetAttr(func_params_[expr]->span, msc_attr::kName); optype = "constant"; @@ -318,11 +319,12 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin const auto& plugin = IsPlugin(optype) ? 
GetPlugin(optype) : Plugin(); // Extract normal attributes - Map attrs; + ffi::Map attrs; if (plugin.defined()) { const auto& op = Downcast(expr)->op; if (target_funcs_.count(op)) { - const auto& opattrs_opt = target_funcs_[op]->GetAttr>(msc_attr::kOpattrs); + const auto& opattrs_opt = + target_funcs_[op]->GetAttr>(msc_attr::kOpattrs); if (opattrs_opt.defined()) { const auto& opattrs = opattrs_opt.value(); ICHECK_EQ(opattrs.size(), plugin->attrs.size()) @@ -341,7 +343,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } else if (const auto* call_node = expr.as()) { if (const auto* v_node = call_node->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - const auto& name_opt = func->GetAttr(relax::attr::kComposite); + const auto& name_opt = func->GetAttr(relax::attr::kComposite); if (name_opt.has_value()) { attrs = FuncAttrGetter().GetAttrs(func); } @@ -365,10 +367,10 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } // Extract attributes from arguments - Array input_types; + ffi::Array input_types; if (!plugin.defined() && expr->IsInstance()) { const auto& call = Downcast(expr); - Array values; + ffi::Array values; if (call->op->IsInstance()) { ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op; values = FuncValueGetter().GetValues(target_funcs_[call->op]); @@ -396,8 +398,8 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } // Build inputs and weights - Array input_names; - Map node_weights; + ffi::Array input_names; + ffi::Map node_weights; if (plugin.defined()) { const auto& call = Downcast(expr); if (call->args.size() == 1) { @@ -419,7 +421,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin continue; } const auto& arg = call_node->args[i]; - Array arg_names; + ffi::Array arg_names; if (expr_tensor_map_.count(arg)) { arg_names = expr_tensor_map_[arg]; } else if (input_types[i] == "input" && 
arg->IsInstance()) { @@ -431,7 +433,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } } } - String weight_name; + ffi::String weight_name; if (input_types[i] != "input" && arg->IsInstance()) { weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName); } else if (input_types[i] != "input" && func_params_.count(arg) && @@ -448,12 +450,12 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin const auto& ref = producer->OutputAt(pair.second); MSCTensor weight; if (input_types[i] == "bias") { - weight = MSCTensor(weight_name, ref->dtype, "O", Array{ref->GetSize()}); + weight = MSCTensor(weight_name, ref->dtype, "O", ffi::Array{ref->GetSize()}); } else if (input_types[i] == "weight" && (optype == "msc.linear" || optype == "msc.linear_bias")) { if (ref->layout.name() == "IO") { - String valid_layout = ref->layout[1].name() + ref->layout[0].name(); - const auto& valid_shape = Array({ref->shape[1], ref->shape[0]}); + ffi::String valid_layout = ref->layout[1].name() + ref->layout[0].name(); + const auto& valid_shape = ffi::Array({ref->shape[1], ref->shape[0]}); weight = MSCTensor(weight_name, ref->dtype, valid_layout, valid_shape); } else { weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape); @@ -512,13 +514,13 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } // Build output tensor - auto build_output = [this](const StructInfo& sinfo, const String& node_name, - const String& layout) { + auto build_output = [this](const StructInfo& sinfo, const ffi::String& node_name, + const ffi::String& layout) { ICHECK(sinfo->IsInstance()) << "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey(); const auto& t_info = Downcast(sinfo); const auto& shape = ArrayUtils::Cast(ExprUtils::GetShape(t_info)); - Array prims; + ffi::Array prims; bool has_prims = false; if (shape.size() > 0) { for (const auto& s : t_info->GetShape().value()) { @@ -537,15 +539,15 @@ const MSCJoint 
GraphBuilder::AddNode(const Expr& expr, const Optional& bin }; // Gather outputs - Array outputs; + ffi::Array outputs; const auto& sinfo = GetStructInfo(expr); - Array layouts = StringUtils::Split(layout, ","); + ffi::Array layouts = StringUtils::Split(layout, ","); size_t num_output = 1; if (const auto* tuple_sinfo = sinfo.as()) { num_output = tuple_sinfo->fields.size(); } if (layouts.size() == 0) { - layouts = Array(num_output, ""); + layouts = ffi::Array(num_output, ""); } ICHECK_EQ(layouts.size(), num_output) << "Layouts " << layouts << " msimatch with output size " << num_output; @@ -553,7 +555,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin const auto& t_name = node_name + ":" + std::to_string(0); outputs.push_back(build_output(sinfo, t_name, layouts[0])); } else if (const auto* s_sinfo = sinfo.as()) { - Array shape{s_sinfo->ndim}; + ffi::Array shape{s_sinfo->ndim}; const auto& t_name = node_name + ":" + std::to_string(0); const auto& dtype = DataType(ffi::StringToDLDataType("int32")); outputs.push_back(MSCTensor(t_name, dtype, layouts[0], shape)); @@ -568,14 +570,14 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } // Build node - Array scope; + ffi::Array scope; if (optype != "input" && optype != "constant") { scope = StringUtils::Split(scope_name_, "."); } const auto& shared_ref = SpanUtils::GetAttr(expr->span, msc_attr::kSharedRef); const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, attrs, scope, inputs, outputs, node_weights); - Array output_names; + ffi::Array output_names; for (size_t i = 0; i < outputs.size(); i++) { output_names.push_back(outputs[i]->name); tensor_input_map_[outputs[i]->name] = std::make_pair(node, i); @@ -587,11 +589,11 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } void GraphBuilder::VisitBindingBlock(const BindingBlock& block) { - String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); + ffi::String 
block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); if (block_name.size() == 0) { block_name = "block"; } - const String& prefix = StringUtils::Join(block_stack_, "."); + const ffi::String& prefix = StringUtils::Join(block_stack_, "."); if (setted_blocks_.count(prefix + "." + block_name)) { int cnt = 1; while (setted_blocks_.count(prefix + "." + block_name + "_" + std::to_string(cnt))) { @@ -638,15 +640,15 @@ const MSCPrim GraphBuilder::AddPrim(const PrimExpr& prim) { // scalar if (prim->IsInstance()) { - Map attrs; + ffi::Map attrs; attrs.Set("value", StringUtils::ToString(prim)); - return MatchOrCreatePrim(prim, "Int", Array(), attrs); + return MatchOrCreatePrim(prim, "Int", ffi::Array(), attrs); } // call if (const auto* c_node = prim.as()) { - String optype; - Array parents; + ffi::String optype; + ffi::Array parents; if (const auto* op_node = c_node->op.as()) { optype = StringUtils::Replace(op_node->name, "tir.", ""); } else { @@ -660,9 +662,9 @@ const MSCPrim GraphBuilder::AddPrim(const PrimExpr& prim) { return MatchOrCreatePrim(prim); } -const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String& optype, - const Array& parents, - const Map& attrs) { +const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const ffi::String& optype, + const ffi::Array& parents, + const ffi::Map& attrs) { if (prim_map_.count(prim)) { return prim_map_[prim]; } @@ -692,7 +694,7 @@ const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String prim_map_.Set(prim, p); return p; } - String name; + ffi::String name; if (const auto* v_node = prim.as()) { name = v_node->name_hint; } else { @@ -705,26 +707,26 @@ const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String } void GraphBuilder::VisitExpr_(const ConstantNode* op) { - if (!expr_tensor_map_.count(GetRef(op))) { - AddNode(GetRef(op)); + if (!expr_tensor_map_.count(ffi::GetRef(op))) { + AddNode(ffi::GetRef(op)); } } void 
GraphBuilder::VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { - const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(ffi::GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { - const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(ffi::GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) { ExprVisitor::VisitBinding_(binding, call_node); - const String& name = config_.use_var_name ? binding->var->name_hint() : ""; + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; try { - AddNode(GetRef(call_node), binding->var, name); + AddNode(ffi::GetRef(call_node), binding->var, name); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to add node from " << binding->var << " : " << binding->value << ", reason: " << err.what(); @@ -734,49 +736,50 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const CallNode* void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(ffi::GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& name = config_.use_var_name ? 
binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(ffi::GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const VarNode* val) { ExprVisitor::VisitBinding_(binding, val); - const auto& output = GetRef(val); + const auto& output = ffi::GetRef(val); ICHECK(expr_tensor_map_.count(output)) << "Can not find var " << output; expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) { ExprVisitor::VisitBinding_(binding, val); - const auto& output = GetRef(val); + const auto& output = ffi::GetRef(val); ICHECK(expr_tensor_map_.count(output)) << "Can not find dataflow var " << output; expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - const auto& name_opt = val->GetAttr(relax::attr::kComposite); + const auto& name_opt = val->GetAttr(relax::attr::kComposite); ICHECK(name_opt.has_value()) << "Unexpected target func without composite"; ICHECK(config_.target.size() > 0 && StringUtils::StartsWith(name_opt.value(), config_.target)) << "Target should be given for target function"; - target_funcs_.Set(binding->var, GetRef(val)); + target_funcs_.Set(binding->var, ffi::GetRef(val)); } -const std::tuple GraphBuilder::ParseFunc(const Function& func) { - String node_name, optype, layout; - const auto& name_opt = func->GetAttr(msc_attr::kUnique); +const std::tuple GraphBuilder::ParseFunc( + const Function& func) { + ffi::String node_name, optype, layout; + const auto& name_opt = func->GetAttr(msc_attr::kUnique); // get node_name if (name_opt.has_value()) { node_name = name_opt.value(); } // get optype - const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); - const auto& optype_opt = 
func->GetAttr(msc_attr::kOptype); - const auto& composite_opt = func->GetAttr(relax::attr::kComposite); + const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); + const auto& optype_opt = func->GetAttr(msc_attr::kOptype); + const auto& composite_opt = func->GetAttr(relax::attr::kComposite); if (codegen_opt.has_value()) { optype = codegen_opt.value(); } else if (optype_opt.has_value()) { @@ -788,7 +791,7 @@ const std::tuple GraphBuilder::ParseFunc(const Function& } } // get layout - const auto& layout_opt = func->GetAttr(msc_attr::kLayout); + const auto& layout_opt = func->GetAttr(msc_attr::kLayout); if (layout_opt.has_value()) { layout = layout_opt.value(); } @@ -802,14 +805,14 @@ void GraphBuilder::VisitPrimExpr(const PrimExpr& prim) { } } -Array GraphBuilder::GetPluginInputs(const Expr& expr) { +ffi::Array GraphBuilder::GetPluginInputs(const Expr& expr) { ICHECK(expr->IsInstance()) << "plugin expr should be call"; const auto& call = Downcast(expr); ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; return Downcast(call->args[1])->fields; } -Map WeightsExtractor::GetWeights(const Function& func) { +ffi::Map WeightsExtractor::GetWeights(const Function& func) { VisitExpr(func); return weights_; } @@ -817,13 +820,13 @@ Map WeightsExtractor::GetWeights(const Function& func) { void WeightsExtractor::VisitExpr_(const ConstantNode* op) { const auto& name = SpanUtils::GetAttr(op->span, msc_attr::kName); const auto& layout = SpanUtils::GetAttr(op->span, msc_attr::kLayout); - const auto& sinfo = GetStructInfo(GetRef(op)); + const auto& sinfo = GetStructInfo(ffi::GetRef(op)); ICHECK(sinfo->IsInstance()) << "Constant StrcutInfo should be TensorStructInfo"; const auto& t_info = Downcast(sinfo); const auto& opt_shape = t_info->GetShape(); const auto& shape = - opt_shape.defined() ? ArrayUtils::Cast(opt_shape.value()) : Array(); + opt_shape.defined() ? 
ArrayUtils::Cast(opt_shape.value()) : ffi::Array(); const auto& weight = MSCTensor(name, t_info->dtype, layout, shape); weights_.Set(weight, op->data); } @@ -836,24 +839,26 @@ void WeightsExtractor::VisitExpr_(const CallNode* op) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.BuildFromRelax", - [](const IRModule& module, const String& entry_name, const String& options) -> MSCGraph { + [](const IRModule& module, const ffi::String& entry_name, + const ffi::String& options) -> MSCGraph { auto builder = GraphBuilder(module, entry_name, options); const auto& func_name = builder.config().byoc_entry.size() > 0 - ? String(builder.config().byoc_entry) + ? ffi::String(builder.config().byoc_entry) : entry_name; const auto& func = Downcast(module->Lookup(func_name)); return builder.Build(func); }) - .def("msc.core.GetRelaxWeights", - [](const IRModule& module, const String& entry_name) -> Map { - const auto& func = Downcast(module->Lookup(entry_name)); - return WeightsExtractor(module).GetWeights(func); - }); -}); + .def( + "msc.core.GetRelaxWeights", + [](const IRModule& module, const ffi::String& entry_name) -> ffi::Map { + const auto& func = Downcast(module->Lookup(entry_name)); + return WeightsExtractor(module).GetWeights(func); + }); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 401c452d95cb..22a4929fe12f 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -50,7 +50,7 @@ namespace msc { using namespace tvm::relax; using Expr = tvm::RelaxExpr; -using tvm::runtime::NDArray; +using tvm::runtime::Tensor; /*! * \brief Config for building MSCGraph. @@ -110,10 +110,10 @@ struct MSCRBuildConfig { class AttrGetter { public: /*! 
- * \brief Get the attributes as Map + * \brief Get the attributes as ffi::Map * \param attrs the attributes. */ - explicit AttrGetter(Map* attrs) : attrs_(attrs) {} + explicit AttrGetter(ffi::Map* attrs) : attrs_(attrs) {} void operator()(const Attrs& attrs) { if (const auto* dict_attrs = attrs.as()) { @@ -125,14 +125,14 @@ class AttrGetter { if (attrs_tinfo->metadata != nullptr) { tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs); - this->VisitAny(String(field_info->name), field_value); + this->VisitAny(ffi::String(field_info->name), field_value); }); } } } private: - void VisitAny(String key, Any value) { + void VisitAny(ffi::String key, Any value) { switch (value.type_index()) { case kTVMFFINone: { attrs_->Set(key, ""); @@ -156,7 +156,7 @@ class AttrGetter { } case kTVMFFISmallStr: case kTVMFFIStr: { - attrs_->Set(key, value.cast()); + attrs_->Set(key, value.cast()); break; } default: { @@ -171,13 +171,13 @@ class AttrGetter { } private: - Map* attrs_; + ffi::Map* attrs_; }; class FuncAttrGetter : public ExprVisitor { public: - /*! \brief Get the attributes as Map*/ - Map GetAttrs(const Expr& expr) { + /*! \brief Get the attributes as ffi::Map*/ + ffi::Map GetAttrs(const Expr& expr) { VisitExpr(expr); return attrs_; } @@ -187,13 +187,13 @@ class FuncAttrGetter : public ExprVisitor { void VisitExpr_(const TupleGetItemNode* op) final; private: - Map attrs_; + ffi::Map attrs_; }; class FuncValueGetter : public ExprVisitor { public: - /*! \brief Get the attributes from prim value as Map*/ - Array GetValues(const Expr& expr) { + /*! 
\brief Get the attributes from prim value as ffi::Map*/ + ffi::Array GetValues(const Expr& expr) { VisitExpr(expr); return values_; } @@ -201,7 +201,7 @@ class FuncValueGetter : public ExprVisitor { void VisitExpr_(const CallNode* op) final; private: - Array values_; + ffi::Array values_; }; class FuncParamsFinder : public ExprVisitor { @@ -215,7 +215,7 @@ class FuncParamsFinder : public ExprVisitor { } /*! \brief Find the func params and bind with arguments*/ - Map FindParams(const Expr& expr) { + ffi::Map FindParams(const Expr& expr) { VisitExpr(expr); return params_; } @@ -226,8 +226,8 @@ class FuncParamsFinder : public ExprVisitor { private: IRModule ref_module_; - Map params_; - Map local_funcs_; + ffi::Map params_; + ffi::Map local_funcs_; }; class LayoutsFinder : public ExprVisitor { @@ -239,7 +239,7 @@ class LayoutsFinder : public ExprVisitor { explicit LayoutsFinder(const IRModule& ref_module) : ExprVisitor() { ref_module_ = ref_module; } /*! \brief Find the layouts form attrs*/ - Map FindLayouts(const Expr& expr) { + ffi::Map FindLayouts(const Expr& expr) { VisitExpr(expr); return layouts_; } @@ -250,8 +250,8 @@ class LayoutsFinder : public ExprVisitor { private: IRModule ref_module_; - Map layouts_; - Map local_funcs_; + ffi::Map layouts_; + ffi::Map local_funcs_; }; class GraphBuilder : public ExprVisitor { @@ -262,7 +262,7 @@ class GraphBuilder : public ExprVisitor { * \param name the name of the graph. * \param options the options of build the graph. */ - explicit GraphBuilder(const IRModule& ref_module, const String& name, + explicit GraphBuilder(const IRModule& ref_module, const ffi::String& name, const std::string& options = "") : ExprVisitor() { ref_module_ = ref_module; @@ -271,7 +271,7 @@ class GraphBuilder : public ExprVisitor { dmlc::JSONReader reader(&is); reader.Read(&config_); } - name_ = config_.graph_name.size() > 0 ? String(config_.graph_name) : name; + name_ = config_.graph_name.size() > 0 ? 
ffi::String(config_.graph_name) : name; if (config_.byoc_entry.size() > 0) { func_params_ = FuncParamsFinder(ref_module).FindParams(ref_module->Lookup(name)); } @@ -285,15 +285,16 @@ class GraphBuilder : public ExprVisitor { const MSCRBuildConfig config() { return config_; } /*! \brief Create and add MSCJoint from expr*/ - const MSCJoint AddNode(const Expr& expr, const Optional& binding_var = std::nullopt, - const String& name = ""); + const MSCJoint AddNode(const Expr& expr, const ffi::Optional& binding_var = std::nullopt, + const ffi::String& name = ""); /*! \brief Create and add MSCPrim from prim*/ const MSCPrim AddPrim(const PrimExpr& prim); - const MSCPrim MatchOrCreatePrim(const PrimExpr& prim, const String& op = "", - const Array& parents = Array(), - const Map& attrs = Map()); + const MSCPrim MatchOrCreatePrim( + const PrimExpr& prim, const ffi::String& op = "", + const ffi::Array& parents = ffi::Array(), + const ffi::Map& attrs = ffi::Map()); void VisitBindingBlock(const BindingBlock& block) final; @@ -319,30 +320,30 @@ class GraphBuilder : public ExprVisitor { private: /*! \brief Get the node_name, optype, layout for func*/ - const std::tuple ParseFunc(const Function& func); + const std::tuple ParseFunc(const Function& func); /*! 
\brief Get the plugin inputs*/ - Array GetPluginInputs(const Expr& expr); + ffi::Array GetPluginInputs(const Expr& expr); - String name_; + ffi::String name_; IRModule ref_module_; MSCRBuildConfig config_; - Map layouts_; - Array nodes_; - Map weights_; - Map> expr_tensor_map_; - std::unordered_map> tensor_input_map_; - std::set ignore_nodes_; + ffi::Map layouts_; + ffi::Array nodes_; + ffi::Map weights_; + ffi::Map> expr_tensor_map_; + std::unordered_map> tensor_input_map_; + std::set ignore_nodes_; // scope name - String scope_name_; - std::set setted_blocks_; - Array block_stack_; + ffi::String scope_name_; + std::set setted_blocks_; + ffi::Array block_stack_; // BYOC maps - Map target_funcs_; - Map func_params_; + ffi::Map target_funcs_; + ffi::Map func_params_; // prims - Array prims_; - Map prim_map_; + ffi::Array prims_; + ffi::Map prim_map_; }; class WeightsExtractor : public ExprVisitor { @@ -358,15 +359,15 @@ class WeightsExtractor : public ExprVisitor { } /*! \brief Visit the constant and save weights */ - Map GetWeights(const Function& func); + ffi::Map GetWeights(const Function& func); void VisitExpr_(const ConstantNode* op) final; void VisitExpr_(const CallNode* op) final; private: - Map weights_; - Map local_funcs_; + ffi::Map weights_; + ffi::Map local_funcs_; IRModule ref_module_; }; diff --git a/src/contrib/msc/core/ir/plugin.cc b/src/contrib/msc/core/ir/plugin.cc index 659cb29628e7..1ff3a8dc8dcd 100644 --- a/src/contrib/msc/core/ir/plugin.cc +++ b/src/contrib/msc/core/ir/plugin.cc @@ -35,9 +35,9 @@ namespace tvm { namespace contrib { namespace msc { -PluginAttr::PluginAttr(const String& name, const String& type, const String& default_value, - const String& describe) { - ObjectPtr n = make_object(); +PluginAttr::PluginAttr(const ffi::String& name, const ffi::String& type, + const ffi::String& default_value, const ffi::String& describe) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->type = std::move(type); n->default_value = 
std::move(default_value); @@ -46,13 +46,13 @@ PluginAttr::PluginAttr(const String& name, const String& type, const String& def } PluginAttr::PluginAttr(const JsonPluginAttr& j_attr) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_attr); data_ = std::move(n); } PluginAttr::PluginAttr(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -81,9 +81,9 @@ void PluginAttrNode::FromJson(const std::string& json_str) { FromJson(j_attr); } -PluginTensor::PluginTensor(const String& name, const String& dtype, const Integer& ndim, - const String& device, const String& describe) { - ObjectPtr n = make_object(); +PluginTensor::PluginTensor(const ffi::String& name, const ffi::String& dtype, const Integer& ndim, + const ffi::String& device, const ffi::String& describe) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->dtype = std::move(dtype); n->ndim = std::move(ndim); @@ -93,13 +93,13 @@ PluginTensor::PluginTensor(const String& name, const String& dtype, const Intege } PluginTensor::PluginTensor(const JsonPluginTensor& j_tensor) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_tensor); data_ = std::move(n); } PluginTensor::PluginTensor(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -130,9 +130,10 @@ void PluginTensorNode::FromJson(const std::string& json_str) { FromJson(j_tensor); } -PluginExtern::PluginExtern(const String& name, const String& header, const String& source, - const String& lib, const String& describe) { - ObjectPtr n = make_object(); +PluginExtern::PluginExtern(const ffi::String& name, const ffi::String& header, + const ffi::String& source, const ffi::String& lib, + const ffi::String& describe) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->header = 
std::move(header); n->source = std::move(source); @@ -142,13 +143,13 @@ PluginExtern::PluginExtern(const String& name, const String& header, const Strin } PluginExtern::PluginExtern(const JsonPluginExtern& j_extern) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_extern); data_ = std::move(n); } PluginExtern::PluginExtern(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -179,13 +180,13 @@ void PluginExternNode::FromJson(const std::string& json_str) { FromJson(j_extern); } -Plugin::Plugin(const String& name, const String& version, const String& describe, - const Array& attrs, const Array& inputs, - const Array& outputs, const Array& buffers, - const Map& externs, - const Map>& support_dtypes, - const Map& options) { - ObjectPtr n = make_object(); +Plugin::Plugin(const ffi::String& name, const ffi::String& version, const ffi::String& describe, + const ffi::Array& attrs, const ffi::Array& inputs, + const ffi::Array& outputs, const ffi::Array& buffers, + const ffi::Map& externs, + const ffi::Map>& support_dtypes, + const ffi::Map& options) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->version = std::move(version); n->describe = std::move(describe); @@ -200,13 +201,13 @@ Plugin::Plugin(const String& name, const String& version, const String& describe } Plugin::Plugin(const JsonPlugin& j_plugin) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_plugin); data_ = std::move(n); } Plugin::Plugin(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -264,7 +265,7 @@ void PluginNode::FromJson(const JsonPlugin& j_plugin) { externs.Set(pair.first, PluginExtern(pair.second)); } for (const auto& pair : j_plugin.support_dtypes) { - Array dtypes; + ffi::Array dtypes; for (const auto& d : pair.second) { 
dtypes.push_back(d); } @@ -301,30 +302,32 @@ int PluginNode::FindDeviceRefIdx(const PluginTensor& tensor) const { return -1; } -const Array ListPluginNames() { return PluginRegistry::Global()->ListAllNames(); } +const ffi::Array ListPluginNames() { return PluginRegistry::Global()->ListAllNames(); } -const Plugin GetPlugin(const String& name) { return PluginRegistry::Global()->Get(name); } +const Plugin GetPlugin(const ffi::String& name) { return PluginRegistry::Global()->Get(name); } -bool IsPlugin(const String& name) { return PluginRegistry::Global()->Registered(name); } +bool IsPlugin(const ffi::String& name) { return PluginRegistry::Global()->Registered(name); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { PluginAttrNode::RegisterReflection(); PluginTensorNode::RegisterReflection(); PluginExternNode::RegisterReflection(); PluginNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.RegisterPlugin", - [](const String& name, const String& json_str) { + [](const ffi::String& name, const ffi::String& json_str) { PluginRegistry::Global()->Register(name, json_str); }) - .def("msc.core.ListPluginNames", []() -> Array { return ListPluginNames(); }) - .def("msc.core.GetPlugin", [](const String& name) -> Plugin { return GetPlugin(name); }) - .def("msc.core.IsPlugin", [](const String& name) -> Bool { return Bool(IsPlugin(name)); }); -}); + .def("msc.core.ListPluginNames", + []() -> ffi::Array { return ListPluginNames(); }) + .def("msc.core.GetPlugin", [](const ffi::String& name) -> Plugin { return GetPlugin(name); }) + .def("msc.core.IsPlugin", + [](const ffi::String& name) -> Bool { return Bool(IsPlugin(name)); }); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h index f0a5dc9937b8..eaf3167dcf4e 100644 --- a/src/contrib/msc/core/ir/plugin.h +++ 
b/src/contrib/msc/core/ir/plugin.h @@ -254,13 +254,13 @@ struct JsonPlugin { class PluginAttrNode : public Object { public: /*! \brief The name of attribute. */ - String name; + ffi::String name; /*! \brief The type of attribute. */ - String type; + ffi::String type; /*! \brief The default_value of attribute. */ - String default_value; + ffi::String default_value; /*! \brief The describe of attribute. */ - String describe; + ffi::String describe; /*! \brief Export attribute to json. */ const JsonPluginAttr ToJson() const; @@ -279,8 +279,7 @@ class PluginAttrNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "msc.core.PluginAttr"; - TVM_DECLARE_FINAL_OBJECT_INFO(PluginAttrNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginAttr", PluginAttrNode, Object); }; /*! @@ -296,8 +295,8 @@ class PluginAttr : public ObjectRef { * \param default_value The default_value of the attribute. * \param describe The describe of the attribute. */ - TVM_DLL PluginAttr(const String& name, const String& type, const String& default_value, - const String& describe); + TVM_DLL PluginAttr(const ffi::String& name, const ffi::String& type, + const ffi::String& default_value, const ffi::String& describe); /*! * \brief The json constructor. @@ -311,7 +310,7 @@ class PluginAttr : public ObjectRef { */ TVM_DLL PluginAttr(const std::string& json_str); - TVM_DEFINE_OBJECT_REF_METHODS(PluginAttr, ObjectRef, PluginAttrNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginAttr, ObjectRef, PluginAttrNode); }; /*! @@ -320,15 +319,15 @@ class PluginAttr : public ObjectRef { class PluginTensorNode : public Object { public: /*! \brief The name of tensor. */ - String name; + ffi::String name; /*! \brief The dtype of tensor. */ - String dtype; + ffi::String dtype; /*! \brief The ndim of tensor. */ Integer ndim; /*! \brief The device of tensor. */ - String device; + ffi::String device; /*! 
\brief The describe of tensor. */ - String describe; + ffi::String describe; /*! \brief Export tensor to json. */ const JsonPluginTensor ToJson() const; @@ -348,8 +347,7 @@ class PluginTensorNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "msc.core.PluginTensor"; - TVM_DECLARE_FINAL_OBJECT_INFO(PluginTensorNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginTensor", PluginTensorNode, Object); }; /*! @@ -366,8 +364,8 @@ class PluginTensor : public ObjectRef { * \param device The device of the tensor. * \param describe The describe of the tensor. */ - TVM_DLL PluginTensor(const String& name, const String& dtype, const Integer& ndim, - const String& device, const String& describe); + TVM_DLL PluginTensor(const ffi::String& name, const ffi::String& dtype, const Integer& ndim, + const ffi::String& device, const ffi::String& describe); /*! * \brief The json constructor. @@ -381,7 +379,7 @@ class PluginTensor : public ObjectRef { */ TVM_DLL PluginTensor(const std::string& json_str); - TVM_DEFINE_OBJECT_REF_METHODS(PluginTensor, ObjectRef, PluginTensorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginTensor, ObjectRef, PluginTensorNode); }; /*! @@ -390,15 +388,15 @@ class PluginTensor : public ObjectRef { class PluginExternNode : public Object { public: /*! \brief The name of extern. */ - String name; + ffi::String name; /*! \brief The header of extern. */ - String header; + ffi::String header; /*! \brief The source of extern. */ - String source; + ffi::String source; /*! \brief The lib of extern. */ - String lib; + ffi::String lib; /*! \brief The describe of extern. */ - String describe; + ffi::String describe; /*! \brief Export extern to json. 
*/ const JsonPluginExtern ToJson() const; @@ -418,8 +416,7 @@ class PluginExternNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "msc.core.PluginExtern"; - TVM_DECLARE_FINAL_OBJECT_INFO(PluginExternNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginExtern", PluginExternNode, Object); }; /*! @@ -436,8 +433,9 @@ class PluginExtern : public ObjectRef { * \param lib The lib of the extern. * \param describe The describe of the extern. */ - TVM_DLL PluginExtern(const String& name, const String& header, const String& source, - const String& lib, const String& describe); + TVM_DLL PluginExtern(const ffi::String& name, const ffi::String& header, + const ffi::String& source, const ffi::String& lib, + const ffi::String& describe); /*! * \brief The json constructor. @@ -451,7 +449,7 @@ class PluginExtern : public ObjectRef { */ TVM_DLL PluginExtern(const std::string& json_str); - TVM_DEFINE_OBJECT_REF_METHODS(PluginExtern, ObjectRef, PluginExternNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginExtern, ObjectRef, PluginExternNode); }; /*! @@ -460,25 +458,25 @@ class PluginExtern : public ObjectRef { class PluginNode : public Object { public: /*! \brief The name of plugin. */ - String name; + ffi::String name; /*! \brief The version of plugin. */ - String version; + ffi::String version; /*! \brief The describe of plugin. */ - String describe; + ffi::String describe; /*! \brief The attributes of plugin. */ - Array attrs; + ffi::Array attrs; /*! \brief The inputs of plugin. */ - Array inputs; + ffi::Array inputs; /*! \brief The outputs of plugin. */ - Array outputs; + ffi::Array outputs; /*! \brief The buffers of plugin. */ - Array buffers; + ffi::Array buffers; /*! \brief The externs of plugin. */ - Map externs; + ffi::Map externs; /*! \brief The support_dtypes of plugin. */ - Map> support_dtypes; + ffi::Map> support_dtypes; /*! 
\brief The options of plugin. */ - Map options; + ffi::Map options; /*! \brief Export plugin to json. */ const JsonPlugin ToJson() const; @@ -508,8 +506,7 @@ class PluginNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const char* _type_key = "msc.core.Plugin"; - TVM_DECLARE_FINAL_OBJECT_INFO(PluginNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.Plugin", PluginNode, Object); }; /*! @@ -531,12 +528,12 @@ class Plugin : public ObjectRef { * \param support_dtypes The support_dtypes of the plugin. * \param options The options of the plugin. */ - TVM_DLL Plugin(const String& name, const String& version, const String& describe, - const Array& attrs, const Array& inputs, - const Array& outputs, const Array& buffers, - const Map& externs, - const Map>& support_dtypes, - const Map& options); + TVM_DLL Plugin(const ffi::String& name, const ffi::String& version, const ffi::String& describe, + const ffi::Array& attrs, const ffi::Array& inputs, + const ffi::Array& outputs, const ffi::Array& buffers, + const ffi::Map& externs, + const ffi::Map>& support_dtypes, + const ffi::Map& options); /*! * \brief The json constructor. @@ -550,7 +547,7 @@ class Plugin : public ObjectRef { */ TVM_DLL Plugin(const std::string& json_str); - TVM_DEFINE_OBJECT_REF_METHODS(Plugin, ObjectRef, PluginNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Plugin, ObjectRef, PluginNode); }; class PluginRegistry { @@ -561,7 +558,7 @@ class PluginRegistry { * \param json_str The json_str. * \return The corresponding entry. */ - bool Register(const String& name, const String& json_str) { + bool Register(const ffi::String& name, const ffi::String& json_str) { plugin_map_[name] = Plugin(json_str); return true; } @@ -571,7 +568,7 @@ class PluginRegistry { * \param name The name of the item. * \return Whether the plugin is registered. 
*/ - bool Registered(const String& name) const { + bool Registered(const ffi::String& name) const { auto it = plugin_map_.find(name); return it != plugin_map_.end(); } @@ -581,7 +578,7 @@ class PluginRegistry { * \param name The name of the item. * \return The corresponding plugin. */ - const Plugin Get(const String& name) const { + const Plugin Get(const ffi::String& name) const { auto it = plugin_map_.find(name); ICHECK(it != plugin_map_.end()) << "Can not find plugin " << name; return it->second; @@ -591,8 +588,8 @@ class PluginRegistry { * \brief List all the plugin names in the registry. * \return The plugin names. */ - Array ListAllNames() const { - Array names; + ffi::Array ListAllNames() const { + ffi::Array names; for (const auto& kv : plugin_map_) { names.push_back(kv.first); } @@ -609,28 +606,28 @@ class PluginRegistry { private: // map from name to plugins. - std::unordered_map plugin_map_; + std::unordered_map plugin_map_; }; /*! * \brief List all plugin names. * \return the corresponding plugin names. */ -const Array ListPluginNames(); +const ffi::Array ListPluginNames(); /*! * \brief Get the registered plugin. * \param name The name of the Plugin. * \return the corresponding plugin. */ -const Plugin GetPlugin(const String& name); +const Plugin GetPlugin(const ffi::String& name); /*! * \brief Check if an plugin is registered. * \param name The name of the item. * \return Whether the plugin is registered. 
*/ -bool IsPlugin(const String& name); +bool IsPlugin(const ffi::String& name); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/printer/cpp_printer.cc b/src/contrib/msc/core/printer/cpp_printer.cc index 1f0fdb11778a..8c2a512a6d86 100644 --- a/src/contrib/msc/core/printer/cpp_printer.cc +++ b/src/contrib/msc/core/printer/cpp_printer.cc @@ -348,7 +348,7 @@ bool CppPrinter::IsEmptyDoc(const ExprDoc& doc) { return id_doc->name == DocSymbol::Empty(); } -void CppPrinter::PrintIndentedBlock(const Array& docs) { +void CppPrinter::PrintIndentedBlock(const ffi::Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { PrintDoc(d); diff --git a/src/contrib/msc/core/printer/cpp_printer.h b/src/contrib/msc/core/printer/cpp_printer.h index bdd25acdebed..62e205a7c749 100644 --- a/src/contrib/msc/core/printer/cpp_printer.h +++ b/src/contrib/msc/core/printer/cpp_printer.h @@ -147,7 +147,7 @@ class CppPrinter : public MSCBasePrinter { bool IsEmptyDoc(const ExprDoc& doc); /*! \brief Print block with indent*/ - void PrintIndentedBlock(const Array& docs); + void PrintIndentedBlock(const ffi::Array& docs); }; } // namespace msc diff --git a/src/contrib/msc/core/printer/msc_base_printer.h b/src/contrib/msc/core/printer/msc_base_printer.h index af369a530dae..10dafb54c2ac 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.h +++ b/src/contrib/msc/core/printer/msc_base_printer.h @@ -97,7 +97,7 @@ class MSCBasePrinter { * \brief Get the printed string of all Doc appended * \sa Append */ - String GetString() const { return output_.str(); } + ffi::String GetString() const { return output_.str(); } protected: /*! \brief Print doc*/ @@ -199,7 +199,7 @@ class MSCBasePrinter { /*! 
\brief Print docs to joined doc */ template - void PrintJoinedDocs(const Array& docs, const String& separator = ", ") { + void PrintJoinedDocs(const ffi::Array& docs, const ffi::String& separator = ", ") { for (size_t i = 0; i < docs.size(); i++) { PrintDoc(docs[i], false); output_ << (i == docs.size() - 1 ? "" : separator); diff --git a/src/contrib/msc/core/printer/msc_doc.cc b/src/contrib/msc/core/printer/msc_doc.cc index b69e554ab9c4..e1cae35be132 100644 --- a/src/contrib/msc/core/printer/msc_doc.cc +++ b/src/contrib/msc/core/printer/msc_doc.cc @@ -29,9 +29,9 @@ namespace tvm { namespace contrib { namespace msc { -DeclareDoc::DeclareDoc(Optional type, ExprDoc variable, Array init_args, +DeclareDoc::DeclareDoc(ffi::Optional type, ExprDoc variable, ffi::Array init_args, bool use_constructor) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->type = type; n->variable = variable; n->init_args = init_args; @@ -40,45 +40,46 @@ DeclareDoc::DeclareDoc(Optional type, ExprDoc variable, Array } StrictListDoc::StrictListDoc(ListDoc list, bool allow_empty) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->list = list; n->allow_empty = allow_empty; this->data_ = std::move(n); } -PointerDoc::PointerDoc(String name) { - ObjectPtr n = make_object(); +PointerDoc::PointerDoc(ffi::String name) { + ObjectPtr n = ffi::make_object(); n->name = name; this->data_ = std::move(n); } -StructDoc::StructDoc(IdDoc name, Array decorators, Array body) { - ObjectPtr n = make_object(); +StructDoc::StructDoc(IdDoc name, ffi::Array decorators, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->name = name; n->decorators = decorators; n->body = body; this->data_ = std::move(n); } -ConstructorDoc::ConstructorDoc(IdDoc name, Array args, Array body) { - ObjectPtr n = make_object(); +ConstructorDoc::ConstructorDoc(IdDoc name, ffi::Array args, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->name = name; n->args = args; n->body = body; 
this->data_ = std::move(n); } -SwitchDoc::SwitchDoc(Array predicates, Array> branchs, - Array default_branch) { - ObjectPtr n = make_object(); +SwitchDoc::SwitchDoc(ffi::Array predicates, ffi::Array> branchs, + ffi::Array default_branch) { + ObjectPtr n = ffi::make_object(); n->predicates = predicates; n->branchs = branchs; n->default_branch = default_branch; this->data_ = std::move(n); } -LambdaDoc::LambdaDoc(IdDoc name, Array args, Array refs, Array body) { - ObjectPtr n = make_object(); +LambdaDoc::LambdaDoc(IdDoc name, ffi::Array args, ffi::Array refs, + ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->name = name; n->args = args; n->refs = refs; @@ -86,7 +87,7 @@ LambdaDoc::LambdaDoc(IdDoc name, Array args, Array refs, Arr this->data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { DeclareDocNode::RegisterReflection(); StrictListDocNode::RegisterReflection(); PointerDocNode::RegisterReflection(); @@ -94,7 +95,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ConstructorDocNode::RegisterReflection(); SwitchDocNode::RegisterReflection(); LambdaDocNode::RegisterReflection(); -}); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/printer/msc_doc.h b/src/contrib/msc/core/printer/msc_doc.h index ea13d74d569f..fe0f6c68338f 100644 --- a/src/contrib/msc/core/printer/msc_doc.h +++ b/src/contrib/msc/core/printer/msc_doc.h @@ -43,11 +43,11 @@ using namespace tvm::script::printer; class DeclareDocNode : public ExprDocNode { public: /*! \brief The type of the variable */ - Optional type; + ffi::Optional type; /*! \brief The variable */ - ExprDoc variable{nullptr}; + ExprDoc variable{ffi::UnsafeInit{}}; /*! \brief The init arguments for the variable. */ - Array init_args; + ffi::Array init_args; /*! 
\brief Whether to use constructor(otherwise initializer) */ bool use_constructor{true}; @@ -59,9 +59,7 @@ class DeclareDocNode : public ExprDocNode { .def_ro("init_args", &DeclareDocNode::init_args) .def_ro("use_constructor", &DeclareDocNode::use_constructor); } - - static constexpr const char* _type_key = "msc.script.printer.DeclareDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(DeclareDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.DeclareDoc", DeclareDocNode, ExprDocNode); }; /*! @@ -78,9 +76,9 @@ class DeclareDoc : public ExprDoc { * \param init_args The init arguments of the variable. * \param use_constructor Whether to use constructor(otherwise initializer). */ - explicit DeclareDoc(Optional type, ExprDoc variable, Array init_args, + explicit DeclareDoc(ffi::Optional type, ExprDoc variable, ffi::Array init_args, bool use_constructor); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DeclareDoc, ExprDoc, DeclareDocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DeclareDoc, ExprDoc, DeclareDocNode); }; /*! @@ -101,9 +99,8 @@ class StrictListDocNode : public ExprDocNode { .def_ro("list", &StrictListDocNode::list) .def_ro("allow_empty", &StrictListDocNode::allow_empty); } - - static constexpr const char* _type_key = "msc.script.printer.StrictListDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(StrictListDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.StrictListDoc", StrictListDocNode, + ExprDocNode); }; /*! @@ -119,7 +116,7 @@ class StrictListDoc : public ExprDoc { * \param allow_empty Whether to allow empty. */ explicit StrictListDoc(ListDoc list, bool allow_empty); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StrictListDoc, ExprDoc, StrictListDocNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StrictListDoc, ExprDoc, StrictListDocNode); }; /*! @@ -130,15 +127,13 @@ class StrictListDoc : public ExprDoc { class PointerDocNode : public ExprDocNode { public: /*! 
\brief The name of the identifier */ - String name; + ffi::String name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("name", &PointerDocNode::name); } - - static constexpr const char* _type_key = "msc.script.printer.PointerDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(PointerDocNode, ExprDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.PointerDoc", PointerDocNode, ExprDocNode); }; /*! @@ -152,8 +147,8 @@ class PointerDoc : public ExprDoc { * \brief Constructor of PointerDoc. * \param name The name of identifier. */ - explicit PointerDoc(String name); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PointerDoc, ExprDoc, PointerDocNode); + explicit PointerDoc(ffi::String name); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PointerDoc, ExprDoc, PointerDocNode); }; /*! @@ -164,11 +159,11 @@ class PointerDoc : public ExprDoc { class StructDocNode : public StmtDocNode { public: /*! \brief The name of class. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! \brief Decorators of class. */ - Array decorators; + ffi::Array decorators; /*! \brief The body of class. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -177,9 +172,7 @@ class StructDocNode : public StmtDocNode { .def_ro("decorators", &StructDocNode::decorators) .def_ro("body", &StructDocNode::body); } - - static constexpr const char* _type_key = "msc.script.printer.StructDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(StructDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.StructDoc", StructDocNode, StmtDocNode); }; /*! @@ -195,8 +188,8 @@ class StructDoc : public StmtDoc { * \param decorators The decorator of class. * \param body The body of class. 
*/ - explicit StructDoc(IdDoc name, Array decorators, Array body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StructDoc, StmtDoc, StructDocNode); + explicit StructDoc(IdDoc name, ffi::Array decorators, ffi::Array body); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StructDoc, StmtDoc, StructDocNode); }; /*! @@ -207,7 +200,7 @@ class StructDoc : public StmtDoc { class ConstructorDocNode : public StmtDocNode { public: /*! \brief The name of function. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of function. * @@ -215,9 +208,9 @@ class ConstructorDocNode : public StmtDocNode { * `annotation` means argument type, * and `rhs` means default value. */ - Array args; + ffi::Array args; /*! \brief The body of function. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -226,9 +219,8 @@ class ConstructorDocNode : public StmtDocNode { .def_ro("args", &ConstructorDocNode::args) .def_ro("body", &ConstructorDocNode::body); } - - static constexpr const char* _type_key = "msc.script.printer.ConstructorDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.ConstructorDoc", ConstructorDocNode, + StmtDocNode); }; /*! @@ -244,8 +236,8 @@ class ConstructorDoc : public StmtDoc { * \param args The arguments of function. * \param body The body of function. */ - explicit ConstructorDoc(IdDoc name, Array args, Array body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ConstructorDoc, StmtDoc, ConstructorDocNode); + explicit ConstructorDoc(IdDoc name, ffi::Array args, ffi::Array body); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ConstructorDoc, StmtDoc, ConstructorDocNode); }; /*! @@ -256,11 +248,11 @@ class ConstructorDoc : public StmtDoc { class SwitchDocNode : public StmtDocNode { public: /*! \brief The predicates of the switch statement. */ - Array predicates; + ffi::Array predicates; /*! 
\brief The branchs of the switch statement. */ - Array> branchs; + ffi::Array> branchs; /*! \brief The default_branch of the switch statement. */ - Array default_branch; + ffi::Array default_branch; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -269,9 +261,7 @@ class SwitchDocNode : public StmtDocNode { .def_ro("branchs", &SwitchDocNode::branchs) .def_ro("default_branch", &SwitchDocNode::default_branch); } - - static constexpr const char* _type_key = "msc.script.printer.SwitchDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(SwitchDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.SwitchDoc", SwitchDocNode, StmtDocNode); }; /*! @@ -287,9 +277,9 @@ class SwitchDoc : public StmtDoc { * \param branchs The branchs of the switch statement. * \param default_branch The default_branch of the switch statement. */ - explicit SwitchDoc(Array predicates, Array> branchs, - Array default_branch); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SwitchDoc, StmtDoc, SwitchDocNode); + explicit SwitchDoc(ffi::Array predicates, ffi::Array> branchs, + ffi::Array default_branch); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SwitchDoc, StmtDoc, SwitchDocNode); }; /*! @@ -300,7 +290,7 @@ class SwitchDoc : public StmtDoc { class LambdaDocNode : public StmtDocNode { public: /*! \brief The name of lambda. */ - IdDoc name{nullptr}; + IdDoc name{ffi::UnsafeInit{}}; /*! * \brief The arguments of lambda. * @@ -308,11 +298,11 @@ class LambdaDocNode : public StmtDocNode { * `annotation` means argument type, * and `rhs` means default value. */ - Array args; + ffi::Array args; /*! \brief References of lambda. */ - Array refs; + ffi::Array refs; /*! \brief The body of lambda. 
*/ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -322,9 +312,7 @@ class LambdaDocNode : public StmtDocNode { .def_ro("refs", &LambdaDocNode::refs) .def_ro("body", &LambdaDocNode::body); } - - static constexpr const char* _type_key = "msc.script.printer.LambdaDoc"; - TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, StmtDocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.LambdaDoc", LambdaDocNode, StmtDocNode); }; /*! @@ -341,8 +329,9 @@ class LambdaDoc : public StmtDoc { * \param refs The references of lambda. * \param body The body of lambda. */ - explicit LambdaDoc(IdDoc name, Array args, Array refs, Array body); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LambdaDoc, StmtDoc, LambdaDocNode); + explicit LambdaDoc(IdDoc name, ffi::Array args, ffi::Array refs, + ffi::Array body); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LambdaDoc, StmtDoc, LambdaDocNode); }; } // namespace msc diff --git a/src/contrib/msc/core/printer/print_utils.cc b/src/contrib/msc/core/printer/print_utils.cc index 234ca3aec9c3..50d36df10bdb 100644 --- a/src/contrib/msc/core/printer/print_utils.cc +++ b/src/contrib/msc/core/printer/print_utils.cc @@ -28,9 +28,9 @@ namespace tvm { namespace contrib { namespace msc { -const String DocSymbol::Empty() { return "::EMPTY"; } +const ffi::String DocSymbol::Empty() { return "::EMPTY"; } -const String DocSymbol::NextLine() { return "::NEXT_LINE"; } +const ffi::String DocSymbol::NextLine() { return "::NEXT_LINE"; } const ExprDoc DocUtils::ToDoc(int64_t val) { return LiteralDoc::Int(val, std::nullopt); } @@ -50,19 +50,19 @@ const ExprDoc DocUtils::ToDoc(const FloatImm& val) { return ToDoc(val->value); } const ExprDoc DocUtils::ToDoc(const char* val) { return IdDoc(std::string(val)); } -const ExprDoc DocUtils::ToDoc(const String& val) { return IdDoc(val); } +const ExprDoc DocUtils::ToDoc(const ffi::String& val) { return IdDoc(val); } const ExprDoc DocUtils::ToDoc(bool val) { return 
LiteralDoc::Boolean(val, std::nullopt); } const ExprDoc DocUtils::ToDoc(const ExprDoc& val) { return val; } -const ExprDoc DocUtils::ToStr(const String& val) { return LiteralDoc::Str(val, std::nullopt); } +const ExprDoc DocUtils::ToStr(const ffi::String& val) { return LiteralDoc::Str(val, std::nullopt); } -const PointerDoc DocUtils::ToPtr(const String& val) { return PointerDoc(val); } +const PointerDoc DocUtils::ToPtr(const ffi::String& val) { return PointerDoc(val); } const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { if (values.size() > 0 || allow_empty) { - Array elements; + ffi::Array elements; for (const auto& v : values) { elements.push_back(ToStr(v)); } @@ -71,7 +71,7 @@ const StrictListDoc DocUtils::ToStrList(const std::vector& values, return StrictListDoc(ListDoc(), false); } -const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { +const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { std::vector v_values; for (const auto& v : values) { v_values.push_back(v); @@ -79,7 +79,7 @@ const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool return ToStrList(v_values, allow_empty); } -const StrictListDoc DocUtils::ToStrList(const Array& values, bool allow_empty) { +const StrictListDoc DocUtils::ToStrList(const ffi::Array& values, bool allow_empty) { std::vector v_values; for (const auto& v : values) { v_values.push_back(v); @@ -87,8 +87,8 @@ const StrictListDoc DocUtils::ToStrList(const Array& values, bool allow_ return ToStrList(v_values, allow_empty); } -const Array DocUtils::ToStmts(const Array& docs) { - Array stmts; +const ffi::Array DocUtils::ToStmts(const ffi::Array& docs) { + ffi::Array stmts; for (const auto& d : docs) { if (d->IsInstance()) { stmts.push_back(Downcast(d)); @@ -101,7 +101,7 @@ const Array DocUtils::ToStmts(const Array& docs) { return stmts; } -const StmtBlockDoc DocUtils::ToStmtBlock(const Array& docs) { +const StmtBlockDoc 
DocUtils::ToStmtBlock(const ffi::Array& docs) { return StmtBlockDoc(ToStmts(docs)); } diff --git a/src/contrib/msc/core/printer/print_utils.h b/src/contrib/msc/core/printer/print_utils.h index b3949d54a762..3ccc1cdc22cc 100644 --- a/src/contrib/msc/core/printer/print_utils.h +++ b/src/contrib/msc/core/printer/print_utils.h @@ -44,10 +44,10 @@ using namespace tvm::script::printer; class DocSymbol { public: /*! * \brief The empty symbol*/ - TVM_DLL static const String Empty(); + TVM_DLL static const ffi::String Empty(); /*! * \brief The next line symbol*/ - TVM_DLL static const String NextLine(); + TVM_DLL static const ffi::String NextLine(); }; /*! @@ -68,30 +68,30 @@ class DocUtils { TVM_DLL static const ExprDoc ToDoc(double val); TVM_DLL static const ExprDoc ToDoc(const FloatImm& val); TVM_DLL static const ExprDoc ToDoc(const char* val); - TVM_DLL static const ExprDoc ToDoc(const String& val); + TVM_DLL static const ExprDoc ToDoc(const ffi::String& val); TVM_DLL static const ExprDoc ToDoc(bool val); TVM_DLL static const ExprDoc ToDoc(const ExprDoc& val); - TVM_DLL static const ExprDoc ToStr(const String& val); - TVM_DLL static const PointerDoc ToPtr(const String& val); + TVM_DLL static const ExprDoc ToStr(const ffi::String& val); + TVM_DLL static const PointerDoc ToPtr(const ffi::String& val); /*! * \brief Change object to DeclareDoc. * \return The DeclareDoc. 
*/ template - TVM_DLL static const DeclareDoc ToDeclare(const String& type, const T& variable, size_t len = 0, - bool use_constructor = true) { - Optional type_doc; + TVM_DLL static const DeclareDoc ToDeclare(const ffi::String& type, const T& variable, + size_t len = 0, bool use_constructor = true) { + ffi::Optional type_doc; if (type.size() == 0) { type_doc = std::nullopt; } else { type_doc = IdDoc(type); } if (len == 0) { - return DeclareDoc(type_doc, ToDoc(variable), Array(), use_constructor); + return DeclareDoc(type_doc, ToDoc(variable), ffi::Array(), use_constructor); } - Array doc_indices{DocUtils::ToDoc(len)}; - return DeclareDoc(type_doc, IndexDoc(ToDoc(variable), doc_indices), Array(), + ffi::Array doc_indices{DocUtils::ToDoc(len)}; + return DeclareDoc(type_doc, IndexDoc(ToDoc(variable), doc_indices), ffi::Array(), use_constructor); } @@ -101,22 +101,22 @@ class DocUtils { */ template TVM_DLL static const AssignDoc ToAssign(const LT& lhs, const RT& rhs, - const String& annotation = "") { + const ffi::String& annotation = "") { if (annotation.size() == 0) { return AssignDoc(ToDoc(lhs), ToDoc(rhs), std::nullopt); } return AssignDoc(ToDoc(lhs), ToDoc(rhs), IdDoc(annotation)); } template - TVM_DLL static const AssignDoc ToAssign(const T& lhs, const String& rhs, - const String& annotation = "") { - Optional rhs_doc; + TVM_DLL static const AssignDoc ToAssign(const T& lhs, const ffi::String& rhs, + const ffi::String& annotation = "") { + ffi::Optional rhs_doc; if (rhs.size() > 0) { rhs_doc = IdDoc(rhs); } else { rhs_doc = std::nullopt; } - Optional annotation_doc; + ffi::Optional annotation_doc; if (annotation.size() > 0) { annotation_doc = IdDoc(annotation); } else { @@ -130,7 +130,7 @@ class DocUtils { * \return The AttrAccessDoc. 
*/ template - TVM_DLL static const AttrAccessDoc ToAttrAccess(const T& value, const String& name) { + TVM_DLL static const AttrAccessDoc ToAttrAccess(const T& value, const ffi::String& name) { return AttrAccessDoc(ToDoc(value), name); } @@ -139,15 +139,15 @@ class DocUtils { * \return The List of Docs. */ template - TVM_DLL static const Array ToDocList(const std::vector& values) { - Array elements; + TVM_DLL static const ffi::Array ToDocList(const std::vector& values) { + ffi::Array elements; for (const auto& v : values) { elements.push_back(ToDoc(v)); } return elements; } template - TVM_DLL static const Array ToDocList(const Array& values) { + TVM_DLL static const ffi::Array ToDocList(const ffi::Array& values) { std::vector v_values; for (const auto& v : values) { v_values.push_back(v); @@ -168,7 +168,7 @@ class DocUtils { return StrictListDoc(ListDoc(), false); } template - TVM_DLL static const StrictListDoc ToList(const Array& values, bool allow_empty = false) { + TVM_DLL static const StrictListDoc ToList(const ffi::Array& values, bool allow_empty = false) { std::vector v_values; for (const auto& v : values) { v_values.push_back(v); @@ -182,9 +182,9 @@ class DocUtils { */ TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, bool allow_empty = false); - TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, + TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, bool allow_empty = false); - TVM_DLL static const StrictListDoc ToStrList(const Array& values, + TVM_DLL static const StrictListDoc ToStrList(const ffi::Array& values, bool allow_empty = false); /*! 
@@ -193,21 +193,21 @@ class DocUtils { */ template TVM_DLL static const IndexDoc ToIndex(const VT& value, const IT& index) { - Array doc_indices; + ffi::Array doc_indices; doc_indices.push_back(ToDoc(index)); return IndexDoc(ToDoc(value), doc_indices); } template TVM_DLL static const IndexDoc ToIndices(const VT& value, const std::vector& indices) { - Array doc_indices; + ffi::Array doc_indices; for (const auto& i : indices) { doc_indices.push_back(ToDoc(i)); } return IndexDoc(ToDoc(value), doc_indices); } template - TVM_DLL static const IndexDoc ToIndices(const VT& value, const Array& indices) { - Array doc_indices; + TVM_DLL static const IndexDoc ToIndices(const VT& value, const ffi::Array& indices) { + ffi::Array doc_indices; for (const auto& i : indices) { doc_indices.push_back(ToDoc(i)); } @@ -218,13 +218,13 @@ class DocUtils { * \brief Convert the docs to Stmts. * \return The Stmts. */ - TVM_DLL static const Array ToStmts(const Array& docs); + TVM_DLL static const ffi::Array ToStmts(const ffi::Array& docs); /*! * \brief Convert the docs to StmtBlock. * \return The StmtBlockDoc. 
*/ - TVM_DLL static const StmtBlockDoc ToStmtBlock(const Array& docs); + TVM_DLL static const StmtBlockDoc ToStmtBlock(const ffi::Array& docs); }; } // namespace msc diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index d62e5ac2a8f6..ffaf035385f1 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -43,9 +43,9 @@ LiteralDoc PrototxtPrinter::ToLiteralDoc(const ffi::Any& obj) { return LiteralDoc::Str(obj_des.str(), std::nullopt); } -DictDoc PrototxtPrinter::ToDictDoc(const Map& dict) { - Array keys; - Array values; +DictDoc PrototxtPrinter::ToDictDoc(const ffi::Map& dict) { + ffi::Array keys; + ffi::Array values; for (const auto& pair : dict) { keys.push_back(IdDoc(pair.first)); if (pair.second.as()) { @@ -57,9 +57,9 @@ DictDoc PrototxtPrinter::ToDictDoc(const Map& dict) { return DictDoc(keys, values); } -DictDoc PrototxtPrinter::ToDictDoc(const std::vector>& dict) { - Array keys; - Array values; +DictDoc PrototxtPrinter::ToDictDoc(const std::vector>& dict) { + ffi::Array keys; + ffi::Array values; for (const auto& pair : dict) { keys.push_back(IdDoc(pair.first)); if (pair.second.as()) { @@ -71,18 +71,18 @@ DictDoc PrototxtPrinter::ToDictDoc(const std::vector>& di return DictDoc(keys, values); } -void PrototxtPrinter::Append(const Map& dict) { +void PrototxtPrinter::Append(const ffi::Map& dict) { DictDoc doc = ToDictDoc(dict); PrintDoc(doc, false); } -void PrototxtPrinter::Append(const std::vector>& dict) { +void PrototxtPrinter::Append(const std::vector>& dict) { DictDoc doc = ToDictDoc(dict); PrintDoc(doc, false); } -void PrototxtPrinter::AppendPair(const String& key, const ffi::Any& value) { - Map dict; +void PrototxtPrinter::AppendPair(const ffi::String& key, const ffi::Any& value) { + ffi::Map dict; dict.Set(key, value); return Append(dict); } diff --git a/src/contrib/msc/core/printer/prototxt_printer.h 
b/src/contrib/msc/core/printer/prototxt_printer.h index e760a179d8dd..f304dcdd5819 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.h +++ b/src/contrib/msc/core/printer/prototxt_printer.h @@ -53,19 +53,19 @@ class PrototxtPrinter : public MSCBasePrinter { static LiteralDoc ToLiteralDoc(const ffi::Any& obj); /*! \brief Change map to DictDoc*/ - static DictDoc ToDictDoc(const Map& dict); + static DictDoc ToDictDoc(const ffi::Map& dict); /*! \brief Change ordered pairs to DictDoc*/ - static DictDoc ToDictDoc(const std::vector>& dict); + static DictDoc ToDictDoc(const std::vector>& dict); /*! \brief Append a map into the final content*/ - void Append(const Map& dict); + void Append(const ffi::Map& dict); /*! \brief Append ordered pairs into the final content*/ - void Append(const std::vector>& dict); + void Append(const std::vector>& dict); /*! \brief Append a map pair into the final content*/ - void AppendPair(const String& key, const ffi::Any& value); + void AppendPair(const ffi::String& key, const ffi::Any& value); protected: /*! 
* \brief Print a DictDoc to prototxt format*/ diff --git a/src/contrib/msc/core/printer/python_printer.cc b/src/contrib/msc/core/printer/python_printer.cc index df75887ce1b6..eb087f7f40e6 100644 --- a/src/contrib/msc/core/printer/python_printer.cc +++ b/src/contrib/msc/core/printer/python_printer.cc @@ -248,7 +248,7 @@ void PythonPrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) { } } -void PythonPrinter::PrintIndentedBlock(const Array& docs) { +void PythonPrinter::PrintIndentedBlock(const ffi::Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { PrintDoc(d); @@ -259,7 +259,7 @@ void PythonPrinter::PrintIndentedBlock(const Array& docs) { DecreaseIndent(); } -void PythonPrinter::PrintDecorators(const Array& decorators) { +void PythonPrinter::PrintDecorators(const ffi::Array& decorators) { for (const ExprDoc& decorator : decorators) { output_ << "@"; PrintDoc(decorator, false); diff --git a/src/contrib/msc/core/printer/python_printer.h b/src/contrib/msc/core/printer/python_printer.h index 31f380bc87be..3e09b1fcdabc 100644 --- a/src/contrib/msc/core/printer/python_printer.h +++ b/src/contrib/msc/core/printer/python_printer.h @@ -92,10 +92,10 @@ class PythonPrinter : public MSCBasePrinter { private: /*! \brief Print block with indent*/ - void PrintIndentedBlock(const Array& docs); + void PrintIndentedBlock(const ffi::Array& docs); /*! 
\brief Print decorators for function and class*/ - void PrintDecorators(const Array& decorators); + void PrintDecorators(const ffi::Array& decorators); }; } // namespace msc diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index df534f4cfae6..992c514ad7ef 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -34,23 +34,23 @@ namespace tvm { namespace relax { using namespace tvm::contrib::msc; -std::tuple, Map> NormalizeNamedBindings( - const Function& func, const Map& untyped_params) { +std::tuple, ffi::Map> NormalizeNamedBindings( + const Function& func, const ffi::Map& untyped_params) { ICHECK(func.defined()); ICHECK(untyped_params.defined()); // Map from string to the variable(s) with that name. - std::unordered_map> string_lookup; + std::unordered_map> string_lookup; std::unordered_set var_set; for (const auto& param : func->params) { string_lookup[param->name_hint()].push_back(param); var_set.insert(param.get()); } - Map relax_var_remap; + ffi::Map relax_var_remap; auto normalize_key = [&](ffi::Any obj) -> relax::Var { - if (auto opt_str = obj.as()) { + if (auto opt_str = obj.as()) { std::string str = opt_str.value(); auto it = string_lookup.find(str); CHECK(it != string_lookup.end()) @@ -83,7 +83,7 @@ std::tuple, Map> NormalizeNamedBindings( auto normalize_value = [&](Var key, ffi::Any obj) -> relax::Expr { if (auto opt = obj.as()) { return opt.value(); - } else if (auto opt = obj.as()) { + } else if (auto opt = obj.as()) { const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint()); return Constant(opt.value(), StructInfo(), span); } else { @@ -96,7 +96,7 @@ std::tuple, Map> NormalizeNamedBindings( } arith::Analyzer analyzer; - Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); + ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); return 
{relax_var_remap, symbolic_var_map}; } @@ -107,7 +107,8 @@ std::tuple, Map> NormalizeNamedBindings( * \param params params dict * \return Function */ -Function FunctionBindNamedParams(Function func, const Map& untyped_params) { +Function FunctionBindNamedParams(Function func, + const ffi::Map& untyped_params) { auto [bind_dict, symbolic_var_map] = NormalizeNamedBindings(func, untyped_params); Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); @@ -121,43 +122,47 @@ Function FunctionBindNamedParams(Function func, const Map& * \param param The param dict * \return The module after binding params. */ -IRModule BindNamedParam(IRModule m, String func_name, Map bind_params) { +IRModule BindNamedParam(IRModule m, ffi::String func_name, + ffi::Map bind_params) { IRModuleNode* new_module = m.CopyOnWrite(); - Map functions = m->functions; + ffi::Map functions = m->functions; for (const auto& func_pr : functions) { if (const auto* relax_f = func_pr.second.as()) { if (relax_f->GetLinkageType() == LinkageType::kExternal) { // Use global_symbol if it's external linkage - Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = + relax_f->GetAttr(tvm::attr::kGlobalSymbol); if (gsymbol.has_value() && gsymbol.value() == func_name) { - Function f_after_bind = FunctionBindNamedParams(GetRef(relax_f), bind_params); + Function f_after_bind = + FunctionBindNamedParams(ffi::GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } else { // Use global var's name_hint if it's internal linkage if (func_pr.first->name_hint == func_name) { - Function f_after_bind = FunctionBindNamedParams(GetRef(relax_f), bind_params); + Function f_after_bind = + FunctionBindNamedParams(ffi::GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } } } - return GetRef(new_module); + return ffi::GetRef(new_module); } namespace transform { -Pass BindNamedParams(String func_name, Map params) { +Pass 
BindNamedParams(ffi::String func_name, ffi::Map params) { auto pass_func = [=](IRModule mod, PassContext pc) { return BindNamedParam(std::move(mod), func_name, params); }; return CreateModulePass(pass_func, 0, "BindNamedParams", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.BindNamedParams", BindNamedParams); -}); +} } // namespace transform diff --git a/src/contrib/msc/core/transform/bind_shape.cc b/src/contrib/msc/core/transform/bind_shape.cc index b7c3491bff1a..c9963ba94e84 100644 --- a/src/contrib/msc/core/transform/bind_shape.cc +++ b/src/contrib/msc/core/transform/bind_shape.cc @@ -37,7 +37,8 @@ namespace relax { */ class ShapeBinder : public ExprMutator { public: - explicit ShapeBinder(IRModule ctx_module, const String& entry_name) : ExprMutator(ctx_module) { + explicit ShapeBinder(IRModule ctx_module, const ffi::String& entry_name) + : ExprMutator(ctx_module) { mod_ = ctx_module; entry_name_ = entry_name; } @@ -51,7 +52,7 @@ class ShapeBinder : public ExprMutator { continue; } if (func->IsInstance()) { - Array new_params; + ffi::Array new_params; for (const auto& p : Downcast(func)->params) { auto struct_info = GetStructInfo(p); if (struct_info->IsInstance()) { @@ -76,7 +77,7 @@ class ShapeBinder : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - Array new_args; + ffi::Array new_args; for (const auto& a : call_node->args) { auto struct_info = GetStructInfo(a); if (a->IsInstance() && struct_info->IsInstance()) { @@ -92,7 +93,7 @@ class ShapeBinder : public ExprMutator { } else if (const auto* op_node = call_node->op.as()) { ICHECK(op_node->name == "relax.reshape" || op_node->name == "relax.image.resize2d") << "Expect ShapeExpr consumer as reshape or image.resize2d, get " - << GetRef(call_node); + << ffi::GetRef(call_node); const auto& opt_shape = Downcast(GetStructInfo(call_node->args[1]))->values; 
ICHECK(opt_shape.defined()) << "Expected shape defined, get " << call_node->args[1]; new_args.push_back(ShapeExpr(opt_shape.value())); @@ -101,7 +102,7 @@ class ShapeBinder : public ExprMutator { ReEmitBinding(binding, builder_->Normalize(new_call)); } else if (const auto* gv_node = call_node->op.as()) { const auto& func_info = Downcast(gv_node->struct_info_); - Array params_info; + ffi::Array params_info; for (const auto& a : new_args) { ICHECK(a->struct_info_.defined()) << "Global func argument without defined struct info " << a; @@ -113,30 +114,30 @@ class ShapeBinder : public ExprMutator { Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); ReEmitBinding(binding, builder_->Normalize(new_call)); } else { - LOG_FATAL << "Unexpected shape consumer " << GetRef(call_node); + LOG_FATAL << "Unexpected shape consumer " << ffi::GetRef(call_node); } } private: IRModule mod_; - String entry_name_; + ffi::String entry_name_; }; -IRModule BindShape(IRModule mod, const String& entry_name) { +IRModule BindShape(IRModule mod, const ffi::String& entry_name) { return ShapeBinder(mod, entry_name).Bind(); } namespace transform { -Pass BindShape(const String& entry_name) { +Pass BindShape(const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::BindShape(m, entry_name); }; return CreateModulePass(pass_func, 0, "BindShape", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.BindShape", BindShape); -}); +} } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index 19b8f08f4780..6f2913ac9599 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -41,7 +41,7 @@ using namespace tvm::contrib::msc; */ class TupleFuser : public ExprMutator { public: - explicit 
TupleFuser(IRModule ctx_module, const String& target, const String& entry_name) + explicit TupleFuser(IRModule ctx_module, const ffi::String& target, const ffi::String& entry_name) : ExprMutator(ctx_module) { mod_ = ctx_module; target_ = target + "."; @@ -54,7 +54,7 @@ class TupleFuser : public ExprMutator { if (gv->name_hint == entry_name_) { main_var = gv; } else { - const auto& name_opt = func->GetAttr(attr::kComposite); + const auto& name_opt = func->GetAttr(attr::kComposite); if (name_opt.has_value() && StringUtils::StartsWith(name_opt.value(), target_)) { target_funcs_.Set(gv, Downcast(func)); } @@ -70,12 +70,12 @@ class TupleFuser : public ExprMutator { void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final { bool has_tuple_arg = false; if (target_funcs_.count(val->op)) { - Array new_args; + ffi::Array new_args; for (size_t i = 0; i < val->args.size(); i++) { const auto& arg = val->args[i]; if (arg->IsInstance()) { - String tuple_name; - const auto& name_opt = target_funcs_[val->op]->GetAttr(msc_attr::kUnique); + ffi::String tuple_name; + const auto& name_opt = target_funcs_[val->op]->GetAttr(msc_attr::kUnique); if (name_opt.has_value()) { if (val->args.size() == 1) { tuple_name = name_opt.value() + "_input"; @@ -114,7 +114,7 @@ class TupleFuser : public ExprMutator { } } if (on_target) { - ReEmitFunc(binding, GetRef(val)); + ReEmitFunc(binding, ffi::GetRef(val)); } else { ExprMutator::VisitBinding_(binding, val); } @@ -122,16 +122,16 @@ class TupleFuser : public ExprMutator { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { if (target_funcs_.count(val->tuple)) { - ReEmitFunc(binding, GetRef(val)); + ReEmitFunc(binding, ffi::GetRef(val)); } else { ExprMutator::VisitBinding_(binding, val); } } private: - Call AddFunc(const Expr& expr, const String tuple_name = "") { + Call AddFunc(const Expr& expr, const ffi::String tuple_name = "") { builder_->BeginDataflowBlock(); - Array inputs; + ffi::Array 
inputs; if (const auto* v_node = expr.as()) { inputs = v_node->fields; } else if (const auto* g_node = expr.as()) { @@ -139,17 +139,17 @@ class TupleFuser : public ExprMutator { } else { LOG_FATAL << "Unexpceted expr " << expr; } - Array func_inputs; - Array call_inputs; - Array params; - Map added_params; + ffi::Array func_inputs; + ffi::Array call_inputs; + ffi::Array params; + ffi::Map added_params; for (size_t i = 0; i < inputs.size(); i++) { if (inputs[i]->IsInstance()) { func_inputs.push_back(inputs[i]); continue; } if (!added_params.count(inputs[i])) { - const auto& name = String("param_" + std::to_string(i)); + const auto& name = ffi::String("param_" + std::to_string(i)); const auto& var = Var(std::move(name), GetStructInfo(inputs[i])); added_params.Set(inputs[i], var); } @@ -159,7 +159,7 @@ class TupleFuser : public ExprMutator { } Expr out_expr; - String func_name; + ffi::String func_name; Span expr_span = expr->span; if (!expr_span.defined()) { ICHECK(tuple_name.size() > 0) << "Missing tuple for " << expr; @@ -180,7 +180,7 @@ class TupleFuser : public ExprMutator { Expr body = builder_->Normalize(output); body = builder_->Normalize(SeqExpr({new_block}, body)); - Map func_attrs; + ffi::Map func_attrs; func_attrs.Set(attr::kPrimitive, true); func_attrs.Set(attr::kComposite, target_ + func_name); func_attrs.Set(msc_attr::kUnique, SpanUtils::GetAttr(expr_span, msc_attr::kName)); @@ -190,7 +190,7 @@ class TupleFuser : public ExprMutator { /*ret_struct_info=*/std::nullopt, // /*is_pure=*/true, // /*attrs=*/DictAttrs(func_attrs)); - Array free_vars = + ffi::Array free_vars = FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); if (!free_vars.empty()) { params.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); @@ -214,28 +214,28 @@ class TupleFuser : public ExprMutator { } IRModule mod_; - String target_; - String entry_name_; - Map target_funcs_; + ffi::String target_; + ffi::String entry_name_; + ffi::Map target_funcs_; }; 
-IRModule FuseTuple(IRModule mod, const String& target, const String& entry_name) { +IRModule FuseTuple(IRModule mod, const ffi::String& target, const ffi::String& entry_name) { return TupleFuser(mod, target, entry_name).Fuse(); } namespace transform { -Pass FuseTuple(const String& target, const String& entry_name) { +Pass FuseTuple(const ffi::String& target, const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::FuseTuple(m, target, entry_name); }; return CreateModulePass(pass_func, 0, "FuseTuple", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.FuseTuple", FuseTuple); -}); +} } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc index 086c475f6d1f..9c5eb7536564 100644 --- a/src/contrib/msc/core/transform/inline_params.cc +++ b/src/contrib/msc/core/transform/inline_params.cc @@ -40,7 +40,8 @@ using namespace tvm::contrib::msc; */ class ParamsInliner : public ExprMutator { public: - explicit ParamsInliner(IRModule ctx_module, const String& entry_name) : ExprMutator(ctx_module) { + explicit ParamsInliner(IRModule ctx_module, const ffi::String& entry_name) + : ExprMutator(ctx_module) { mod_ = ctx_module; entry_name_ = entry_name; } @@ -54,22 +55,22 @@ class ParamsInliner : public ExprMutator { continue; } if (func->IsInstance()) { - Array new_params; - Array attrs; + ffi::Array new_params; + ffi::Array attrs; for (const auto& p : Downcast(func)->params) { auto struct_info = GetStructInfo(p); if (struct_info->IsInstance()) { continue; } if (struct_info->IsInstance()) { - const auto& optype_opt = func->GetAttr(msc_attr::kOptype); + const auto& optype_opt = func->GetAttr(msc_attr::kOptype); ICHECK(optype_opt.has_value()) << "Can not find attr " << msc_attr::kOptype << " form extern func"; extern_types_.Set(p, optype_opt.value()); 
continue; } if (const auto* tuple_info = struct_info.as()) { - Array new_fields; + ffi::Array new_fields; for (const auto& i : tuple_info->fields) { if (i->IsInstance()) { new_fields.push_back(i); @@ -88,7 +89,7 @@ class ParamsInliner : public ExprMutator { continue; } const auto& new_func = Downcast(VisitExpr(func)); - Map func_attrs = new_func->attrs->dict; + ffi::Map func_attrs = new_func->attrs->dict; if (attrs.size() > 0) { func_attrs.Set(msc_attr::kOpattrs, attrs); } @@ -105,7 +106,7 @@ class ParamsInliner : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - Array new_args; + ffi::Array new_args; bool has_inline = false; for (const auto& a : call_node->args) { auto struct_info = GetStructInfo(a); @@ -124,8 +125,8 @@ class ParamsInliner : public ExprMutator { has_inline = true; } else if (call_node->op->IsInstance() && a->IsInstance()) { const auto& tuple = Downcast(a); - Array new_fields; - Array new_infos; + ffi::Array new_fields; + ffi::Array new_infos; for (const auto& f : tuple->fields) { if (f->IsInstance()) { @@ -152,7 +153,7 @@ class ParamsInliner : public ExprMutator { ReEmitBinding(binding, builder_->Normalize(new_call)); } else if (const auto* gv_node = call_node->op.as()) { const auto& func_info = Downcast(gv_node->struct_info_); - Array params_info; + ffi::Array params_info; for (const auto& a : new_args) { ICHECK(a->struct_info_.defined()) << "Global func argument without defined struct info " << a; @@ -164,31 +165,31 @@ class ParamsInliner : public ExprMutator { Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); ReEmitBinding(binding, builder_->Normalize(new_call)); } else { - LOG_FATAL << "Unexpected shape consumer " << GetRef(call_node); + LOG_FATAL << "Unexpected shape consumer " << ffi::GetRef(call_node); } } private: IRModule mod_; - String entry_name_; - Map extern_types_; + ffi::String entry_name_; + ffi::Map extern_types_; }; -IRModule 
InlineParams(IRModule mod, const String& entry_name) { +IRModule InlineParams(IRModule mod, const ffi::String& entry_name) { return ParamsInliner(mod, entry_name).Bind(); } namespace transform { -Pass InlineParams(const String& entry_name) { +Pass InlineParams(const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::InlineParams(m, entry_name); }; return CreateModulePass(pass_func, 0, "InlineParams", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.InlineParams", InlineParams); -}); +} } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc index a634b8e9e36a..a4f46dce7fe4 100644 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -57,12 +57,12 @@ LayoutDecision LayoutUtils::InferLayoutDecisionAt(const Expr& expr, } bool LayoutUtils::LayoutInfered(const Expr& expr) { - const String& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + const ffi::String& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); return layout.size() > 0; } bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { - const String& saved_layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + const ffi::String& saved_layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); const auto& sinfo = GetStructInfo(expr); if (sinfo->IsInstance() || sinfo->IsInstance()) { if (!layout.IsLeaf()) { @@ -80,8 +80,8 @@ bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { if (layout.IsLeaf()) { return false; } - String layout_str; - Array nested_layouts = layout.NestedArray(); + ffi::String layout_str; + ffi::Array nested_layouts = layout.NestedArray(); for (size_t i = 0; i < nested_layouts.size(); i++) { if (!nested_layouts[i].IsLeaf()) { return false; @@ -109,7 
+109,7 @@ const NLayout LayoutUtils::GetNLayout(const Expr& expr) { return LayoutDecision(SpanUtils::GetAttr(expr->span, msc_attr::kLayout)); } if (sinfo->IsInstance()) { - String layout_str = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + ffi::String layout_str = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); std::vector output_layout; for (const auto& l : StringUtils::Split(layout_str, ",")) { output_layout.push_back(LayoutDecision(l)); @@ -134,7 +134,7 @@ bool LayoutUtils::HasUnknownDimTensor(const NLayout& nlayout) { return find; } -bool LayoutUtils::HasUnknownDimTensor(const Array& args) { +bool LayoutUtils::HasUnknownDimTensor(const ffi::Array& args) { for (const auto& arg : args) { if (IsNestedTensor(arg)) { if (HasUnknownDimTensor(GetNLayout(arg))) { @@ -204,8 +204,8 @@ const LayoutDecision LayoutUtils::ReduceLayout(const LayoutDecision& src_layout, } const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout, - const Array& axes) { - String layout_str; + const ffi::Array& axes) { + ffi::String layout_str; for (const auto& a : axes) { layout_str = layout_str + src_layout->layout[a->value].name(); } @@ -214,7 +214,7 @@ const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout, const std::vector& axes) { - String layout_str; + ffi::String layout_str; for (const auto& a : axes) { layout_str = layout_str + src_layout->layout[a].name(); } diff --git a/src/contrib/msc/core/transform/layout_utils.h b/src/contrib/msc/core/transform/layout_utils.h index 787c73cc8404..88bcc5703589 100644 --- a/src/contrib/msc/core/transform/layout_utils.h +++ b/src/contrib/msc/core/transform/layout_utils.h @@ -100,7 +100,7 @@ class LayoutUtils { * \brief Check if the args has unknown dim tensor. * \return Whether the args has unknown dim tensor. 
*/ - TVM_DLL static bool HasUnknownDimTensor(const Array& args); + TVM_DLL static bool HasUnknownDimTensor(const ffi::Array& args); /*! * \brief Insert axes to the Layout @@ -120,7 +120,7 @@ class LayoutUtils { * \return The new layout. */ TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, - const Array& axes); + const ffi::Array& axes); TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, const std::vector& axes); diff --git a/src/contrib/msc/core/transform/rewrite_utils.cc b/src/contrib/msc/core/transform/rewrite_utils.cc index 9cbc7c1a8c51..a20e7d5ac3b0 100644 --- a/src/contrib/msc/core/transform/rewrite_utils.cc +++ b/src/contrib/msc/core/transform/rewrite_utils.cc @@ -29,27 +29,27 @@ namespace tvm { namespace contrib { namespace msc { -Var RewriteUtils::ReEmit(BlockBuilder builder, const String& name, const Expr& expr) { +Var RewriteUtils::ReEmit(BlockBuilder builder, const ffi::String& name, const Expr& expr) { expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name); return builder->Emit(expr, name); } -Var RewriteUtils::MakeCall(BlockBuilder builder, const String& name, Expr op, Array args, - Attrs attrs) { +Var RewriteUtils::MakeCall(BlockBuilder builder, const ffi::String& name, Expr op, + ffi::Array args, Attrs attrs) { const auto& call = Call(op, args, attrs); return ReEmit(builder, name, call); } -Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name, double value, +Expr RewriteUtils::MakeConstant(BlockBuilder builder, const ffi::String& name, double value, const DataType& dtype, size_t ndim) { - const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value)); + const auto& data = support::FloatImmToTensor(FloatImm(dtype, value)); Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name); const auto& constant = Constant(data, std::nullopt, span); if (ndim == 0) { return constant; } static const Op& reshape_op = Op::Get("relax.reshape"); - Array 
exp_shape(ndim, Integer(1)); + ffi::Array exp_shape(ndim, Integer(1)); return MakeCall(builder, name + "_exp", reshape_op, {constant, ShapeExpr(exp_shape)}); } diff --git a/src/contrib/msc/core/transform/rewrite_utils.h b/src/contrib/msc/core/transform/rewrite_utils.h index 307581b274ec..b5dc5e4f2a64 100644 --- a/src/contrib/msc/core/transform/rewrite_utils.h +++ b/src/contrib/msc/core/transform/rewrite_utils.h @@ -49,20 +49,20 @@ class RewriteUtils { * \brief Emit call with span name. * \return The emitted var. */ - TVM_DLL static Var ReEmit(BlockBuilder builder, const String& name, const Expr& expr); + TVM_DLL static Var ReEmit(BlockBuilder builder, const ffi::String& name, const Expr& expr); /*! * \brief Make and emit a call binding with span. * \return The emitted var. */ - TVM_DLL static Var MakeCall(BlockBuilder builder, const String& name, Expr op, Array args, - Attrs attrs = Attrs()); + TVM_DLL static Var MakeCall(BlockBuilder builder, const ffi::String& name, Expr op, + ffi::Array args, Attrs attrs = Attrs()); /*! * \brief Make and emit a (shaped)constant with span. * \return The constant/reshape. 
*/ - TVM_DLL static Expr MakeConstant(BlockBuilder builder, const String& name, double value, + TVM_DLL static Expr MakeConstant(BlockBuilder builder, const ffi::String& name, double value, const DataType& dtype, size_t ndim = 0); }; diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc index 85819ea58dc6..16ce44cede16 100644 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc @@ -41,7 +41,8 @@ using namespace tvm::contrib::msc; */ class ByocNameSetter : public ExprMutator { public: - explicit ByocNameSetter(IRModule ctx_module, const String& target, const String& entry_name) + explicit ByocNameSetter(IRModule ctx_module, const ffi::String& target, + const ffi::String& entry_name) : ExprMutator(ctx_module) { mod_ = ctx_module; target_ = target; @@ -54,9 +55,9 @@ class ByocNameSetter : public ExprMutator { if (gv->name_hint == entry_name_) { continue; } - const auto& name_opt = func->GetAttr(attr::kCodegen); + const auto& name_opt = func->GetAttr(attr::kCodegen); if (name_opt.has_value() && name_opt.value() == target_) { - const String& func_name = target_ + "_" + std::to_string(func_cnt); + const ffi::String& func_name = target_ + "_" + std::to_string(func_cnt); const auto& new_func = Downcast(VisitExpr(func)); builder_->UpdateFunction(gv, WithAttr(new_func, msc_attr::kUnique, func_name)); func_cnt += 1; @@ -66,7 +67,7 @@ class ByocNameSetter : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); ExprMutator::VisitBinding_(binding, val); } @@ -74,7 +75,7 @@ class ByocNameSetter : public ExprMutator { ExprMutator::VisitBinding_(binding, val); if (val->op->IsInstance()) { ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op; - const auto& name_opt = 
local_funcs_[val->op]->GetAttr(msc_attr::kUnique); + const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); if (name_opt.has_value()) { val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value()); } @@ -83,29 +84,29 @@ class ByocNameSetter : public ExprMutator { private: IRModule mod_; - String target_; - String entry_name_; - Map new_funcs_; - Map local_funcs_; + ffi::String target_; + ffi::String entry_name_; + ffi::Map new_funcs_; + ffi::Map local_funcs_; }; -IRModule SetBYOCAttrs(IRModule mod, const String& target, const String& entry_name) { +IRModule SetBYOCAttrs(IRModule mod, const ffi::String& target, const ffi::String& entry_name) { return ByocNameSetter(mod, target, entry_name).SetNames(); } namespace transform { -Pass SetBYOCAttrs(const String& target, const String& entry_name) { +Pass SetBYOCAttrs(const ffi::String& target, const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::SetBYOCAttrs(m, target, entry_name); }; return CreateModulePass(pass_func, 0, "SetBYOCAttrs", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.SetBYOCAttrs", SetBYOCAttrs); -}); +} } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 59711a99188d..90dd47cb2d36 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -35,9 +35,9 @@ namespace relax { using namespace tvm::contrib::msc; -std::tuple AccumulateMatch(const Array& input_shape, - const Array& output_shape, size_t in_start, - size_t out_start) { +std::tuple AccumulateMatch(const ffi::Array& input_shape, + const ffi::Array& output_shape, + size_t in_start, size_t out_start) { // find input position in_pos and output position out_pos // 
cumsum(in_shape[in_start:in_pos])==cumsum(out_shape[out_start:out_pos]) std::vector in_shape, out_shape; @@ -84,7 +84,8 @@ std::tuple AccumulateMatch(const Array& input_shape, } std::tuple, std::vector> InferReshapeAxes( - const Array& input_shape, const Array& output_shape, int batch_dim) { + const ffi::Array& input_shape, const ffi::Array& output_shape, + int batch_dim) { std::vector expand_axes, reduce_axes; size_t in_start = 0; while (in_start < input_shape.size()) { @@ -120,11 +121,11 @@ std::tuple, std::vector> InferReshapeAxes( } // Forward and Backward infer -InferLayoutOutput MSCInferLayoutConv(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput MSCInferLayoutConv( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision data_layout, kernel_layout, out_layout; - const String& op_name = Downcast(call->op)->name; + const ffi::String& op_name = Downcast(call->op)->name; if (op_name == "relax.nn.conv1d") { const auto* attrs = call->attrs.as(); data_layout = LayoutDecision(attrs->data_layout); @@ -144,11 +145,11 @@ InferLayoutOutput MSCInferLayoutConv(const Call& call, return InferLayoutOutput({data_layout, kernel_layout}, {out_layout}, Attrs()); } -InferLayoutOutput MSCInferLayoutPool2d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput MSCInferLayoutPool2d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision layout, out_layout; - const String& op_name = Downcast(call->op)->name; + const ffi::String& op_name = Downcast(call->op)->name; if (op_name == "relax.nn.adaptive_avg_pool2d") { const auto* attrs = call->attrs.as(); layout = LayoutDecision(attrs->layout); @@ -161,9 +162,9 @@ InferLayoutOutput MSCInferLayoutPool2d(const Call& call, return InferLayoutOutput({layout}, {out_layout}, Attrs()); } -InferLayoutOutput MSCInferLayoutResize2d(const 
Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput MSCInferLayoutResize2d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto* attrs = call->attrs.as(); const auto& data_layout = LayoutDecision(attrs->layout); const auto& shape_layout = LayoutDecision("O"); @@ -171,10 +172,10 @@ InferLayoutOutput MSCInferLayoutResize2d(const Call& call, } // Forward Infer -InferLayoutOutput ForwardInferLayoutCommon(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - Array input_layouts; +InferLayoutOutput ForwardInferLayoutCommon( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ffi::Array input_layouts; LayoutDecision layout_hint; for (const auto& arg : call->args) { const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); @@ -190,7 +191,7 @@ InferLayoutOutput ForwardInferLayoutCommon(const Call& call, if (sinfo->IsInstance()) { return InferLayoutOutput(input_layouts, {layout_hint}, Attrs()); } - Array output_layouts; + ffi::Array output_layouts; if (const auto* tuple_sinfo = sinfo.as()) { for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) { output_layouts.push_back(layout_hint); @@ -200,10 +201,10 @@ InferLayoutOutput ForwardInferLayoutCommon(const Call& call, return InferLayoutOutput(); } -InferLayoutOutput ForwardInferLayoutBroadcast(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - Array input_layouts; +InferLayoutOutput ForwardInferLayoutBroadcast( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ffi::Array input_layouts; LayoutDecision layout_hint; for (const auto& arg : call->args) { const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); @@ -224,15 +225,15 @@ InferLayoutOutput ForwardInferLayoutBroadcast(const Call& call, return InferLayoutOutput(); 
} -InferLayoutOutput ForwardInferLayoutInplace(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutInplace( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); } -InferLayoutOutput ForwardInferLayoutBinary(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutBinary( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& output = ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); if (!output.defined()) { return output; @@ -256,9 +257,9 @@ InferLayoutOutput ForwardInferLayoutBinary(const Call& call, return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); } -InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutArgMaxMin( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -280,9 +281,9 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutBatchNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); @@ -300,9 +301,9 @@ InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, {{in_layout, 
g_layout, g_layout}}, Attrs()); } -InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForkwardInferLayoutExpandDims( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -320,9 +321,9 @@ InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutNormalize( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); @@ -339,9 +340,9 @@ InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, return InferLayoutOutput({in_layout, g_layout, g_layout}, {in_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutMatmul( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& a_shape = ExprUtils::GetShape(call->args[0]); const auto& b_shape = ExprUtils::GetShape(call->args[1]); if (a_shape.size() == 0) { @@ -358,7 +359,7 @@ InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, } } size_t start = a_layout->layout.ndim() - b_shape.size(); - String pre_layout; + ffi::String pre_layout; for (size_t i = start; i < a_layout->layout.ndim() - 2; i++) { pre_layout = pre_layout + a_layout->layout[i].name(); } @@ -366,9 +367,9 @@ InferLayoutOutput ForwardInferLayoutMatmul(const Call& 
call, return InferLayoutOutput({a_layout, b_layout}, {a_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutPermute(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutPermute( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -388,9 +389,9 @@ InferLayoutOutput ForwardInferLayoutPermute(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutReduceAxis( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -414,9 +415,9 @@ InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutReshape(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutReshape( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -444,9 +445,9 @@ InferLayoutOutput ForwardInferLayoutReshape(const Call& call, return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& 
var_layout_map) { +InferLayoutOutput ForwardInferLayoutSqueeze( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -475,9 +476,9 @@ InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutTake(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutTake( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); const auto& input_shape = ExprUtils::GetShape(call->args[0]); @@ -508,9 +509,9 @@ InferLayoutOutput ForwardInferLayoutTake(const Call& call, return InferLayoutOutput(); } -InferLayoutOutput ForwardInferLayoutPlugin(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutPlugin( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { if (!call->args[0]->IsInstance()) { return InferLayoutOutput(); } @@ -626,9 +627,9 @@ TVM_REGISTER_OP("relax.call_dps_packed") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutPlugin); // Backward Infer -InferLayoutOutput BackwardInferLayoutCommon(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutCommon( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { NLayout output_layout = LayoutUtils::InferNLayout(call, var_layout_map); LayoutDecision layout_hint; if (output_layout.IsLeaf()) { @@ -643,7 
+644,7 @@ InferLayoutOutput BackwardInferLayoutCommon(const Call& call, if (!layout_hint->layout.defined()) { return InferLayoutOutput(); } - Array input_layouts; + ffi::Array input_layouts; for (const auto& arg : call->args) { const auto& saved_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); if (saved_layout->layout.defined()) { @@ -655,9 +656,9 @@ InferLayoutOutput BackwardInferLayoutCommon(const Call& call, return InferLayoutOutput(input_layouts, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutBinary(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutBinary( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& output = BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); if (!output.defined()) { return output; @@ -681,15 +682,15 @@ InferLayoutOutput BackwardInferLayoutBinary(const Call& call, return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); } -InferLayoutOutput BackwardInferLayoutInplace(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutInplace( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { return BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); } -InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutArgMaxMin( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -708,9 +709,9 @@ InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); 
} -InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutBatchNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecisionAt(call, var_layout_map, 0); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -720,9 +721,9 @@ InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, {{output_layout, g_layout, g_layout}}, Attrs()); } -InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutExpandDims( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -740,9 +741,9 @@ InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutNormalize( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecisionAt(call, var_layout_map, 0); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -751,9 +752,9 @@ InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, return InferLayoutOutput({output_layout, g_layout, g_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutMatmul( + const Call& call, const 
ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -763,7 +764,7 @@ InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, return InferLayoutOutput(); } size_t start = output_layout->layout.ndim() - b_shape.size(); - String pre_layout; + ffi::String pre_layout; for (size_t i = start; i < output_layout->layout.ndim() - 2; i++) { pre_layout = pre_layout + output_layout->layout[i].name(); } @@ -771,9 +772,9 @@ InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, return InferLayoutOutput({output_layout, b_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutPermute(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutPermute( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -802,9 +803,9 @@ InferLayoutOutput BackwardInferLayoutPermute(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutReduceAxis( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -825,9 +826,9 @@ InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutReshape(const Call& call, - const Map>& 
desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutReshape( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -855,9 +856,9 @@ InferLayoutOutput BackwardInferLayoutReshape(const Call& call, return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutSqueeze( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -886,9 +887,9 @@ InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutTake(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutTake( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); @@ -912,9 +913,9 @@ InferLayoutOutput BackwardInferLayoutTake(const Call& call, return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutTupleInputs(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput 
BackwardInferLayoutTupleInputs( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -1091,16 +1092,17 @@ class LayoutInfer : public ExprVisitor { continue; } // Infer by op_node - Op op = Downcast(GetRef(op_node)); + Op op = Downcast(ffi::GetRef(op_node)); InferLayoutOutput infered_layout; const auto& msc_infer_map = Op::GetAttrMap("FMSCBackwardInferLayout"); try { if (msc_infer_map.count(op)) { FRelaxInferLayout f = msc_infer_map[op]; - infered_layout = f(call, Map>(), var_layout_map_); - } else { infered_layout = - BackwardInferLayoutCommon(call, Map>(), var_layout_map_); + f(call, ffi::Map>(), var_layout_map_); + } else { + infered_layout = BackwardInferLayoutCommon( + call, ffi::Map>(), var_layout_map_); } } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << err.what(); @@ -1118,7 +1120,7 @@ class LayoutInfer : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { ExprVisitor::VisitBinding_(binding, call_node); - const auto& call = GetRef(call_node); + const auto& call = ffi::GetRef(call_node); if (const auto* v_node = call->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); RecordExpr(binding->var, call); @@ -1143,7 +1145,7 @@ class LayoutInfer : public ExprVisitor { } if (infer_outputs) { // infer layouts - Op op = Downcast(GetRef(op_node)); + Op op = Downcast(ffi::GetRef(op_node)); InferLayoutOutput infered_layout; const auto& msc_infer_map = Op::GetAttrMap("FMSCForwardInferLayout"); const auto& relax_infer_map = Op::GetAttrMap("FRelaxInferLayout"); @@ -1151,14 +1153,16 @@ class LayoutInfer : public ExprVisitor { try { if (msc_infer_map.count(op)) { FRelaxInferLayout f = msc_infer_map[op]; - infered_layout = f(call, Map>(), 
var_layout_map_); - } else if (!relax_infer_map.count(op)) { infered_layout = - ForwardInferLayoutCommon(call, Map>(), var_layout_map_); + f(call, ffi::Map>(), var_layout_map_); + } else if (!relax_infer_map.count(op)) { + infered_layout = ForwardInferLayoutCommon( + call, ffi::Map>(), var_layout_map_); } if (relax_infer_map.count(op) && !infered_layout.defined()) { FRelaxInferLayout f = relax_infer_map[op]; - infered_layout = f(call, Map>(), var_layout_map_); + infered_layout = + f(call, ffi::Map>(), var_layout_map_); set_inputs = false; } } catch (runtime::InternalError& err) { @@ -1187,14 +1191,14 @@ class LayoutInfer : public ExprVisitor { } void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); } void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, GetRef(val)); + RecordExpr(binding->var, ffi::GetRef(val)); if (IsNestedTensor(binding->var)) { - Array input_layouts; + ffi::Array input_layouts; for (const auto& field : val->fields) { input_layouts.push_back(LayoutUtils::InferLayoutDecision(field, var_layout_map_)); } @@ -1204,15 +1208,15 @@ class LayoutInfer : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, GetRef(val)); - const auto& out_layout = LayoutUtils::InferLayoutDecisionAt(GetRef(val)->tuple, - var_layout_map_, val->index); + RecordExpr(binding->var, ffi::GetRef(val)); + const auto& out_layout = LayoutUtils::InferLayoutDecisionAt( + ffi::GetRef(val)->tuple, var_layout_map_, val->index); SetExprLayout(binding->var, out_layout); } void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) final { ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, GetRef(val)); + 
RecordExpr(binding->var, ffi::GetRef(val)); SetExprLayout(binding->var, LayoutDecision("O")); } @@ -1252,7 +1256,7 @@ class LayoutInfer : public ExprVisitor { } } - void SetInputLayouts(const Call& call, const Array& input_layouts) { + void SetInputLayouts(const Call& call, const ffi::Array& input_layouts) { if (input_layouts.size() == call->args.size()) { for (size_t i = 0; i < input_layouts.size(); i++) { SetExprLayout(call->args[i], input_layouts[i]); @@ -1309,10 +1313,10 @@ class LayoutInfer : public ExprVisitor { IRModule ref_module_; bool infered_; - Map var_map_; - Array ordered_exprs_; + ffi::Map var_map_; + ffi::Array ordered_exprs_; std::unordered_map var_layout_map_; - Map local_funcs_; + ffi::Map local_funcs_; }; // class LayoutInfer class LayoutChecker : public ExprVisitor { @@ -1326,14 +1330,14 @@ class LayoutChecker : public ExprVisitor { void VisitExpr_(const CallNode* call) final { ExprVisitor::VisitExpr_(call); - if (!LayoutUtils::LayoutInfered(GetRef(call))) { + if (!LayoutUtils::LayoutInfered(ffi::GetRef(call))) { missing_num_++; } } void VisitExpr_(const ConstantNode* cn) final { ExprVisitor::VisitExpr_(cn); - if (!LayoutUtils::LayoutInfered(GetRef(cn))) { + if (!LayoutUtils::LayoutInfered(ffi::GetRef(cn))) { missing_num_++; } } @@ -1352,7 +1356,7 @@ void SetExprLayout(const IRModule& ref_module, const Expr& func, bool allow_miss namespace transform { -Pass SetExprLayout(bool allow_missing, const String& entry_name) { +Pass SetExprLayout(bool allow_missing, const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { relax::SetExprLayout(m, m->Lookup(entry_name), allow_missing); return m; @@ -1360,10 +1364,10 @@ Pass SetExprLayout(bool allow_missing, const String& entry_name) { return CreateModulePass(pass_func, 0, "SetExprLayout", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.SetExprLayout", SetExprLayout); -}); +} } // 
namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index 14ea3ccfec7b..d0231afedba5 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -36,10 +36,10 @@ namespace relax { class FuncNameGetter : public ExprVisitor { public: - explicit FuncNameGetter(const Array& arg_names) : arg_names_(arg_names) {} + explicit FuncNameGetter(const ffi::Array& arg_names) : arg_names_(arg_names) {} - /*! \brief Get the attributes from prim value as Map*/ - String HintName(const Expr& expr) { + /*! \brief Get the attributes from prim value as ffi::Map*/ + ffi::String HintName(const Expr& expr) { name_ = ""; ExprVisitor::VisitExpr(expr); return name_; @@ -73,8 +73,8 @@ class FuncNameGetter : public ExprVisitor { } private: - String name_; - Array arg_names_; + ffi::String name_; + ffi::Array arg_names_; }; /*! @@ -82,16 +82,16 @@ class FuncNameGetter : public ExprVisitor { */ class RelaxExprNameSetter : public ExprVisitor { public: - explicit RelaxExprNameSetter(const IRModule& ref_module, const String& target, - const Map& var_names) + explicit RelaxExprNameSetter(const IRModule& ref_module, const ffi::String& target, + const ffi::Map& var_names) : ref_module_(ref_module), target_{target}, var_names_{var_names} {} void VisitBindingBlock(const BindingBlock& block) final { - String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); + ffi::String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); if (block_name.size() == 0) { block_name = "block"; } - const String& prefix = StringUtils::Join(block_stack_, "."); + const ffi::String& prefix = StringUtils::Join(block_stack_, "."); if (setted_blocks_.count(prefix + "." + block_name)) { int cnt = 1; while (setted_blocks_.count(prefix + "." 
+ block_name + "_" + std::to_string(cnt))) { @@ -101,7 +101,7 @@ class RelaxExprNameSetter : public ExprVisitor { } setted_blocks_.insert(prefix + "." + block_name); block_stack_.push_back(block_name); - const String& unique_name = StringUtils::Join(block_stack_, "."); + const ffi::String& unique_name = StringUtils::Join(block_stack_, "."); block->span = SpanUtils::SetAttr(block->span, msc_attr::kName, unique_name); ExprVisitor::VisitBindingBlock(block); block_stack_.pop_back(); @@ -109,16 +109,16 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitExpr_(const ConstantNode* val) { ExprVisitor::VisitExpr_(val); - const String& unique_name = GetUniqueName(GetRef(val), "const"); + const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "const"); if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } - expr_names_.Set(GetRef(val), unique_name); + expr_names_.Set(ffi::GetRef(val), unique_name); } void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& unique_name = GetUniqueName(GetRef(val), "const"); + const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "const"); if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } @@ -127,7 +127,7 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& unique_name = GetUniqueName(GetRef(val), "shape"); + const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "shape"); if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } @@ -136,7 +136,7 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const 
VarBindingNode* binding, const TupleNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& unique_name = GetUniqueName(GetRef(val), "tuple"); + const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "tuple"); if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } @@ -145,7 +145,7 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { ExprVisitor::VisitBinding_(binding, val); - String unique_name; + ffi::String unique_name; if (expr_names_.count(val->tuple)) { unique_name = expr_names_[val->tuple] + "." + std::to_string(val->index); } else if (const auto* v_node = val->tuple.as()) { @@ -159,15 +159,15 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { ExprVisitor::VisitBinding_(binding, val); - const auto& name_opt = val->GetAttr(attr::kComposite); + const auto& name_opt = val->GetAttr(attr::kComposite); if (name_opt.has_value()) { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); } } void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { ExprVisitor::VisitBinding_(binding, val); - String name_hint, optype; + ffi::String name_hint, optype; bool use_unique = true; if (var_names_.count(binding->var->name_hint())) { name_hint = var_names_[binding->var->name_hint()]; @@ -177,7 +177,7 @@ class RelaxExprNameSetter : public ExprVisitor { const auto& func = Downcast(val->args[0]); name_hint = func->global_symbol; optype = func->global_symbol; - const String& input_name = GetUniqueName(val->args[1], "plugin_inputs"); + const ffi::String& input_name = GetUniqueName(val->args[1], "plugin_inputs"); if (input_name != SpanUtils::GetAttr(val->args[1]->span, msc_attr::kName)) { val->args[1]->span = SpanUtils::SetAttr(val->args[1]->span, 
msc_attr::kName, input_name); } @@ -190,27 +190,28 @@ class RelaxExprNameSetter : public ExprVisitor { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); ExprVisitor::VisitExpr(func); optype = GetFuncType(func); - name_hint = GetFuncName(GetRef(val), func); + name_hint = GetFuncName(ffi::GetRef(val), func); use_unique = false; } else if (local_funcs_.count(val->op)) { ExprVisitor::VisitExpr(local_funcs_[val->op]); optype = GetFuncType(local_funcs_[val->op]); - name_hint = GetFuncName(GetRef(val), local_funcs_[val->op]); + name_hint = GetFuncName(ffi::GetRef(val), local_funcs_[val->op]); use_unique = false; } if (name_hint.size() > 0) { // set name - const String& unique_name = - use_unique ? GetUniqueName(GetRef(val), name_hint) : name_hint; + const ffi::String& unique_name = + use_unique ? GetUniqueName(ffi::GetRef(val), name_hint) : name_hint; if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } // set constant consumer && shared_ref - Array input_types; + ffi::Array input_types; try { input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true); } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(val) << " : " << err.what(); + LOG(WARNING) << "Failed to GetInputTypes for " << ffi::GetRef(val) << " : " + << err.what(); throw err; } for (size_t i = 0; i < input_types.size(); i++) { @@ -218,7 +219,7 @@ class RelaxExprNameSetter : public ExprVisitor { continue; } if (const auto* c_node = val->args[i].as()) { - const String& const_name = SpanUtils::GetAttr(c_node->span, msc_attr::kName); + const ffi::String& const_name = SpanUtils::GetAttr(c_node->span, msc_attr::kName); if (constant_consumers_.count(const_name)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kSharedRef, constant_consumers_[const_name]); @@ -232,8 +233,8 @@ class RelaxExprNameSetter : public ExprVisitor { } private: - const String 
GetUniqueName(const Expr& expr, const String& name_hint) { - String expr_name = SpanUtils::GetAttr(expr->span, msc_attr::kName); + const ffi::String GetUniqueName(const Expr& expr, const ffi::String& name_hint) { + ffi::String expr_name = SpanUtils::GetAttr(expr->span, msc_attr::kName); if (expr_name.size() == 0) { expr_name = name_hint; } @@ -256,10 +257,10 @@ class RelaxExprNameSetter : public ExprVisitor { return expr_name; } - const String GetFuncType(const Function& func) { - String optype; - const auto& comp_opt = func->GetAttr(attr::kComposite); - const auto& code_opt = func->GetAttr(attr::kCodegen); + const ffi::String GetFuncType(const Function& func) { + ffi::String optype; + const auto& comp_opt = func->GetAttr(attr::kComposite); + const auto& code_opt = func->GetAttr(attr::kCodegen); if (comp_opt.has_value()) { optype = comp_opt.value(); } else if (code_opt.has_value()) { @@ -273,15 +274,15 @@ class RelaxExprNameSetter : public ExprVisitor { return optype; } - const String GetFuncName(const Call& call, const Function& func) { - String name; + const ffi::String GetFuncName(const Call& call, const Function& func) { + ffi::String name; // get from unique - const auto& name_opt = func->GetAttr(msc_attr::kUnique); + const auto& name_opt = func->GetAttr(msc_attr::kUnique); if (name_opt.has_value()) { return name_opt.value(); } // get from exprs in the func - Array arg_names; + ffi::Array arg_names; for (const auto& a : call->args) { arg_names.push_back(expr_names_.count(a) ? 
expr_names_[a] : ""); } @@ -298,26 +299,26 @@ class RelaxExprNameSetter : public ExprVisitor { return GetUniqueName(call, name); } - Map setted_names_; - Map constant_consumers_; - std::set setted_blocks_; - Array block_stack_; - Map expr_names_; - Map local_funcs_; + ffi::Map setted_names_; + ffi::Map constant_consumers_; + std::set setted_blocks_; + ffi::Array block_stack_; + ffi::Map expr_names_; + ffi::Map local_funcs_; IRModule ref_module_; - String target_; - Map var_names_; + ffi::String target_; + ffi::Map var_names_; }; // class ExprNameSetter -void SetRelaxExprName(const IRModule& ref_module, const Expr& e, const String& target, - const Map& var_names) { +void SetRelaxExprName(const IRModule& ref_module, const Expr& e, const ffi::String& target, + const ffi::Map& var_names) { RelaxExprNameSetter(ref_module, target, var_names).VisitExpr(e); } namespace transform { -Pass SetRelaxExprName(const String& entry_name, const String& target, - const Map& var_names) { +Pass SetRelaxExprName(const ffi::String& entry_name, const ffi::String& target, + const ffi::Map& var_names) { auto pass_func = [=](IRModule m, PassContext pc) { relax::SetRelaxExprName(m, m->Lookup(entry_name), target, var_names); return m; @@ -325,10 +326,10 @@ Pass SetRelaxExprName(const String& entry_name, const String& target, return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.SetRelaxExprName", SetRelaxExprName); -}); +} } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index f4a79602f506..bc70c809af7c 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -69,8 +69,8 @@ int CommonUtils::CompareVersion(const std::vector& given_version, return 0; } -int CommonUtils::CompareVersion(const Array& given_version, - const Array& target_version) { +int 
CommonUtils::CompareVersion(const ffi::Array& given_version, + const ffi::Array& target_version) { std::vector int_given_version; std::vector int_target_version; for (const auto& v : given_version) { @@ -82,7 +82,7 @@ int CommonUtils::CompareVersion(const Array& given_version, return CompareVersion(int_given_version, int_target_version); } -const String CommonUtils::ToAttrKey(const String& key) { +const ffi::String CommonUtils::ToAttrKey(const ffi::String& key) { if (key == "name") { return msc_attr::kName; } @@ -111,7 +111,7 @@ const String CommonUtils::ToAttrKey(const String& key) { TVM_FFI_UNREACHABLE(); } -bool StringUtils::Contains(const String& src_string, const String& sub_string) { +bool StringUtils::Contains(const ffi::String& src_string, const ffi::String& sub_string) { if (src_string.size() == 0) { return false; } @@ -125,7 +125,7 @@ bool StringUtils::Contains(const String& src_string, const String& sub_string) { return pos >= 0; } -bool StringUtils::StartsWith(const String& src_string, const String& sub_string) { +bool StringUtils::StartsWith(const ffi::String& src_string, const ffi::String& sub_string) { if (src_string.size() == 0) { return false; } @@ -138,7 +138,7 @@ bool StringUtils::StartsWith(const String& src_string, const String& sub_string) return pos == 0; } -bool StringUtils::EndsWith(const String& src_string, const String& sub_string) { +bool StringUtils::EndsWith(const ffi::String& src_string, const ffi::String& sub_string) { if (src_string.size() == 0) { return false; } @@ -154,8 +154,9 @@ bool StringUtils::EndsWith(const String& src_string, const String& sub_string) { return static_cast(pos) == src_cstring.size() - sub_cstring.size(); } -const Array StringUtils::Split(const String& src_string, const String& sep) { - Array sub_strings; +const ffi::Array StringUtils::Split(const ffi::String& src_string, + const ffi::String& sep) { + ffi::Array sub_strings; if (src_string.size() == 0) { return sub_strings; } @@ -175,26 +176,27 @@ const Array 
StringUtils::Split(const String& src_string, const String& s return sub_strings; } -const String StringUtils::Join(const Array& sub_strings, const String& joint) { - String join_str = ""; +const ffi::String StringUtils::Join(const ffi::Array& sub_strings, + const ffi::String& joint) { + ffi::String join_str = ""; for (size_t i = 0; i < sub_strings.size(); i++) { join_str = join_str + sub_strings[i] + (i == sub_strings.size() - 1 ? "" : joint); } return join_str; } -const String StringUtils::Join(const std::vector& sub_strings, - const std::string& joint) { - Array new_strings; +const ffi::String StringUtils::Join(const std::vector& sub_strings, + const std::string& joint) { + ffi::Array new_strings; for (const auto& s : sub_strings) { new_strings.push_back(s); } return Join(new_strings, joint); } -const String StringUtils::Replace(const String& src_string, const String& old_str, - const String& new_str) { - String new_string; +const ffi::String StringUtils::Replace(const ffi::String& src_string, const ffi::String& old_str, + const ffi::String& new_str) { + ffi::String new_string; const auto& sub_strings = Split(src_string, old_str); for (size_t i = 0; i < sub_strings.size(); i++) { new_string = new_string + sub_strings[i] + (i == sub_strings.size() - 1 ? 
"" : new_str); @@ -202,10 +204,11 @@ const String StringUtils::Replace(const String& src_string, const String& old_st return new_string; } -const std::tuple StringUtils::SplitOnce(const String& src_string, const String& sep, - bool from_left) { +const std::tuple StringUtils::SplitOnce(const ffi::String& src_string, + const ffi::String& sep, + bool from_left) { if (src_string.size() == 0) { - return std::make_tuple(String(), String()); + return std::make_tuple(ffi::String(), ffi::String()); } std::string src_cstring = src_string; const std::string& csep = sep; @@ -213,17 +216,18 @@ const std::tuple StringUtils::SplitOnce(const String& src_string if (pos >= 0) { return std::make_tuple(src_cstring.substr(0, pos), src_cstring.substr(pos + csep.size())); } - return std::make_tuple(src_string, String()); + return std::make_tuple(src_string, ffi::String()); } -const Array StringUtils::GetClosures(const String& src_string, const String& left, - const String& right) { - Array tokens; +const ffi::Array StringUtils::GetClosures(const ffi::String& src_string, + const ffi::String& left, + const ffi::String& right) { + ffi::Array tokens; if (src_string.size() == 0) { return tokens; } - String token = "start"; - String left_str = src_string; + ffi::String token = "start"; + ffi::String left_str = src_string; while (token.size() > 0) { std::tie(token, left_str) = StringUtils::SplitOnce(left_str, left); if (left_str.size() > 0) { @@ -238,35 +242,36 @@ const Array StringUtils::GetClosures(const String& src_string, const Str return tokens; } -const String StringUtils::GetClosureOnce(const String& src_string, const String& left, - const String& right, bool from_left) { +const ffi::String StringUtils::GetClosureOnce(const ffi::String& src_string, + const ffi::String& left, const ffi::String& right, + bool from_left) { if (src_string.size() == 0) { return ""; } - String val = std::get<1>(SplitOnce(src_string, left, from_left)); + ffi::String val = std::get<1>(SplitOnce(src_string, left, 
from_left)); if (val.size() > 0) { val = std::get<0>(StringUtils::SplitOnce(val, right, from_left)); } return val; } -const String StringUtils::Upper(const String& src_string) { +const ffi::String StringUtils::Upper(const ffi::String& src_string) { std::string str = std::string(src_string); std::transform(str.begin(), str.end(), str.begin(), ::toupper); return str; } -const String StringUtils::Lower(const String& src_string) { +const ffi::String StringUtils::Lower(const ffi::String& src_string) { std::string str = std::string(src_string); std::transform(str.begin(), str.end(), str.begin(), ::tolower); return str; } -const String StringUtils::ToString(const ffi::Any& obj) { - String obj_string; +const ffi::String StringUtils::ToString(const ffi::Any& obj) { + ffi::String obj_string; if (obj == nullptr) { obj_string = ""; - } else if (auto opt_str = obj.as()) { + } else if (auto opt_str = obj.as()) { obj_string = *opt_str; } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); @@ -291,7 +296,8 @@ const String StringUtils::ToString(const ffi::Any& obj) { return obj_string; } -bool ArrayUtils::CompareArrays(const Array& left, const Array& right, int size) { +bool ArrayUtils::CompareArrays(const ffi::Array& left, + const ffi::Array& right, int size) { if (left.size() == right.size() && left.size() == 0) { return true; } @@ -314,7 +320,7 @@ bool ArrayUtils::CompareArrays(const Array& left, const Array& r return true; } -PrimExpr ArrayUtils::Accumulate(const Array& array, int pos) { +PrimExpr ArrayUtils::Accumulate(const ffi::Array& array, int pos) { size_t t_pos = pos < 0 ? 
array.size() + pos + 1 : pos; PrimExpr accumulate = Integer(1); for (size_t i = 0; i < t_pos; i++) { @@ -323,7 +329,7 @@ PrimExpr ArrayUtils::Accumulate(const Array& array, int pos) { return accumulate; } -bool ArrayUtils::Broadcastable(const Array& lhs, const Array& rhs) { +bool ArrayUtils::Broadcastable(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return false; } @@ -345,16 +351,16 @@ bool ArrayUtils::Broadcastable(const Array& lhs, const Array return true; } -const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& value) { +const Span SpanUtils::SetAttr(const Span& span, const ffi::String& key, const ffi::String& value) { if (value.size() == 0) { return span; } - String new_source; - Array tokens{"<" + key + ">", ""}; + ffi::String new_source; + ffi::Array tokens{"<" + key + ">", ""}; if (span.defined() && span->source_name.defined()) { - const String& source_str = span->source_name->name; - String left = std::get<0>(StringUtils::SplitOnce(source_str, tokens[0])); - String right = std::get<1>(StringUtils::SplitOnce(source_str, tokens[1])); + const ffi::String& source_str = span->source_name->name; + ffi::String left = std::get<0>(StringUtils::SplitOnce(source_str, tokens[0])); + ffi::String right = std::get<1>(StringUtils::SplitOnce(source_str, tokens[1])); if (StringUtils::Contains(source_str, tokens[0]) && StringUtils::Contains(source_str, tokens[1])) { new_source = left + tokens[0] + value + tokens[1] + right; @@ -371,29 +377,29 @@ const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& return Span(SourceName::Get(new_source), 0, 0, 0, 0); } -String SpanUtils::GetAttr(const Span& span, const String& key) { +ffi::String SpanUtils::GetAttr(const Span& span, const ffi::String& key) { if (span.defined() && span->source_name.defined()) { - Array tokens{"<" + key + ">", ""}; + ffi::Array tokens{"<" + key + ">", ""}; return StringUtils::GetClosureOnce(span->source_name->name, 
tokens[0], tokens[1]); } return ""; } -const Map SpanUtils::GetAttrs(const Span& span) { - Map attrs; +const ffi::Map SpanUtils::GetAttrs(const Span& span) { + ffi::Map attrs; for (const auto& key : StringUtils::GetClosures(span->source_name->name, "")) { attrs.Set(key, GetAttr(span, key)); } return attrs; } -const Span SpanUtils::CreateWithAttr(const String& key, const String& value) { +const Span SpanUtils::CreateWithAttr(const ffi::String& key, const ffi::String& value) { return SetAttr(Span(), key, value); } -const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs_num, - bool as_relax) { - Array input_types; +const ffi::Array ExprUtils::GetInputTypes(const ffi::String& optype, size_t inputs_num, + bool as_relax) { + ffi::Array input_types; if (as_relax && (optype == "broadcast_to" || optype == "reshape")) { input_types.push_back("input"); if (inputs_num > 1) { @@ -490,12 +496,12 @@ const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs return input_types; } -const Array ExprUtils::GetInputTypes(const Call& call) { - const String& optype = StringUtils::Replace(Downcast(call->op)->name, "relax.", ""); +const ffi::Array ExprUtils::GetInputTypes(const Call& call) { + const ffi::String& optype = StringUtils::Replace(Downcast(call->op)->name, "relax.", ""); return GetInputTypes(optype, call->args.size(), true); } -const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { +const ffi::String ExprUtils::GetSpanName(const Expr& expr, const ffi::String& suffix) { const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName); if (suffix.size() > 0) { return name + "_" + suffix; @@ -503,13 +509,13 @@ const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { return name; } -const Array ExprUtils::GetShape(const TensorStructInfo& sinfo, bool as_int) { +const ffi::Array ExprUtils::GetShape(const TensorStructInfo& sinfo, bool as_int) { const auto& shape_opt = sinfo->GetShape(); if 
(!shape_opt.defined()) { - return Array(); + return ffi::Array(); } if (as_int) { - Array shape; + ffi::Array shape; for (const auto& s : shape_opt.value()) { shape.push_back(s->IsInstance() ? s : Integer(-1)); } @@ -518,7 +524,7 @@ const Array ExprUtils::GetShape(const TensorStructInfo& sinfo, bool as return shape_opt.value(); } -const Array ExprUtils::GetShape(const Expr& expr, bool as_int) { +const ffi::Array ExprUtils::GetShape(const Expr& expr, bool as_int) { return GetShape(Downcast(GetStructInfo(expr)), as_int); } @@ -526,27 +532,27 @@ const DataType ExprUtils::GetDataType(const Expr& expr) { return Downcast(GetStructInfo(expr))->dtype; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.SpanGetAttr", SpanUtils::GetAttr) .def("msc.core.SpanGetAttrs", SpanUtils::GetAttrs) .def("msc.core.SpanCreateWithAttr", - [](const String& key, const String& value) -> Span { + [](const ffi::String& key, const ffi::String& value) -> Span { return SpanUtils::CreateWithAttr(key, value); }) .def("msc.core.SpanSetAttr", - [](const Span& span, const String& key, const String& value) -> Span { + [](const Span& span, const ffi::String& key, const ffi::String& value) -> Span { return SpanUtils::SetAttr(span, key, value); }) - .def( - "msc.core.CompareVersion", - [](const Array& given_version, const Array& target_version) -> Integer { - return Integer(CommonUtils::CompareVersion(given_version, target_version)); - }) + .def("msc.core.CompareVersion", + [](const ffi::Array& given_version, + const ffi::Array& target_version) -> Integer { + return Integer(CommonUtils::CompareVersion(given_version, target_version)); + }) .def("msc.core.ToAttrKey", - [](const String& key) -> String { return CommonUtils::ToAttrKey(key); }); -}); + [](const ffi::String& key) -> ffi::String { return CommonUtils::ToAttrKey(key); }); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/utils.h 
b/src/contrib/msc/core/utils.h index aeb7f9eb88fd..a0732d5848ac 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -82,13 +82,13 @@ class CommonUtils { */ TVM_DLL static int CompareVersion(const std::vector& given_version, const std::vector& target_version); - TVM_DLL static int CompareVersion(const Array& given_version, - const Array& target_version); + TVM_DLL static int CompareVersion(const ffi::Array& given_version, + const ffi::Array& target_version); /*! * \brief Get attr key. * \return The attr key. */ - TVM_DLL static const String ToAttrKey(const String& key); + TVM_DLL static const ffi::String ToAttrKey(const ffi::String& key); }; /*! @@ -97,83 +97,87 @@ class CommonUtils { class StringUtils { public: /*! - * \brief Check if the String contains a substring. + * \brief Check if the ffi::String contains a substring. * \return Whether substring is contained. */ - TVM_DLL static bool Contains(const String& src_string, const String& sub_string); + TVM_DLL static bool Contains(const ffi::String& src_string, const ffi::String& sub_string); /*! - * \brief Check if the String starts with a substring. + * \brief Check if the ffi::String starts with a substring. * \return Whether string starts with substring. */ - TVM_DLL static bool StartsWith(const String& src_string, const String& sub_string); + TVM_DLL static bool StartsWith(const ffi::String& src_string, const ffi::String& sub_string); /*! - * \brief Check if the String ens with a substring. + * \brief Check if the ffi::String ens with a substring. * \return Whether string endswith substring. */ - TVM_DLL static bool EndsWith(const String& src_string, const String& sub_string); + TVM_DLL static bool EndsWith(const ffi::String& src_string, const ffi::String& sub_string); /*! - * \brief Split the String into sub Strings. + * \brief Split the ffi::String into sub Strings. * \return The SubStrings. 
*/ - TVM_DLL static const Array Split(const String& src_string, const String& sep); + TVM_DLL static const ffi::Array Split(const ffi::String& src_string, + const ffi::String& sep); /*! * \brief Join the SubStrings into String. * \return The String. */ - TVM_DLL static const String Join(const Array& sub_strings, const String& joint); - TVM_DLL static const String Join(const std::vector& sub_strings, - const std::string& joint); + TVM_DLL static const ffi::String Join(const ffi::Array& sub_strings, + const ffi::String& joint); + TVM_DLL static const ffi::String Join(const std::vector& sub_strings, + const std::string& joint); /*! * \brief Replace the substring old to new in String. * \return The replaced String. */ - TVM_DLL static const String Replace(const String& src_string, const String& old_str, - const String& new_str); + TVM_DLL static const ffi::String Replace(const ffi::String& src_string, + const ffi::String& old_str, const ffi::String& new_str); /*! - * \brief Split the String into two sub Strings, only split by the frist seq. + * \brief Split the ffi::String into two sub Strings, only split by the frist seq. * \return The SubStrings. */ - TVM_DLL static const std::tuple SplitOnce(const String& src_string, - const String& sep, - bool from_left = false); + TVM_DLL static const std::tuple SplitOnce(const ffi::String& src_string, + const ffi::String& sep, + bool from_left = false); /*! * \brief Get the tokens between left and right. * \return The Tokens. */ - TVM_DLL static const Array GetClosures(const String& src_string, const String& left, - const String& right); + TVM_DLL static const ffi::Array GetClosures(const ffi::String& src_string, + const ffi::String& left, + const ffi::String& right); /*! * \brief Get the first token between left and right. * \return The Token. 
*/ - TVM_DLL static const String GetClosureOnce(const String& src_string, const String& left, - const String& right, bool from_left = true); + TVM_DLL static const ffi::String GetClosureOnce(const ffi::String& src_string, + const ffi::String& left, const ffi::String& right, + bool from_left = true); /*! * \brief Change string to upper. * \return The String. */ - TVM_DLL static const String Upper(const String& src_string); + TVM_DLL static const ffi::String Upper(const ffi::String& src_string); /*! * \brief Change string to lower. * \return The String. */ - TVM_DLL static const String Lower(const String& src_string); + TVM_DLL static const ffi::String Lower(const ffi::String& src_string); /*! * \brief Change Object to String. * \return The String. */ - TVM_DLL static const String ToString(const ffi::Any& obj); + TVM_DLL static const ffi::String ToString(const ffi::Any& obj); }; /*! @@ -186,9 +190,9 @@ class ArrayUtils { * \return The replaced Array. */ template - TVM_DLL static const Array Replace(const Array& src_array, const T& old_ele, - const T& new_ele) { - Array new_array; + TVM_DLL static const ffi::Array Replace(const ffi::Array& src_array, const T& old_ele, + const T& new_ele) { + ffi::Array new_array; for (const auto& a : src_array) { if (a == old_ele) { new_array.push_back(new_ele); @@ -218,8 +222,8 @@ class ArrayUtils { * \return The downcasted array */ template - TVM_DLL static const Array Cast(const Array& src_array) { - Array new_array; + TVM_DLL static const ffi::Array Cast(const ffi::Array& src_array) { + ffi::Array new_array; for (const auto& s : src_array) { new_array.push_back(Downcast(s)); } @@ -231,21 +235,21 @@ class ArrayUtils { * \return The producted array */ template - TVM_DLL static const Array> Product(const Array>& arrays) { - Array> p_arrays; + TVM_DLL static const ffi::Array> Product(const ffi::Array>& arrays) { + ffi::Array> p_arrays; if (arrays.size() == 1) { for (const auto& a : arrays[0]) { - p_arrays.push_back(Array{a}); + 
p_arrays.push_back(ffi::Array{a}); } return p_arrays; } - Array> sub_arrays; + ffi::Array> sub_arrays; for (size_t i = 0; i < arrays.size() - 1; i++) { sub_arrays.push_back(arrays[i]); } for (const auto& p_array : Product(sub_arrays)) { for (const auto& a : arrays[arrays.size() - 1]) { - Array sub_array = p_array; + ffi::Array sub_array = p_array; sub_array.push_back(a); p_arrays.push_back(sub_array); } @@ -254,22 +258,23 @@ class ArrayUtils { } /*! - * \brief Compare String arrays. + * \brief Compare ffi::String arrays. * \return Whether two array are same. */ - TVM_DLL static bool CompareArrays(const Array& left, const Array& right, - int size = -1); + TVM_DLL static bool CompareArrays(const ffi::Array& left, + const ffi::Array& right, int size = -1); /*! * \brief Accumulate array. * \return The accumulate result */ - TVM_DLL static PrimExpr Accumulate(const Array& array, int pos = -1); + TVM_DLL static PrimExpr Accumulate(const ffi::Array& array, int pos = -1); /*! * \brief Check if lhs array is broadcastable to rhs. * \return broadcastable */ - TVM_DLL static bool Broadcastable(const Array& lhs, const Array& rhs); + TVM_DLL static bool Broadcastable(const ffi::Array& lhs, + const ffi::Array& rhs); }; /*! @@ -281,25 +286,26 @@ class SpanUtils { * \brief Set value to the Span. * \return The new Span. */ - TVM_DLL static const Span SetAttr(const Span& span, const String& key, const String& value); + TVM_DLL static const Span SetAttr(const Span& span, const ffi::String& key, + const ffi::String& value); /*! * \brief Get the value in value from the Span. * \return The value String. */ - TVM_DLL static String GetAttr(const Span& span, const String& key); + TVM_DLL static ffi::String GetAttr(const Span& span, const ffi::String& key); /*! * \brief Get all the key:value in format value from the Span. * \return The Attrs Map. */ - TVM_DLL static const Map GetAttrs(const Span& span); + TVM_DLL static const ffi::Map GetAttrs(const Span& span); /*! 
* \brief Create a span with value. * \return The created Span. */ - TVM_DLL static const Span CreateWithAttr(const String& key, const String& value); + TVM_DLL static const Span CreateWithAttr(const ffi::String& key, const ffi::String& value); }; /*! @@ -311,21 +317,21 @@ class ExprUtils { * \brief Get the input types of call. * \return The input types. */ - TVM_DLL static const Array GetInputTypes(const String& optype, size_t inputs_num, - bool as_relax); + TVM_DLL static const ffi::Array GetInputTypes(const ffi::String& optype, + size_t inputs_num, bool as_relax); /*! * \brief Get the input types of call. * \return The input types. */ - TVM_DLL static const Array GetInputTypes(const Call& call); + TVM_DLL static const ffi::Array GetInputTypes(const Call& call); /*! * \brief Get the scalar value of ndarray. * \return The scalar value. */ template - TVM_DLL static const T GetScalar(const runtime::NDArray& array, size_t i = 0) { + TVM_DLL static const T GetScalar(const runtime::Tensor& array, size_t i = 0) { if (array->dtype.code == kDLInt) { if (array->dtype.bits == 8) { return T(reinterpret_cast(array->data)[i]); @@ -371,14 +377,15 @@ class ExprUtils { * \brief Get name in span. * \return The name. */ - TVM_DLL static const String GetSpanName(const Expr& expr, const String& suffix = ""); + TVM_DLL static const ffi::String GetSpanName(const Expr& expr, const ffi::String& suffix = ""); /*! * \brief Get shape of expr. * \return The shape. */ - TVM_DLL static const Array GetShape(const TensorStructInfo& sinfo, bool as_int = true); - TVM_DLL static const Array GetShape(const Expr& expr, bool as_int = true); + TVM_DLL static const ffi::Array GetShape(const TensorStructInfo& sinfo, + bool as_int = true); + TVM_DLL static const ffi::Array GetShape(const Expr& expr, bool as_int = true); /*! * \brief Get dtype of expr. 
diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 1a5bdfeacb33..30488fcc9af0 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -40,7 +40,7 @@ void TensorflowCodeGen::CodeGenHelper() { .func_arg("name", "str") .func_arg("shape", "List[int]") .func_arg("dtype", "str") - .func_arg("weights", "Dict[str, tvm.nd.array]") + .func_arg("weights", "Dict[str, tvm.runtime.Tensor]") .func_start() .cond_if("name in weights") .func_call("tf_v1.get_variable", "var") @@ -63,7 +63,7 @@ void TensorflowCodeGen::CodeGenGraph() { const auto& pair = graph()->FindProducerAndIdx(i); stack_.func_arg(IdxOutputBase(pair.first, pair.second), "tf_v1.Tensor"); } - stack_.func_arg("weights", "Dict[str, tvm.nd.array]").func_start(); + stack_.func_arg("weights", "Dict[str, tvm.runtime.Tensor]").func_start(); // define weights stack_.comment("Define the weights"); for (const auto& n : graph()->node_names) { @@ -88,7 +88,7 @@ void TensorflowCodeGen::CodeGenGraph() { } CodeGenNode(node, config()->use_tools); } - Array idx_outputs; + ffi::Array idx_outputs; for (const auto& o : graph()->GetOutputs()) { const auto& pair = graph()->FindProducerAndIdx(o); idx_outputs.push_back(IdxOutputBase(pair.first, pair.second)); @@ -139,7 +139,7 @@ void TensorflowCodeGen::CodeGenInference() { .scope_end(); } -const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { +const ffi::Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTFV1OpCodes(); auto it = ops_map->find(node->optype); ICHECK(it != ops_map->end()) << "Unsupported tensorflow op(" << node->optype << "): " << node; @@ -152,16 +152,16 @@ const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.framework.tensorflow.GetTensorflowSources", - [](const 
MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { + [](const MSCGraph& graph, const ffi::String& codegen_config, + const ffi::String& print_config) -> ffi::Map { TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); codegen.Init(); return codegen.GetSources(print_config); }); -}); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tensorflow/codegen.h b/src/contrib/msc/framework/tensorflow/codegen.h index af2579980a39..5052c11004d2 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.h +++ b/src/contrib/msc/framework/tensorflow/codegen.h @@ -59,10 +59,10 @@ class TensorflowCodeGen : public PyCodeGen GetOpCodes(const MSCJoint& node) final; + const ffi::Array GetOpCodes(const MSCJoint& node) final; /*! \brief Get tensor type of the framework*/ - const String TensorType() const final { return "tf_v1.Tensor"; } + const ffi::String TensorType() const final { return "tf_v1.Tensor"; } }; } // namespace msc diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc index 570088ee35c2..d47021d84da5 100644 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc +++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc @@ -29,17 +29,16 @@ namespace tvm { namespace contrib { namespace msc { -const Array TFV1OpCode::GetDocs() { +const ffi::Array TFV1OpCode::GetDocs() { stack_.Config(this); CodeGenBuild(); return stack_.GetDocs(); } -const std::pair> TFV1OpCode::GetPadding(const String& strides_key, - const String& kernel_key, - const String& padding_key) { - String pad_mod = ""; - Array padding; +const std::pair> TFV1OpCode::GetPadding( + const ffi::String& strides_key, const ffi::String& kernel_key, const ffi::String& padding_key) { + ffi::String pad_mod = ""; + ffi::Array padding; std::vector kernel_size; if (node()->optype == "nn.conv2d" || node()->optype == "msc.conv2d_bias") { const auto& weight = 
node()->WeightAt("weight"); @@ -98,7 +97,7 @@ const std::pair> TFV1OpCode::GetPadding(const String& stri #define TFV1_OP_CODEGEN_METHODS(TypeName) \ public: \ - TypeName(const String& func_name) : TFV1OpCode(func_name) {} + TypeName(const ffi::String& func_name) : TFV1OpCode(func_name) {} class TFV1ArgMaxMinCodeGen : public TFV1OpCode { TFV1_OP_CODEGEN_METHODS(TFV1ArgMaxMinCodeGen) @@ -128,23 +127,25 @@ class TFV1AstypeCodeGen : public TFV1OpCode { class TFV1AxesCodeGen : public TFV1OpCode { public: - TFV1AxesCodeGen(const String& func_name, const String& attr_name) : TFV1OpCode(func_name) { + TFV1AxesCodeGen(const ffi::String& func_name, const ffi::String& attr_name) + : TFV1OpCode(func_name) { attr_name_ = attr_name; } protected: void CodeGenBuild() final { - const String& key = node()->HasAttr("axes") ? "axes" : "axis"; + const ffi::String& key = node()->HasAttr("axes") ? "axes" : "axis"; stack_.op_call().op_input_arg().op_list_arg(key, attr_name_).op_name_arg(); } private: - String attr_name_; + ffi::String attr_name_; }; class TFV1AxisCodeGen : public TFV1OpCode { public: - TFV1AxisCodeGen(const String& func_name, const String& attr_name) : TFV1OpCode(func_name) { + TFV1AxisCodeGen(const ffi::String& func_name, const ffi::String& attr_name) + : TFV1OpCode(func_name) { attr_name_ = attr_name; } @@ -154,7 +155,7 @@ class TFV1AxisCodeGen : public TFV1OpCode { } private: - String attr_name_; + ffi::String attr_name_; }; class TFV1BatchnormCodeGen : public TFV1OpCode { @@ -168,8 +169,8 @@ class TFV1BatchnormCodeGen : public TFV1OpCode { .op_arg("center") .op_arg("momentum") .op_arg("epsilon"); - Array weight_names{"gamma", "beta", "mean", "var"}; - Array init_names{"gamma", "beta", "moving_mean", "moving_variance"}; + ffi::Array weight_names{"gamma", "beta", "mean", "var"}; + ffi::Array init_names{"gamma", "beta", "moving_mean", "moving_variance"}; for (size_t i = 0; i < weight_names.size(); i++) { const auto& w_doc = 
DocUtils::ToStr(node()->WeightAt(weight_names[i])->name); stack_.inplace_start("tf_v1.constant_initializer", init_names[i] + "_initializer") @@ -219,7 +220,7 @@ class TFV1ConstantCodeGen : public TFV1OpCode { class TFV1ConvCodeGen : public TFV1OpCode { public: - TFV1ConvCodeGen(const String& func_name, bool use_bias) : TFV1OpCode(func_name) { + TFV1ConvCodeGen(const ffi::String& func_name, bool use_bias) : TFV1OpCode(func_name) { use_bias_ = use_bias; } @@ -318,19 +319,19 @@ class TFV1PadCodeGen : public TFV1OpCode { protected: void CodeGenBuild() final { - String mode; + ffi::String mode; const auto& attr_mode = node()->GetTypeAttr("pad_mode"); if (attr_mode == "constant") { mode = "CONSTANT"; } else { LOG_FATAL << "Unexpected pad mode " << node(); } - Array pad_width; + ffi::Array pad_width; const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); ICHECK(attr_pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); for (size_t i = 0; i < attr_pad_width.size(); i += 2) { - const String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + - std::to_string(attr_pad_width[i + 1]) + "]"; + const ffi::String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + + std::to_string(attr_pad_width[i + 1]) + "]"; pad_width.push_back(cur_pad); } const auto& val_producer = node()->ProducerOf(1); @@ -349,7 +350,7 @@ class TFV1Pool2dCodeGen : public TFV1OpCode { protected: void CodeGenBuild() final { - String pooling_type; + ffi::String pooling_type; if (node()->optype == "nn.avg_pool2d") { pooling_type = "AVG"; } else if (node()->optype == "nn.max_pool2d") { @@ -413,7 +414,7 @@ class TFV1Resize2dCodeGen : public TFV1OpCode { protected: void CodeGenBuild() final { - String func_name; + ffi::String func_name; const auto& method = node()->GetTypeAttr("method"); const auto& coordinate_transformation_mode = node()->GetTypeAttr("coordinate_transformation_mode"); @@ -502,8 +503,10 @@ class TFV1TupleCodeGen : public TFV1OpCode { void 
CodeGenBuild() final { stack_.op_call().op_inputs_arg(); } }; -const std::shared_ptr>> GetTFV1OpCodes() { - static auto map = std::make_shared>>(); +const std::shared_ptr>> +GetTFV1OpCodes() { + static auto map = + std::make_shared>>(); if (!map->empty()) return map; // binary && unary ops map->emplace("abs", std::make_shared("tf_v1.abs")); diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h index bda7e6e99336..a744ffc701e4 100644 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h +++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h @@ -50,14 +50,14 @@ class TFV1OpCode : public BaseOpCode * \param func_name the function name for the node. * \param config the config json for the node. */ - explicit TFV1OpCode(const String& func_name) + explicit TFV1OpCode(const ffi::String& func_name) : BaseOpCode(func_name) {} /*! \brief Convert node to docs*/ - const Array GetDocs() final; + const ffi::Array GetDocs() final; /*! \brief Get dtype string*/ - const String DType(const DataType& dtype) final { + const ffi::String DType(const DataType& dtype) final { return "tf_v1." + BaseOpCode::DType(dtype); } @@ -68,16 +68,17 @@ class TFV1OpCode : public BaseOpCode virtual void CodeGenBuild() = 0; /*! \brief Get padding mode or array*/ - const std::pair> GetPadding(const String& strides_key, - const String& kernel_key = "", - const String& padding_key = "padding"); + const std::pair> GetPadding( + const ffi::String& strides_key, const ffi::String& kernel_key = "", + const ffi::String& padding_key = "padding"); }; /*! 
* \brief Get the map of available TFV1OpCode, use optype as key * \return Map of */ -const std::shared_ptr>> GetTFV1OpCodes(); +const std::shared_ptr>> +GetTFV1OpCodes(); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 7acd0f215502..1be8cf0836c9 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -48,7 +48,7 @@ void TensorRTCodeGen::CodeGenClassDeclare() { } // plugin headers if (config()->use_plugin) { - std::set plugins; + std::set plugins; for (const auto& n : graph()->node_names) { const auto& node = graph()->FindNode(n); if (IsPlugin(node->optype) && !plugins.count(node->optype)) { @@ -95,7 +95,7 @@ void TensorRTCodeGen::CodeGenClassDeclare() { void TensorRTCodeGen::CodeGenClassDefine() { auto malloc_buffer = [this](const MSCTensor& tensor) { - const String& idx_var = "idx_" + IdxTensor(tensor); + const ffi::String& idx_var = "idx_" + IdxTensor(tensor); this->stack_ .func_call("getBindingIndex", DocUtils::ToDeclare("int", idx_var), DocUtils::ToPtr("engine")) @@ -121,8 +121,8 @@ void TensorRTCodeGen::CodeGenClassDefine() { // save codegen before build if (config()->use_tools) { const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_step"); - before_build_codes_ = - pf(GetStepCtx(), "before_build", graph()->name, config()->tools_tag).cast>(); + before_build_codes_ = pf(GetStepCtx(), "before_build", graph()->name, config()->tools_tag) + .cast>(); } if (graph()->weight_holders.size() > 0) { stack_.func_call("TRTUtils::LoadWeights", "mWeights") @@ -144,7 +144,7 @@ void TensorRTCodeGen::CodeGenClassDefine() { stack_.comment("Mark batch size"); stack_.func_call("createOptimizationProfile", DocUtils::ToDeclare("auto", "profile"), DocUtils::ToPtr("builder")); - Array batch_flags{"MIN", "MAX", "OPT"}; + ffi::Array batch_flags{"MIN", "MAX", "OPT"}; for (const auto& i : 
graph()->GetInputs()) { for (const auto& f : batch_flags) { stack_.func_call("setDimensions", std::nullopt, DocUtils::ToPtr("profile")) @@ -207,8 +207,8 @@ void TensorRTCodeGen::CodeGenClassDefine() { // save codegen after build if (config()->use_tools) { const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_step"); - after_build_codes_ = - pf(GetStepCtx(), "after_build", graph()->name, config()->tools_tag).cast>(); + after_build_codes_ = pf(GetStepCtx(), "after_build", graph()->name, config()->tools_tag) + .cast>(); } // end define build method stack_.func_end("true"); @@ -470,7 +470,7 @@ void TensorRTCodeGen::CodeGenCmake() { if (config()->use_plugin) { stack_.line("add_definitions(-DPLUGIN_SUPPORT_TENSORRT)").line(); } - String link_libs = " ${TRT_LIBS}"; + ffi::String link_libs = " ${TRT_LIBS}"; if (config()->extern_libs.size() > 0) { stack_.line("set(EXTERN_LIBS " + StringUtils::Join(config()->extern_libs, " ") + ")"); link_libs = link_libs + " ${EXTERN_LIBS}"; @@ -481,17 +481,18 @@ void TensorRTCodeGen::CodeGenCmake() { .line("target_link_libraries(" + graph()->name + link_libs + ")"); } -const String TensorRTCodeGen::IdxTensor(const MSCTensor& tensor) { +const ffi::String TensorRTCodeGen::IdxTensor(const MSCTensor& tensor) { const auto& pair = graph()->FindProducerAndIdx(tensor); - const String& prefix = "tensor_" + std::to_string(pair.first->index); + const ffi::String& prefix = "tensor_" + std::to_string(pair.first->index); if (pair.first->outputs.size() > 1) { return prefix + "_" + std::to_string(pair.second); } return prefix; } -const String TensorRTCodeGen::CppDType(const DataType& dtype) { - const String& dtype_name = CppCodeGen::DType(dtype); +const ffi::String TensorRTCodeGen::CppDType(const DataType& dtype) { + const ffi::String& dtype_name = + CppCodeGen::DType(dtype); if (dtype_name == "int32") { return "int"; } @@ -507,11 +508,11 @@ const String TensorRTCodeGen::CppDType(const DataType& dtype) { return dtype_name; } -const 
String TensorRTCodeGen::GetTensorBytes(const MSCTensor& tensor) { +const ffi::String TensorRTCodeGen::GetTensorBytes(const MSCTensor& tensor) { return std::to_string(tensor->GetSize()->value) + " * sizeof(" + CppDType(tensor->dtype) + ")"; } -void TensorRTCodeGen::ReturnOnFail(const String& flag, const String& err) { +void TensorRTCodeGen::ReturnOnFail(const ffi::String& flag, const ffi::String& err) { stack_.cond_if("!" + flag) .func_call("logger.log") .call_arg("ILogger::Severity::kERROR") @@ -521,11 +522,11 @@ void TensorRTCodeGen::ReturnOnFail(const String& flag, const String& err) { } template -const String TensorRTCodeGen::ToDims(const std::vector& dims, bool use_ndim) { +const ffi::String TensorRTCodeGen::ToDims(const std::vector& dims, bool use_ndim) { if (dims.size() == 2 && !use_ndim) { return "DimsHW{" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) + "}"; } - String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; + ffi::String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; for (size_t i = 0; i < dims.size(); i++) { dims_str = dims_str + std::to_string(dims[i]) + (i < dims.size() - 1 ? 
"," : ""); } @@ -533,7 +534,7 @@ const String TensorRTCodeGen::ToDims(const std::vector& dims, bool use_ndim) return dims_str; } -const String TensorRTCodeGen::ToDims(const Array& dims, bool use_ndim) { +const ffi::String TensorRTCodeGen::ToDims(const ffi::Array& dims, bool use_ndim) { std::vector int_dims; for (const auto& d : dims) { int_dims.push_back(d->value); @@ -541,7 +542,7 @@ const String TensorRTCodeGen::ToDims(const Array& dims, bool use_ndim) return ToDims(int_dims, use_ndim); } -const Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { +const ffi::Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTensorRTOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported tensorrt op(" << node->optype << "): " << node; @@ -554,8 +555,8 @@ const Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { } } -const Map TensorRTCodeGen::GetTensorCtx(const MSCTensor& tensor) { - Map tensor_ctx; +const ffi::Map TensorRTCodeGen::GetTensorCtx(const MSCTensor& tensor) { + ffi::Map tensor_ctx; tensor_ctx.Set("ctx", "network"); for (const auto& pair : CppCodeGen::GetTensorCtx(tensor)) { @@ -564,8 +565,8 @@ const Map TensorRTCodeGen::GetTensorCtx(const MSCTensor& tensor) return tensor_ctx; } -const Map TensorRTCodeGen::GetStepCtx() { - Map step_ctx; +const ffi::Map TensorRTCodeGen::GetStepCtx() { + ffi::Map step_ctx; step_ctx.Set("network", "network"); step_ctx.Set("config", "config"); step_ctx.Set("builder", "builder"); @@ -575,42 +576,42 @@ const Map TensorRTCodeGen::GetStepCtx() { return step_ctx; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.framework.tensorrt.GetTensorRTSources", - [](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { + [](const MSCGraph& graph, const ffi::String& codegen_config, + const ffi::String& print_config) -> ffi::Map { TensorRTCodeGen codegen = 
TensorRTCodeGen(graph, codegen_config); codegen.Init(); return codegen.GetSources(print_config); }) - .def("msc.framework.tensorrt.GetTensorRTRoot", []() -> String { + .def("msc.framework.tensorrt.GetTensorRTRoot", []() -> ffi::String { #ifdef TENSORRT_ROOT_DIR return TENSORRT_ROOT_DIR; #else return ""; #endif }); -}); +} /*! * \brief Create runtime modules for MSC TensorRT. * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. */ -Array MSCTensorRTCompiler(Array functions, - Map target_option, - Map constant_names) { - Array compiled_functions; +ffi::Array MSCTensorRTCompiler(ffi::Array functions, + ffi::Map target_option, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; - const auto& name_opt = func->GetAttr(msc_attr::kUnique); + const auto& name_opt = func->GetAttr(msc_attr::kUnique); ICHECK(name_opt.has_value()) << "Can not find " << msc_attr::kUnique << " from attrs"; const auto& name = name_opt.value(); std::string func_name = GetExtSymbol(func); ICHECK(target_option.count(name)) << "Can not find target option for " << name; - const auto& options = Downcast(target_option[name]); + const auto& options = Downcast(target_option[name]); MSCJSONSerializer serializer(constant_names, options); serializer.serialize(func); std::string graph_json = serializer.GetJSON(); @@ -622,10 +623,10 @@ Array MSCTensorRTCompiler(Array functions, return compiled_functions; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.ext.msc_tensorrt", MSCTensorRTCompiler); -}); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tensorrt/codegen.h b/src/contrib/msc/framework/tensorrt/codegen.h index ea06a17f7c2b..87b4c330e40b 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.h +++ b/src/contrib/msc/framework/tensorrt/codegen.h @@ 
-60,34 +60,34 @@ class TensorRTCodeGen : public CppCodeGen GetOpCodes(const MSCJoint& node) final; + const ffi::Array GetOpCodes(const MSCJoint& node) final; /*! \brief Get the tensor context for codegen_tensor*/ - const Map GetTensorCtx(const MSCTensor& tensor) final; + const ffi::Map GetTensorCtx(const MSCTensor& tensor) final; /*! \brief Get the step context for codegen_step*/ - const Map GetStepCtx() final; + const ffi::Map GetStepCtx() final; /*! \brief Generate return on fail codes*/ - void ReturnOnFail(const String& flag, const String& err); + void ReturnOnFail(const ffi::String& flag, const ffi::String& err); /*! \brief Get the index tensor*/ - const String IdxTensor(const MSCTensor& tensor); + const ffi::String IdxTensor(const MSCTensor& tensor); /*! \brief Get the dtype from the datatype*/ - const String CppDType(const DataType& dtype); + const ffi::String CppDType(const DataType& dtype); /*! \brief Generate describe for tensor bytes*/ - const String GetTensorBytes(const MSCTensor& tensor); + const ffi::String GetTensorBytes(const MSCTensor& tensor); /*! \brief Get the tensorrt dims from dims*/ template - const String ToDims(const std::vector& dims, bool use_ndim = true); - const String ToDims(const Array& dims, bool use_ndim = true); + const ffi::String ToDims(const std::vector& dims, bool use_ndim = true); + const ffi::String ToDims(const ffi::Array& dims, bool use_ndim = true); private: - Array before_build_codes_; - Array after_build_codes_; + ffi::Array before_build_codes_; + ffi::Array after_build_codes_; }; } // namespace msc diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h b/src/contrib/msc/framework/tensorrt/codegen_utils.h index f006b21b816e..3a16e668fe96 100644 --- a/src/contrib/msc/framework/tensorrt/codegen_utils.h +++ b/src/contrib/msc/framework/tensorrt/codegen_utils.h @@ -40,8 +40,8 @@ namespace msc { class TensorRTCodeGenHelper : public BaseCodeGenHelper { public: /*! 
\brief Get describe for default node input*/ - const String IdxInputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool process = false) final { + const ffi::String IdxInputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, + const ffi::String& suffix = "", bool process = false) final { const auto& pair = node->ProducerAndIdxOf(idx); if (pair.first->optype == "input") { return "*" + IdxNodeBase(pair.first, prefix, suffix); @@ -53,8 +53,8 @@ class TensorRTCodeGenHelper : public BaseCodeGenHelper { } /*! \brief Get describe for default node output*/ - const String IdxOutputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool mark_exit = false) final { + const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, + const ffi::String& suffix = "", bool mark_exit = false) final { if (node->optype == "argmax" || node->optype == "argmin") { ICHECK_EQ(idx, 0) << "argmax and argmin only has 1 output, get " << idx; return IdxNodeBase(node, prefix, suffix) + "->getOutput(1)"; @@ -70,8 +70,8 @@ class TensorRTCodeGenHelper : public BaseCodeGenHelper { } /*! 
\brief Get describe for default node weight*/ - const String IdxWeightBase(const MSCJoint& node, const String& wtype, const String& suffix = "", - bool process = false) final { + const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, + const ffi::String& suffix = "", bool process = false) final { return "mWeights[\"" + node->WeightAt(wtype)->name + "\"]"; } }; diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc index 5a63ecbc7d06..4fde2bf8bc2e 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -31,7 +31,7 @@ namespace tvm { namespace contrib { namespace msc { -const Array TensorRTOpCode::GetDocs() { +const ffi::Array TensorRTOpCode::GetDocs() { stack_.Config(this); CodeGenBuild(); if (node()->optype == "tuple") { @@ -52,7 +52,7 @@ const Array TensorRTOpCode::GetDocs() { return stack_.GetDocs(); } -void TensorRTOpCode::SetPadding(const String& key) { +void TensorRTOpCode::SetPadding(const ffi::String& key) { const auto& padding = node()->GetTypeArrayAttr("padding"); if (padding.size() == 1) { SetLayerByDimsValue("Padding", std::vector{padding[0], padding[0]}, false); @@ -67,8 +67,8 @@ void TensorRTOpCode::SetPadding(const String& key) { } } -const String TensorRTOpCode::DeclareInputs(bool simplify) { - const String& inputs_ref = "inputs_" + std::to_string(node()->index); +const ffi::String TensorRTOpCode::DeclareInputs(bool simplify) { + const ffi::String& inputs_ref = "inputs_" + std::to_string(node()->index); if (node()->parents.size() == 1 && simplify) { const auto& idx_input = StringUtils::Replace(IdxInput(), "*", ""); stack_.declare("std::vector", inputs_ref + "_vec") @@ -85,9 +85,10 @@ const String TensorRTOpCode::DeclareInputs(bool simplify) { return inputs_ref; } -const String TensorRTOpCode::DType(const DataType& dtype) { - const String& dtype_name = BaseOpCode::DType(dtype); - String 
dtype_enum; +const ffi::String TensorRTOpCode::DType(const DataType& dtype) { + const ffi::String& dtype_name = + BaseOpCode::DType(dtype); + ffi::String dtype_enum; if (dtype_name == "int8") { dtype_enum = "DataType::kINT8"; } else if (dtype_name == "int32") { @@ -105,11 +106,11 @@ const String TensorRTOpCode::DType(const DataType& dtype) { } template -const String TensorRTOpCode::ToDims(const std::vector& dims, bool use_ndim) { +const ffi::String TensorRTOpCode::ToDims(const std::vector& dims, bool use_ndim) { if (dims.size() == 2 && !use_ndim) { return "DimsHW{" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) + "}"; } - String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; + ffi::String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; for (size_t i = 0; i < dims.size(); i++) { dims_str = dims_str + std::to_string(dims[i]) + (i < dims.size() - 1 ? "," : ""); } @@ -117,7 +118,7 @@ const String TensorRTOpCode::ToDims(const std::vector& dims, bool use_ndim) { return dims_str; } -const String TensorRTOpCode::ToDims(const Array& dims, bool use_ndim) { +const ffi::String TensorRTOpCode::ToDims(const ffi::Array& dims, bool use_ndim) { std::vector int_dims; for (const auto& d : dims) { int_dims.push_back(d->value); @@ -125,7 +126,7 @@ const String TensorRTOpCode::ToDims(const Array& dims, bool use_ndim) { return ToDims(int_dims, use_ndim); } -const String TensorRTOpCode::AttrToDims(const String& key, bool use_ndim) { +const ffi::String TensorRTOpCode::AttrToDims(const ffi::String& key, bool use_ndim) { const auto& dims = node()->GetTypeArrayAttr(key); return ToDims(dims, use_ndim); } @@ -139,7 +140,7 @@ const size_t TensorRTOpCode::ToReduceAxis(const std::vector& axes, size_t n return reduce_axis; } -const size_t TensorRTOpCode::AttrToReduceAxis(const String& key, size_t ndim) { +const size_t TensorRTOpCode::AttrToReduceAxis(const ffi::String& key, size_t ndim) { std::vector axes; if (node()->GetAttr(key, &axes)) { return ToReduceAxis(axes, 
ndim); @@ -149,56 +150,57 @@ const size_t TensorRTOpCode::AttrToReduceAxis(const String& key, size_t ndim) { return ToReduceAxis(std::vector{axis}, ndim); } -const size_t TensorRTOpCode::AttrToAxis(const String& key, size_t ndim) { +const size_t TensorRTOpCode::AttrToAxis(const ffi::String& key, size_t ndim) { size_t valid_ndim = ndim == 0 ? node()->InputAt(0)->Ndim() : ndim; int axis = node()->GetTypeAttr(key); return CommonUtils::GetIndex(axis, valid_ndim); } template -void TensorRTOpCode::SetLayerByAttr(const String& method, const String& key) { +void TensorRTOpCode::SetLayerByAttr(const ffi::String& method, const ffi::String& key) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())).op_arg(key, ""); } template -void TensorRTOpCode::SetLayerByValue(const String& method, const T& value) { +void TensorRTOpCode::SetLayerByValue(const ffi::String& method, const T& value) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())).call_arg(value); } -void TensorRTOpCode::SetLayerByDimsAttr(const String& method, const String& key, bool use_ndim) { +void TensorRTOpCode::SetLayerByDimsAttr(const ffi::String& method, const ffi::String& key, + bool use_ndim) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(AttrToDims(key, use_ndim)); } template -void TensorRTOpCode::SetLayerByDimsValue(const String& method, const std::vector& value, +void TensorRTOpCode::SetLayerByDimsValue(const ffi::String& method, const std::vector& value, bool use_ndim) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(ToDims(value, use_ndim)); } -void TensorRTOpCode::SetLayerByDimsValue(const String& method, const Array& value, - bool use_ndim) { +void TensorRTOpCode::SetLayerByDimsValue(const ffi::String& method, + const ffi::Array& value, bool use_ndim) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(ToDims(value, use_ndim)); } #define 
TENSORRT_OP_CODEGEN_METHODS(TypeName) \ public: \ - TypeName(const String& func_name) : TensorRTOpCode(func_name) {} + TypeName(const ffi::String& func_name) : TensorRTOpCode(func_name) {} -#define TENSORRT_FLAG_OP_CODEGEN_METHODS(TypeName) \ - public: \ - TypeName(const String& func_name, const String& symbol) : TensorRTOpCode(func_name) { \ - symbol_ = symbol; \ - } \ - \ - private: \ - String symbol_; +#define TENSORRT_FLAG_OP_CODEGEN_METHODS(TypeName) \ + public: \ + TypeName(const ffi::String& func_name, const ffi::String& symbol) : TensorRTOpCode(func_name) { \ + symbol_ = symbol; \ + } \ + \ + private: \ + ffi::String symbol_; class TensorRTActivationCodeGen : public TensorRTOpCode { public: - explicit TensorRTActivationCodeGen(const String& symbol) : TensorRTOpCode("Activation") { + explicit TensorRTActivationCodeGen(const ffi::String& symbol) : TensorRTOpCode("Activation") { symbol_ = symbol; } @@ -214,7 +216,7 @@ class TensorRTActivationCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { @@ -232,7 +234,7 @@ class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { stride.push_back(in_sizes[i] / out_sizes[i]); kernel.push_back((in_sizes[i] - (out_sizes[i] - 1) * stride[i])); } - const String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; + const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? 
"Nd" : ""; stack_.op_call() .op_input_arg() .call_arg("PoolingType::k" + symbol_) @@ -243,7 +245,7 @@ class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { class TensorRTArgmaxminCodeGen : public TensorRTOpCode { public: - explicit TensorRTArgmaxminCodeGen(const String& symbol) : TensorRTOpCode("TopK") { + explicit TensorRTArgmaxminCodeGen(const ffi::String& symbol) : TensorRTOpCode("TopK") { symbol_ = symbol; } @@ -258,7 +260,7 @@ class TensorRTArgmaxminCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTAstypeCodeGen : public TensorRTOpCode { @@ -318,7 +320,7 @@ class TensorRTConstantCodeGen : public TensorRTOpCode { class TensorRTConvCodeGen : public TensorRTOpCode { public: - TensorRTConvCodeGen(const String& func_name, bool use_bias) : TensorRTOpCode(func_name) { + TensorRTConvCodeGen(const ffi::String& func_name, bool use_bias) : TensorRTOpCode(func_name) { use_bias_ = use_bias; } @@ -342,7 +344,7 @@ class TensorRTConvCodeGen : public TensorRTOpCode { } else { stack_.call_arg("mWeights[\"" + node()->name + ".bias\"]"); } - const String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; + const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? 
"Nd" : ""; SetLayerByDimsAttr("Stride" + suffix, "strides", false); SetLayerByDimsAttr("Dilation" + suffix, "dilation", false); SetLayerByAttr("NbGroups", "groups"); @@ -355,7 +357,7 @@ class TensorRTConvCodeGen : public TensorRTOpCode { class TensorRTElemwiseCodeGen : public TensorRTOpCode { public: - explicit TensorRTElemwiseCodeGen(const String& symbol) : TensorRTOpCode("ElementWise") { + explicit TensorRTElemwiseCodeGen(const ffi::String& symbol) : TensorRTOpCode("ElementWise") { symbol_ = symbol; } @@ -365,7 +367,7 @@ class TensorRTElemwiseCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTGetItemCodeGen : public TensorRTOpCode { @@ -396,7 +398,7 @@ class TensorRTInputCodeGen : public TensorRTOpCode { class TensorRTLinearCodeGen : public TensorRTOpCode { public: - TensorRTLinearCodeGen(const String& func_name, bool use_bias) : TensorRTOpCode(func_name) { + TensorRTLinearCodeGen(const ffi::String& func_name, bool use_bias) : TensorRTOpCode(func_name) { use_bias_ = use_bias; } @@ -464,7 +466,7 @@ class TensorRTPermuteDimsCodeGen : public TensorRTOpCode { axes.push_back(i - 1); } } - const String& perm_ref = "perm_" + std::to_string(node()->index); + const ffi::String& perm_ref = "perm_" + std::to_string(node()->index); stack_.op_call().op_input_arg().declare("Permutation", perm_ref); for (size_t i = 0; i < axes.size(); i++) { stack_.assign(perm_ref + ".order[" + std::to_string(i) + "]", @@ -476,7 +478,7 @@ class TensorRTPermuteDimsCodeGen : public TensorRTOpCode { class TensorRTPool2dCodeGen : public TensorRTOpCode { public: - explicit TensorRTPool2dCodeGen(const String& symbol) : TensorRTOpCode("PoolingNd") { + explicit TensorRTPool2dCodeGen(const ffi::String& symbol) : TensorRTOpCode("PoolingNd") { symbol_ = symbol; } @@ -486,7 +488,7 @@ class TensorRTPool2dCodeGen : public TensorRTOpCode { .op_input_arg() .call_arg("PoolingType::k" + symbol_) .call_arg(AttrToDims("pool_size", false)); - const String& suffix 
= CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; + const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; SetLayerByDimsAttr("Stride" + suffix, "strides", false); if (node()->GetTypeAttr("ceil_mode")) { SetLayerByValue("PaddingMode", "PaddingMode::kEXPLICIT_ROUND_UP"); @@ -498,12 +500,12 @@ class TensorRTPool2dCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTReduceCodeGen : public TensorRTOpCode { public: - explicit TensorRTReduceCodeGen(const String& symbol) : TensorRTOpCode("Reduce") { + explicit TensorRTReduceCodeGen(const ffi::String& symbol) : TensorRTOpCode("Reduce") { symbol_ = symbol; } @@ -517,7 +519,7 @@ class TensorRTReduceCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTReshapeCodeGen : public TensorRTOpCode { @@ -540,7 +542,7 @@ class TensorRTResize2dCodeGen : public TensorRTOpCode { void CodeGenBuild() final { stack_.op_call().op_input_arg(); const auto& method = node()->GetTypeAttr("method"); - String resize_mode; + ffi::String resize_mode; if (method == "linear") { resize_mode = "LINEAR"; } else if (method == "nearest_neighbor") { @@ -663,7 +665,7 @@ class TensorRTTopkCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { - const String& symbol = node()->GetTypeAttr("largest") ? "MAX" : "MIN"; + const ffi::String& symbol = node()->GetTypeAttr("largest") ? 
"MAX" : "MIN"; stack_.op_call() .op_input_arg() .call_arg("TopKOperation::k" + symbol) @@ -685,7 +687,7 @@ class TensorRTTupleCodeGen : public TensorRTOpCode { class TensorRTUnaryCodeGen : public TensorRTOpCode { public: - explicit TensorRTUnaryCodeGen(const String& symbol) : TensorRTOpCode("Unary") { + explicit TensorRTUnaryCodeGen(const ffi::String& symbol) : TensorRTOpCode("Unary") { symbol_ = symbol; } @@ -695,7 +697,7 @@ class TensorRTUnaryCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTWhereCodeGen : public TensorRTOpCode { @@ -718,9 +720,9 @@ class TensorRTPluginOpCodeGen : public TensorRTOpCode { const auto& plugin = GetPlugin(node()->optype); const auto& input_ref = "inputs_" + std::to_string(producer->index); - const String& func_name = "plugin::" + node()->optype + "DynamicPlugin"; - const String& plugin_ref = "plugin_" + std::to_string(node()->index); - const String& layouts_ref = "layouts_" + std::to_string(node()->index); + const ffi::String& func_name = "plugin::" + node()->optype + "DynamicPlugin"; + const ffi::String& plugin_ref = "plugin_" + std::to_string(node()->index); + const ffi::String& layouts_ref = "layouts_" + std::to_string(node()->index); stack_.declare("std::vector", layouts_ref, 0, false); for (const auto& i : node()->GetInputs()) { stack_.declare_arg(DocUtils::ToStr(i->layout.name())); @@ -735,9 +737,10 @@ class TensorRTPluginOpCodeGen : public TensorRTOpCode { } }; -const std::shared_ptr>> +const std::shared_ptr>> GetTensorRTOpCodes() { - static auto map = std::make_shared>>(); + static auto map = + std::make_shared>>(); if (!map->empty()) return map; // unary ops map->emplace("abs", std::make_shared("ABS")); diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h index 2d9bcb6acfa2..ddf7fb1522be 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h +++ 
b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h @@ -49,22 +49,22 @@ class TensorRTOpCode : public BaseOpCode(func_name) {} /*! \brief Convert node to docs*/ - const Array GetDocs() final; + const ffi::Array GetDocs() final; /*! \brief Get func_name for the default node*/ - const String callee_name() final { + const ffi::String callee_name() final { return "network->add" + BaseOpCode::callee_name(); } /*! \brief Get valid return name for the default node*/ - const String ret_name() final { return "auto " + IdxNode(); } + const ffi::String ret_name() final { return "auto " + IdxNode(); } /*! \brief Get the dtype from the datatype*/ - const String DType(const DataType& dtype) final; + const ffi::String DType(const DataType& dtype) final; protected: TensorRTOpCodeStack stack_; @@ -73,50 +73,52 @@ class TensorRTOpCode : public BaseOpCode - const String ToDims(const std::vector& dims, bool use_ndim = true); - const String ToDims(const Array& dims, bool use_ndim = true); + const ffi::String ToDims(const std::vector& dims, bool use_ndim = true); + const ffi::String ToDims(const ffi::Array& dims, bool use_ndim = true); /*! \brief Get the tensorrt dims from attribute*/ - const String AttrToDims(const String& key, bool use_ndim = true); + const ffi::String AttrToDims(const ffi::String& key, bool use_ndim = true); /*! \brief Get the tensorrt reduce axis from dims*/ const size_t ToReduceAxis(const std::vector& axes, size_t ndim = 0); /*! \brief Get the tensorrt reduce axis from attribute*/ - const size_t AttrToReduceAxis(const String& key = "axis", size_t ndim = 0); + const size_t AttrToReduceAxis(const ffi::String& key = "axis", size_t ndim = 0); /*! \brief Get the attribute axis from attribute*/ - const size_t AttrToAxis(const String& key = "axis", size_t ndim = 0); + const size_t AttrToAxis(const ffi::String& key = "axis", size_t ndim = 0); /*! 
\brief Set layer by attribute*/ template - void SetLayerByAttr(const String& method, const String& key); + void SetLayerByAttr(const ffi::String& method, const ffi::String& key); /*! \brief Set layer by value*/ template - void SetLayerByValue(const String& method, const T& value); + void SetLayerByValue(const ffi::String& method, const T& value); /*! \brief Set layer by dims attribute*/ - void SetLayerByDimsAttr(const String& method, const String& key, bool use_ndim = true); + void SetLayerByDimsAttr(const ffi::String& method, const ffi::String& key, bool use_ndim = true); /*! \brief Set layer by dims value*/ template - void SetLayerByDimsValue(const String& method, const std::vector& value, bool use_ndim = true); - void SetLayerByDimsValue(const String& method, const Array& value, bool use_ndim = true); + void SetLayerByDimsValue(const ffi::String& method, const std::vector& value, + bool use_ndim = true); + void SetLayerByDimsValue(const ffi::String& method, const ffi::Array& value, + bool use_ndim = true); }; /*! 
* \brief Get the map of available TensorRTOpCode, use optype as key * \return Map of */ -const std::shared_ptr>> +const std::shared_ptr>> GetTensorRTOpCodes(); } // namespace msc diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 3d43c74958ec..e3579ec7ef77 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -58,7 +58,7 @@ struct TensorRTTransConfig { } }; -const TensorRTTransConfig ParseConfig(const String& config_str) { +const TensorRTTransConfig ParseConfig(const ffi::String& config_str) { TensorRTTransConfig config; if (config_str.size() > 0) { std::istringstream is(config_str); @@ -70,12 +70,12 @@ const TensorRTTransConfig ParseConfig(const String& config_str) { using FRewriteTensorRT = ffi::TypedFunction& new_calls, const String& config)>; + const ffi::Map& new_calls, const ffi::String& config)>; -const Array BroadcastShape(const Array& src_shape, - const Array& out_shape) { +const ffi::Array BroadcastShape(const ffi::Array& src_shape, + const ffi::Array& out_shape) { size_t diff = out_shape.size() - src_shape.size(); - Array leading_shape, tailing_shape; + ffi::Array leading_shape, tailing_shape; for (size_t i = 0; i < diff; i++) { leading_shape.push_back(Integer(1)); } @@ -95,7 +95,7 @@ const Array BroadcastShape(const Array& src_shape, } Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto& shape_a = ExprUtils::GetShape(call->args[0]); const auto& shape_b = ExprUtils::GetShape(call->args[1]); @@ -118,7 +118,7 @@ Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; if (new_calls.count(call->args[0]) && new_calls[call->args[0]]->op == Op::Get("relax.nn.conv1d")) { @@ -135,7 +135,7 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, const auto* conv_attrs = conv2d->attrs.as(); if (conv_attrs->data_layout == "NCHW") { // expand bias reshape - Array exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), bias_shape[2]}; + ffi::Array exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), bias_shape[2]}; static const Op& reshape_op = Op::Get("relax.reshape"); const auto& exp_bias = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_bias"), reshape_op, @@ -155,14 +155,14 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto& out_dtype = ExprUtils::GetDataType(var); const auto* src_attrs = src_call->attrs.as(); ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64)) << "Unexpected out dtype " << out_dtype; static const Op& topk_op = Op::Get("relax.topk"); - auto topk_attrs = make_object(); + auto topk_attrs = ffi::make_object(); topk_attrs->k = 1; if (src_attrs->axis.has_value()) { topk_attrs->axis = src_attrs->axis.value(); @@ -187,7 +187,7 @@ Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call } Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); @@ -218,8 +218,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call static const Op& exp_op = Op::Get("relax.exp"); // prepare q,k,v - auto permute_attrs = make_object(); - Array axes{Integer(0), Integer(2), Integer(1), Integer(3)}; + auto permute_attrs = ffi::make_object(); + ffi::Array axes{Integer(0), Integer(2), Integer(1), Integer(3)}; permute_attrs->axes = axes; const auto& q_trans = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"), permute_dims_op, @@ -230,17 +230,17 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call const auto& v_trans = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"), permute_dims_op, {call->args[2]}, Attrs(permute_attrs)); - Array q_shape({batch_size * num_head, seq_len, head_dim}); + ffi::Array q_shape({batch_size * num_head, seq_len, head_dim}); const auto& q_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_reshape"), reshape_op, {q_trans, ShapeExpr(q_shape)}); - Array 
k_shape({batch_size * num_head, seq_len_kv, head_dim}); + ffi::Array k_shape({batch_size * num_head, seq_len_kv, head_dim}); const auto& k_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape"), reshape_op, {k_trans, ShapeExpr(k_shape)}); - Array v_shape({batch_size * num_head, seq_len_kv, head_dim_v}); + ffi::Array v_shape({batch_size * num_head, seq_len_kv, head_dim_v}); const auto& v_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_reshape"), reshape_op, {v_trans, ShapeExpr(v_shape)}); - auto reduce_permute_attrs = make_object(); - Array v_axes{Integer(0), Integer(2), Integer(1)}; + auto reduce_permute_attrs = ffi::make_object(); + ffi::Array v_axes{Integer(0), Integer(2), Integer(1)}; reduce_permute_attrs->axes = v_axes; // transpose for batch_matmul const auto& k_reshape_trans = @@ -248,7 +248,7 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call permute_dims_op, {k_reshape}, Attrs(reduce_permute_attrs)); // calculate product - auto matmul_attrs = make_object(); + auto matmul_attrs = ffi::make_object(); matmul_attrs->out_dtype = in_dtype; const auto& qk_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"), matmul_op, @@ -273,8 +273,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call // bias Expr prod = p_scale; if (call->args.size() == 4) { - Array exp_shape{batch_size, num_head, seq_len, seq_len_kv}; - Array reduce_shape{batch_size * num_head, seq_len, seq_len_kv}; + ffi::Array exp_shape{batch_size, num_head, seq_len, seq_len_kv}; + ffi::Array reduce_shape{batch_size * num_head, seq_len, seq_len_kv}; const auto& prod_exp = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_exp"), reshape_op, {prod, ShapeExpr(exp_shape)}); const auto& prod_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_add"), @@ -286,7 +286,7 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, 
const Call& src_call // causal_mask Expr s_value; if (!src_attrs->causal_mask.has_value()) { - auto softmax_attrs = make_object(); + auto softmax_attrs = ffi::make_object(); softmax_attrs->axis = 2; s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), softmax_op, {prod}, Attrs(softmax_attrs)); @@ -302,8 +302,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call } const auto& p_masked = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked"), tril_op, {prod, tril_k}); - auto reduce_attrs = make_object(); - Array axis{Integer(2)}; + auto reduce_attrs = ffi::make_object(); + ffi::Array axis{Integer(2)}; reduce_attrs->axis = axis; reduce_attrs->keepdims = true; const auto& p_max = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_max"), @@ -324,18 +324,18 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call // final calculation const auto& o_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_prod"), matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs)); - Array o_shape{batch_size, num_head, seq_len, head_dim_v}; + ffi::Array o_shape{batch_size, num_head, seq_len, head_dim_v}; return Call(reshape_op, {o_prod, ShapeExpr(o_shape)}, Attrs(), call->sinfo_args, call->span); } Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); // define expand shape - Array exp_shape(input_shape.size(), Integer(1)); + ffi::Array exp_shape(input_shape.size(), Integer(1)); exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]); // create eps constant @@ -380,11 +380,11 @@ Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset"), add_op, {res, exp_offset}); } - return Tuple(Array{res}, call->span); + return Tuple(ffi::Array{res}, call->span); } Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto& output_shape = ExprUtils::GetShape(var); @@ -394,8 +394,8 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_ca int64_t in_dim = Downcast(input_shape[i])->value; int64_t out_dim = Downcast(output_shape[i])->value; if (in_dim != out_dim) { - Array concat_inputs(out_dim / in_dim, concat_input); - auto concat_attrs = make_object(); + ffi::Array concat_inputs(out_dim / in_dim, concat_input); + auto concat_attrs = ffi::make_object(); concat_attrs->axis = i; concat_input = RewriteUtils::MakeCall( builder, ExprUtils::GetSpanName(call, "concat_" + std::to_string(i)), concat_op, @@ -406,17 +406,19 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_ca } Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) 
? new_calls[src_call] : src_call; const auto* src_attrs = src_call->attrs.as(); const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto& weight_shape = ExprUtils::GetShape(call->args[1]); const auto& output_shape = ExprUtils::GetShape(var); if (src_attrs->data_layout == "NCW") { - Array new_args; + ffi::Array new_args; // expand inputs - Array exp_input_shape{input_shape[0], input_shape[1], Integer(1), input_shape[2]}; - Array exp_weight_shape{weight_shape[0], weight_shape[1], Integer(1), weight_shape[2]}; + ffi::Array exp_input_shape{input_shape[0], input_shape[1], Integer(1), + input_shape[2]}; + ffi::Array exp_weight_shape{weight_shape[0], weight_shape[1], Integer(1), + weight_shape[2]}; static const Op& reshape_op = Op::Get("relax.reshape"); new_args.push_back(RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_input"), reshape_op, @@ -426,11 +428,11 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, {call->args[1], ShapeExpr(exp_weight_shape)})); // change to conv2d static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); - auto conv_attrs = make_object(); - conv_attrs->strides = Array{src_attrs->strides[0], Integer(1)}; + auto conv_attrs = ffi::make_object(); + conv_attrs->strides = ffi::Array{src_attrs->strides[0], Integer(1)}; conv_attrs->padding = - Array{Integer(0), src_attrs->padding[0], Integer(0), src_attrs->padding[1]}; - conv_attrs->dilation = Array{src_attrs->dilation[0], Integer(1)}; + ffi::Array{Integer(0), src_attrs->padding[0], Integer(0), src_attrs->padding[1]}; + conv_attrs->dilation = ffi::Array{src_attrs->dilation[0], Integer(1)}; conv_attrs->groups = src_attrs->groups; conv_attrs->data_layout = "NCHW"; conv_attrs->kernel_layout = "OIHW"; @@ -448,7 +450,7 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& 
new_calls, const ffi::String& config) { // 0.5 * x * (1 + erf(sqrt(0.5) * x)) const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; size_t in_dim = ExprUtils::GetShape(call->args[0]).size(); @@ -476,7 +478,7 @@ Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { // 0.5 * x * (1 + tanh(sqrt(2/pi) * (0.044715F * pow(x, 3) + x))) const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; size_t in_dim = ExprUtils::GetShape(call->args[0]).size(); @@ -517,13 +519,13 @@ Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); - Array group_shape = input_shape; - Array exp_shape(input_shape.size(), Integer(1)); + ffi::Array group_shape = input_shape; + ffi::Array exp_shape(input_shape.size(), Integer(1)); size_t axis = CommonUtils::GetIndex(src_attrs->channel_axis, input_shape.size()); int64_t channel_dim = Downcast(input_shape[axis])->value * Downcast(input_shape[axis + 1])->value / src_attrs->num_groups; @@ -551,7 +553,7 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call {call->args[0], ShapeExpr(group_shape)}); // mean(input) - auto mean_attrs = make_object(); + auto mean_attrs = ffi::make_object(); mean_attrs->axis = src_attrs->axes; mean_attrs->keepdims = true; const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op, @@ -566,7 +568,7 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call mean_op, {square}, Attrs(mean_attrs)); // sqrt(var + epsilon) - Array exp_eps_shape(input_shape.size(), Integer(1)); + ffi::Array exp_eps_shape(input_shape.size(), Integer(1)); const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"), reshape_op, {eps, ShapeExpr(exp_eps_shape)}); const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), @@ -599,12 +601,12 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call } Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); - Array exp_shape(input_shape.size(), Integer(1)); + ffi::Array exp_shape(input_shape.size(), Integer(1)); for (const auto& a : src_attrs->axes) { size_t index = CommonUtils::GetIndex(static_cast(a->value), input_shape.size()); exp_shape.Set(index, input_shape[index]); @@ -624,7 +626,7 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call static const Op& subtract_op = Op::Get("relax.subtract"); // mean(input) - auto mean_attrs = make_object(); + auto mean_attrs = ffi::make_object(); mean_attrs->axis = src_attrs->axes; mean_attrs->keepdims = true; const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op, @@ -639,7 +641,7 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call mean_op, {square}, Attrs(mean_attrs)); // sqrt(var + epsilon) - Array exp_eps_shape(input_shape.size(), Integer(1)); + ffi::Array exp_eps_shape(input_shape.size(), Integer(1)); const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"), reshape_op, {eps, ShapeExpr(exp_eps_shape)}); const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), @@ -676,7 +678,7 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call } Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& trt_config = ParseConfig(config); const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto& shape_a = ExprUtils::GetShape(call->args[0]); @@ -686,27 +688,27 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, trt_config.linear_to_conv) { const auto& out_shape = ExprUtils::GetShape(var); PrimExpr accumulate = ArrayUtils::Accumulate(shape_a, shape_a.size() - 1); - Array exp_shape{accumulate, shape_a[shape_a.size() - 1], Integer(1), Integer(1)}; + ffi::Array exp_shape{accumulate, shape_a[shape_a.size() - 1], Integer(1), Integer(1)}; const auto& exp_in = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_in"), reshape_op, {call->args[0], ShapeExpr(exp_shape)}); // transpose and expand weight to OIHW static const Op& permute_dims_op = Op::Get("relax.permute_dims"); - auto permute_attrs = make_object(); - Array axes{Integer(1), Integer(0)}; + auto permute_attrs = ffi::make_object(); + ffi::Array axes{Integer(1), Integer(0)}; permute_attrs->axes = axes; const auto& trans_weight = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "trans_weight"), permute_dims_op, {call->args[1]}, Attrs(permute_attrs)); - Array weight_shape{shape_b[1], shape_b[0], Integer(1), Integer(1)}; + ffi::Array weight_shape{shape_b[1], shape_b[0], Integer(1), Integer(1)}; const auto& exp_weight = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_weight"), reshape_op, {trans_weight, ShapeExpr(weight_shape)}); // to conv2d static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); - auto conv_attrs = make_object(); - conv_attrs->strides = Array{Integer(1), Integer(1)}; - conv_attrs->padding = Array{Integer(0), Integer(0), Integer(0), Integer(0)}; - conv_attrs->dilation = Array{Integer(1), Integer(1)}; + auto conv_attrs = ffi::make_object(); + conv_attrs->strides = ffi::Array{Integer(1), Integer(1)}; + conv_attrs->padding = ffi::Array{Integer(0), Integer(0), Integer(0), Integer(0)}; + conv_attrs->dilation = ffi::Array{Integer(1), Integer(1)}; conv_attrs->groups = 1; 
conv_attrs->data_layout = "NCHW"; conv_attrs->kernel_layout = "OIHW"; @@ -717,7 +719,7 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, return Call(reshape_op, {conv2d, ShapeExpr(out_shape)}, Attrs(), call->sinfo_args, call->span); } if (shape_a.size() > shape_b.size()) { - Array exp_shape(shape_a.size(), Integer(1)); + ffi::Array exp_shape(shape_a.size(), Integer(1)); size_t diff = shape_a.size() - shape_b.size(); for (size_t i = diff; i < shape_a.size(); i++) { exp_shape.Set(i, shape_b[i - diff]); @@ -728,7 +730,7 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span); } if (shape_a.size() < shape_b.size()) { - Array exp_shape(shape_b.size(), Integer(1)); + ffi::Array exp_shape(shape_b.size(), Integer(1)); size_t diff = shape_b.size() - shape_a.size(); for (size_t i = diff; i < shape_b.size(); i++) { exp_shape.Set(i, shape_a[i - diff]); @@ -742,7 +744,7 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); @@ -761,7 +763,7 @@ Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; // create ops static const Op& multiply_op = Op::Get("relax.multiply"); @@ -773,7 +775,7 @@ Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& output_shape = ExprUtils::GetShape(var); static const Op& reshape_op = Op::Get("relax.reshape"); @@ -782,7 +784,7 @@ Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call } Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto* src_attrs = src_call->attrs.as(); @@ -797,7 +799,7 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, split_ends.push_back(i * size + size); } } else if (src_attrs->indices_or_sections->IsInstance()) { - const auto& indices = Downcast>(src_attrs->indices_or_sections); + const auto& indices = Downcast>(src_attrs->indices_or_sections); int64_t last_index = 0; for (size_t i = 0; i < indices.size(); ++i) { split_begins.push_back(last_index); @@ -811,14 +813,15 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, << src_attrs->indices_or_sections->GetTypeKey() << ")"; } // create strided_slices - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < split_begins.size(); i++) { static const Op& strided_slice_op = Op::Get("relax.strided_slice"); - const auto& axes = Tuple(Array{PrimValue(IntImm(DataType::Int(64), axis))}); - const auto& begin = Tuple(Array{PrimValue(IntImm(DataType::Int(64), 
split_begins[i]))}); - const auto& end = Tuple(Array{PrimValue(IntImm(DataType::Int(64), split_ends[i]))}); - const auto& strides = Tuple(Array{PrimValue(IntImm(DataType::Int(64), 1))}); - auto attrs = make_object(); + const auto& axes = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), axis))}); + const auto& begin = + Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), split_begins[i]))}); + const auto& end = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), split_ends[i]))}); + const auto& strides = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), 1))}); + auto attrs = ffi::make_object(); attrs->assume_inbound = true; const auto& slice = RewriteUtils::MakeCall( builder, ExprUtils::GetSpanName(call, "slice_" + std::to_string(i)), strided_slice_op, @@ -872,17 +875,17 @@ TVM_REGISTER_OP("relax.split").set_attr("FRewriteTensorRT", Re class TensorRTTransformer : public ExprMutator { public: - explicit TensorRTTransformer(IRModule ctx_module, const String& config) + explicit TensorRTTransformer(IRModule ctx_module, const ffi::String& config) : ExprMutator(ctx_module) { config_ = config; } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { if (const auto* op_node = call_node->op.as()) { - const auto& op = Downcast(GetRef(op_node)); + const auto& op = Downcast(ffi::GetRef(op_node)); const auto& rewrite_map = Op::GetAttrMap("FRewriteTensorRT"); if (rewrite_map.count(op)) { - const auto& call = GetRef(call_node); + const auto& call = ffi::GetRef(call_node); FRewriteTensorRT f = rewrite_map[op]; const auto& new_call = f(builder_, binding->var, call, new_calls_, config_); if (new_call != call) { @@ -897,27 +900,28 @@ class TensorRTTransformer : public ExprMutator { } private: - Map new_calls_; - String config_; + ffi::Map new_calls_; + ffi::String config_; }; -Function TransformTensorRT(const Function& func, const IRModule& module, const String& config) { +Function TransformTensorRT(const Function& func, const IRModule& module, + 
const ffi::String& config) { return Downcast(TensorRTTransformer(module, config).VisitExpr(func)); } namespace transform { -Pass TransformTensorRT(const String& config) { +Pass TransformTensorRT(const ffi::String& config) { auto pass_func = [=](Function f, IRModule m, PassContext pc) { return relax::TransformTensorRT(f, m, config); }; return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.TransformTensorRT", TransformTensorRT); -}); +} } // namespace transform } // namespace relax diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index 68c55bb9cbce..c81646f8b267 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -92,7 +92,7 @@ void TorchCodeGen::CodeGenGraph() { } CodeGenNode(node, config()->use_tools); } - Array idx_outputs; + ffi::Array idx_outputs; for (const auto& o : graph()->GetOutputs()) { const auto& pair = graph()->FindProducerAndIdx(o); idx_outputs.push_back(IdxOutputBase(pair.first, pair.second, true)); @@ -140,7 +140,7 @@ void TorchCodeGen::CodeGenInference() { } } -const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { +const ffi::Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTorchOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported torch op(" << node->optype << "): " << node; @@ -153,16 +153,16 @@ const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.framework.torch.GetTorchSources", - [](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { + [](const MSCGraph& graph, const ffi::String& codegen_config, + const ffi::String& print_config) -> 
ffi::Map { TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); codegen.Init(); return codegen.GetSources(print_config); }); -}); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/torch/codegen.h b/src/contrib/msc/framework/torch/codegen.h index 0ee860bb55c8..1e5032309cb6 100644 --- a/src/contrib/msc/framework/torch/codegen.h +++ b/src/contrib/msc/framework/torch/codegen.h @@ -56,10 +56,10 @@ class TorchCodeGen : public PyCodeGen { void CodeGenInference() final; /*! \brief Get the docs for the op*/ - const Array GetOpCodes(const MSCJoint& node) final; + const ffi::Array GetOpCodes(const MSCJoint& node) final; /*! \brief Get tensor type of the framework*/ - const String TensorType() const final { return "torch.Tensor"; } + const ffi::String TensorType() const final { return "torch.Tensor"; } private: bool is_init_; diff --git a/src/contrib/msc/framework/torch/codegen_utils.h b/src/contrib/msc/framework/torch/codegen_utils.h index c63de27519e0..13dee306e942 100644 --- a/src/contrib/msc/framework/torch/codegen_utils.h +++ b/src/contrib/msc/framework/torch/codegen_utils.h @@ -39,8 +39,8 @@ namespace msc { class TorchCodeGenHelper : public BaseCodeGenHelper { public: /*! 
\brief Get describe for default node input*/ - const String IdxOutputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool mark_exit = false) final { + const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, + const ffi::String& suffix = "", bool mark_exit = false) final { if ((node->optype == "max" || node->optype == "min") && node->OutputAt(0)->Ndim() > 0) { ICHECK(idx == 0) << "max and min op only support 1 outputs, get " << node; return IdxNodeBase(node, prefix, suffix) + ".values"; diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index 9e3652f04118..8f649469855e 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -30,7 +30,7 @@ namespace tvm { namespace contrib { namespace msc { -const Array TorchOpCode::GetDocs() { +const ffi::Array TorchOpCode::GetDocs() { stack_.Config(this); if (is_init()) { CodeGenInit(); @@ -50,7 +50,7 @@ void TorchOpCode::CodeGenInit() { void TorchOpCode::CodeGenForward() { stack_.op_call().op_inputs_arg(false); } -const StrictListDoc TorchOpCode::GetPadding(const String& key) { +const StrictListDoc TorchOpCode::GetPadding(const ffi::String& key) { std::vector padding, src_padding; ICHECK(node()->GetAttr(key, &src_padding)); if (node()->optype == "nn.conv1d" || node()->optype == "msc.conv1d_bias") { @@ -76,9 +76,9 @@ const StrictListDoc TorchOpCode::GetPadding(const String& key) { return DocUtils::ToList(padding); } -#define TORCH_OP_CODEGEN_METHODS(TypeName) \ - public: \ - TypeName(const String& module_name, const String& func_name) \ +#define TORCH_OP_CODEGEN_METHODS(TypeName) \ + public: \ + TypeName(const ffi::String& module_name, const ffi::String& func_name) \ : TorchOpCode(module_name, func_name) {} class TorchAdaptivePoolCodeGen : public TorchOpCode { @@ -118,7 +118,7 @@ class TorchAxesCodeGen : public TorchOpCode 
{ protected: void CodeGenInit() final { if (module_name().size() > 0) { - const String& key = node()->HasAttr("axes") ? "axes" : "axis"; + const ffi::String& key = node()->HasAttr("axes") ? "axes" : "axis"; stack_.op_call().op_list_arg(key, ""); } else { TorchOpCode::CodeGenInit(); @@ -129,7 +129,7 @@ class TorchAxesCodeGen : public TorchOpCode { if (module_name().size() > 0) { TorchOpCode::CodeGenForward(); } else { - const String& key = node()->HasAttr("axes") ? "axes" : "axis"; + const ffi::String& key = node()->HasAttr("axes") ? "axes" : "axis"; stack_.op_call().op_input_arg().op_list_arg(key, ""); } } @@ -268,7 +268,7 @@ class TorchConstantCodeGen : public TorchOpCode { class TorchConvCodeGen : public TorchOpCode { public: - TorchConvCodeGen(const String& module_name, const String& func_name, bool use_bias) + TorchConvCodeGen(const ffi::String& module_name, const ffi::String& func_name, bool use_bias) : TorchOpCode(module_name, func_name), use_bias_(use_bias) {} protected: @@ -343,9 +343,9 @@ class TorchExpandDimsCodeGen : public TorchOpCode { protected: void CodeGenForward() final { const auto& axes = node()->GetTypeArrayAttr("axis"); - String idx_input = IdxInput(); + ffi::String idx_input = IdxInput(); for (size_t i = 0; i < axes.size(); i++) { - String idx_out = IdxNode(); + ffi::String idx_out = IdxNode(); if (i < axes.size() - 1) { idx_out = idx_out + "_" + std::to_string(i); } @@ -400,7 +400,7 @@ class TorchLayerNormCodeGen : public TorchOpCode { << "Only support center and scale batchnorm, get " << node(); const auto& axes = CommonUtils::GetIndices(node()->GetTypeArrayAttr("axes"), node()->InputAt(0)->Ndim()); - Array normalized_shape; + ffi::Array normalized_shape; for (const auto& a : axes) { normalized_shape.push_back(node()->InputAt(0)->DimAt(a)); } @@ -412,7 +412,7 @@ class TorchLayerNormCodeGen : public TorchOpCode { class TorchLinearCodeGen : public TorchOpCode { public: - TorchLinearCodeGen(const String& module_name, const String& func_name, 
bool use_bias) + TorchLinearCodeGen(const ffi::String& module_name, const ffi::String& func_name, bool use_bias) : TorchOpCode(module_name, func_name), use_bias_(use_bias) {} protected: @@ -546,7 +546,7 @@ class TorchReshapeCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - Array shape = node()->OutputAt(0)->shape; + ffi::Array shape = node()->OutputAt(0)->shape; const auto& out_layout = node()->OutputAt(0)->layout; if (out_layout.defined()) { int32_t batch_dim = out_layout.IndexOf(tvm::tir::LayoutAxis::Get("N")); @@ -564,7 +564,7 @@ class TorchResize2dCodeGen : public TorchOpCode { protected: void CodeGenForward() final { const auto& method = node()->GetTypeAttr("method"); - String v_method; + ffi::String v_method; if (method == "nearest_neighbor") { v_method = "nearest"; } else { @@ -657,7 +657,7 @@ class TorchStridedSliceCodeGen : public TorchOpCode { for (size_t i = 0; i < axes.size(); i++) { axes_map[axes[i]] = i; } - Array slice; + ffi::Array slice; for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { if (axes_map.count(i)) { size_t idx = axes_map[i]; @@ -712,8 +712,10 @@ class TorchPluginOpCodeGen : public TorchOpCode { void CodeGenForward() final { stack_.op_call().op_inputs_arg(false); } }; -const std::shared_ptr>> GetTorchOpCodes() { - static auto map = std::make_shared>>(); +const std::shared_ptr>> +GetTorchOpCodes() { + static auto map = + std::make_shared>>(); if (!map->empty()) return map; // simple ops diff --git a/src/contrib/msc/framework/torch/torch_opcode.h b/src/contrib/msc/framework/torch/torch_opcode.h index 80b7f5c60d1d..e732e502ce31 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.h +++ b/src/contrib/msc/framework/torch/torch_opcode.h @@ -49,31 +49,31 @@ class TorchOpCode : public BaseOpCode { * \param func_name the function name for the node. * \param config the config json for the node. 
*/ - explicit TorchOpCode(const String& module_name, const String& func_name) + explicit TorchOpCode(const ffi::String& module_name, const ffi::String& func_name) : BaseOpCode(func_name) { module_name_ = module_name; } /*! \brief Config the TorchOpCode*/ void Config(const MSCJoint& node, const std::shared_ptr config, bool is_init, - const Map& prims) { + const ffi::Map& prims) { BaseOpCode::Config(node, config, prims); is_init_ = is_init; module_ref_ = "self." + StringUtils::Replace(node->name, ".", "_"); } /*! \brief Get return describe for default node*/ - const String IdxNode() final { + const ffi::String IdxNode() final { return is_init_ ? module_ref_ : BaseOpCode::IdxNode(); }; /*! \brief Get dtype string*/ - const String DType(const DataType& dtype) final { + const ffi::String DType(const DataType& dtype) final { return "torch." + BaseOpCode::DType(dtype); } /*! \brief Get func_name for the default node*/ - const String callee_name() final { + const ffi::String callee_name() final { if (is_init_) { return module_name_; } @@ -84,7 +84,7 @@ class TorchOpCode : public BaseOpCode { } /*! \brief Convert node to docs*/ - const Array GetDocs() final; + const ffi::Array GetDocs() final; protected: TorchOpCodeStack stack_; @@ -96,28 +96,29 @@ class TorchOpCode : public BaseOpCode { virtual void CodeGenForward(); /*! \brief Get the padding from op*/ - const StrictListDoc GetPadding(const String& key = "padding"); + const StrictListDoc GetPadding(const ffi::String& key = "padding"); /*! \brief Get the is_init_ of codegen*/ bool is_init() { return is_init_; } /*! \brief Get the module_name of codegen*/ - const String module_name() { return module_name_; } + const ffi::String module_name() { return module_name_; } /*! 
\brief Get the module_ref of codegen*/ - const String module_ref() { return module_ref_; } + const ffi::String module_ref() { return module_ref_; } private: bool is_init_; - String module_name_; - String module_ref_; + ffi::String module_name_; + ffi::String module_ref_; }; /*! * \brief Get the map of available TorchOpCode, use optype as key * \return Map of */ -const std::shared_ptr>> GetTorchOpCodes(); +const std::shared_ptr>> +GetTorchOpCodes(); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 7c42ba8d142a..29445ed7ccc3 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -35,7 +35,7 @@ void RelaxCodeGen::CodeGenHeader() { void RelaxCodeGen::CodeGenGraph() { stack_.func_def(graph()->name, "tvm.IRModule"); - Array idx_inputs; + ffi::Array idx_inputs; for (const auto& i : graph()->GetInputs()) { const auto& pair = graph()->FindProducerAndIdx(i); const auto& idx_input = IdxOutputBase(pair.first, pair.second); @@ -89,13 +89,13 @@ void RelaxCodeGen::CodeGenGraph() { } // mark outputs stack_.comment("Emit the outputs"); - Array idx_exits; + ffi::Array idx_exits; for (const auto& e : graph()->GetExits()) { const auto& idx_exit = IdxNodeBase(e) + (config()->use_tools ? 
"_exit" : ""); if (config()->use_tools) { if (e->outputs.size() > 1) { - Array tuple_outputs; + ffi::Array tuple_outputs; for (size_t o_idx = 0; o_idx < e->outputs.size(); o_idx++) { const auto& t_output = IdxOutputBase(e, o_idx, true); tuple_outputs.push_back(t_output); @@ -151,7 +151,7 @@ void RelaxCodeGen::CodeGenInference() { const auto& producer = graph()->FindProducer(i); stack_.call_arg(IdxNodeBase(producer)); } - String target, device; + ffi::String target, device; if (config()->test_device == "cpu") { target = "llvm"; device = "tvm.cpu()"; @@ -189,7 +189,7 @@ void RelaxCodeGen::CodeGenInference() { } } -const String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { +const ffi::String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { if (prim->optype == "shape") { const auto& producer = graph()->FindNode(prim->GetTypeAttr("producer")); int out_idx = prim->GetTypeAttr("out_idx"); @@ -199,7 +199,7 @@ const String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { return PyCodeGen::DescribePrim(prim); } -const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { +const ffi::Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetRelaxOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported relax op(" << node->optype << "): " << node; @@ -212,16 +212,16 @@ const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.framework.tvm.GetRelaxSources", - [](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { + [](const MSCGraph& graph, const ffi::String& codegen_config, + const ffi::String& print_config) -> ffi::Map { RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); codegen.Init(); return codegen.GetSources(print_config); }); -}); +} } // namespace msc } // namespace contrib diff --git 
a/src/contrib/msc/framework/tvm/codegen.h b/src/contrib/msc/framework/tvm/codegen.h index 249105b5a50b..0874e21acd4d 100644 --- a/src/contrib/msc/framework/tvm/codegen.h +++ b/src/contrib/msc/framework/tvm/codegen.h @@ -56,13 +56,13 @@ class RelaxCodeGen : public PyCodeGen { void CodeGenInference() final; /*! \brief Describe the prim*/ - const String DescribePrim(const MSCPrim& prim) final; + const ffi::String DescribePrim(const MSCPrim& prim) final; /*! \brief Get the docs for the op*/ - const Array GetOpCodes(const MSCJoint& node) final; + const ffi::Array GetOpCodes(const MSCJoint& node) final; /*! \brief Get tensor type of the framework*/ - const String TensorType() const final { return "relax.Expr"; } + const ffi::String TensorType() const final { return "relax.Expr"; } }; } // namespace msc diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index a4be884858dc..da2cdfba5914 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -29,7 +29,7 @@ namespace tvm { namespace contrib { namespace msc { -const Array RelaxOpCode::GetDocs() { +const ffi::Array RelaxOpCode::GetDocs() { stack_.Config(this); CodeGenBuild(); bool emit_var = true; @@ -43,14 +43,14 @@ const Array RelaxOpCode::GetDocs() { return stack_.GetDocs(); } -void RelaxOpCode::BuilderEmit(const String& ret, const String& name) { +void RelaxOpCode::BuilderEmit(const ffi::String& ret, const ffi::String& name) { stack_.func_call("block_builder.emit", ret).call_arg(ret); if (name.size() > 0) { stack_.call_arg(DocUtils::ToStr(name), "name_hint"); } } -const ExprDoc RelaxOpCode::GetOutDtype(const String& key, int input_idx) { +const ExprDoc RelaxOpCode::GetOutDtype(const ffi::String& key, int input_idx) { if (config()->use_tools && input_idx >= 0 && node()->inputs.size() > static_cast(input_idx)) { return DocUtils::ToDoc(IdxInput(input_idx) + ".struct_info.dtype"); @@ -62,7 +62,7 @@ const ExprDoc 
RelaxOpCode::GetOutDtype(const String& key, int input_idx) { return DocUtils::ToStr(out_dtype); } -const std::vector RelaxOpCode::GetAxes(const String& key) { +const std::vector RelaxOpCode::GetAxes(const ffi::String& key) { std::vector axes; int axis; if (!node()->GetAttr(key, &axes) && node()->GetAttr(key, &axis)) { @@ -73,7 +73,7 @@ const std::vector RelaxOpCode::GetAxes(const String& key) { #define RELAX_OP_CODEGEN_METHODS(TypeName) \ public: \ - TypeName(const String& func_name) : RelaxOpCode(func_name) {} + TypeName(const ffi::String& func_name) : RelaxOpCode(func_name) {} class RelaxAdaptivePool2dCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxAdaptivePool2dCodeGen) @@ -101,7 +101,7 @@ class RelaxAttentionCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { for (size_t i = 0; i < 3; i++) { - const String& axes_key = i == 0 ? "axes" : "axes_" + std::to_string(i); + const ffi::String& axes_key = i == 0 ? "axes" : "axes_" + std::to_string(i); stack_.op_call("relax.op.permute_dims", IdxInput(i)) .op_input_arg(i) .op_list_arg(axes_key, "axes"); @@ -129,7 +129,7 @@ class RelaxAxesCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - const String& key = node()->HasAttr("axes") ? "axes" : "axis"; + const ffi::String& key = node()->HasAttr("axes") ? 
"axes" : "axis"; stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(GetAxes(key)), key); } }; @@ -210,7 +210,7 @@ class RelaxBiasAddCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { int axis = CommonUtils::GetIndex(node()->GetTypeAttr("axis"), node()->OutputAt(0)->Ndim()); - Array expand_shape; + ffi::Array expand_shape; for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { if (i == static_cast(axis)) { expand_shape.push_back(node()->InputAt(0)->DimAt(i)); @@ -263,7 +263,7 @@ class RelaxConstantCodeGen : public RelaxOpCode { class RelaxConvCodeGen : public RelaxOpCode { public: - RelaxConvCodeGen(const String& func_name, bool use_bias) + RelaxConvCodeGen(const ffi::String& func_name, bool use_bias) : RelaxOpCode(func_name), use_bias_(use_bias) {} protected: @@ -286,7 +286,7 @@ class RelaxConvCodeGen : public RelaxOpCode { << "out_layout or data_layout should be given, get " << node(); } const auto& out_layout = tir::Layout(out_layout_str); - Array expand_shape; + ffi::Array expand_shape; for (size_t i = 0; i < node()->OutputAt(0)->Ndim(); i++) { if (out_layout[i].name() == "C") { expand_shape.push_back(node()->OutputAt(0)->DimAt(i)); @@ -335,7 +335,7 @@ class RelaxEinsumCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - const String& key = config()->from_relay ? "equation" : "subscripts"; + const ffi::String& key = config()->from_relay ? 
"equation" : "subscripts"; stack_.op_call().op_inputs_arg().op_str_arg(key, "subscripts"); } }; @@ -480,12 +480,12 @@ class RelaxPadCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - Array pad_width; + ffi::Array pad_width; const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); ICHECK(attr_pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); for (size_t i = 0; i < attr_pad_width.size(); i += 2) { - const String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + - std::to_string(attr_pad_width[i + 1]) + "]"; + const ffi::String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + + std::to_string(attr_pad_width[i + 1]) + "]"; pad_width.push_back(cur_pad); } stack_.op_call() @@ -507,6 +507,7 @@ class RelaxPool2dCodeGen : public RelaxOpCode { .op_list_arg("strides") .op_list_arg("padding") .op_list_arg("dilation") + .op_arg("count_include_pad") .op_arg("ceil_mode") .op_str_arg("layout") .op_str_arg("out_layout"); @@ -530,7 +531,7 @@ class RelaxPermuteDimsCodeGen : public RelaxOpCode { class RelaxReduceAxisCodeGen : public RelaxOpCode { public: - RelaxReduceAxisCodeGen(const String& func_name, bool as_list) + RelaxReduceAxisCodeGen(const ffi::String& func_name, bool as_list) : RelaxOpCode(func_name), as_list_(as_list) {} protected: @@ -602,7 +603,7 @@ class RelaxResize2dCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { // roi has forced to be float list - Array roi_list; + ffi::Array roi_list; std::vector roi = node()->GetTypeArrayAttr("roi"); for (const auto& r : roi) { roi_list.push_back("float(" + std::to_string(r) + ")"); @@ -680,7 +681,7 @@ class RelaxTileCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - const String& key = config()->from_relay ? "reps" : "repeats"; + const ffi::String& key = config()->from_relay ? 
"reps" : "repeats"; stack_.op_call().op_input_arg().op_list_arg(key, "repeats"); } }; @@ -698,7 +699,7 @@ class RelaxTriCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { if (node()->optype == "trilu") { - const String& func_name = + const ffi::String& func_name = node()->GetTypeAttr("upper") ? "relax.op.triu" : "relax.op.tril"; stack_.op_call(func_name).op_input_arg().op_arg("k"); } else { @@ -720,8 +721,10 @@ class RelaxPluginOpCodeGen : public RelaxOpCode { } }; -const std::shared_ptr>> GetRelaxOpCodes() { - static auto map = std::make_shared>>(); +const std::shared_ptr>> +GetRelaxOpCodes() { + static auto map = + std::make_shared>>(); if (!map->empty()) return map; // binary && unary ops map->emplace("abs", std::make_shared("relax.op.abs")); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.h b/src/contrib/msc/framework/tvm/relax_opcode.h index e5914149184e..bbbee44d822d 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.h +++ b/src/contrib/msc/framework/tvm/relax_opcode.h @@ -49,11 +49,11 @@ class RelaxOpCode : public BaseOpCode { * \param func_name the function name for the node. * \param config the config json for the node. */ - explicit RelaxOpCode(const String& func_name) + explicit RelaxOpCode(const ffi::String& func_name) : BaseOpCode(func_name) {} /*! \brief Convert node to docs*/ - const Array GetDocs() final; + const ffi::Array GetDocs() final; protected: RelaxOpCodeStack stack_; @@ -62,20 +62,21 @@ class RelaxOpCode : public BaseOpCode { virtual void CodeGenBuild() = 0; /*! \brief coda stack emit docs*/ - void BuilderEmit(const String& ret, const String& name = ""); + void BuilderEmit(const ffi::String& ret, const ffi::String& name = ""); /*! \brief Get the out_dtype attribute*/ - const ExprDoc GetOutDtype(const String& key = "out_dtype", int input_idx = 0); + const ExprDoc GetOutDtype(const ffi::String& key = "out_dtype", int input_idx = 0); /*! 
\brief Get the axes attribute*/ - const std::vector GetAxes(const String& key = "axes"); + const std::vector GetAxes(const ffi::String& key = "axes"); }; /*! * \brief Get the map of available RelaxOpCode, use optype as key * \return Map of */ -const std::shared_ptr>> GetRelaxOpCodes(); +const std::shared_ptr>> +GetRelaxOpCodes(); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/plugin/base_codegen.h b/src/contrib/msc/plugin/base_codegen.h index cd5f03ff7716..fcb1f3982f79 100644 --- a/src/contrib/msc/plugin/base_codegen.h +++ b/src/contrib/msc/plugin/base_codegen.h @@ -66,13 +66,15 @@ class BasePluginCodeGen { virtual ~BasePluginCodeGen() = default; /*! \brief Get plugin sources*/ - virtual const Map GetBuildSources(const std::string& print_options = "") { - Map sources; + virtual const ffi::Map GetBuildSources( + const std::string& print_options = "") { + ffi::Map sources; // plugin sources for (const auto& name : ListPluginNames()) { const auto& plugin = GetPlugin(name); // attr declare - const String& attr_macro = "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_ATTR_H_"; + const ffi::String& attr_macro = + "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_ATTR_H_"; this->stack_.line("#ifndef " + attr_macro) .line("#define " + attr_macro) .line() @@ -90,7 +92,8 @@ class BasePluginCodeGen { EndNamespace(); sources.Set(plugin->name + "_attr.cc", ToCppSource(print_options)); // op decalre - const String& op_macro = "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_OP_H_"; + const ffi::String& op_macro = + "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_OP_H_"; this->stack_.line("#ifndef " + op_macro).line("#define " + op_macro).line(); CodeGenOpHeader(plugin); StartNamespace(); @@ -114,7 +117,7 @@ class BasePluginCodeGen { } } // cmakelists - std::set devices; + std::set devices; for (const auto& name : ListPluginNames()) { const auto& plugin = GetPlugin(name); for (const auto& pair : plugin->externs) { @@ 
-129,8 +132,9 @@ class BasePluginCodeGen { } /*! \brief Get manager sources*/ - virtual const Map GetManagerSources(const std::string& print_options = "") { - Map sources; + virtual const ffi::Map GetManagerSources( + const std::string& print_options = "") { + ffi::Map sources; CodeGenManagerDepends(); this->stack_.class_def("PluginManager(object)").class_start(); CodeGenManagerMethods(); @@ -138,7 +142,7 @@ class BasePluginCodeGen { CodeGenOpBuilder(GetPlugin(name)); } if (this->config()->need_convert) { - Map symbols; + ffi::Map symbols; this->stack_.func_def("get_convert_map") .func_decorator("classmethod") .func_arg("cls", "object") @@ -165,7 +169,7 @@ class BasePluginCodeGen { /*! \brief Header of plugin files*/ virtual void CodeGenOpHeader(const Plugin& plugin) { this->stack_.line("#include \"" + plugin->name + "_attr.h\""); - std::set include_headers; + std::set include_headers; for (const auto& pair : plugin->externs) { if (pair.second->header.size() > 0 && !include_headers.count(pair.second->header)) { this->stack_.line("#include \"" + pair.second->header + "\""); @@ -194,7 +198,8 @@ class BasePluginCodeGen { /*! \brief Codegen safe call extern*/ void CodeGenSafeCall(const PluginExtern& extern_func, - const Array& call_args = Array(), const String& ret = "") { + const ffi::Array& call_args = ffi::Array(), + const ffi::String& ret = "") { this->stack_.scope_start("try {").func_call(extern_func->name, ret); for (const auto& arg : call_args) { this->stack_.call_arg(arg); @@ -244,14 +249,15 @@ class BasePluginCodeGen { virtual void CodeGenOpRuntime(const Plugin& plugin) {} /*! \brief Codegen cmake file*/ - virtual void CodeGenCmake(const std::set& devices) { + virtual void CodeGenCmake(const std::set& devices) { CodeGenPreCmake(devices); CodeGenPostCmake(devices); } /*! 
\brief Codegen cmake start*/ - void CodeGenPreCmake(const std::set& devices, - const Map& extra_flags = Map()) { + void CodeGenPreCmake(const std::set& devices, + const ffi::Map& extra_flags = + ffi::Map()) { const auto& p_name = this->config()->project_name; stack_.line("cmake_minimum_required(VERSION " + this->config()->cmake_version + " FATAL_ERROR)") .line("project(" + p_name + ")"); @@ -277,9 +283,9 @@ class BasePluginCodeGen { } /*! \brief Codegen cmake end*/ - void CodeGenPostCmake(const std::set& devices, - const Array& extra_includes = Array(), - const Array& extra_libs = Array()) { + void CodeGenPostCmake(const std::set& devices, + const ffi::Array& extra_includes = ffi::Array(), + const ffi::Array& extra_libs = ffi::Array()) { const auto& p_name = this->config()->project_name; stack_.line() .line("file(GLOB_RECURSE PLUGIN_HEADERS src/*.h)") @@ -293,7 +299,7 @@ class BasePluginCodeGen { stack_.line("add_library(" + p_name + " SHARED ${PLUGIN_CC_SRCS})"); } // define includes - String includes = StringUtils::Join(extra_includes, " "); + ffi::String includes = StringUtils::Join(extra_includes, " "); if (this->config()->includes.size() > 0) { includes = includes + " " + StringUtils::Join(this->config()->includes, " "); } @@ -301,7 +307,7 @@ class BasePluginCodeGen { stack_.line("target_include_directories(" + p_name + " PUBLIC " + includes + ")"); } // define libs - String link_libs = StringUtils::Join(extra_libs, " "); + ffi::String link_libs = StringUtils::Join(extra_libs, " "); const auto& libs = StringUtils::Join(this->config()->libs, " "); if (libs.size() > 0) { link_libs = link_libs + " " + libs; @@ -496,10 +502,10 @@ class BasePluginCodeGen { } /*! \brief Codegen convert function for plugin*/ - virtual const String CodeGenOpConvert(const Plugin& plugin) { return plugin->name; } + virtual const ffi::String CodeGenOpConvert(const Plugin& plugin) { return plugin->name; } /*! 
\brief Change code stack to cpp source*/ - const String ToCppSource(const std::string& print_options = "") { + const ffi::String ToCppSource(const std::string& print_options = "") { CppPrinter printer(print_options); for (const auto& d : this->stack_.GetDocs()) { printer.Append(d); @@ -509,7 +515,7 @@ class BasePluginCodeGen { } /*! \brief Change code stack to python source*/ - const String ToPySource(const std::string& print_options = "") { + const ffi::String ToPySource(const std::string& print_options = "") { PythonPrinter printer(print_options); for (const auto& d : this->stack_.GetDocs()) { printer.Append(d); @@ -518,23 +524,23 @@ class BasePluginCodeGen { return printer.GetString(); } - std::vector> GetDtypeMatrix(const Plugin& plugin) { - std::vector> matrix; + std::vector> GetDtypeMatrix(const Plugin& plugin) { + std::vector> matrix; if (plugin->support_dtypes.size() == 0) { - std::unordered_map dtypes; + std::unordered_map dtypes; for (size_t i = 0; i < plugin->inputs.size(); i++) { dtypes[i] = plugin->inputs[i]->dtype; } matrix.push_back(dtypes); } else { - Array templates; - Array> condidates; + ffi::Array templates; + ffi::Array> condidates; for (const auto& pair : plugin->support_dtypes) { templates.push_back(pair.first); condidates.push_back(pair.second); } for (const auto& t_dtypes : ArrayUtils::Product(condidates)) { - std::unordered_map dtypes; + std::unordered_map dtypes; for (size_t i = 0; i < templates.size(); i++) { for (size_t in_idx = 0; in_idx < plugin->inputs.size(); in_idx++) { if (plugin->inputs[in_idx]->dtype == templates[i]) { @@ -554,11 +560,11 @@ class BasePluginCodeGen { return matrix; } - const Map GetTensorDtypes(const Plugin& plugin, - const std::unordered_map& dtypes) { - Map tensor_dtypes; + const ffi::Map GetTensorDtypes( + const Plugin& plugin, const std::unordered_map& dtypes) { + ffi::Map tensor_dtypes; for (const auto& pair : dtypes) { - const String& ref_dtype = plugin->inputs[pair.first]->dtype; + const ffi::String& 
ref_dtype = plugin->inputs[pair.first]->dtype; for (const auto& t : plugin->inputs) { if (t->dtype == ref_dtype) { tensor_dtypes.Set(t->name, pair.second); @@ -579,8 +585,8 @@ class BasePluginCodeGen { } /*! \brief Change plugin comment in python*/ - const String GetPyComment(const Plugin& plugin) { - String comment = "Python wrapper for " + plugin->name + "\nInputs\n------"; + const ffi::String GetPyComment(const Plugin& plugin) { + ffi::String comment = "Python wrapper for " + plugin->name + "\nInputs\n------"; for (const auto& t : plugin->inputs) { comment = comment + "\n" + t->name + ": " + t->dtype + "\n " + t->describe; } @@ -598,16 +604,16 @@ class BasePluginCodeGen { } /*! \brief Get class name for meta attrs*/ - const String MetaAttrCls(const Plugin& plugin) const { return plugin->name + "MetaAttr"; } + const ffi::String MetaAttrCls(const Plugin& plugin) const { return plugin->name + "MetaAttr"; } /*! \brief Get converter name for plugin*/ - const String ConverterName(const Plugin& plugin) const { return plugin->name + "Converter"; } + const ffi::String ConverterName(const Plugin& plugin) const { return plugin->name + "Converter"; } /*! \brief Check if the type is list type. */ - bool IsListType(const String& type) { return StringUtils::StartsWith(type, "list"); } + bool IsListType(const ffi::String& type) { return StringUtils::StartsWith(type, "list"); } /*! \brief Get type of element. */ - const String GetEleType(const String& type) { + const ffi::String GetEleType(const ffi::String& type) { if (!IsListType(type)) { return ""; } @@ -615,7 +621,7 @@ class BasePluginCodeGen { } /*! \brief Type name in cpp*/ - virtual const String ToCppType(const String& type) { + virtual const ffi::String ToCppType(const ffi::String& type) { if (IsListType(type)) { const auto& ele_type = GetEleType(type); return "std::vector<" + ToCppType(ele_type) + ">"; @@ -636,7 +642,7 @@ class BasePluginCodeGen { } /*! 
\brief Type name in python*/ - virtual const String ToPyType(const String& type) { + virtual const ffi::String ToPyType(const ffi::String& type) { if (IsListType(type)) { const auto& ele_type = GetEleType(type); return "List[" + ToPyType(ele_type) + "]"; diff --git a/src/contrib/msc/plugin/tensorrt_codegen.cc b/src/contrib/msc/plugin/tensorrt_codegen.cc index f1ab676b707f..890b9a6df7b3 100644 --- a/src/contrib/msc/plugin/tensorrt_codegen.cc +++ b/src/contrib/msc/plugin/tensorrt_codegen.cc @@ -120,7 +120,7 @@ void TensorRTPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { .for_start("i", 0, plugin->attrs.size()); for (size_t i = 0; i < plugin->attrs.size(); i++) { const auto& attr = plugin->attrs[i]; - const String& cond = "strcmp(fields[i].name, \"" + attr->name + "\") == 0"; + const ffi::String& cond = "strcmp(fields[i].name, \"" + attr->name + "\") == 0"; if (i == 0) { stack_.switch_start(cond); } else { @@ -275,7 +275,7 @@ void TensorRTPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { .declare("bool", "support"); size_t cnt = 0; for (const auto& dtypes : GetDtypeMatrix(plugin)) { - const String& cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")"; + const ffi::String& cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")"; if (cnt == 0) { stack_.switch_start(cond); } else { @@ -374,7 +374,7 @@ void TensorRTPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { .declare("bool", "support"); size_t cnt = 0; for (const auto& dtypes : GetDtypeMatrix(plugin)) { - String cond; + ffi::String cond; for (size_t i = 0; i < plugin->inputs.size(); i++) { cond = cond + "io_desc[" + std::to_string(i) + "].type == TRTUtils::ToDataType(\"" + dtypes.at(i) + "\")"; @@ -419,8 +419,8 @@ void TensorRTPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { CodegenCreator(plugin, true, false); } -void TensorRTPluginCodeGen::CodeGenCmake(const std::set& devices) { - Map flags; +void TensorRTPluginCodeGen::CodeGenCmake(const std::set& devices) { + 
ffi::Map flags; flags.Set("PLUGIN_SUPPORT_TENSORRT", ""); flags.Set("TRT_MAJOR", std::to_string(config()->version[0])); flags.Set("TRT_MINOR", std::to_string(config()->version[1])); @@ -432,7 +432,7 @@ void TensorRTPluginCodeGen::CodeGenCmake(const std::set& devices) { .line("find_library(TRT_LIBS nvinfer HINTS " + config()->tensorrt_root + " PATH_SUFFIXES lib)") .line("set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-terminate\")"); - Array includes, libs; + ffi::Array includes, libs; includes.push_back("${TRT_INCLUDE_DIR}"); libs.push_back("${TRT_LIBS}"); CodeGenPostCmake(devices, includes, libs); @@ -454,7 +454,7 @@ void TensorRTPluginCodeGen::CodeGenManagerMethods() { void TensorRTPluginCodeGen::CodegenOpCommonMethods(const Plugin& plugin, bool dynamic, bool in_declare) { const auto& op_cls = OpCls(plugin, dynamic); - const String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2"; + const ffi::String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2"; if (in_declare) { stack_.comment("common methods for " + op_cls); stack_.constructor_def(op_cls).constructor_arg("name", "const std::string&"); @@ -567,7 +567,7 @@ void TensorRTPluginCodeGen::CodegenOpCommonMethods(const Plugin& plugin, bool dy .line("assert(char_buf == (start_buf + getSerializationSize()));") .func_end(); // getPluginType - const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); + const ffi::String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); stack_.func_def(op_cls + "::getPluginType", "const char*") .func_decorator("const noexcept") .func_start() @@ -644,7 +644,7 @@ void TensorRTPluginCodeGen::CodegenOpMembers(const Plugin& plugin, bool dynamic) void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare) { const auto& creator_cls = CreatorCls(plugin, dynamic); - const String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2"; + const ffi::String& plugin_cls = dynamic ? 
"IPluginV2DynamicExt" : "IPluginV2"; if (in_declare) { stack_.class_def(creator_cls + " : public IPluginCreator") .class_start() @@ -679,7 +679,7 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b .line() .class_end(); } else { - const String& attr_name = MetaAttrCls(plugin); + const ffi::String& attr_name = MetaAttrCls(plugin); // static members stack_.comment("static members and register for " + plugin->name) .declare("PluginFieldCollection", creator_cls + "::collection_") @@ -705,7 +705,7 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b .func_call("data", fields_doc, DocUtils::ToDoc("fields_")) .constructor_end(); // getPluginName - const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); + const ffi::String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); stack_.func_def(creator_cls + "::getPluginName", "const char*") .func_decorator("const noexcept") .func_start() @@ -753,7 +753,7 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b .for_start("i", plugin->attrs.size(), fields_size); for (size_t i = 0; i < plugin->inputs.size(); i++) { const auto& tensor = plugin->inputs[i]; - const String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0"; + const ffi::String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0"; if (i == 0) { stack_.switch_start(cond); } else { @@ -794,7 +794,7 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b } void TensorRTPluginCodeGen::CodegenOutputInfer(const Plugin& plugin, bool as_desc) { - Array infer_args{"input_metas_", "meta_attr_", "false"}; + ffi::Array infer_args{"input_metas_", "meta_attr_", "false"}; stack_.line("assert(n_inputs == " + std::to_string(plugin->inputs.size()) + ");") .func_call("resize", "", "input_metas_") .call_arg(plugin->inputs.size()) @@ -810,7 +810,7 @@ void TensorRTPluginCodeGen::CodegenOutputInfer(const Plugin& 
plugin, bool as_des } void TensorRTPluginCodeGen::CodegenBufferInfer(const Plugin& plugin) { - Array infer_args{"input_metas_", "meta_attr_", "false"}; + ffi::Array infer_args{"input_metas_", "meta_attr_", "false"}; CodeGenSafeCall(plugin->externs["infer_buffer"], infer_args, "buffer_metas_"); stack_.for_start("b", "buffer_metas_") .assign("size", "size + max_batch * b.size(false)") @@ -820,12 +820,12 @@ void TensorRTPluginCodeGen::CodegenBufferInfer(const Plugin& plugin) { void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { ICHECK(plugin->externs.count("cuda_compute")) << "cuda_compute is needed fo TensorRT plugin"; auto prepare_tensor = [this, &dynamic](const PluginTensor& tensor, - const Map& dtypes, size_t idx, - const String& collect) { - const String& t_name = "d_" + tensor->name; - const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const String& tensor_type = "DataTensor<" + t_dtype + ">"; - const String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; + const ffi::Map& dtypes, + size_t idx, const ffi::String& collect) { + const ffi::String& t_name = "d_" + tensor->name; + const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; + const ffi::String& anno = collect == "input" ? 
"const " + tensor_type + "&" : tensor_type; stack_.func_call("TRTUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)); const auto& t_meta = DocUtils::ToIndex(collect + "_metas_", idx); if (dynamic) { @@ -844,8 +844,8 @@ void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { }; for (const auto& dtypes : GetDtypeMatrix(plugin)) { const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - Array compute_args; - String dtype_cond = ""; + ffi::Array compute_args; + ffi::String dtype_cond = ""; if (dynamic) { for (size_t i = 0; i < plugin->inputs.size(); i++) { dtype_cond = dtype_cond + "input_descs[" + std::to_string(i) + @@ -858,19 +858,19 @@ void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { // prepare compute datas stack_.cond_if(dtype_cond).comment("prepare compute datas"); for (size_t i = 0; i < plugin->inputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); compute_args.push_back(t_name); } for (size_t i = 0; i < plugin->outputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); compute_args.push_back(t_name); } if (plugin->buffers.size() > 0) { stack_.assign("offset", 0, "size_t"); for (size_t i = 0; i < plugin->buffers.size(); i++) { - const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "buffer"); + const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "buffer"); compute_args.push_back(t_name); - const String& size_name = "size_" + plugin->buffers[i]->name; + const ffi::String& size_name = "size_" + plugin->buffers[i]->name; stack_ .func_call("size", DocUtils::ToDeclare("size_t", size_name), DocUtils::ToIndex("buffer_metas_", i)) @@ -885,11 +885,11 @@ void 
TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.plugin.GetTensorRTPluginSources", - [](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { + [](const ffi::String& codegen_config, const ffi::String& print_config, + const ffi::String& codegen_type) -> ffi::Map { TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config); if (codegen_type == "build") { return codegen.GetBuildSources(print_config); @@ -897,9 +897,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (codegen_type == "manager") { return codegen.GetManagerSources(print_config); } - return Map(); + return ffi::Map(); }); -}); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/plugin/tensorrt_codegen.h b/src/contrib/msc/plugin/tensorrt_codegen.h index 24fb4e5dfca2..c5b0e585a139 100644 --- a/src/contrib/msc/plugin/tensorrt_codegen.h +++ b/src/contrib/msc/plugin/tensorrt_codegen.h @@ -79,25 +79,25 @@ class TensorRTPluginCodeGen : public BasePluginCodeGen& devices) final; + void CodeGenCmake(const std::set& devices) final; /*! \brief Codegen manager methods*/ void CodeGenManagerMethods() final; private: /*! \brief Op class name of plugin*/ - const String OpCls(const Plugin& plugin, bool dynamic) const { + const ffi::String OpCls(const Plugin& plugin, bool dynamic) const { return plugin->name + (dynamic ? "DynamicPlugin" : "Plugin"); } /*! \brief Creator class name of plugin*/ - const String CreatorCls(const Plugin& plugin, bool dynamic) const { + const ffi::String CreatorCls(const Plugin& plugin, bool dynamic) const { return plugin->name + (dynamic ? 
"DynamicCreator" : "Creator"); } bool IsMixPrecision(const Plugin& plugin) { for (const auto& dtypes : GetDtypeMatrix(plugin)) { - String ref_dtype = ""; + ffi::String ref_dtype = ""; for (const auto& pair : dtypes) { if (ref_dtype.size() == 0) { ref_dtype = pair.second; diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc index 63d068acab34..d5a2b5353de4 100644 --- a/src/contrib/msc/plugin/torch_codegen.cc +++ b/src/contrib/msc/plugin/torch_codegen.cc @@ -153,7 +153,7 @@ void TorchPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { CodeGenMalloc(plugin, plugin->buffers, "buffer"); } // do the compute - String device_cond = ""; + ffi::String device_cond = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { device_cond = device_cond + "input_tensors[" + std::to_string(i) + "].is_cuda()"; @@ -216,15 +216,15 @@ void TorchPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { .func_end(); } -void TorchPluginCodeGen::CodeGenCmake(const std::set& devices) { - Map flags; +void TorchPluginCodeGen::CodeGenCmake(const std::set& devices) { + ffi::Map flags; flags.Set("PLUGIN_SUPPORT_TORCH", ""); CodeGenPreCmake(devices, flags); stack_.line() .line("set(CMAKE_CXX_STANDARD 17)") .line("list(APPEND CMAKE_PREFIX_PATH \"" + config()->torch_prefix + "\")") .line("find_package(Torch REQUIRED)"); - Array includes, libs; + ffi::Array includes, libs; libs.push_back("${TORCH_LIBRARIES}"); CodeGenPostCmake(devices, includes, libs); } @@ -366,14 +366,14 @@ void TorchPluginCodeGen::CodeGenConvertDepends() { .line(); } -const String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { +const ffi::String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { stack_.func_def(ConverterName(plugin), "relax.Var") .func_arg("node", "fx.node.Node") .func_arg("ctx", "TorchFXImporter") .func_start() .func_call("retrieve_args", "args", "ctx") 
.call_arg("node"); - Array args; + ffi::Array args; for (size_t i = 0; i < plugin->inputs.size(); i++) { const auto& tensor = plugin->inputs[i]; stack_.assign(tensor->name, DocUtils::ToIndex("args", i + 1)); @@ -407,9 +407,9 @@ const String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { .call_arg("op") .call_arg("name"); if (plugin->outputs.size() == 1) { - stack_.func_end(DocUtils::ToList(Array{"var"})); + stack_.func_end(DocUtils::ToList(ffi::Array{"var"})); } else { - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < plugin->outputs.size(); i++) { const auto& tensor = plugin->outputs[i]; stack_.func_call("relax.TupleGetItem", tensor->name).call_arg("var").call_arg(i); @@ -420,9 +420,10 @@ const String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { return EntryName(plugin); } -void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, const Array& tensors, - const String& collect) { - Array call_args{"input_metas", "meta_attr_", "true"}; +void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, + const ffi::Array& tensors, + const ffi::String& collect) { + ffi::Array call_args{"input_metas", "meta_attr_", "true"}; stack_.line().comment("malloc " + collect).declare("std::vector", collect + "_metas"); CodeGenSafeCall(plugin->externs["infer_" + collect], call_args, collect + "_metas"); for (size_t i = 0; i < tensors.size(); i++) { @@ -442,13 +443,14 @@ void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, const Array& dtypes, - size_t idx, const String& collect) { - const String& t_name = "d_" + tensor->name; - const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const String& tensor_type = "DataTensor<" + t_dtype + ">"; - const String& anno = collect == "input" ? 
"const " + tensor_type + "&" : tensor_type; +void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& device) { + auto prepare_tensor = [this](const PluginTensor& tensor, + const ffi::Map& dtypes, size_t idx, + const ffi::String& collect) { + const ffi::String& t_name = "d_" + tensor->name; + const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; + const ffi::String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; stack_.func_call("TorchUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)) .call_arg(DocUtils::ToIndex(collect + "_tensors", idx)) .call_arg(DocUtils::ToIndex(collect + "_metas", idx)) @@ -459,8 +461,8 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi if (plugin->externs.count(device + "_compute")) { for (const auto& dtypes : GetDtypeMatrix(plugin)) { const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - Array compute_args; - String dtype_cond = ""; + ffi::Array compute_args; + ffi::String dtype_cond = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { dtype_cond = dtype_cond + "input_metas[" + std::to_string(i) + "].data_type() == DataUtils::ToMetaType(\"" + dtypes.at(i) + "\")"; @@ -469,15 +471,15 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi // prepare compute datas stack_.cond_if(dtype_cond).comment("prepare compute datas"); for (size_t i = 0; i < plugin->inputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); compute_args.push_back(t_name); } for (size_t i = 0; i < plugin->outputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, 
i, "output"); compute_args.push_back(t_name); } for (size_t i = 0; i < plugin->buffers.size(); i++) { - const String& t_name = prepare_tensor(plugin->buffers[i], tensor_dtypes, i, "buffer"); + const ffi::String& t_name = prepare_tensor(plugin->buffers[i], tensor_dtypes, i, "buffer"); compute_args.push_back(t_name); } compute_args.push_back("meta_attr_"); @@ -494,11 +496,11 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.plugin.GetTorchPluginSources", - [](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { + [](const ffi::String& codegen_config, const ffi::String& print_config, + const ffi::String& codegen_type) -> ffi::Map { TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config); if (codegen_type == "build") { return codegen.GetBuildSources(print_config); @@ -506,9 +508,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (codegen_type == "manager") { return codegen.GetManagerSources(print_config); } - return Map(); + return ffi::Map(); }); -}); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/plugin/torch_codegen.h b/src/contrib/msc/plugin/torch_codegen.h index 4452650e2271..1dae9134e704 100644 --- a/src/contrib/msc/plugin/torch_codegen.h +++ b/src/contrib/msc/plugin/torch_codegen.h @@ -79,7 +79,7 @@ class TorchPluginCodeGen : public BasePluginCodeGen { void CodeGenOpDefine(const Plugin& plugin) final; /*! \brief Codegen cmake file*/ - void CodeGenCmake(const std::set& devices) final; + void CodeGenCmake(const std::set& devices) final; /*! \brief Codegen manager depends*/ void CodeGenManagerDepends() final; @@ -94,18 +94,18 @@ class TorchPluginCodeGen : public BasePluginCodeGen { void CodeGenConvertDepends() final; /*! 
\brief Codegen convert function for plugin*/ - const String CodeGenOpConvert(const Plugin& plugin) final; + const ffi::String CodeGenOpConvert(const Plugin& plugin) final; private: /*! \brief Codegen malloc for outputs/buffers*/ - void CodeGenMalloc(const Plugin& plugin, const Array& tensors, - const String& collect); + void CodeGenMalloc(const Plugin& plugin, const ffi::Array& tensors, + const ffi::String& collect); /*! \brief Codegen compute*/ - void CodeGenCompute(const Plugin& plugin, const String& device); + void CodeGenCompute(const Plugin& plugin, const ffi::String& device); /*! \brief Entry name of torch function*/ - const String EntryName(const Plugin& plugin) { + const ffi::String EntryName(const Plugin& plugin) { std::string lower_name; const std::string& name = std::string(plugin->name); for (size_t i = 0; i < name.size(); i++) { @@ -119,7 +119,7 @@ class TorchPluginCodeGen : public BasePluginCodeGen { } /*! \brief Type name in torch*/ - const String ToTorchType(const String& type) { + const ffi::String ToTorchType(const ffi::String& type) { if (type == "float") { return "double"; } diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index 7410867aaf25..7a109a147280 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -35,7 +35,7 @@ void TVMPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { stack_.comment("convert exprs to meta attrs method") .func_def(attr_name + "_from_exprs", "const " + attr_name); for (const auto& a : plugin->attrs) { - const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(a->type) ? 
"Tuple" : "PrimValue"; stack_.func_arg(a->name, "const " + anno + "&"); } // args to meta_attr @@ -50,12 +50,12 @@ void TVMPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { // exprs to meta_attr stack_.func_def(attr_name + "_from_exprs", "const " + attr_name); for (const auto& a : plugin->attrs) { - const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; stack_.func_arg(a->name, "const " + anno + "&"); } stack_.func_start().declare(attr_name, "meta_attr"); for (const auto& a : plugin->attrs) { - const String& convert = IsListType(a->type) ? "AttrFromPrims" : "AttrFromPrim"; + const ffi::String& convert = IsListType(a->type) ? "AttrFromPrims" : "AttrFromPrim"; stack_.func_call("TVMUtils::" + convert) .call_arg(a->name) .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); @@ -92,30 +92,30 @@ void TVMPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { void TVMPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) { // infer struct info - stack_.func_def("InferStructInfo" + plugin->name, "Array"); + stack_.func_def("InferStructInfo" + plugin->name, "ffi::Array"); for (const auto& t : plugin->inputs) { stack_.func_arg(t->name, "const Expr&"); } for (const auto& a : plugin->attrs) { - const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(a->type) ? 
"Tuple" : "PrimValue"; stack_.func_arg(a->name, "const " + anno + "&"); } // infer layout stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput") - .func_arg("inputs", "const Array&") + .func_arg("inputs", "const ffi::Array&") .func_arg("var_layout_map", "const VarLayoutMap&"); } void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { const auto& attr_name = MetaAttrCls(plugin); // infer struct info - Array infer_args{"input_metas", "meta_attr", "false"}; - stack_.func_def("InferStructInfo" + plugin->name, "Array"); + ffi::Array infer_args{"input_metas", "meta_attr", "false"}; + stack_.func_def("InferStructInfo" + plugin->name, "ffi::Array"); for (const auto& t : plugin->inputs) { stack_.func_arg(t->name, "const Expr&"); } for (const auto& a : plugin->attrs) { - const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; stack_.func_arg(a->name, "const " + anno + "&"); } stack_.func_start() @@ -133,7 +133,7 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { } stack_.declare("std::vector", "output_metas"); CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas"); - stack_.declare("Array", "output_sinfo"); + stack_.declare("ffi::Array", "output_sinfo"); for (size_t i = 0; i < plugin->outputs.size(); i++) { stack_.func_call("push_back", "", "output_sinfo") .inplace_start("TVMUtils::ToTensorStructInfo") @@ -152,20 +152,20 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { // infer layout stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput") - .func_arg("inputs", "const Array&") + .func_arg("inputs", "const ffi::Array&") .func_arg("var_layout_map", "const VarLayoutMap&") .func_start() .comment("define attrs"); for (size_t i = 0; i < plugin->attrs.size(); i++) { const auto& attr = plugin->attrs[i]; - const String& anno = IsListType(attr->type) ? 
"Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(attr->type) ? "Tuple" : "PrimValue"; stack_ .func_call("Downcast<" + anno + ">", DocUtils::ToDeclare("const auto&", "attr_" + attr->name)) .call_arg(DocUtils::ToIndex("inputs", i + plugin->inputs.size())); } - stack_.declare("Array", "arg_layouts") - .declare("Array", "output_layouts") + stack_.declare("ffi::Array", "arg_layouts") + .declare("ffi::Array", "output_layouts") .comment("extract meta attrs") .func_call(attr_name + "_from_exprs", "const " + attr_name + "& meta_attr"); for (const auto& a : plugin->attrs) { @@ -201,7 +201,7 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { .call_arg(DocUtils::ToAttrAccess(DocUtils::ToIndex("output_metas", "i"), "layout_name()")) .inplace_end() .for_end() - .declare("Array", "input_layouts") + .declare("ffi::Array", "input_layouts") .func_call("push_back", "", "input_layouts") .inplace_start("LayoutDecision") .call_arg(DocUtils::ToStr("")) @@ -229,10 +229,10 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { ICHECK(!plugin->externs.count("infer_buffer")) << "infer_buffer is not supported for tvm runtime"; const auto& attr_name = MetaAttrCls(plugin); const auto& func_name = ComputeName(plugin); - String device_cond = ""; - String device_index = ""; + ffi::String device_cond = ""; + ffi::String device_index = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { - String device_type = ""; + ffi::String device_type = ""; if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { device_type = "DLDeviceType::kDLCUDA"; } else { @@ -267,8 +267,8 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { .line(); } -void TVMPluginCodeGen::CodeGenCmake(const std::set& devices) { - Map flags; +void TVMPluginCodeGen::CodeGenCmake(const std::set& devices) { + ffi::Map flags; flags.Set("PLUGIN_SUPPORT_TVM", ""); CodeGenPreCmake(devices, flags); stack_.line("set(CMAKE_CXX_STANDARD 17)") @@ -276,7 +276,7 @@ 
void TVMPluginCodeGen::CodeGenCmake(const std::set& devices) { .line() .line("set(TVM_ROOT " + config()->tvm_root + ")") .line("find_library(TVM_LIB NAMES tvm HINTS ${TVM_ROOT}/build NO_DEFAULT_PATH)"); - Array includes, libs; + ffi::Array includes, libs; includes.push_back("${TVM_ROOT}/include"); includes.push_back("${TVM_ROOT}/3rdparty/dmlc-core/include"); includes.push_back("${TVM_ROOT}/3rdparty/dlpack/include"); @@ -318,7 +318,7 @@ void TVMPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) { stack_.func_arg(attr->name, ToPyType(attr->type), attr->default_value); } stack_.func_arg("name", "str", "\"" + plugin->name + "\"").func_start(); - Array args; + ffi::Array args; for (const auto& t : plugin->inputs) { args.push_back(t->name); } @@ -345,15 +345,17 @@ void TVMPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) { stack_.func_end("op").comment(GetPyComment(plugin), true); } -void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device) { +void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& device) { if (plugin->externs.count(device + "_compute")) { // compute with dtype - auto prepare_tensor = [this](const PluginTensor& tensor, const Map& dtypes, - size_t idx, const String& collect) { - const String& t_name = "d_" + tensor->name; - const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const String& tensor_type = "DataTensor<" + t_dtype + ">"; - const String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; + auto prepare_tensor = [this](const PluginTensor& tensor, + const ffi::Map& dtypes, size_t idx, + const ffi::String& collect) { + const ffi::String& t_name = "d_" + tensor->name; + const ffi::String& t_dtype = + dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; + const ffi::String& anno = collect == "input" ? 
"const " + tensor_type + "&" : tensor_type; stack_.func_call("TVMUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)) .call_arg(tensor->name) .call_arg(collect == "input"); @@ -361,8 +363,8 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device }; for (const auto& dtypes : GetDtypeMatrix(plugin)) { const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - Array compute_args; - String dtype_cond = ""; + ffi::Array compute_args; + ffi::String dtype_cond = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { const auto& t_name = plugin->inputs[i]->name; dtype_cond = dtype_cond + "TVMUtils::ToMetaType(" + t_name + @@ -372,18 +374,18 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device // prepare compute datas stack_.cond_if(dtype_cond).comment("prepare compute datas"); for (size_t i = 0; i < plugin->inputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); compute_args.push_back(t_name); } for (size_t i = 0; i < plugin->outputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); compute_args.push_back(t_name); } ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm"; compute_args.push_back("meta_attr"); if (device == "cuda") { // TODO(tvm-team): update to support get stream from device id - stack_.assign("stream", "TVMFFIEnvGetCurrentStream(kDLCUDA, 0)", "auto"); + stack_.assign("stream", "TVMFFIEnvGetStream(kDLCUDA, 0)", "auto"); compute_args.push_back("stream"); } CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args); @@ -394,11 +396,11 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device } } -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.plugin.GetTVMPluginSources", - [](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { + [](const ffi::String& codegen_config, const ffi::String& print_config, + const ffi::String& codegen_type) -> ffi::Map { TVMPluginCodeGen codegen = TVMPluginCodeGen(codegen_config); if (codegen_type == "build") { return codegen.GetBuildSources(print_config); @@ -406,9 +408,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (codegen_type == "manager") { return codegen.GetManagerSources(print_config); } - return Map(); + return ffi::Map(); }); -}); +} } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/plugin/tvm_codegen.h b/src/contrib/msc/plugin/tvm_codegen.h index 520e35de95c6..926c5162005a 100644 --- a/src/contrib/msc/plugin/tvm_codegen.h +++ b/src/contrib/msc/plugin/tvm_codegen.h @@ -82,7 +82,7 @@ class TVMPluginCodeGen : public BasePluginCodeGen { void CodeGenOpRuntime(const Plugin& plugin) final; /*! \brief Codegen cmake file*/ - void CodeGenCmake(const std::set& devices) final; + void CodeGenCmake(const std::set& devices) final; /*! \brief Codegen manager depends*/ void CodeGenManagerDepends() final; @@ -95,13 +95,13 @@ class TVMPluginCodeGen : public BasePluginCodeGen { private: /*! \brief Func name of compute*/ - const String ComputeName(const Plugin& plugin) { return plugin->name + "_compute"; } + const ffi::String ComputeName(const Plugin& plugin) { return plugin->name + "_compute"; } /*! \brief Codegen compute*/ - void CodeGenCompute(const Plugin& plugin, const String& device); + void CodeGenCompute(const Plugin& plugin, const ffi::String& device); /*! 
\brief Type name in tvm*/ - const String ToTVMType(const String& type) { + const ffi::String ToTVMType(const ffi::String& type) { if (type == "string") { return "StringImm"; } diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc index 41c75c875b78..81d6bb7e5891 100644 --- a/src/ir/analysis.cc +++ b/src/ir/analysis.cc @@ -29,25 +29,25 @@ namespace tvm { namespace ir { -Map> CollectCallMap(const IRModule& mod) { +ffi::Map> CollectCallMap(const IRModule& mod) { struct CalleeCollectorImpl : CalleeCollector { void Mark(GlobalVar gvar) override { gvars.push_back(gvar); } support::OrderedSet gvars; }; - Map> call_map; + ffi::Map> call_map; for (const auto& [gvar, base_func] : mod->functions) { CalleeCollectorImpl collector; CalleeCollector::vtable()(base_func, &collector); - call_map.Set(gvar, Array{collector.gvars.begin(), collector.gvars.end()}); + call_map.Set(gvar, ffi::Array{collector.gvars.begin(), collector.gvars.end()}); } return call_map; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.analysis.CollectCallMap", CollectCallMap); -}); +} } // namespace ir } // namespace tvm diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index 3436d49b02ee..3dd7c6a5ff8f 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -56,7 +56,7 @@ BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) { } } // namespace -Pass ApplyPassToFunction(Pass pass, String func_name_regex, +Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex, bool error_if_no_function_matches_regex) { auto pass_name = static_cast(std::stringstream() << "ApplyPassTo" << func_name_regex) @@ -65,15 +65,15 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, auto pass_func = [pass, func_name_regex, error_if_no_function_matches_regex]( IRModule mod, PassContext) -> IRModule { bool at_least_one_function_matched_regex = false; - 
std::unordered_set keep_original_version; - std::unordered_set internal_functions; + std::unordered_set keep_original_version; + std::unordered_set internal_functions; IRModule subset; for (auto [gvar, func] : mod->functions) { std::string name = gvar->name_hint; if (tvm::runtime::regex_match(name, func_name_regex)) { at_least_one_function_matched_regex = true; - if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { // Function may be mutated, but is an internal function. Mark // it as externally-exposed, so that any call-tracing internal // transforms do not remove this function, in case it its @@ -97,7 +97,7 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, if (error_if_no_function_matches_regex) { CHECK(at_least_one_function_matched_regex) << "No function matched regex '" << func_name_regex << "', out of functions " << [&]() { - Array function_names; + ffi::Array function_names; for (const auto& [gvar, func] : mod->functions) { function_names.push_back(gvar->name_hint); } @@ -130,10 +130,10 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, return CreateModulePass(pass_func, 0, pass_name, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("transform.ApplyPassToFunction", ApplyPassToFunction); -}); +} } // namespace transform } // namespace tvm diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 66a43f93c7d5..748f4bf5c93f 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -28,12 +28,12 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { AttrFieldInfoNode::RegisterReflection(); DictAttrsNode::RegisterReflection(); -}); +} -DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { +DictAttrs WithAttrs(DictAttrs attrs, ffi::Map new_attrs) { if (new_attrs.empty()) { return attrs; } @@ -45,7 +45,7 @@ DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { return attrs; } 
-DictAttrs WithAttr(DictAttrs attrs, String key, ffi::Any value) { +DictAttrs WithAttr(DictAttrs attrs, ffi::String key, ffi::Any value) { attrs.CopyOnWrite()->dict.Set(key, value); return attrs; } @@ -57,23 +57,23 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { void DictAttrsNode::InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { - String key = args[i].cast(); + ffi::String key = args[i].cast(); ffi::AnyView val = args[i + 1]; dict.Set(key, val); } } -DictAttrs::DictAttrs(Map dict) { - ObjectPtr n = make_object(); +DictAttrs::DictAttrs(ffi::Map dict) { + ObjectPtr n = ffi::make_object(); n->dict = std::move(dict); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ tvm::ffi::reflection::ObjectDef(); }); +TVM_FFI_STATIC_INIT_BLOCK() { tvm::ffi::reflection::ObjectDef(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.DictAttrsGetDict", [](DictAttrs attrs) { return attrs->dict; }); -}); +} } // namespace tvm diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index fa48ceba288b..e20c6b8e1715 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -29,25 +29,25 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { DiagnosticNode::RegisterReflection(); DiagnosticRendererNode::RegisterReflection(); DiagnosticContextNode::RegisterReflection(); -}); +} // failed to check to argument arg0.dims[0] != 0 /* Diagnostic */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("diagnostics.Diagnostic", [](int level, Span span, String message) { + refl::GlobalDef().def("diagnostics.Diagnostic", [](int level, Span span, ffi::String message) { return Diagnostic(static_cast(level), span, message); }); -}); +} Diagnostic::Diagnostic(DiagnosticLevel level, Span span, const std::string& message) { - auto n = make_object(); 
+ auto n = ffi::make_object(); n->level = level; n->span = span; n->message = message; @@ -94,13 +94,15 @@ DiagnosticBuilder Diagnostic::Help(ObjectRef loc) { return DiagnosticBuilder(DiagnosticLevel::kHelp, loc); } -DiagnosticBuilder Diagnostic::Bug(const Object* loc) { return Bug(GetRef(loc)); } +DiagnosticBuilder Diagnostic::Bug(const Object* loc) { return Bug(ffi::GetRef(loc)); } -DiagnosticBuilder Diagnostic::Error(const Object* loc) { return Error(GetRef(loc)); } +DiagnosticBuilder Diagnostic::Error(const Object* loc) { + return Error(ffi::GetRef(loc)); +} -DiagnosticBuilder Diagnostic::Note(const Object* loc) { return Note(GetRef(loc)); } +DiagnosticBuilder Diagnostic::Note(const Object* loc) { return Note(ffi::GetRef(loc)); } -DiagnosticBuilder Diagnostic::Help(const Object* loc) { return Help(GetRef(loc)); } +DiagnosticBuilder Diagnostic::Help(const Object* loc) { return Help(ffi::GetRef(loc)); } /* Diagnostic Renderer */ @@ -108,18 +110,18 @@ void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { (*this)->rendere TVM_DLL DiagnosticRenderer::DiagnosticRenderer( ffi::TypedFunction renderer) { - auto n = make_object(); + auto n = ffi::make_object(); n->renderer = renderer; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("diagnostics.DiagnosticRenderer", [](ffi::TypedFunction renderer) { return DiagnosticRenderer(renderer); }); -}); +} /* Diagnostic Context */ @@ -143,42 +145,42 @@ void DiagnosticContext::Render() { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "diagnostics.DiagnosticRendererRender", [](DiagnosticRenderer renderer, DiagnosticContext ctx) { renderer.Render(ctx); }); -}); +} DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) { CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null 
function"; - auto n = make_object(); + auto n = ffi::make_object(); n->module = module; n->renderer = renderer; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("diagnostics.DiagnosticContext", [](const IRModule& module, const DiagnosticRenderer& renderer) { return DiagnosticContext(module, renderer); }); -}); +} /*! \brief Emit a diagnostic. */ void DiagnosticContext::Emit(const Diagnostic& diagnostic) { (*this)->diagnostics.push_back(diagnostic); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("diagnostics.Emit", [](DiagnosticContext ctx, const Diagnostic& diagnostic) { return ctx.Emit(diagnostic); }) .def("diagnostics.DiagnosticContextRender", [](DiagnosticContext context) { return context.Render(); }); -}); +} /*! \brief Emit a diagnostic. */ void DiagnosticContext::EmitFatal(const Diagnostic& diagnostic) { @@ -210,11 +212,11 @@ DiagnosticContext DiagnosticContext::Default(const IRModule& module) { return DiagnosticContext(module, renderer); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("diagnostics.Default", [](const IRModule& module) { return DiagnosticContext::Default(module); }); -}); +} std::ostream& EmitDiagnosticHeader(std::ostream& out, const Span& span, DiagnosticLevel level, std::string msg) { @@ -328,13 +330,13 @@ DiagnosticRenderer TerminalRenderer(std::ostream& out) { }); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def(DEFAULT_RENDERER, []() { return TerminalRenderer(std::cerr); }) .def("diagnostics.GetRenderer", []() { return GetRenderer(); }) .def("diagnostics.ClearRenderer", []() { tvm::ffi::Function::RemoveGlobal(OVERRIDE_RENDERER); }); -}); +} } // namespace tvm diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 
bc91db0ce45d..5a6e2c662b61 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -27,7 +27,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ EnvFuncNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { EnvFuncNode::RegisterReflection(); } using ffi::Any; using ffi::Function; @@ -42,15 +42,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ObjectPtr CreateEnvNode(const std::string& name) { auto f = tvm::ffi::Function::GetGlobal(name); ICHECK(f.has_value()) << "Cannot find global function \'" << name << '\''; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->func = *f; n->name = name; return n; } -EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); } +EnvFunc EnvFunc::Get(const ffi::String& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.EnvFuncGet", EnvFunc::Get) @@ -69,5 +69,5 @@ TVM_FFI_STATIC_INIT_BLOCK({ return node->name; }) .def("__data_from_json__", EnvFunc::Get); -}); +} } // namespace tvm diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 43112335988f..b856854a5d8f 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -33,7 +33,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { BaseExprNode::RegisterReflection(); PrimExprNode::RegisterReflection(); RelaxExprNode::RegisterReflection(); @@ -42,19 +42,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ IntImmNode::RegisterReflection(); FloatImmNode::RegisterReflection(); RangeNode::RegisterReflection(); -}); +} PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} -PrimExpr PrimExpr::ConvertFallbackValue(String value) { return tir::StringImm(value); } +PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tir::StringImm(value); } IntImm::IntImm(DataType dtype, int64_t value, Span span) { 
ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype << " was supplied."; - ICHECK(dtype.is_int() || dtype.is_uint()) - << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied."; + ICHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool()) + << "ValueError: IntImm supports only int or uint or bool type, but " << dtype + << " was supplied."; if (dtype.is_uint()) { ICHECK_GE(value, 0U) << "ValueError: Literal value " << value << " is negative for unsigned integer type " << dtype; @@ -62,7 +63,7 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK_LT(value, 1LL << dtype.bits()) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } - } else if (dtype.bits() == 1) { + } else if (dtype.bits() == 1 || dtype.is_bool()) { // int(1) ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype; } else if (dtype.bits() < 64) { @@ -71,19 +72,19 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK_LT(value, 1LL << (dtype.bits() - 1)) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = dtype; node->value = value; node->span = span; data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.IntImm", [](DataType dtype, int64_t value, Span span) { return IntImm(dtype, value, span); }); -}); +} FloatImm::FloatImm(DataType dtype, double value, Span span) { ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; @@ -174,56 +175,56 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { << dtype; } } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = dtype; node->value = value; node->span = span; data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.FloatImm", [](DataType dtype, double value, Span span) { return FloatImm(dtype, value, span); }); -}); +} Range::Range(PrimExpr begin, PrimExpr end, Span span) - : Range(make_object(begin, tir::is_zero(begin) ? end : (end - begin), span)) {} + : Range(ffi::make_object(begin, tir::is_zero(begin) ? end : (end - begin), span)) {} Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { - return Range(make_object(min, extent, span)); + return Range(ffi::make_object(min, extent, span)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.Range_from_min_extent", Range::FromMinExtent) - .def("ir.Range", [](PrimExpr begin, Optional end, Span span) -> Range { + .def("ir.Range", [](PrimExpr begin, ffi::Optional end, Span span) -> Range { if (end.defined()) { return Range(begin, end.value(), span); } else { return Range(IntImm(begin->dtype, 0), begin, span); } }); -}); +} -GlobalVar::GlobalVar(String name_hint, Span span) { - ObjectPtr n = make_object(); +GlobalVar::GlobalVar(ffi::String name_hint, Span span) { + ObjectPtr n = ffi::make_object(); n->name_hint = std::move(name_hint); n->span = std::move(span); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("ir.GlobalVar", [](String name) { return GlobalVar(name); }) + .def("ir.GlobalVar", [](ffi::String name) { return GlobalVar(name); }) .def("ir.DebugPrint", [](ObjectRef ref) { std::stringstream ss; ss << ref; return ss.str(); }); -}); +} } // namespace tvm diff --git a/src/ir/function.cc b/src/ir/function.cc index 6cf0cd35ceee..de14d57b3ef8 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -30,47 +30,54 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef() .def("ir.BaseFunc_Attrs", [](BaseFunc func) { return func->attrs; }) .def("ir.BaseFuncCopy", [](BaseFunc func) { return func; }) .def("ir.BaseFuncWithAttr", - [](ffi::RValueRef func_ref, String key, Any value) -> BaseFunc { + [](ffi::RValueRef func_ref, ffi::String key, Any value) -> BaseFunc { BaseFunc func = *std::move(func_ref); if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); } else if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); } else { LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); } }) .def("ir.BaseFuncWithAttrs", - [](ffi::RValueRef func_ref, Map attr_map) -> BaseFunc { + [](ffi::RValueRef func_ref, + ffi::Map attr_map) -> BaseFunc { BaseFunc func = *std::move(func_ref); if (func->IsInstance()) { return WithAttrs(Downcast(std::move(func)), attr_map); } if (const auto f = tvm::ffi::Function::GetGlobal("relax.FuncWithAttrs")) { - if (auto ret = (*f)(func, attr_map).cast>()) { + if (auto ret = (*f)(func, attr_map).cast>()) { return ret.value(); } } + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); TVM_FFI_UNREACHABLE(); }) - .def("ir.BaseFuncWithoutAttr", [](ffi::RValueRef func_ref, String key) -> BaseFunc { - BaseFunc func = *std::move(func_ref); - if (func->IsInstance()) { - return WithoutAttr(Downcast(std::move(func)), key); - } else if (func->IsInstance()) { - return WithoutAttr(Downcast(std::move(func)), key); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - TVM_FFI_UNREACHABLE(); - } - }); -}); + .def("ir.BaseFuncWithoutAttr", + [](ffi::RValueRef func_ref, ffi::String key) -> BaseFunc { + BaseFunc func = *std::move(func_ref); + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else 
if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + TVM_FFI_UNREACHABLE(); + } + }); +} } // namespace tvm diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 566702f5dd63..151387d3c25a 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -26,31 +26,31 @@ #include namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { VDeviceNode::RegisterReflection(); DummyGlobalInfoNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.DummyGlobalInfo", []() { - auto n = DummyGlobalInfo(make_object()); + auto n = DummyGlobalInfo(ffi::make_object()); return n; }); -}); +} VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->target = std::move(tgt); n->vdevice_id = std::move(dev_id); n->memory_scope = std::move(mem_scope); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.VDevice", [](Target tgt, int dev_id, MemoryScope mem_scope) { return VDevice(tgt, dev_id, mem_scope); }); -}); +} } // namespace tvm diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index 9d4e66bfa466..115eba152948 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -32,19 +32,19 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ GlobalVarSupplyNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { GlobalVarSupplyNode::RegisterReflection(); } GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply, std::unordered_map name_to_var_map) { - auto n = make_object(name_supply, name_to_var_map); + auto n = ffi::make_object(name_supply, name_to_var_map); data_ = std::move(n); } std::string GetModuleName(const IRModule& 
module) { - return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); + return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); } -GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupply() { +GlobalVarSupply::GlobalVarSupply(const ffi::Array& modules) : GlobalVarSupply() { if (!modules.empty()) { IRModule first_mod = modules.front(); this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod); @@ -57,7 +57,7 @@ GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupp } GlobalVarSupply::GlobalVarSupply(const IRModule module) - : GlobalVarSupply(Array{module}) {} + : GlobalVarSupply(ffi::Array{module}) {} void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) { name_supply_->ReserveName(var->name_hint, false); @@ -72,8 +72,8 @@ GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply, std::unordered_map name_to_var_map) : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {} -GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_prefix) { - String final_name = name_supply_->ReserveName(name, add_prefix); +GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const ffi::String& name, bool add_prefix) { + ffi::String final_name = name_supply_->ReserveName(name, add_prefix); auto it = name_to_var_map_.find(final_name); if (it != name_to_var_map_.end()) { @@ -85,8 +85,8 @@ GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_pref } } -GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { - String final_name = name_supply_->FreshName(name, add_prefix); +GlobalVar GlobalVarSupplyNode::FreshGlobal(ffi::String name, bool add_prefix) { + ffi::String final_name = name_supply_->FreshName(name, add_prefix); ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) << "GlobalVar already exists for name " << final_name; GlobalVar var = GlobalVar(final_name); @@ -94,7 
+94,7 @@ GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { return var; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.GlobalVarSupply_NameSupply", @@ -102,10 +102,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("ir.GlobalVarSupply_IRModule", [](IRModule mod) { return GlobalVarSupply(std::move(mod)); }) .def("ir.GlobalVarSupply_IRModules", - [](const Array& mods) { return GlobalVarSupply(mods); }) + [](const ffi::Array& mods) { return GlobalVarSupply(mods); }) .def_method("ir.GlobalVarSupply_FreshGlobal", &GlobalVarSupplyNode::FreshGlobal) .def_method("ir.GlobalVarSupply_UniqueGlobalFor", &GlobalVarSupplyNode::UniqueGlobalFor) .def_method("ir.GlobalVarSupply_ReserveGlobalVar", &GlobalVarSupplyNode::ReserveGlobalVar); -}); +} } // namespace tvm diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index 74176cb373cc..011968d105c5 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -33,7 +33,7 @@ namespace tvm { namespace instrument { -TVM_FFI_STATIC_INIT_BLOCK({ PassInstrumentNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { PassInstrumentNode::RegisterReflection(); } /*! * \brief Base PassInstrument implementation @@ -83,9 +83,8 @@ class BasePassInstrumentNode : public PassInstrumentNode { * \param info The pass information. */ void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const final; - - static constexpr const char* _type_key = "instrument.PassInstrument"; - TVM_DECLARE_FINAL_OBJECT_INFO(BasePassInstrumentNode, PassInstrumentNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("instrument.PassInstrument", BasePassInstrumentNode, + PassInstrumentNode); }; /*! @@ -110,7 +109,7 @@ class BasePassInstrument : public PassInstrument { * \param run_after_pass_callback Callback to call after a pass run. 
*/ TVM_DLL BasePassInstrument( - String name, ffi::TypedFunction enter_pass_ctx_callback, + ffi::String name, ffi::TypedFunction enter_pass_ctx_callback, ffi::TypedFunction exit_pass_ctx_callback, ffi::TypedFunction should_run_callback, ffi::TypedFunction @@ -118,16 +117,17 @@ class BasePassInstrument : public PassInstrument { ffi::TypedFunction run_after_pass_callback); - TVM_DEFINE_OBJECT_REF_METHODS(BasePassInstrument, PassInstrument, BasePassInstrumentNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BasePassInstrument, PassInstrument, + BasePassInstrumentNode); }; BasePassInstrument::BasePassInstrument( - String name, ffi::TypedFunction enter_pass_ctx_callback, + ffi::String name, ffi::TypedFunction enter_pass_ctx_callback, ffi::TypedFunction exit_pass_ctx_callback, ffi::TypedFunction should_run_callback, ffi::TypedFunction run_before_pass_callback, ffi::TypedFunction run_after_pass_callback) { - auto pi = make_object(); + auto pi = ffi::make_object(); pi->name = std::move(name); pi->enter_pass_ctx_callback = std::move(enter_pass_ctx_callback); @@ -176,11 +176,11 @@ void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module, } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "instrument.PassInstrument", - [](String name, ffi::TypedFunction enter_pass_ctx, + [](ffi::String name, ffi::TypedFunction enter_pass_ctx, ffi::TypedFunction exit_pass_ctx, ffi::TypedFunction should_run, ffi::TypedFunction run_before_pass, @@ -188,7 +188,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run, run_before_pass, run_after_pass); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -204,7 +204,7 @@ struct PassProfile { using Time = std::chrono::time_point; /*! \brief The name of the pass being profiled. */ - String name; + ffi::String name; /*! 
\brief The time when the pass was entered. */ Time start; /*! \brief The time when the pass completed. */ @@ -214,13 +214,13 @@ struct PassProfile { /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ std::vector children; - explicit PassProfile(String name) + explicit PassProfile(ffi::String name) : name(name), start(Clock::now()), end(Clock::now()), children() {} /*! \brief Gets the PassProfile of the currently executing pass. */ static PassProfile* Current(); /*! \brief Pushes a new PassProfile with the given pass name. */ - static void EnterPass(String name); + static void EnterPass(ffi::String name); /*! \brief Pops the current PassProfile. */ static void ExitPass(); }; @@ -237,7 +237,7 @@ struct PassProfileThreadLocalEntry { /*! \brief Thread local store to hold the pass profiling data. */ typedef dmlc::ThreadLocalStore PassProfileThreadLocalStore; -void PassProfile::EnterPass(String name) { +void PassProfile::EnterPass(ffi::String name) { PassProfile* cur = PassProfile::Current(); cur->children.emplace_back(name); PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); @@ -260,13 +260,13 @@ PassProfile* PassProfile::Current() { } } -String RenderPassProfiles() { +ffi::String RenderPassProfiles() { PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; if (entry->root.children.empty()) { LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; - return String(); + return ffi::String(); } // (depth, parent_duration, pass) @@ -312,7 +312,7 @@ String RenderPassProfiles() { return os.str(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("instrument.RenderTimePassProfiles", RenderPassProfiles) @@ -332,7 +332,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* enter_pass_ctx */ nullptr, exit_pass_ctx, /* should_run */ 
nullptr, run_before_pass, run_after_pass); }); -}); +} } // namespace instrument } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 3ca4457b9871..b0104ba14d17 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -36,11 +36,11 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ IRModuleNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { IRModuleNode::RegisterReflection(); } -IRModule::IRModule(tvm::Map functions, SourceMap source_map, DictAttrs attrs, - Map> global_infos) { - auto n = make_object(); +IRModule::IRModule(tvm::ffi::Map functions, SourceMap source_map, + DictAttrs attrs, ffi::Map> global_infos) { + auto n = ffi::make_object(); n->functions = std::move(functions); n->global_var_map_ = {}; n->source_map = source_map; @@ -109,11 +109,11 @@ uint64_t IRModuleNode::SHash(uint64_t init_hash, return hash_value; } -bool IRModuleNode::ContainGlobalVar(const String& name) const { +bool IRModuleNode::ContainGlobalVar(const ffi::String& name) const { return global_var_map_.find(name) != global_var_map_.end(); } -GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { +GlobalVar IRModuleNode::GetGlobalVar(const ffi::String& name) const { auto it = global_var_map_.find(name); if (it == global_var_map_.end()) { std::ostringstream msg; @@ -132,7 +132,7 @@ GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { return (*it).second; } -tvm::Array IRModuleNode::GetGlobalVars() const { +tvm::ffi::Array IRModuleNode::GetGlobalVars() const { std::vector global_vars; for (const auto& pair : global_var_map_) { global_vars.push_back(pair.second); @@ -140,7 +140,7 @@ tvm::Array IRModuleNode::GetGlobalVars() const { std::sort(global_vars.begin(), global_vars.end(), [](const GlobalVar& lhs, const GlobalVar& rhs) { return lhs->name_hint < rhs->name_hint; }); - return tvm::Array(global_vars); + return tvm::ffi::Array(global_vars); } void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { @@ -165,7 +165,7 @@ 
void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) { this->Add(var, func, true); } -void IRModuleNode::UpdateGlobalInfo(const String& name, const Array& info) { +void IRModuleNode::UpdateGlobalInfo(const ffi::String& name, const ffi::Array& info) { this->global_infos.Set(name, info); } @@ -182,7 +182,7 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { return (*it).second; } -BaseFunc IRModuleNode::Lookup(const String& name) const { +BaseFunc IRModuleNode::Lookup(const ffi::String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } @@ -199,15 +199,15 @@ IRModule IRModuleNode::ShallowCopy() { } IRModule IRModule::FromExpr(const RelaxExpr& expr, - const tvm::Map& global_funcs) { + const tvm::ffi::Map& global_funcs) { auto mod = IRModule(global_funcs); - String gv_name; + ffi::String gv_name; // All global definitions must be functions. BaseFunc func; if (auto func_node = expr.as()) { func = func_node.value(); - if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { // Function literal has been annotated with it's required global symbol. 
gv_name = opt.value(); } @@ -225,22 +225,22 @@ IRModule IRModule::FromExpr(const RelaxExpr& expr, return mod; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.IRModule", - [](tvm::Map funcs, tvm::ObjectRef attrs, - Map> global_infos) { + [](tvm::ffi::Map funcs, tvm::ObjectRef attrs, + ffi::Map> global_infos) { auto dict_attrs = [&attrs]() { if (!attrs.defined()) { return DictAttrs(); } else if (auto* as_dict_attrs = attrs.as()) { - return GetRef(as_dict_attrs); + return ffi::GetRef(as_dict_attrs); } else if (attrs.as()) { - return tvm::DictAttrs(Downcast>(attrs)); + return tvm::DictAttrs(Downcast>(attrs)); } else { - LOG(FATAL) - << "Expected attrs argument to be either DictAttrs or Map"; + LOG(FATAL) << "Expected attrs argument to be either DictAttrs or " + "ffi::Map"; } }(); @@ -259,11 +259,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ return mod; }) .def("ir.Module_Remove", - [](IRModule mod, Variant var) -> IRModule { + [](IRModule mod, ffi::Variant var) -> IRModule { GlobalVar gvar = [&]() { if (auto opt = var.as()) { return opt.value(); - } else if (auto opt = var.as()) { + } else if (auto opt = var.as()) { return mod->GetGlobalVar(opt.value()); } else { LOG(FATAL) << "InternalError: " @@ -274,10 +274,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ return mod; }) .def("ir.Module_Contains", - [](IRModule mod, Variant var) -> bool { + [](IRModule mod, ffi::Variant var) -> bool { if (auto opt = var.as()) { return mod->functions.count(opt.value()); - } else if (auto opt = var.as()) { + } else if (auto opt = var.as()) { return mod->global_var_map_.count(opt.value()); } else { LOG(FATAL) << "InternalError: " @@ -288,30 +288,30 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("ir.Module_GetGlobalVars", &IRModuleNode::GetGlobalVars) .def_method("ir.Module_ContainGlobalVar", &IRModuleNode::ContainGlobalVar) .def("ir.Module_Lookup", [](IRModule mod, GlobalVar var) { return mod->Lookup(var); }) - 
.def("ir.Module_Lookup_str", [](IRModule mod, String var) { return mod->Lookup(var); }) + .def("ir.Module_Lookup_str", [](IRModule mod, ffi::String var) { return mod->Lookup(var); }) .def("ir.Module_FromExpr", &IRModule::FromExpr) .def("ir.Module_Update", [](IRModule mod, IRModule from) { mod->Update(from); }) .def("ir.Module_UpdateFunction", [](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }) .def("ir.Module_UpdateGlobalInfo", - [](IRModule mod, String name, Array global_info) { + [](IRModule mod, ffi::String name, ffi::Array global_info) { mod->UpdateGlobalInfo(name, global_info); }) .def("ir.Module_GetAttrs", [](IRModule mod) -> ObjectRef { return mod->GetAttrs(); }) .def("ir.Module_WithAttr", - [](ffi::RValueRef mod, String key, ffi::Any value) -> IRModule { + [](ffi::RValueRef mod, ffi::String key, ffi::Any value) -> IRModule { return WithAttr(*std::move(mod), key, value); }) .def("ir.Module_WithoutAttr", - [](ffi::RValueRef mod, String key) -> IRModule { + [](ffi::RValueRef mod, ffi::String key) -> IRModule { return WithoutAttr(*std::move(mod), key); }) .def("ir.Module_WithAttrs", - [](ffi::RValueRef mod, Map attr_map) -> IRModule { + [](ffi::RValueRef mod, ffi::Map attr_map) -> IRModule { return WithAttrs(*std::move(mod), attr_map); }) .def("ir.Module_GetAttr", - [](IRModule mod, String key) -> ObjectRef { return mod->GetAttr(key); }); -}); + [](IRModule mod, ffi::String key) -> ObjectRef { return mod->GetAttr(key); }); +} } // namespace tvm diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc index 24b5e72735a0..e5b94dff5a06 100644 --- a/src/ir/name_supply.cc +++ b/src/ir/name_supply.cc @@ -30,13 +30,13 @@ namespace tvm { -NameSupply::NameSupply(const String& prefix, std::unordered_map name_map) { - auto n = make_object(prefix, std::move(name_map)); +NameSupply::NameSupply(const ffi::String& prefix, std::unordered_map name_map) { + auto n = ffi::make_object(prefix, std::move(name_map)); data_ = std::move(n); } -String 
NameSupplyNode::ReserveName(const String& name, bool add_prefix) { - String final_name = name; +ffi::String NameSupplyNode::ReserveName(const ffi::String& name, bool add_prefix) { + ffi::String final_name = name; if (add_prefix) { final_name = add_prefix_to_name(name); } @@ -44,8 +44,9 @@ String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { return final_name; } -String NameSupplyNode::FreshName(const String& name, bool add_prefix, bool add_underscore) { - String unique_name = name; +ffi::String NameSupplyNode::FreshName(const ffi::String& name, bool add_prefix, + bool add_underscore) { + ffi::String unique_name = name; if (add_prefix) { unique_name = add_prefix_to_name(name); } @@ -53,8 +54,8 @@ String NameSupplyNode::FreshName(const String& name, bool add_prefix, bool add_u return unique_name; } -bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) { - String unique_name = name; +bool NameSupplyNode::ContainsName(const ffi::String& name, bool add_prefix) { + ffi::String unique_name = name; if (add_prefix) { unique_name = add_prefix_to_name(name); } @@ -62,7 +63,7 @@ bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) { return name_map.count(unique_name); } -String NameSupplyNode::add_prefix_to_name(const String& name) { +ffi::String NameSupplyNode::add_prefix_to_name(const ffi::String& name) { if (prefix_.empty()) { return name; } @@ -90,13 +91,14 @@ std::string NameSupplyNode::GetUniqueName(std::string name, bool add_underscore) return name; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + NameSupplyNode::RegisterReflection(); refl::GlobalDef() - .def("ir.NameSupply", [](String prefix) { return NameSupply(prefix); }) + .def("ir.NameSupply", [](ffi::String prefix) { return NameSupply(prefix); }) .def_method("ir.NameSupply_FreshName", &NameSupplyNode::FreshName) .def_method("ir.NameSupply_ReserveName", &NameSupplyNode::ReserveName) 
.def_method("ir.NameSupply_ContainsName", &NameSupplyNode::ContainsName); -}); +} } // namespace tvm diff --git a/src/ir/op.cc b/src/ir/op.cc index 1bb0e7007b28..514b45c65ad0 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -34,7 +34,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ OpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { OpNode::RegisterReflection(); } using ffi::Any; using ffi::Function; @@ -44,47 +44,49 @@ using tir::FLowerIntrinsic; using OpRegistry = AttrRegistry; // find operator by name -const Op& Op::Get(const String& name) { +const Op& Op::Get(const ffi::String& name) { const OpRegEntry* reg = OpRegistry::Global()->Get(name); ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered"; return reg->op(); } OpRegEntry::OpRegEntry(uint32_t reg_index) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->index_ = reg_index; op_ = Op(n); } -OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) { +OpRegEntry& OpRegEntry::RegisterOrGet(const ffi::String& name) { return OpRegistry::Global()->RegisterOrGet(name); } // Get attribute map by key -const AttrRegistryMapContainerMap& Op::GetAttrMapContainer(const String& attr_name) { +const AttrRegistryMapContainerMap& Op::GetAttrMapContainer(const ffi::String& attr_name) { return OpRegistry::Global()->GetAttrMap(attr_name); } // Check if a key is present in the registry. -bool Op::HasAttrMap(const String& attr_name) { return OpRegistry::Global()->HasAttrMap(attr_name); } +bool Op::HasAttrMap(const ffi::String& attr_name) { + return OpRegistry::Global()->HasAttrMap(attr_name); +} // Resets attr of the OpAttrMap. 
void OpRegEntry::reset_attr(const std::string& attr_name) { OpRegistry::Global()->ResetAttr(attr_name, op_); } -void OpRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { +void OpRegEntry::UpdateAttr(const ffi::String& key, ffi::Any value, int plevel) { OpRegistry::Global()->UpdateAttr(key, op_, value, plevel); } // Frontend APIs -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.ListOpNames", []() { return OpRegistry::Global()->ListAllNames(); }) - .def("ir.GetOp", [](String name) -> Op { return Op::Get(name); }) + .def("ir.GetOp", [](ffi::String name) -> Op { return Op::Get(name); }) .def("ir.OpGetAttr", - [](Op op, String attr_name) -> ffi::Any { + [](Op op, ffi::String attr_name) -> ffi::Any { auto op_map = Op::GetAttrMap(attr_name); ffi::Any rv; if (op_map.count(op)) { @@ -93,19 +95,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ return rv; }) .def("ir.OpHasAttr", - [](Op op, String attr_name) -> bool { return Op::HasAttrMap(attr_name); }) + [](Op op, ffi::String attr_name) -> bool { return Op::HasAttrMap(attr_name); }) .def("ir.OpSetAttr", - [](Op op, String attr_name, ffi::AnyView value, int plevel) { + [](Op op, ffi::String attr_name, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attr(attr_name, value, plevel); }) .def("ir.OpResetAttr", - [](Op op, String attr_name) { + [](Op op, ffi::String attr_name) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); reg.reset_attr(attr_name); }) .def("ir.RegisterOp", - [](String op_name, String descr) { + [](ffi::String op_name, ffi::String descr) { const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before"; @@ -113,7 +115,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ op.describe(descr); }) .def("ir.OpAddArgument", - [](Op op, String name, String type, String description) { + [](Op op, 
ffi::String name, ffi::String type, ffi::String description) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.add_argument(name, type, description); }) @@ -128,12 +130,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ reg.set_num_inputs(n); }) .def("ir.OpSetAttrsTypeKey", - [](Op op, String key) { + [](Op op, ffi::String key) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attrs_type_key(key); }) .def("ir.RegisterOpAttr", - [](String op_name, String attr_key, ffi::AnyView value, int plevel) { + [](ffi::String op_name, ffi::String attr_key, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); // enable resgiteration and override of certain properties if (attr_key == "num_inputs" && plevel > 128) { @@ -145,19 +147,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ } }) .def("ir.RegisterOpLowerIntrinsic", - [](String name, ffi::Function f, String target, int plevel) { + [](ffi::String name, ffi::Function f, ffi::String target, int plevel) { tvm::OpRegEntry::RegisterOrGet(name).set_attr( target + ".FLowerIntrinsic", f, plevel); }); // override OpNode to use name as the repr refl::TypeAttrDef() .def("__data_to_json__", - [](const OpNode* node) -> String { + [](const OpNode* node) -> ffi::String { // simply save as the string return node->name; }) - .def("__data_from_json__", [](const String& name) -> Op { return Op::Get(name); }); -}); + .def("__data_from_json__", [](const ffi::String& name) -> Op { return Op::Get(name); }); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/ir/replace_global_vars.cc b/src/ir/replace_global_vars.cc index 9887a111f958..98b5b74c42cd 100644 --- a/src/ir/replace_global_vars.cc +++ b/src/ir/replace_global_vars.cc @@ -31,7 +31,7 @@ namespace tvm { namespace transform { -IRModule ReplaceGlobalVars(IRModule mod, Map replacements) { +IRModule ReplaceGlobalVars(IRModule mod, ffi::Map replacements) { if 
(replacements.empty()) { return mod; } @@ -63,32 +63,36 @@ IRModule ReplaceGlobalVars(IRModule mod, Map replacements) return mod; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("transform.ReplaceGlobalVars", ReplaceGlobalVars); -}); +} IRModule ModuleReplaceGlobalVars( - IRModule mod, Map, Variant> replacements) { - Map gvar_replacements; + IRModule mod, + ffi::Map, ffi::Variant> + replacements) { + ffi::Map gvar_replacements; for (const auto& [before, after] : replacements) { GlobalVar gvar_before; if (auto gvar = before.as()) { gvar_before = gvar.value(); - } else if (auto str = before.as()) { + } else if (auto str = before.as()) { gvar_before = mod->GetGlobalVar(str.value()); } else { - LOG(FATAL) << "Variant must contain either String or GlobalVar"; + LOG(FATAL) + << "ffi::Variant must contain either ffi::String or GlobalVar"; } GlobalVar gvar_after; if (auto gvar = after.as()) { gvar_after = gvar.value(); - } else if (auto str = after.as()) { + } else if (auto str = after.as()) { gvar_after = gvar_before; gvar_after.CopyOnWrite()->name_hint = str.value(); } else { - LOG(FATAL) << "Variant must contain either String or GlobalVar"; + LOG(FATAL) + << "ffi::Variant must contain either ffi::String or GlobalVar"; } gvar_replacements.Set(gvar_before, gvar_after); @@ -97,10 +101,10 @@ IRModule ModuleReplaceGlobalVars( return ReplaceGlobalVars(mod, gvar_replacements); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.Module_ReplaceGlobalVars", ModuleReplaceGlobalVars); -}); +} } // namespace transform } // namespace tvm diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 588efe9c6a4e..521b02db44b5 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -29,7 +29,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
SourceNameNode::RegisterReflection(); SpanNode::RegisterReflection(); @@ -44,16 +44,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ return node->name; }) .def("__data_from_json__", SourceName::Get); -}); +} -ObjectPtr GetSourceNameNode(const String& name) { +ObjectPtr GetSourceNameNode(const ffi::String& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr - static std::unordered_map> source_map; + static std::unordered_map> source_map; auto sn = source_map.find(name); if (sn == source_map.end()) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); source_map[name] = n; n->name = std::move(name); return n; @@ -62,16 +62,16 @@ ObjectPtr GetSourceNameNode(const String& name) { } } -ObjectPtr GetSourceNameNodeByStr(const std::string& name) { +ObjectPtr GetSourceNameNodeByStr(const std::string& name) { return GetSourceNameNode(name); } -SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } +SourceName SourceName::Get(const ffi::String& name) { return SourceName(GetSourceNameNode(name)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.SourceName", SourceName::Get); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); Span::Span(SourceName source_name, int line, int end_line, int column, int end_column) { - auto n = make_object(); + auto n = ffi::make_object(); n->source_name = std::move(source_name); n->line = line; n->end_line = end_line; @@ -99,9 +99,9 @@ Span Span::Merge(const Span& other) const { std::max((*this)->end_column, other->end_column)); } -SequentialSpan::SequentialSpan(tvm::Array spans) { - auto n = make_object(); - tvm::Array tmp_spans; +SequentialSpan::SequentialSpan(tvm::ffi::Array spans) { + auto n = ffi::make_object(); + 
tvm::ffi::Array tmp_spans; for (const Span& s : spans) { if (const SequentialSpanNode* seq_s = s.as()) { tmp_spans.insert(tmp_spans.end(), seq_s->spans.begin(), seq_s->spans.end()); @@ -120,9 +120,9 @@ SequentialSpan::SequentialSpan(tvm::Array spans) { } SequentialSpan::SequentialSpan(std::initializer_list init) { - auto n = make_object(); - tvm::Array spans = tvm::Array(init); - tvm::Array tmp_spans; + auto n = ffi::make_object(); + tvm::ffi::Array spans = tvm::ffi::Array(init); + tvm::ffi::Array tmp_spans; for (const Span& s : spans) { if (const SequentialSpanNode* seq_s = s.as()) { tmp_spans.insert(tmp_spans.end(), seq_s->spans.begin(), seq_s->spans.end()); @@ -140,15 +140,15 @@ SequentialSpan::SequentialSpan(std::initializer_list init) { data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.Span", [](SourceName source_name, int line, int end_line, int column, int end_column) { return Span(source_name, line, end_line, column, end_column); }) - .def("ir.SequentialSpan", [](tvm::Array spans) { return SequentialSpan(spans); }); -}); + .def("ir.SequentialSpan", [](tvm::ffi::Array spans) { return SequentialSpan(spans); }); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -172,7 +172,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) /*! \brief Construct a source from a string. 
*/ Source::Source(SourceName src_name, std::string source) { - auto n = make_object(); + auto n = ffi::make_object(); n->source_name = std::move(src_name); n->source = std::move(source); @@ -201,7 +201,7 @@ Source::Source(SourceName src_name, std::string source) { data_ = n; } -tvm::String Source::GetLine(int line) { +tvm::ffi::String Source::GetLine(int line) { VLOG(1) << "Source::GetLine: line=" << line; ICHECK(line - 1 < static_cast((*this)->line_map.size())) << "requested line: " << line << "at index: " << (line - 1) @@ -212,28 +212,28 @@ tvm::String Source::GetLine(int line) { int line_start = range.first; int line_length = range.second; VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; - // TODO(@jroesch): expose substring on tvm::String. + // TODO(@jroesch): expose substring on tvm::ffi::String. auto line_text = std::string((*this)->source).substr(line_start, line_length); VLOG(1) << "Source::GetLine: line_text=" << line_text; return line_text; } -SourceMap::SourceMap(Map source_map) { - auto n = make_object(); +SourceMap::SourceMap(ffi::Map source_map) { + auto n = ffi::make_object(); n->source_map = std::move(source_map); data_ = std::move(n); } void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("SourceMapAdd", [](SourceMap map, String name, String content) { + refl::GlobalDef().def("SourceMapAdd", [](SourceMap map, ffi::String name, ffi::String content) { auto src_name = SourceName::Get(name); Source source(src_name, content); map.Add(source); return src_name; }); -}); +} } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index d82f02f3dfb9..3cbf8a629fc3 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -31,19 +31,13 @@ #include #include -#include -#include #include -#include - -#include "../runtime/regex.h" 
namespace tvm { namespace transform { using tvm::ReprPrinter; using tvm::ffi::Any; -using tvm::ffi::PackedArgs; TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool); @@ -54,21 +48,23 @@ struct PassContextThreadLocalEntry { /*! \brief The current pass context. */ std::stack context_stack; - PassContextThreadLocalEntry() { default_context = PassContext(make_object()); } + PassContextThreadLocalEntry() { + default_context = PassContext(ffi::make_object()); + } }; /*! \brief Thread local store to hold the pass context. */ -typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; +typedef dmlc::ThreadLocalStore PassContextThreadLocalStore; void PassContext::EnterWithScope() { InstrumentEnterPassContext(); - PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = PassContextThreadLocalStore::Get(); entry->context_stack.push(*this); } void PassContext::ExitWithScope() { - PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = PassContextThreadLocalStore::Get(); ICHECK(!entry->context_stack.empty()); ICHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); @@ -77,7 +73,7 @@ void PassContext::ExitWithScope() { } PassContext PassContext::Current() { - PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = PassContextThreadLocalStore::Get(); if (!entry->context_stack.empty()) { return entry->context_stack.top(); } else { @@ -86,7 +82,7 @@ PassContext PassContext::Current() { } // linearly scan the pass array to match pass_name -bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { +bool PassArrayContains(const ffi::Array& pass_array, const std::string& pass_name) { for (auto x : pass_array) { if (x == pass_name) return true; } @@ -107,7 +103,7 @@ bool PassContext::PassEnabled(const PassInfo& info) const { class 
PassConfigManager { public: - void Register(std::string key, String value_type_str, + void Register(std::string key, ffi::String value_type_str, std::function legalization) { ICHECK_EQ(key2vtype_.count(key), 0U); ValueTypeInfo info; @@ -117,7 +113,7 @@ class PassConfigManager { } // Trying to validate and legalize a config. - void Legalize(Map* config) { + void Legalize(ffi::Map* config) { std::vector> update; for (auto [key, value] : *config) { auto it = key2vtype_.find(key); @@ -149,10 +145,10 @@ class PassConfigManager { } } - Map> ListConfigs() { - Map> configs; + ffi::Map> ListConfigs() { + ffi::Map> configs; for (const auto& kv : key2vtype_) { - Map metadata; + ffi::Map metadata; metadata.Set("type", kv.second.type_str); configs.Set(kv.first, metadata); } @@ -173,20 +169,20 @@ class PassConfigManager { std::unordered_map key2vtype_; }; -void PassContext::RegisterConfigOption(const char* key, String value_type_str, +void PassContext::RegisterConfigOption(const char* key, ffi::String value_type_str, std::function legalization) { PassConfigManager::Global()->Register(key, value_type_str, legalization); } -Map> PassContext::ListConfigs() { +ffi::Map> PassContext::ListConfigs() { return PassConfigManager::Global()->ListConfigs(); } -PassContext PassContext::Create() { return PassContext(make_object()); } +PassContext PassContext::Create() { return PassContext(ffi::make_object()); } namespace { struct ClearOnError { - Array* instruments{nullptr}; + ffi::Array* instruments{nullptr}; ~ClearOnError() { if (instruments) { @@ -244,7 +240,7 @@ struct ExitPassSuccesses { bool all_initialized{false}; std::vector successes; - Array* instruments{nullptr}; + ffi::Array* instruments{nullptr}; }; } // namespace @@ -366,20 +362,19 @@ class ModulePassNode : public PassNode { * \brief Get the pass information/meta data. 
*/ PassInfo Info() const override { return pass_info; } - - static constexpr const char* _type_key = "transform.ModulePass"; - TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("transform.ModulePass", ModulePassNode, PassNode); }; class ModulePass : public Pass { public: ModulePass(std::function pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ModulePass, Pass, ModulePassNode); }; -PassInfo::PassInfo(int opt_level, String name, tvm::Array required, bool traceable) { - auto pass_info = make_object(); +PassInfo::PassInfo(int opt_level, ffi::String name, tvm::ffi::Array required, + bool traceable) { + auto pass_info = ffi::make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); @@ -389,7 +384,7 @@ PassInfo::PassInfo(int opt_level, String name, tvm::Array required, bool ModulePass::ModulePass(std::function pass_func, PassInfo pass_info) { - auto n = make_object(); + auto n = ffi::make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); data_ = std::move(n); @@ -429,15 +424,15 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c return mod; } -Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { - auto n = make_object(); +Sequential::Sequential(tvm::ffi::Array passes, PassInfo pass_info) { + auto n = ffi::make_object(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); data_ = std::move(n); } -Sequential::Sequential(tvm::Array passes, String name) { - auto n = make_object(); +Sequential::Sequential(tvm::ffi::Array passes, ffi::String name) { + auto n = ffi::make_object(); n->passes = std::move(passes); PassInfo pass_info = PassInfo(0, std::move(name), {}, /* traceable */ false); n->pass_info = std::move(pass_info); @@ -457,7 +452,7 @@ void 
SequentialNode::ResolveDependency(const IRModule& mod) { << "\n"; } -Pass GetPass(const String& pass_name) { +Pass GetPass(const ffi::String& pass_name) { std::optional f; if (pass_name.operator std::string().find("transform.") != std::string::npos) { f = tvm::ffi::Function::GetGlobal(pass_name); @@ -492,23 +487,22 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c } Pass CreateModulePass(std::function pass_func, int opt_level, - String name, tvm::Array required, bool traceable) { + ffi::String name, tvm::ffi::Array required, bool traceable) { PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return ModulePass(std::move(pass_func), pass_info); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("transform.PassInfo", - [](int opt_level, String name, tvm::Array required, bool traceable) { - return PassInfo(opt_level, name, required, traceable); - }) + [](int opt_level, ffi::String name, tvm::ffi::Array required, + bool traceable) { return PassInfo(opt_level, name, required, traceable); }) .def_packed("transform.Info", [](ffi::PackedArgs args, ffi::Any* ret) { Pass pass = args[0].cast(); *ret = pass->Info(); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, tvm::ReprPrinter* p) { @@ -528,14 +522,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { PassContextNode::RegisterReflection(); PassInfoNode::RegisterReflection(); SequentialNode::RegisterReflection(); ModulePassNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("transform.MakeModulePass", @@ -548,7 +542,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def("transform.RunPass", [](Pass pass, ffi::RValueRef mod) { return pass(*std::move(mod)); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, 
vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -558,18 +552,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << info->opt_level; }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("transform.Sequential", [](ffi::PackedArgs args, ffi::Any* ret) { - auto passes = args[0].cast>(); + auto passes = args[0].cast>(); int opt_level = args[1].cast(); std::string name = args[2].cast(); - auto required = args[3].cast>(); + auto required = args[3].cast>(); bool traceable = args[4].cast(); PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); *ret = Sequential(passes, pass_info); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -585,12 +579,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "]"; }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "transform.PassContext", - [](int opt_level, Array required, Array disabled, - Array instruments, Optional> config) { + [](int opt_level, ffi::Array required, ffi::Array disabled, + ffi::Array instruments, + ffi::Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; @@ -604,7 +599,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ PassConfigManager::Global()->Legalize(&(pctx->config)); return pctx; }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -627,34 +622,34 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("transform.GetCurrentPassContext", PassContext::Current) .def("transform.EnterPassContext", PassContext::Internal::EnterScope) .def("transform.ExitPassContext", PassContext::Internal::ExitScope) 
.def("transform.OverrideInstruments", - [](PassContext pass_ctx, Array instruments) { + [](PassContext pass_ctx, ffi::Array instruments) { pass_ctx.InstrumentExitPassContext(); pass_ctx->instruments = instruments; pass_ctx.InstrumentEnterPassContext(); }); -}); +} -Pass PrintIR(String header, bool show_meta_data) { - auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { +Pass PrintIR(ffi::String header) { + auto pass_func = [header](IRModule mod, const PassContext& ctx) { LOG(INFO) << "PrintIR(" << header << "):\n" << mod; return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("transform.PrintIR", PrintIR) .def("transform.ListConfigs", PassContext::ListConfigs); -}); +} } // namespace transform } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index 4afa785aaedd..b28e20a78f89 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -26,29 +26,29 @@ #include namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { TypeNode::RegisterReflection(); PrimTypeNode::RegisterReflection(); PointerTypeNode::RegisterReflection(); TupleTypeNode::RegisterReflection(); FuncTypeNode::RegisterReflection(); TensorMapTypeNode::RegisterReflection(); -}); +} PrimType::PrimType(runtime::DataType dtype, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = dtype; n->span = std::move(span); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.PrimType", [](runtime::DataType dtype) { return PrimType(dtype); }); -}); +} -PointerType::PointerType(Type element_type, String storage_scope) { - ObjectPtr n = make_object(); +PointerType::PointerType(Type element_type, ffi::String storage_scope) { + ObjectPtr n = ffi::make_object(); if 
(storage_scope.empty()) { n->storage_scope = "global"; } else { @@ -58,46 +58,46 @@ PointerType::PointerType(Type element_type, String storage_scope) { data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.PointerType", [](Type element_type, String storage_scope = "") { + refl::GlobalDef().def("ir.PointerType", [](Type element_type, ffi::String storage_scope = "") { return PointerType(element_type, storage_scope); }); -}); +} -FuncType::FuncType(tvm::Array arg_types, Type ret_type, Span span) { - ObjectPtr n = make_object(); +FuncType::FuncType(tvm::ffi::Array arg_types, Type ret_type, Span span) { + ObjectPtr n = ffi::make_object(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); n->span = std::move(span); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.FuncType", [](tvm::Array arg_types, Type ret_type) { + refl::GlobalDef().def("ir.FuncType", [](tvm::ffi::Array arg_types, Type ret_type) { return FuncType(arg_types, ret_type); }); -}); +} -TupleType::TupleType(Array fields, Span span) { - ObjectPtr n = make_object(); +TupleType::TupleType(ffi::Array fields, Span span) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); n->span = std::move(span); data_ = std::move(n); } -TupleType TupleType::Empty() { return TupleType(Array()); } +TupleType TupleType::Empty() { return TupleType(ffi::Array()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("ir.TupleType", [](Array fields) { return TupleType(fields); }) + .def("ir.TupleType", [](ffi::Array fields) { return TupleType(fields); }) .def("ir.TensorMapType", [](Span span) { return TensorMapType(span); }); -}); +} TensorMapType::TensorMapType(Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = 
ffi::make_object(); n->span = std::move(span); data_ = std::move(n); } diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index 774c9d8f245f..3c81ca107eab 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -49,7 +49,7 @@ Type TypeMutator::VisitType(const Type& t) { } // Type Mutator. -Array TypeMutator::MutateArray(Array arr) { +ffi::Array TypeMutator::MutateArray(ffi::Array arr) { // The array will do copy on write // If no changes are made, the original array will be returned. return arr.Map([this](const Type& ty) { return VisitType(ty); }); @@ -58,32 +58,32 @@ Array TypeMutator::MutateArray(Array arr) { Type TypeMutator::VisitType_(const FuncTypeNode* op) { bool changed = false; - Array new_args = MutateArray(op->arg_types); + ffi::Array new_args = MutateArray(op->arg_types); changed = changed || !new_args.same_as(op->arg_types); Type new_ret_type = VisitType(op->ret_type); changed = changed || !new_ret_type.same_as(op->ret_type); - if (!changed) return GetRef(op); + if (!changed) return ffi::GetRef(op); return FuncType(new_args, new_ret_type); } Type TypeMutator::VisitType_(const TupleTypeNode* op) { - Array new_fields = MutateArray(op->fields); + ffi::Array new_fields = MutateArray(op->fields); if (new_fields.same_as(op->fields)) { - return GetRef(op); + return ffi::GetRef(op); } else { return TupleType(new_fields); } } -Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef(op); } +Type TypeMutator::VisitType_(const PrimTypeNode* op) { return ffi::GetRef(op); } Type TypeMutator::VisitType_(const PointerTypeNode* op) { Type element_type = VisitType(op->element_type); if (element_type.same_as(op->element_type)) { - return GetRef(op); + return ffi::GetRef(op); } else { return PointerType(element_type, op->storage_scope); } diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 9c2ba084ad41..5c00a9bdbc4e 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -40,7 +40,7 
@@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { if (const auto* func = base_func.as()) { last_func = func; if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - return GetRef(func); + return ffi::GetRef(func); } if (gv->name_hint == "main") { main_func = func; @@ -50,7 +50,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { } // Priority 2: PrimFunc whose name is `main` if (main_func != nullptr) { - return GetRef(main_func); + return ffi::GetRef(main_func); } // Priority 3: The only PrimFunc in the IRModule if (num_prim_func == 0) { @@ -61,7 +61,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" << mod; } - return GetRef(last_func); + return ffi::GetRef(last_func); } /******** ArgInfo ********/ @@ -69,11 +69,11 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { // The JSON object is always an array whose first element is a tag. For example: // `['TENSOR', 'float32', [1, 224, 224, 3]] // Step 1. 
Extract the tag - Optional tag{std::nullopt}; + ffi::Optional tag{std::nullopt}; try { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() >= 1); - tag = json_array->at(0).cast(); + tag = json_array->at(0).cast(); } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj << "\nThe error is: " << e.what(); @@ -86,12 +86,12 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { throw; } -Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { +ffi::Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { using support::AsVector; - Array result; + ffi::Array result; result.reserve(func->params.size()); for (const tir::Var& arg : func->params) { - if (Optional _buffer = func->buffer_map.Get(arg)) { + if (ffi::Optional _buffer = func->buffer_map.Get(arg)) { tir::Buffer buffer = _buffer.value(); result.push_back(TensorInfo(/*dtype=*/buffer->dtype, /*shape=*/AsVector(buffer->shape))); @@ -102,10 +102,10 @@ Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { return result; } -Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) { +ffi::Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) { if (remove_preproc) { IRModule new_mod = - tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true)(mod); + tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_tensor_rewrite*/ true)(mod); return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod)); } return ArgInfo::FromPrimFunc(FindEntryFunc(mod)); @@ -114,28 +114,28 @@ Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) /******** TensorInfo ********/ TensorInfo::TensorInfo(runtime::DataType dtype, ffi::Shape shape) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = dtype; n->shape = shape; this->data_ = std::move(n); } ObjectRef TensorInfoNode::AsJSON() const { - static String tag = "TENSOR"; - 
String dtype = DLDataTypeToString(this->dtype); - Array shape = support::AsArray(this->shape); - return Array{tag, dtype, shape}; + static ffi::String tag = "TENSOR"; + ffi::String dtype = DLDataTypeToString(this->dtype); + ffi::Array shape = support::AsArray(this->shape); + return ffi::Array{tag, dtype, shape}; } TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { DLDataType dtype; - Array shape; + ffi::Array shape; try { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 3); // Load json[1] => dtype { - String dtype_str = json_array->at(1).cast(); + ffi::String dtype_str = json_array->at(1).cast(); dtype = StringToDLDataType(dtype_str); } // Load json[2] => shape @@ -160,9 +160,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); /******** FFI ********/ -TVM_FFI_STATIC_INIT_BLOCK({ TensorInfoNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { TensorInfoNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("meta_schedule.ArgInfoAsJSON", &ArgInfoNode::AsJSON) @@ -172,7 +172,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("meta_schedule.TensorInfo", [](runtime::DataType dtype, ffi::Shape shape) -> TensorInfo { return TensorInfo(dtype, shape); }); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/builder/builder.cc b/src/meta_schedule/builder/builder.cc index 062e32e58e83..195547bee764 100644 --- a/src/meta_schedule/builder/builder.cc +++ b/src/meta_schedule/builder/builder.cc @@ -26,48 +26,50 @@ namespace meta_schedule { /******** Constructors ********/ BuilderInput::BuilderInput(IRModule mod, Target target, - Optional> params) { - ObjectPtr n = make_object(); + ffi::Optional> params) { + ObjectPtr n = ffi::make_object(); n->mod = std::move(mod); n->target = std::move(target); n->params = std::move(params); data_ = std::move(n); } -BuilderResult::BuilderResult(Optional 
artifact_path, Optional error_msg) { - ObjectPtr n = make_object(); +BuilderResult::BuilderResult(ffi::Optional artifact_path, + ffi::Optional error_msg) { + ObjectPtr n = ffi::make_object(); n->artifact_path = std::move(artifact_path); n->error_msg = std::move(error_msg); data_ = std::move(n); } Builder Builder::PyBuilder(BuilderNode::FBuild f_build) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_build = std::move(f_build); return Builder(std::move(n)); } /******** FFI ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { BuilderInputNode::RegisterReflection(); BuilderResultNode::RegisterReflection(); PyBuilderNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.BuilderInput", - [](IRModule mod, Target target, Optional> params) - -> BuilderInput { return BuilderInput(mod, target, params); }) - .def("meta_schedule.BuilderResult", - [](Optional artifact_path, Optional error_msg) -> BuilderResult { - return BuilderResult(artifact_path, error_msg); + [](IRModule mod, Target target, + ffi::Optional> params) -> BuilderInput { + return BuilderInput(mod, target, params); }) + .def("meta_schedule.BuilderResult", + [](ffi::Optional artifact_path, ffi::Optional error_msg) + -> BuilderResult { return BuilderResult(artifact_path, error_msg); }) .def_method("meta_schedule.BuilderBuild", &BuilderNode::Build) .def("meta_schedule.BuilderPyBuilder", Builder::PyBuilder); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc index 242939802885..4cc13787ae96 100644 --- a/src/meta_schedule/cost_model/cost_model.cc +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -23,24 +23,25 @@ namespace tvm { namespace meta_schedule { -void PyCostModelNode::Load(const String& path) { +void PyCostModelNode::Load(const ffi::String& path) 
{ ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; f_load(path); } -void PyCostModelNode::Save(const String& path) { +void PyCostModelNode::Save(const ffi::String& path) { ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; f_save(path); } -void PyCostModelNode::Update(const TuneContext& context, const Array& candidates, - const Array& results) { +void PyCostModelNode::Update(const TuneContext& context, + const ffi::Array& candidates, + const ffi::Array& results) { ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; f_update(context, candidates, results); } std::vector PyCostModelNode::Predict(const TuneContext& context, - const Array& candidates) { + const ffi::Array& candidates) { ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; std::vector result(candidates.size(), 0.0); f_predict(context, candidates, result.data()); @@ -52,7 +53,7 @@ CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // PyCostModelNode::FUpdate f_update, // PyCostModelNode::FPredict f_predict, // PyCostModelNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_load = std::move(f_load); n->f_save = std::move(f_save); n->f_update = std::move(f_update); @@ -70,22 +71,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("meta_schedule.CostModelLoad", &CostModelNode::Load) .def_method("meta_schedule.CostModelSave", &CostModelNode::Save) .def_method("meta_schedule.CostModelUpdate", &CostModelNode::Update) .def("meta_schedule.CostModelPredict", - [](CostModel model, // - const TuneContext& context, // - Array candidates, // + [](CostModel model, // + const TuneContext& context, // + ffi::Array candidates, // void* p_addr) -> void { std::vector result = model->Predict(context, 
candidates); std::copy(result.begin(), result.end(), static_cast(p_addr)); }) .def("meta_schedule.CostModelPyCostModel", CostModel::PyCostModel); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 3b96ed0ca8b0..a7548c95b6cb 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -46,20 +46,20 @@ ObjectRef WorkloadNode::AsJSON() const { // Dump the JSON string to base64 std::string b64_mod = Base64Encode(json_mod); // Output - return Array{SHash2Str(this->shash), String(b64_mod)}; + return ffi::Array{SHash2Str(this->shash), ffi::String(b64_mod)}; } Workload Workload::FromJSON(const ObjectRef& json_obj) { - IRModule mod{nullptr}; + IRModule mod{ffi::UnsafeInit()}; THashCode shash = 0; try { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 2); // Load json[0] => shash - String str_shash = json_array->at(0).cast(); + ffi::String str_shash = json_array->at(0).cast(); // Load json[1] => mod { - String b64_mod = json_array->at(1).cast(); + ffi::String b64_mod = json_array->at(1).cast(); std::string json_mod = Base64Decode(b64_mod); mod = LoadJSON(json_mod).cast(); std::stringstream(str_shash) >> shash; @@ -73,9 +73,11 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { /******** TuningRecord ********/ -TuningRecord::TuningRecord(tir::Trace trace, Workload workload, Optional> run_secs, - Optional target, Optional> args_info) { - ObjectPtr n = make_object(); +TuningRecord::TuningRecord(tir::Trace trace, Workload workload, + ffi::Optional> run_secs, + ffi::Optional target, + ffi::Optional> args_info) { + ObjectPtr n = ffi::make_object(); n->trace = trace; n->workload = workload; n->run_secs = run_secs; @@ -96,10 +98,10 @@ MeasureCandidate TuningRecordNode::AsMeasureCandidate() const { } ObjectRef TuningRecordNode::AsJSON() const { - Optional> json_args_info; - Optional 
json_target; + ffi::Optional> json_args_info; + ffi::Optional json_target; if (args_info.defined()) { - Array info; + ffi::Array info; info.reserve(args_info.value().size()); for (const ArgInfo& arg_info : args_info.value()) { info.push_back(arg_info->AsJSON()); @@ -109,10 +111,10 @@ ObjectRef TuningRecordNode::AsJSON() const { if (target.defined()) { json_target = target.value()->Export(); } - return Array{trace->AsJSON(false), // - run_secs, // - json_target, // - json_args_info}; + return ffi::Array{trace->AsJSON(false), // + run_secs, // + json_target, // + json_args_info}; } bool TuningRecordNode::IsValid() const { @@ -131,10 +133,10 @@ bool TuningRecordNode::IsValid() const { } TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) { - tir::Trace trace{nullptr}; - Optional> run_secs; - Optional target; - Optional> args_info; + tir::Trace trace{ffi::UnsafeInit()}; + ffi::Optional> run_secs; + ffi::Optional target; + ffi::Optional> args_info; try { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 4); @@ -144,12 +146,12 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w } // Load json[2] => target if (json_array->at(2) != nullptr) { - target = Target(json_array->at(2).cast>()); + target = Target(json_array->at(2).cast>()); } // Load json[3] => args_info if (json_array->at(3) != nullptr) { const ffi::ArrayObj* json_args_info = json_array->at(3).cast(); - Array info; + ffi::Array info; info.reserve(json_args_info->size()); for (Any json_arg_info : *json_args_info) { info.push_back(ArgInfo::FromJSON(json_arg_info.cast())); @@ -173,15 +175,18 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w } /******** Database ********/ -DatabaseNode::DatabaseNode(String mod_eq_name) { mod_eq_ = ModuleEquality::Create(mod_eq_name); } +DatabaseNode::DatabaseNode(ffi::String mod_eq_name) { + mod_eq_ = ModuleEquality::Create(mod_eq_name); +} 
DatabaseNode::~DatabaseNode() = default; -Optional DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target, - const String& workload_name) { +ffi::Optional DatabaseNode::QueryTuningRecord(const IRModule& mod, + const Target& target, + const ffi::String& workload_name) { if (!this->HasWorkload(mod)) { return std::nullopt; } - Array records = this->GetTopK(this->CommitWorkload(mod), 1); + ffi::Array records = this->GetTopK(this->CommitWorkload(mod), 1); if (records.empty()) { return std::nullopt; } @@ -189,9 +194,10 @@ Optional DatabaseNode::QueryTuningRecord(const IRModule& mod, cons return records[0]; } -Optional DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target, - const String& workload_name) { - if (Optional opt_record = this->QueryTuningRecord(mod, target, workload_name)) { +ffi::Optional DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) { + if (ffi::Optional opt_record = + this->QueryTuningRecord(mod, target, workload_name)) { TuningRecord record = opt_record.value(); tir::Schedule sch = tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, @@ -203,9 +209,9 @@ Optional DatabaseNode::QuerySchedule(const IRModule& mod, const T } } -Optional DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target, - const String& workload_name) { - if (Optional opt_sch = this->QuerySchedule(mod, target, workload_name)) { +ffi::Optional DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) { + if (ffi::Optional opt_sch = this->QuerySchedule(mod, target, workload_name)) { return opt_sch.value()->mod(); } else { return std::nullopt; @@ -244,7 +250,7 @@ void Database::EnterWithScope() { ThreadLocalDatabases()->push_back(*this); } void Database::ExitWithScope() { ThreadLocalDatabases()->pop_back(); } -Optional Database::Current() { +ffi::Optional Database::Current() { std::vector* tls = 
ThreadLocalDatabases(); if (tls->empty()) { return std::nullopt; @@ -254,7 +260,7 @@ Optional Database::Current() { } /******** PyDatabase ********/ -PyDatabaseNode::PyDatabaseNode(String mod_eq_name) : DatabaseNode(mod_eq_name) {} +PyDatabaseNode::PyDatabaseNode(ffi::String mod_eq_name) : DatabaseNode(mod_eq_name) {} Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, @@ -264,8 +270,8 @@ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, - PyDatabaseNode::FSize f_size, String mod_eq_name) { - ObjectPtr n = make_object(mod_eq_name); + PyDatabaseNode::FSize f_size, ffi::String mod_eq_name) { + ObjectPtr n = ffi::make_object(mod_eq_name); n->f_has_workload = f_has_workload; n->f_commit_workload = f_commit_workload; n->f_commit_tuning_record = f_commit_tuning_record; @@ -280,21 +286,21 @@ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, /******** FFI ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { WorkloadNode::RegisterReflection(); TuningRecordNode::RegisterReflection(); PyDatabaseNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.Workload", [](IRModule mod) { return Workload(mod); }) .def_method("meta_schedule.WorkloadAsJSON", &WorkloadNode::AsJSON) .def("meta_schedule.WorkloadFromJSON", &Workload::FromJSON) .def("meta_schedule.TuningRecord", - [](tir::Trace trace, Workload workload, Optional> run_secs, - Optional target, Optional> args_info) { + [](tir::Trace trace, Workload workload, ffi::Optional> run_secs, + ffi::Optional target, ffi::Optional> args_info) { return TuningRecord(trace, workload, run_secs, target, args_info); }) 
.def_method("meta_schedule.TuningRecordAsMeasureCandidate", @@ -315,7 +321,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("meta_schedule.DatabaseQueryIRModule", &DatabaseNode::QueryIRModule) .def_method("meta_schedule.DatabaseDumpPruned", &DatabaseNode::DumpPruned) .def("meta_schedule.DatabasePyDatabase", Database::PyDatabase); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index fd24072aae8f..10274fd2f792 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -57,10 +57,10 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { os << "]"; } else if (const auto* dict = json_obj.as()) { int n = dict->size(); - std::vector> key_values; + std::vector> key_values; key_values.reserve(n); for (const auto& kv : *dict) { - if (auto key = kv.first.try_cast()) { + if (auto key = kv.first.try_cast()) { key_values.emplace_back(key.value(), kv.second); } else { LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: " @@ -81,7 +81,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { } os << "}"; } else if (json_obj.as()) { - JSONDumps(String(SaveJSON(json_obj)), os); + JSONDumps(ffi::String(SaveJSON(json_obj)), os); } else { LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj.GetTypeKey(); } @@ -241,7 +241,7 @@ class JSONTokenizer { LOG(FATAL) << "ValueError: Unexpected end of string"; } ++cur_; - *token = Token{TokenType::kString, String(str)}; + *token = Token{TokenType::kString, ffi::String(str)}; return true; } @@ -315,9 +315,9 @@ class JSONParser { } } - Array ParseArray() { + ffi::Array ParseArray() { bool is_first = true; - Array results; + ffi::Array results; for (;;) { Token token; if (is_first) { @@ -347,9 +347,9 @@ class JSONParser { return results; } - Map ParseDict() { + ffi::Map ParseDict() { bool is_first = true; - Map results; + ffi::Map 
results; for (;;) { Token token; if (is_first) { @@ -376,7 +376,7 @@ class JSONParser { CHECK(token.type == TokenType::kColon) << "ValueError: Unexpected token before: " << tokenizer_.cur_; Any value = ParseObject(tokenizer_.Next()); - results.Set(Downcast(key), value); + results.Set(Downcast(key), value); continue; } else { LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_; diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index aeae22f4ca41..862d0fd05a10 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -35,10 +35,10 @@ namespace meta_schedule { * \param allow_missing Whether to create new file when the given path is not found. * \return An array containing lines read from the json file. */ -std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing) { +std::vector JSONFileReadLines(const ffi::String& path, int num_threads, bool allow_missing) { std::ifstream is(path); if (is.good()) { - std::vector json_strs; + std::vector json_strs; for (std::string str; std::getline(is, str);) { json_strs.push_back(str); } @@ -61,7 +61,7 @@ std::vector JSONFileReadLines(const String& path, int num_threads, bool all * \param path The path to the json file. * \param line The line to append. */ -void JSONFileAppendLine(const String& path, const std::string& line) { +void JSONFileAppendLine(const ffi::String& path, const std::string& line) { std::ofstream os(path, std::ofstream::app); CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path; os << line << std::endl; @@ -70,14 +70,14 @@ void JSONFileAppendLine(const String& path, const std::string& line) { /*! \brief The default database implementation, which mimics two database tables with two files. 
*/ class JSONDatabaseNode : public DatabaseNode { public: - explicit JSONDatabaseNode(String mod_eq_name = "structural") + explicit JSONDatabaseNode(ffi::String mod_eq_name = "structural") : DatabaseNode(mod_eq_name), workloads2idx_(/*bucket_count*/ 0, WorkloadHash(), WorkloadEqual(GetModuleEquality())) {} /*! \brief The path to the workload table */ - String path_workload; + ffi::String path_workload; /*! \brief The path to the tuning record table */ - String path_tuning_record; + ffi::String path_tuning_record; /*! \brief All the workloads in the database */ std::unordered_map workloads2idx_; /*! \brief All the tuning records in the database */ @@ -89,9 +89,7 @@ class JSONDatabaseNode : public DatabaseNode { .def_ro("path_workload", &JSONDatabaseNode::path_workload) .def_ro("path_tuning_record", &JSONDatabaseNode::path_tuning_record); } - - static constexpr const char* _type_key = "meta_schedule.JSONDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.JSONDatabase", JSONDatabaseNode, DatabaseNode); public: bool HasWorkload(const IRModule& mod) { @@ -115,18 +113,18 @@ class JSONDatabaseNode : public DatabaseNode { void CommitTuningRecord(const TuningRecord& record) { this->tuning_records_.insert(record); JSONFileAppendLine(this->path_tuning_record, - JSONDumps(Array{ + JSONDumps(ffi::Array{ /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)), /*tuning_record=*/record->AsJSON() // })); } - Array GetTopK(const Workload& workload, int top_k) { + ffi::Array GetTopK(const Workload& workload, int top_k) { CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; if (top_k == 0) { return {}; } - Array results; + ffi::Array results; results.reserve(top_k); for (const TuningRecord& record : this->tuning_records_) { auto run_secs = record->run_secs; @@ -144,8 +142,8 @@ class JSONDatabaseNode : public DatabaseNode { return results; } - Array GetAllTuningRecords() { - Array 
results; + ffi::Array GetAllTuningRecords() { + ffi::Array results; results.reserve(Size()); for (const TuningRecord& record : this->tuning_records_) { results.push_back(record); @@ -156,10 +154,10 @@ class JSONDatabaseNode : public DatabaseNode { int64_t Size() { return tuning_records_.size(); } }; -Database Database::JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing, - String mod_eq_name) { +Database Database::JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record, + bool allow_missing, ffi::String mod_eq_name) { int num_threads = std::thread::hardware_concurrency(); - ObjectPtr n = make_object(mod_eq_name); + ObjectPtr n = ffi::make_object(mod_eq_name); // Load `n->workloads2idx_` from `path_workload` std::vector workloads; { @@ -173,7 +171,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, // Todo(tvm-team): re-enable the shash check when we get environment // independent structural hash values. if (recalc_hash != workload->shash) { - ObjectPtr wkl = make_object(*workload.get()); + ObjectPtr wkl = ffi::make_object(*workload.get()); wkl->shash = recalc_hash; workload = Workload(wkl); } @@ -185,11 +183,11 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, { std::vector json_objs = JSONFileReadLines(path_tuning_record, num_threads, allow_missing); std::vector records; - records.resize(json_objs.size(), TuningRecord{nullptr}); + records.resize(json_objs.size(), TuningRecord{ffi::UnsafeInit()}); support::parallel_for_dynamic( 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { auto json_obj = json_objs[task_id].cast(); - Workload workload{nullptr}; + Workload workload{ffi::UnsafeInit()}; try { const ffi::ArrayObj* arr = json_obj.as(); ICHECK_EQ(arr->size(), 2); @@ -215,12 +213,12 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, return Database(n); } -TVM_FFI_STATIC_INIT_BLOCK({ 
JSONDatabaseNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { JSONDatabaseNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.DatabaseJSONDatabase", Database::JSONDatabase); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index ec08fd62a232..ef144e47631c 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -26,10 +26,10 @@ namespace meta_schedule { class MemoryDatabaseNode : public DatabaseNode { public: - explicit MemoryDatabaseNode(String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {} + explicit MemoryDatabaseNode(ffi::String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {} - Array records; - Array workloads; + ffi::Array records; + ffi::Array workloads; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -37,9 +37,8 @@ class MemoryDatabaseNode : public DatabaseNode { .def_ro("records", &MemoryDatabaseNode::records) .def_ro("workloads", &MemoryDatabaseNode::workloads); } - - static constexpr const char* _type_key = "meta_schedule.MemoryDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(MemoryDatabaseNode, DatabaseNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MemoryDatabase", MemoryDatabaseNode, + DatabaseNode); public: bool HasWorkload(const IRModule& mod) final { @@ -64,7 +63,7 @@ class MemoryDatabaseNode : public DatabaseNode { void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; if (top_k == 0) { return {}; @@ -88,24 +87,24 @@ class MemoryDatabaseNode : public DatabaseNode { } } - Array 
GetAllTuningRecords() final { return records; } + ffi::Array GetAllTuningRecords() final { return records; } int64_t Size() final { return records.size(); } }; -Database Database::MemoryDatabase(String mod_eq_name) { - ObjectPtr n = make_object(mod_eq_name); +Database Database::MemoryDatabase(ffi::String mod_eq_name) { + ObjectPtr n = ffi::make_object(mod_eq_name); n->records.clear(); n->workloads.clear(); return Database(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.DatabaseMemoryDatabase", Database::MemoryDatabase); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ MemoryDatabaseNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { MemoryDatabaseNode::RegisterReflection(); } } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/ordered_union_database.cc b/src/meta_schedule/database/ordered_union_database.cc index 07526fbc45ab..ddb38af9d581 100644 --- a/src/meta_schedule/database/ordered_union_database.cc +++ b/src/meta_schedule/database/ordered_union_database.cc @@ -25,22 +25,21 @@ namespace meta_schedule { class OrderedUnionDatabaseNode : public DatabaseNode { public: - Array databases; + ffi::Array databases; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("databases", &OrderedUnionDatabaseNode::databases); } - - static constexpr const char* _type_key = "meta_schedule.OrderedUnionDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(OrderedUnionDatabaseNode, DatabaseNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.OrderedUnionDatabase", OrderedUnionDatabaseNode, + DatabaseNode); public: - Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& task_name) final { + ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& task_name) final { for (const Database& db : databases) { - if (Optional record = 
db->QueryTuningRecord(mod, target, task_name)) { + if (ffi::Optional record = db->QueryTuningRecord(mod, target, task_name)) { return record; } } @@ -62,12 +61,12 @@ class OrderedUnionDatabaseNode : public DatabaseNode { throw; } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetTopK"; throw; } - Array GetAllTuningRecords() final { + ffi::Array GetAllTuningRecords() final { LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetAllTuningRecords"; throw; } @@ -78,19 +77,19 @@ class OrderedUnionDatabaseNode : public DatabaseNode { } }; -Database Database::OrderedUnionDatabase(Array databases) { - ObjectPtr n = make_object(); +Database Database::OrderedUnionDatabase(ffi::Array databases) { + ObjectPtr n = ffi::make_object(); n->databases = std::move(databases); return Database(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.DatabaseOrderedUnionDatabase", Database::OrderedUnionDatabase); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ OrderedUnionDatabaseNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { OrderedUnionDatabaseNode::RegisterReflection(); } } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc index 1f85654cfa0c..5825b6834b8f 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -25,7 +25,8 @@ namespace meta_schedule { class ScheduleFnDatabaseNode : public DatabaseNode { public: - explicit ScheduleFnDatabaseNode(String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {} + explicit ScheduleFnDatabaseNode(ffi::String mod_eq_name = "structural") + : DatabaseNode(mod_eq_name) {} ffi::TypedFunction schedule_fn; @@ -34,14 +35,13 @@ class 
ScheduleFnDatabaseNode : public DatabaseNode { refl::ObjectDef().def_ro("schedule_fn", &ScheduleFnDatabaseNode::schedule_fn); } - - static constexpr const char* _type_key = "meta_schedule.ScheduleFnDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnDatabaseNode, DatabaseNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.ScheduleFnDatabase", ScheduleFnDatabaseNode, + DatabaseNode); public: - Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& workload_name) final { - if (Optional sch = this->QuerySchedule(mod, target, workload_name)) { + ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { + if (ffi::Optional sch = this->QuerySchedule(mod, target, workload_name)) { return TuningRecord(sch.value()->trace().value(), /*workload=*/Workload(mod, 0), // /*run_secs=*/std::nullopt, // @@ -51,8 +51,8 @@ class ScheduleFnDatabaseNode : public DatabaseNode { return std::nullopt; } - Optional QuerySchedule(const IRModule& mod, const Target& target, - const String& workload_name) final { + ffi::Optional QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { tir::Schedule sch = tir::Schedule::Traced(WithAttr(mod, "task_name", workload_name), /*rand_state=*/-1, @@ -79,12 +79,12 @@ class ScheduleFnDatabaseNode : public DatabaseNode { throw; } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetTopK"; throw; } - Array GetAllTuningRecords() final { + ffi::Array GetAllTuningRecords() final { LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetAllTuningRecords"; throw; } @@ -96,18 +96,18 @@ class ScheduleFnDatabaseNode : public DatabaseNode { }; Database Database::ScheduleFnDatabase(ffi::TypedFunction schedule_fn, - String mod_eq_name) { - ObjectPtr n = make_object(mod_eq_name); + ffi::String 
mod_eq_name) { + ObjectPtr n = ffi::make_object(mod_eq_name); n->schedule_fn = std::move(schedule_fn); return Database(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.DatabaseScheduleFnDatabase", Database::ScheduleFnDatabase); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ ScheduleFnDatabaseNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ScheduleFnDatabaseNode::RegisterReflection(); } } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/union_database.cc b/src/meta_schedule/database/union_database.cc index 38864a5fcc03..125bcb7ac45f 100644 --- a/src/meta_schedule/database/union_database.cc +++ b/src/meta_schedule/database/union_database.cc @@ -25,28 +25,26 @@ namespace meta_schedule { class UnionDatabaseNode : public DatabaseNode { public: - Array databases; + ffi::Array databases; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("databases", &UnionDatabaseNode::databases); } - - static constexpr const char* _type_key = "meta_schedule.UnionDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(UnionDatabaseNode, DatabaseNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.UnionDatabase", UnionDatabaseNode, DatabaseNode); public: - Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& task_name) final { + ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& task_name) final { std::vector results; results.reserve(databases.size()); for (const Database& db : databases) { - if (Optional record = db->QueryTuningRecord(mod, target, task_name)) { + if (ffi::Optional record = db->QueryTuningRecord(mod, target, task_name)) { results.push_back(record.value()); } } std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs()); - return results.empty() ? 
Optional(std::nullopt) : results[0]; + return results.empty() ? ffi::Optional(std::nullopt) : results[0]; } bool HasWorkload(const IRModule& mod) final { @@ -64,12 +62,12 @@ class UnionDatabaseNode : public DatabaseNode { throw; } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { LOG(FATAL) << "NotImplementedError: UnionDatabase.GetTopK"; throw; } - Array GetAllTuningRecords() final { + ffi::Array GetAllTuningRecords() final { LOG(FATAL) << "NotImplementedError: UnionDatabase.GetAllTuningRecords"; throw; } @@ -80,18 +78,18 @@ class UnionDatabaseNode : public DatabaseNode { } }; -Database Database::UnionDatabase(Array databases) { - ObjectPtr n = make_object(); +Database Database::UnionDatabase(ffi::Array databases) { + ObjectPtr n = ffi::make_object(); n->databases = std::move(databases); return Database(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.DatabaseUnionDatabase", Database::UnionDatabase); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ UnionDatabaseNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { UnionDatabaseNode::RegisterReflection(); } } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index 41980adc0034..6410a50c133e 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -28,9 +28,9 @@ namespace tvm { namespace meta_schedule { -ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, - Array dispatched, int weight) { - ObjectPtr n = make_object(); +ExtractedTask::ExtractedTask(ffi::String task_name, IRModule mod, Target target, + ffi::Array dispatched, int weight) { + ObjectPtr n = ffi::make_object(); n->task_name = task_name; n->mod = mod; n->target = target; @@ -39,16 +39,16 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, 
Target target, data_ = n; } -TVM_FFI_STATIC_INIT_BLOCK({ ExtractedTaskNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ExtractedTaskNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ExtractedTask", - [](String task_name, IRModule mod, Target target, - Array dispatched, int weight) -> ExtractedTask { + [](ffi::String task_name, IRModule mod, Target target, + ffi::Array dispatched, int weight) -> ExtractedTask { return ExtractedTask(task_name, mod, target, dispatched, weight); }); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc index 1f0668a84922..978ba658020c 100644 --- a/src/meta_schedule/feature_extractor/feature_extractor.cc +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -23,8 +23,8 @@ namespace tvm { namespace meta_schedule { -Array PyFeatureExtractorNode::ExtractFrom( - const TuneContext& context, const Array& candidates) { +ffi::Array PyFeatureExtractorNode::ExtractFrom( + const TuneContext& context, const ffi::Array& candidates) { ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; return f_extract_from(context, candidates); } @@ -32,7 +32,7 @@ Array PyFeatureExtractorNode::ExtractFrom( FeatureExtractor FeatureExtractor::PyFeatureExtractor( PyFeatureExtractorNode::FExtractFrom f_extract_from, // PyFeatureExtractorNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_extract_from = std::move(f_extract_from); n->f_as_string = std::move(f_as_string); return FeatureExtractor(n); @@ -47,18 +47,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { FeatureExtractorNode::RegisterReflection(); 
PyFeatureExtractorNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("meta_schedule.FeatureExtractorExtractFrom", &FeatureExtractorNode::ExtractFrom) .def("meta_schedule.FeatureExtractorPyFeatureExtractor", FeatureExtractor::PyFeatureExtractor); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index d99fe6cc7847..9072ccf62a94 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -84,7 +84,8 @@ std::vector GetBufferShape(const Buffer& buffer, arith::Analyzer* analy * \return The value of `pragma_auto_unroll_max_step` if it exists, or -1 if it does not exist */ int64_t GetPragmaAutoUnroll(const ForNode* loop) { - if (Optional auto_unroll = GetAnn(loop, tir::attr::pragma_auto_unroll_max_step)) { + if (ffi::Optional auto_unroll = + GetAnn(loop, tir::attr::pragma_auto_unroll_max_step)) { return auto_unroll.value()->value; } return -1; @@ -216,18 +217,18 @@ int64_t GetVarStride(const std::vector& multi_indices, const IntVec& } /*! - * \brief Converts a 2-dimensional STL vector to a TVM NDArray + * \brief Converts a 2-dimensional STL vector to a TVM Tensor * \param src The source 2-dimensional STL vector * \param second_dim_size The length of the second dimension. When the first dim of src is 0, - * second_dim_size must be specified, and in such case the shape of the result NDArray is + * second_dim_size must be specified, and in such case the shape of the result Tensor is * (0, second_dim_size). 
- * \return The converted TVM NDArray + * \return The converted TVM Tensor */ -runtime::NDArray AsNDArray(const std::vector>& src, int second_dim_size = -1) { +runtime::Tensor AsTensor(const std::vector>& src, int second_dim_size = -1) { int n = src.size(); ICHECK(!src.empty() || second_dim_size != -1); int m = src.empty() ? second_dim_size : src[0].size(); - runtime::NDArray tgt = runtime::NDArray::Empty( + runtime::Tensor tgt = runtime::Tensor::Empty( /*shape=*/{n, m}, /*dtype=*/DLDataType{kDLFloat, 64, 1}, /*ctx=*/DLDevice{kDLCPU, 0}); @@ -267,16 +268,16 @@ Pass SimplifyForFeatureExtraction() { PrimExpr VisitExpr_(const SelectNode* node) final { if (HasBufferLoad(node->true_value) || HasBufferLoad(node->false_value) || HasBufferLoad(node->condition)) { - return GetRef(node); } return make_const(node->dtype, 1.0); } PrimExpr VisitExpr_(const VarNode* var) final { - if (unit_vars_.count(GetRef(var))) { + if (unit_vars_.count(ffi::GetRef(var))) { return make_const(var->dtype, 0.0); } - return GetRef(var); + return ffi::GetRef(var); } Stmt VisitStmt_(const ForNode* loop) final { @@ -308,7 +309,7 @@ Pass SimplifyForFeatureExtraction() { */ Sequential PassListForPerStoreFeature() { return Sequential({ - tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true), + tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_tensor_rewrite*/ true), tir::transform::SimplifyForFeatureExtraction(), tir::transform::LowerCrossThreadReduction(), tir::transform::LowerInitBlock(), @@ -859,7 +860,7 @@ void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::Analyzer* // For each buffer, we find the loop stride on it const BufferNode* buffer = this->buffer; int ndim = this->buffer->shape.size(); - IntVec buffer_shape = utils::GetBufferShape(GetRef(buffer), analyzer); + IntVec buffer_shape = utils::GetBufferShape(ffi::GetRef(buffer), analyzer); // Calculate the buffer's stride from its shape IntVec buffer_stride(ndim); if (ndim >= 1) { @@ -1398,11 
+1399,11 @@ class PerStoreFeatureNode : public FeatureExtractorNode { } } - Array ExtractFrom(const TuneContext& tune_context, - const Array& candidates) { + ffi::Array ExtractFrom(const TuneContext& tune_context, + const ffi::Array& candidates) { auto& target_keys = tune_context->target.value()->keys; bool is_gpu = std::find(target_keys.begin(), target_keys.end(), "gpu") != target_keys.end(); - std::vector results; + std::vector results; results.resize(candidates.size()); std::unique_ptr feature_group6 = nullptr; if (extract_workload) { @@ -1417,20 +1418,19 @@ class PerStoreFeatureNode : public FeatureExtractorNode { feature_group6->Export(&feature); } } - results[task_id] = tir::utils::AsNDArray(features, this->feature_vector_length); + results[task_id] = tir::utils::AsTensor(features, this->feature_vector_length); }; support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f); return results; } - - static constexpr const char* _type_key = "meta_schedule.PerStoreFeature"; - TVM_DECLARE_FINAL_OBJECT_INFO(PerStoreFeatureNode, FeatureExtractorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PerStoreFeature", PerStoreFeatureNode, + FeatureExtractorNode); }; FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, int arith_intensity_curve_num_samples, int cache_line_bytes, bool extract_workload) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->buffers_per_store = buffers_per_store; n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples; n->cache_line_bytes = cache_line_bytes; @@ -1446,13 +1446,13 @@ FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, return FeatureExtractor(n); } -TVM_FFI_STATIC_INIT_BLOCK({ PerStoreFeatureNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { PerStoreFeatureNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("meta_schedule.FeatureExtractorPerStoreFeature", FeatureExtractor::PerStoreFeature); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index 89b2934fe28e..a7b455eec782 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -26,9 +26,9 @@ namespace meta_schedule { class AddToDatabaseNode : public MeasureCallbackNode { public: void Apply(const TaskScheduler& task_scheduler, int task_id, - const Array& measure_candidates, - const Array& builder_results, - const Array& runner_results) final { + const ffi::Array& measure_candidates, + const ffi::Array& builder_results, + const ffi::Array& runner_results) final { if (!task_scheduler->database_.defined()) { return; } @@ -42,11 +42,11 @@ class AddToDatabaseNode : public MeasureCallbackNode { for (int i = 0; i < n; ++i) { RunnerResult result = runner_results[i]; MeasureCandidate candidate = measure_candidates[i]; - Array run_secs{nullptr}; + ffi::Array run_secs{nullptr}; if (result->run_secs.defined()) { run_secs = result->run_secs.value(); } else { - run_secs = Array{FloatImm(DataType::Float(32), 1e10)}; + run_secs = ffi::Array{FloatImm(DataType::Float(32), 1e10)}; } database->CommitTuningRecord(TuningRecord( /*trace=*/candidate->sch->trace().value(), @@ -57,20 +57,26 @@ class AddToDatabaseNode : public MeasureCallbackNode { } } - static constexpr const char* _type_key = "meta_schedule.AddToDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(AddToDatabaseNode, MeasureCallbackNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.AddToDatabase", AddToDatabaseNode, + MeasureCallbackNode); }; MeasureCallback MeasureCallback::AddToDatabase() { - ObjectPtr n = make_object(); + ObjectPtr n = 
ffi::make_object(); return MeasureCallback(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); refl::GlobalDef().def("meta_schedule.MeasureCallbackAddToDatabase", MeasureCallback::AddToDatabase); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index 08feaf354eee..bf5172349b13 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -23,11 +23,11 @@ namespace tvm { namespace meta_schedule { -void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler, // - int task_id, // - const Array& measure_candidates, // - const Array& builds, // - const Array& results) { +void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const ffi::Array& measure_candidates, // + const ffi::Array& builds, // + const ffi::Array& results) { ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; auto _ = Profiler::TimedScope("MeasureCallback/" + this->f_as_string()); return f_apply(task_scheduler, task_id, measure_candidates, builds, results); @@ -35,13 +35,13 @@ void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler, MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // PyMeasureCallbackNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_apply = std::move(f_apply); n->f_as_string = std::move(f_as_string); return MeasureCallback(n); } -Array MeasureCallback::Default() { +ffi::Array MeasureCallback::Default() { return { MeasureCallback::AddToDatabase(), MeasureCallback::RemoveBuildArtifact(), @@ -58,18 +58,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { MeasureCallbackNode::RegisterReflection(); PyMeasureCallbackNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("meta_schedule.MeasureCallbackApply", &MeasureCallbackNode::Apply) .def("meta_schedule.MeasureCallbackPyMeasureCallback", MeasureCallback::PyMeasureCallback) .def("meta_schedule.MeasureCallbackDefault", MeasureCallback::Default); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc index 69fcd186f3c4..18f00efab5fc 100644 --- a/src/meta_schedule/measure_callback/remove_build_artifact.cc +++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc @@ -26,32 +26,38 @@ namespace meta_schedule { class RemoveBuildArtifactNode : public MeasureCallbackNode { public: void Apply(const TaskScheduler& task_scheduler, int task_id, - const Array& measure_candidates, - const Array& builder_results, - const Array& runner_results) final { + const ffi::Array& measure_candidates, + const ffi::Array& builder_results, + const ffi::Array& runner_results) final { static auto f_rm = tvm::ffi::Function::GetGlobalRequired("meta_schedule.remove_build_dir"); auto _ = Profiler::TimedScope("MeasureCallback/RemoveBuildArtifact"); for (const BuilderResult& build_result : builder_results) { - if (Optional path = build_result->artifact_path) { + if (ffi::Optional path = build_result->artifact_path) { f_rm(path.value()); } } } - static constexpr const char* _type_key = "meta_schedule.RemoveBuildArtifact"; - TVM_DECLARE_FINAL_OBJECT_INFO(RemoveBuildArtifactNode, MeasureCallbackNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RemoveBuildArtifact", RemoveBuildArtifactNode, + MeasureCallbackNode); }; 
MeasureCallback MeasureCallback::RemoveBuildArtifact() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return MeasureCallback(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + RemoveBuildArtifactNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.MeasureCallbackRemoveBuildArtifact", MeasureCallback::RemoveBuildArtifact); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 1db62d5e5068..845e14e1e7ea 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -26,9 +26,9 @@ namespace meta_schedule { class UpdateCostModelNode : public MeasureCallbackNode { public: void Apply(const TaskScheduler& task_scheduler, int task_id, - const Array& measure_candidates, - const Array& builder_results, - const Array& runner_results) final { + const ffi::Array& measure_candidates, + const ffi::Array& builder_results, + const ffi::Array& runner_results) final { auto _ = Profiler::TimedScope("MeasureCallback/UpdateCostModel"); const TaskRecord& task = task_scheduler->tasks_[task_id]; if (!task_scheduler->cost_model_.defined()) { @@ -39,8 +39,8 @@ class UpdateCostModelNode : public MeasureCallbackNode { ICHECK_EQ(measure_candidates.size(), builder_results.size()); ICHECK_EQ(runner_results.size(), builder_results.size()); int n = builder_results.size(); - Array pruned_candidate; - Array pruned_runner_result; + ffi::Array pruned_candidate; + ffi::Array pruned_runner_result; pruned_candidate.reserve(n); pruned_runner_result.reserve(n); for (int i = 0; i < n; i++) { @@ -55,20 +55,26 @@ class UpdateCostModelNode : public MeasureCallbackNode { cost_model->Update(task->ctx, pruned_candidate, pruned_runner_result); } - static constexpr const char* _type_key = "meta_schedule.UpdateCostModel"; - 
TVM_DECLARE_FINAL_OBJECT_INFO(UpdateCostModelNode, MeasureCallbackNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.UpdateCostModel", UpdateCostModelNode, + MeasureCallbackNode); }; MeasureCallback MeasureCallback::UpdateCostModel() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return MeasureCallback(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + UpdateCostModelNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.MeasureCallbackUpdateCostModel", MeasureCallback::UpdateCostModel); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index df8c45b5e697..8eb1f46b0b22 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -34,53 +34,53 @@ class ModuleEqualityStructural : public ModuleEquality { public: size_t Hash(IRModule mod) const { return tvm::StructuralHash()(mod); } bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); } - String GetName() const { return "structural"; } + ffi::String GetName() const { return "structural"; } }; -class ModuleEqualityIgnoreNDArray : public ModuleEquality { +class ModuleEqualityIgnoreTensor : public ModuleEquality { public: size_t Hash(IRModule mod) const { return tvm::ffi::StructuralHash::Hash(mod, /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + /*skip_tensor_content=*/true); } bool Equal(IRModule lhs, IRModule rhs) const { return tvm::ffi::StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + /*skip_tensor_content=*/true); } - String GetName() const { return "ignore-ndarray"; } + ffi::String GetName() const { return "ignore-tensor"; } }; -// The NDArray-ignoring variant of structural equal / hash is used for the 
module equality +// The Tensor-ignoring variant of structural equal / hash is used for the module equality // on the extracted anchor blocks. class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { auto anchor_block = tir::FindAnchorBlock(mod); if (anchor_block) { - return ffi::StructuralHash::Hash(GetRef(anchor_block), + return ffi::StructuralHash::Hash(ffi::GetRef(anchor_block), /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + /*skip_tensor_content=*/true); } - return ModuleEqualityIgnoreNDArray().Hash(mod); + return ModuleEqualityIgnoreTensor().Hash(mod); } bool Equal(IRModule lhs, IRModule rhs) const { auto anchor_block_lhs = tir::FindAnchorBlock(lhs); auto anchor_block_rhs = tir::FindAnchorBlock(rhs); if (anchor_block_lhs && anchor_block_rhs) { - return tvm::ffi::StructuralEqual::Equal(GetRef(anchor_block_lhs), - GetRef(anchor_block_rhs), + return tvm::ffi::StructuralEqual::Equal(ffi::GetRef(anchor_block_lhs), + ffi::GetRef(anchor_block_rhs), /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + /*skip_tensor_content=*/true); } - return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs); + return ModuleEqualityIgnoreTensor().Equal(lhs, rhs); } - String GetName() const { return "anchor-block"; } + ffi::String GetName() const { return "anchor-block"; } }; std::unique_ptr ModuleEquality::Create(const std::string& mod_eq_name) { if (mod_eq_name == "structural") { return std::make_unique(); - } else if (mod_eq_name == "ignore-ndarray") { - return std::make_unique(); + } else if (mod_eq_name == "ignore-tensor") { + return std::make_unique(); } else if (mod_eq_name == "anchor-block") { return std::make_unique(); } diff --git a/src/meta_schedule/module_equality.h b/src/meta_schedule/module_equality.h index 7aa3944a4048..f9546438157d 100644 --- a/src/meta_schedule/module_equality.h +++ b/src/meta_schedule/module_equality.h @@ -34,17 +34,17 @@ class ModuleEquality { virtual size_t Hash(IRModule mod) const = 0; virtual 
bool Equal(IRModule lhs, IRModule rhs) const = 0; - virtual String GetName() const = 0; + virtual ffi::String GetName() const = 0; /*! * \brief Create a ModuleEquality instance * \param mod_eq_name A string to specify the module equality testing and hashing method. * It must be one of the followings: * - "structural": Use StructuralEqual/Hash - * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + * - "ignore-tensor": Same as "structural", but ignore tensor raw data during * equality testing and hashing. * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a - * given module. The "ignore-ndarray" varint is used for the extracted blocks + * given module. The "ignore-tensor" varint is used for the extracted blocks * or in case no anchor block is found. * For the definition of the anchor block, see tvm/tir/analysis.h. * \return An owning pointer to the created instance diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 7825e8909429..4ad979648aca 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -37,9 +37,8 @@ class MutateComputeLocationNode : public MutatorNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation"; - TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MutateComputeLocation", + MutateComputeLocationNode, MutatorNode); public: // Inherit from `MutatorNode` @@ -47,10 +46,10 @@ class MutateComputeLocationNode : public MutatorNode { this->json_mod_ = SaveJSON(context->mod.value()); } // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` 
Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } @@ -86,9 +85,9 @@ std::vector MutateComputeLocationNode::Fin InstructionKind::Get("SampleComputeLocation"); std::vector candidates; - auto f_decision_provider = [&](const tir::Instruction& inst, // - const Array& inputs, // - const Array& attrs, // + auto f_decision_provider = [&](const tir::Instruction& inst, // + const ffi::Array& inputs, // + const ffi::Array& attrs, // const Any& decision) -> Any { if (inst->kind.same_as(inst_sample_compute_location)) { // Step 1. Extract the instruction input and the old decision. @@ -118,7 +117,7 @@ std::vector MutateComputeLocationNode::Fin return candidates; } -Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { std::vector candidates = FindCandidates(trace, rand_state); if (candidates.empty()) { return std::nullopt; @@ -129,16 +128,16 @@ Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* } Mutator Mutator::MutateComputeLocation() { - return Mutator(make_object()); + return Mutator(ffi::make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ MutateComputeLocationNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { MutateComputeLocationNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.MutatorMutateComputeLocation", Mutator::MutateComputeLocation); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index b7c532ae5b0f..66266dd2a539 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -37,7 +37,7 @@ bool IsAnnotateWithParallel(const Instruction& inst) { return false; } 
ICHECK_EQ(inst->attrs.size(), 1); - String ann_key = Downcast(inst->attrs[0]); + ffi::String ann_key = Downcast(inst->attrs[0]); return ann_key == attr::meta_schedule_parallel; } @@ -79,13 +79,13 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { * \return The parallel structure */ std::vector> AnalyzeParallel(const ScheduleState& self, - const String& block_name, const String& func_name, - int64_t limit) { - Array block_srefs = + const ffi::String& block_name, + const ffi::String& func_name, int64_t limit) { + ffi::Array block_srefs = tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name)); ICHECK_EQ(block_srefs.size(), 1); const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]); - ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); + ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(ffi::GetRef(block)); std::vector> results; results.reserve(info.realizes.size()); for (const BlockRealize& realize : info.realizes) { @@ -176,9 +176,8 @@ class MutateParallelNode : public MutatorNode { refl::ObjectDef().def_ro("max_jobs_per_core", &MutateParallelNode::max_jobs_per_core); } - - static constexpr const char* _type_key = "meta_schedule.MutateParallel"; - TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MutateParallel", MutateParallelNode, + MutatorNode); public: struct Candidate; @@ -189,10 +188,10 @@ class MutateParallelNode : public MutatorNode { this->json_mod_ = SaveJSON(context->mod.value()); } // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } }; @@ -204,9 +203,9 @@ struct MutateParallelNode::Candidate { /*! \brief The current parallel extent */ int64_t parallel_extent; /*! 
\brief The name of the root block */ - String block_name; + ffi::String block_name; /*! \brief The name of the PrimFunc */ - String func_name; + ffi::String func_name; }; /*! @@ -241,14 +240,14 @@ bool FindParallelDecision(const Trace& trace, TRandState* rand_state, const InstructionNode* get_block_inst = get_block_insts.at(Downcast(ann_inst->inputs[0]).get()); ICHECK_EQ(get_block_inst->attrs.size(), 2); - candidate->inst = GetRef(ann_inst); + candidate->inst = ffi::GetRef(ann_inst); candidate->parallel_extent = Downcast(ann_inst->inputs[1])->value; - candidate->block_name = Downcast(get_block_inst->attrs[0]); - candidate->func_name = Downcast(get_block_inst->attrs[1]); + candidate->block_name = Downcast(get_block_inst->attrs[0]); + candidate->func_name = Downcast(get_block_inst->attrs[1]); return true; } -Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) { // Step 1. Find a parallel decision. Candidate candidate; if (!FindParallelDecision(trace, rand_state, &candidate)) { @@ -293,7 +292,7 @@ Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_s } int64_t limit = it->second; // Step 6. 
Assemble a new trace - Array insts; + ffi::Array insts; insts.reserve(trace->insts.size()); for (const Instruction& inst : trace->insts) { if (inst.same_as(candidate.inst)) { @@ -308,17 +307,17 @@ Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_s } Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_jobs_per_core = max_jobs_per_core; return Mutator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ MutateParallelNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { MutateParallelNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.MutatorMutateParallel", Mutator::MutateParallel); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 26e3a4709a91..ef9c30729485 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -37,9 +37,8 @@ class MutateThreadBindingNode : public MutatorNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.MutateThreadBinding"; - TVM_DECLARE_FINAL_OBJECT_INFO(MutateThreadBindingNode, MutatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MutateThreadBinding", MutateThreadBindingNode, + MutatorNode); public: // Inherit from `MutatorNode` @@ -47,10 +46,10 @@ class MutateThreadBindingNode : public MutatorNode { this->json_mod_ = SaveJSON(context->mod.value()); } // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); 
return Mutator(n); } @@ -111,7 +110,7 @@ std::vector MutateThreadBindingNode::FindCan } ICHECK_EQ(inst->inputs.size(), 1); ICHECK_EQ(inst->attrs.size(), 1); - if (Downcast(inst->attrs[0]) != "threadIdx.x") return false; + if (Downcast(inst->attrs[0]) != "threadIdx.x") return false; return sampled_split_insts.find(Downcast(inst->inputs[0]).get()) != sampled_split_insts.end(); @@ -143,17 +142,17 @@ std::vector MutateThreadBindingNode::FindCan ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; - int decision = Downcast(trace->decisions[GetRef(sample_inst)])->value; + int decision = Downcast(trace->decisions[ffi::GetRef(sample_inst)])->value; std::vector probs = - support::AsVector(Downcast>(sample_inst->attrs[1])); + support::AsVector(Downcast>(sample_inst->attrs[1])); - candidates.emplace_back(GetRef(sample_inst), probs, decision); + candidates.emplace_back(ffi::GetRef(sample_inst), probs, decision); } return candidates; } -Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) { std::vector candidates = FindCandidates(trace, rand_state); if (candidates.empty()) { return std::nullopt; @@ -168,14 +167,16 @@ Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* r return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); } -Mutator Mutator::MutateThreadBinding() { return Mutator(make_object()); } +Mutator Mutator::MutateThreadBinding() { + return Mutator(ffi::make_object()); +} -TVM_FFI_STATIC_INIT_BLOCK({ MutateThreadBindingNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { MutateThreadBindingNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.MutateThreadBinding", Mutator::MutateThreadBinding); -}); +} } // namespace meta_schedule } 
// namespace tvm diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index fc56feedfba8..e2f3689d2854 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -37,7 +37,7 @@ using tir::Trace; */ std::vector DowncastTilingDecision(const ObjectRef& decision) { const auto* arr = TVM_TYPE_AS(decision, ffi::ArrayObj); - return support::AsVector(GetRef>(arr)); + return support::AsVector(ffi::GetRef>(arr)); } /*! @@ -60,18 +60,17 @@ class MutateTileSizeNode : public MutatorNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.MutateTileSize"; - TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MutateTileSize", MutateTileSizeNode, + MutatorNode); public: // Inherit from `MutatorNode` void InitializeWithTuneContext(const TuneContext& context) final {} // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } }; @@ -119,7 +118,7 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, if (inst->kind.same_as(inst_annotate)) { ICHECK_EQ(inst->attrs.size(), 1); ICHECK_EQ(inst->inputs.size(), 2); - if (Downcast(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { + if (Downcast(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { const auto* ann_val = inst->inputs[1].as(); ICHECK(ann_val); annotated.insert(ann_val); @@ -134,7 +133,7 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, if (annotated.count(inst->outputs[0].as())) { ICHECK_EQ(inst->attrs.size(), 2); std::vector probs = - 
support::AsVector(Downcast>(inst->attrs[1])); + support::AsVector(Downcast>(inst->attrs[1])); if (probs.size() == 1) { // Skip mutating the sampling instructions who have only single candidate. continue; @@ -191,8 +190,8 @@ struct FactorMemo { std::mutex mutex_; }; -Optional MutateSampleTileSize(const Trace& trace, Instruction inst, - std::vector tiles, TRandState* rand_state) { +ffi::Optional MutateSampleTileSize(const Trace& trace, Instruction inst, + std::vector tiles, TRandState* rand_state) { int n_splits = tiles.size(); // Step 1. Choose two loops, `x` and `y` int x, y; @@ -235,11 +234,11 @@ Optional MutateSampleTileSize(const Trace& trace, Instruction inst, } } -Optional MutateSampleVectorize(const Trace& trace, Instruction inst, - int64_t original_decision, TRandState* rand_state) { +ffi::Optional MutateSampleVectorize(const Trace& trace, Instruction inst, + int64_t original_decision, TRandState* rand_state) { ICHECK_EQ(inst->attrs.size(), 2); std::vector probs = - support::AsVector(Downcast>(inst->attrs[1])); + support::AsVector(Downcast>(inst->attrs[1])); probs.erase(probs.begin() + original_decision); int result = tir::MakeMultinomialSampler(rand_state, probs)(); if (result >= original_decision) { @@ -248,7 +247,7 @@ Optional MutateSampleVectorize(const Trace& trace, Instruction inst, return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true); } -Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { std::vector sample_perfect_tile_insts; std::vector sample_vectorize_insts; std::vector> sample_perfect_tile_tiles; @@ -271,14 +270,14 @@ Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s } } -Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } +Mutator Mutator::MutateTileSize() { return Mutator(ffi::make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ MutateTileSizeNode::RegisterReflection(); }); 
+TVM_FFI_STATIC_INIT_BLOCK() { MutateTileSizeNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.MutatorMutateTileSize", Mutator::MutateTileSize); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 74b3cae05d52..dab987708238 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -35,7 +35,7 @@ bool IsAnnotateWithUnroll(const Instruction& inst) { return false; } ICHECK_EQ(inst->attrs.size(), 1); - String ann_key = Downcast(inst->attrs[0]); + ffi::String ann_key = Downcast(inst->attrs[0]); return ann_key == attr::meta_schedule_unroll_explicit || ann_key == attr::meta_schedule_unroll_implicit; } @@ -56,19 +56,17 @@ class MutateUnrollNode : public MutatorNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.MutateUnroll"; - TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MutateUnroll", MutateUnrollNode, MutatorNode); public: struct Candidate; // Inherit from `MutatorNode` void InitializeWithTuneContext(const TuneContext& context) final {} // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } }; @@ -118,14 +116,15 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, ICHECK(sample_insts.count(var_rv)); const InstructionNode* sample_inst = sample_insts.at(var_rv); ICHECK_EQ(sample_inst->attrs.size(), 2); - candidate->inst = GetRef(sample_inst); - candidate->decision = 
Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->inst = ffi::GetRef(sample_inst); + candidate->decision = + Downcast(trace->decisions[ffi::GetRef(sample_inst)])->value; candidate->probs = - support::AsVector(Downcast>(sample_inst->attrs[1])); + support::AsVector(Downcast>(sample_inst->attrs[1])); return true; } -Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { Candidate candidate; if (!FindUnrollDecision(trace, rand_state, &candidate)) { return std::nullopt; @@ -141,14 +140,14 @@ Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_sta return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); } -Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } +Mutator Mutator::MutateUnroll() { return Mutator(ffi::make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ MutateUnrollNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { MutateUnrollNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.MutatorMutateUnroll", Mutator::MutateUnroll); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 50ab81f95f27..fd8fe45bf185 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -29,7 +29,7 @@ void PyMutatorNode::InitializeWithTuneContext(const TuneContext& context) { f_initialize_with_tune_context(context); } -Optional PyMutatorNode::Apply( +ffi::Optional PyMutatorNode::Apply( const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) { ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; return f_apply(trace, *rand_state); @@ -45,7 +45,7 @@ Mutator Mutator::PyMutator( PyMutatorNode::FApply 
f_apply, // PyMutatorNode::FClone f_clone, // PyMutatorNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); n->f_clone = std::move(f_clone); @@ -53,25 +53,25 @@ Mutator Mutator::PyMutator( return Mutator(n); } -Map Mutator::DefaultLLVM() { - return Map{ +ffi::Map Mutator::DefaultLLVM() { + return ffi::Map{ {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)}, {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)}, {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}}; } -Map Mutator::DefaultCUDA() { - return Map{ +ffi::Map Mutator::DefaultCUDA() { + return ffi::Map{ {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.08)}, {Mutator::MutateThreadBinding(), FloatImm(DataType::Float(64), 0.02)}}; } -Map Mutator::DefaultCUDATensorCore() { return Mutator::DefaultCUDA(); } +ffi::Map Mutator::DefaultCUDATensorCore() { return Mutator::DefaultCUDA(); } -Map Mutator::DefaultHexagon() { - return Map{ +ffi::Map Mutator::DefaultHexagon() { + return ffi::Map{ {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)}, {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)}, @@ -87,18 +87,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { MutatorNode::RegisterReflection(); PyMutatorNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("meta_schedule.MutatorInitializeWithTuneContext", &MutatorNode::InitializeWithTuneContext) 
.def("meta_schedule.MutatorApply", - [](Mutator self, tir::Trace trace, TRandState seed) -> Optional { + [](Mutator self, tir::Trace trace, TRandState seed) -> ffi::Optional { TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); return self->Apply(trace, &seed_); @@ -109,7 +109,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("meta_schedule.MutatorDefaultCUDA", Mutator::DefaultCUDA) .def("meta_schedule.MutatorDefaultCUDATensorCore", Mutator::DefaultCUDATensorCore) .def("meta_schedule.MutatorDefaultHexagon", Mutator::DefaultHexagon); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 0aef44c58bcf..9f59404de5ef 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -83,7 +83,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { } // map loop variable to zero for the store index & simplify - Array store_index = bufferstorenode->indices; + ffi::Array store_index = bufferstorenode->indices; // Use DetectIterMap to detect whether store index is non-contiguous. arith::Analyzer analyzer; @@ -94,7 +94,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { } // map loop variable to zero for the load index & simplify - Array load_index = bufferloadnode->indices; + ffi::Array load_index = bufferloadnode->indices; // Use DetectIterMap to detect whether load index is non-contiguous. 
auto load_iter_map = DetectIterMap(load_index, input_iters, 1, @@ -110,7 +110,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { } bool found_ = false; - Map input_iters = Map(); + ffi::Map input_iters = ffi::Map(); }; } // namespace tir @@ -133,9 +133,9 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as()) { - IRModule lowered{nullptr}; + IRModule lowered{ffi::UnsafeInit()}; try { - auto pass_list = Array(); + auto pass_list = ffi::Array(); pass_list.push_back(tir::transform::BindTarget(this->target)); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); @@ -152,9 +152,10 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::VectorizeLoop(true)); pass_list.push_back(tir::transform::StorageRewrite()); - tir::PrimFunc f = - WithAttr(GetRef(prim_func), "global_symbol", String(g_var->name_hint)); - IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); + tir::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", + ffi::String(g_var->name_hint)); + IRModule mod = + IRModule(ffi::Map({{GlobalVar(g_var->name_hint), f}})); lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); } catch (const dmlc::Error& e) { return false; @@ -169,27 +170,34 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { // Inherited from PostprocNode Postproc Clone() const { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return Postproc(n); } - static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy"; - TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + 
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.DisallowAsyncStridedMemCopy", + DisallowAsyncStridedMemCopyNode, PostprocNode); private: tvm::Target target; }; Postproc Postproc::DisallowAsyncStridedMemCopy() { - ObjectPtr n = make_object(); + ObjectPtr n = + ffi::make_object(); return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + DisallowAsyncStridedMemCopyNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.PostprocDisallowAsyncStridedMemCopy", Postproc::DisallowAsyncStridedMemCopy); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc index 47588c42a0a5..df7344455e6d 100644 --- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -71,23 +71,29 @@ class DisallowDynamicLoopNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } // Inherited from PostprocNode Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } - static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop"; - TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.DisallowDynamicLoop", DisallowDynamicLoopNode, + PostprocNode); }; Postproc Postproc::DisallowDynamicLoop() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + DisallowDynamicLoopNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.PostprocDisallowDynamicLoop", Postproc::DisallowDynamicLoop); -}); +} } // 
namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index ccf280860d80..41557830afb6 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -44,7 +44,7 @@ Postproc Postproc::PyPostproc( PyPostprocNode::FApply f_apply, // PyPostprocNode::FClone f_clone, // PyPostprocNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); n->f_clone = std::move(f_clone); @@ -52,8 +52,8 @@ Postproc Postproc::PyPostproc( return Postproc(n); } -Array Postproc::DefaultLLVM() { - return Array{ +ffi::Array Postproc::DefaultLLVM() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), @@ -61,16 +61,24 @@ Array Postproc::DefaultLLVM() { }; } -Array Postproc::DefaultCPUTensorization() { - return Array{ +ffi::Array Postproc::DefaultCPUTensorization() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/true), Postproc::RewriteLayout(), }; } -Array Postproc::DefaultCUDA() { - return Array{ +ffi::Array Postproc::DefaultRISCV() { + return ffi::Array{ + Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), + Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/false), + Postproc::RewriteLayout(), + }; +} + +ffi::Array Postproc::DefaultCUDA() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteCooperativeFetch(), Postproc::RewriteUnboundBlock(/*max_threadblocks=*/256), @@ -80,8 +88,8 @@ Array Postproc::DefaultCUDA() { }; } -Array Postproc::DefaultCUDATensorCore() { - return Array{ +ffi::Array Postproc::DefaultCUDATensorCore() { + return 
ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteCooperativeFetch(), Postproc::RewriteUnboundBlock(/*max_threadblocks=*/256), @@ -94,8 +102,8 @@ Array Postproc::DefaultCUDATensorCore() { }; } -Array Postproc::DefaultHexagon() { - return Array{ +ffi::Array Postproc::DefaultHexagon() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), Postproc::RewriteLayout(), Postproc::VerifyVTCMLimit(), @@ -111,12 +119,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { PostprocNode::RegisterReflection(); PyPostprocNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("meta_schedule.PostprocInitializeWithTuneContext", @@ -128,7 +136,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("meta_schedule.PostprocDefaultCUDA", Postproc::DefaultCUDA) .def("meta_schedule.PostprocDefaultCUDATensorCore", Postproc::DefaultCUDATensorCore) .def("meta_schedule.PostprocDefaultHexagon", Postproc::DefaultHexagon); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index d7009c0596f5..ae7b693efd94 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -30,14 +30,15 @@ namespace tir { * \param axis The axis name expected * \return std::nullopt if parsing fails; Otherwise, the extent of thread axis */ -Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) { +ffi::Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst, + ffi::String axis) { static InstructionKind inst_kind_bind = InstructionKind::Get("Bind"); if (!inst->kind.same_as(inst_kind_bind)) { return 
std::nullopt; } ICHECK_EQ(inst->inputs.size(), 1); ICHECK_EQ(inst->attrs.size(), 1); - String thread_axis = Downcast(inst->attrs[0]); + ffi::String thread_axis = Downcast(inst->attrs[0]); if (thread_axis != axis) { return std::nullopt; } @@ -51,15 +52,15 @@ Optional ParseThreadBinding(const Schedule& sch, const Instruction& ins * \param vector_lane The number of vector lane in vectorized cooperative fetching * \return std::nullopt if parsing fails; Otherwise, the annotated block */ -Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, - int64_t* vector_lane) { +ffi::Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, + int64_t* vector_lane) { static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); if (!inst->kind.same_as(inst_kind_annotate)) { return std::nullopt; } ICHECK_EQ(inst->inputs.size(), 2); ICHECK_EQ(inst->attrs.size(), 1); - String ann_key = Downcast(inst->attrs[0]); + ffi::String ann_key = Downcast(inst->attrs[0]); if (ann_key != attr::meta_schedule_cooperative_fetch) { return std::nullopt; } @@ -80,7 +81,7 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) { } ICHECK_EQ(inst->inputs.size(), 2); ICHECK_EQ(inst->attrs.size(), 1); - String ann_key = Downcast(inst->attrs[0]); + ffi::String ann_key = Downcast(inst->attrs[0]); return ann_key == attr::warp_execution; } @@ -124,7 +125,7 @@ class RewriteCooperativeFetchNode : public PostprocNode { // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { - if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { + if (ffi::Optional v = context->target.value()->GetAttr("thread_warp_size")) { this->thread_warp_size_ = v.value()->value; } else { TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target"; @@ -135,12 +136,12 @@ class RewriteCooperativeFetchNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final; Postproc Clone() const { - 
ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } - static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch"; - TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RewriteCooperativeFetch", + RewriteCooperativeFetchNode, PostprocNode); private: int thread_warp_size_ = -1; @@ -153,11 +154,13 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { int64_t vector_lane = 1; std::vector> tasks; for (const tir::Instruction& inst : trace->insts) { - if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { + if (ffi::Optional new_thread_extent = + tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { thread_extent_x = new_thread_extent.value()->value; continue; } - if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { + if (ffi::Optional new_thread_extent = + tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { thread_extent_y = new_thread_extent.value()->value; continue; } @@ -165,7 +168,7 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { thread_extent_x = thread_warp_size_; continue; } - Optional opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane); + ffi::Optional opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane); if (!opt_block_rv.defined()) { continue; } @@ -191,29 +194,30 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { } if (thread_extent_y != -1) { if (vector_lane > 1) { - Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_y), // - Integer(thread_extent_x), // - Integer(vector_lane)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_y), // + Integer(thread_extent_x), // + Integer(vector_lane)}); sch->Vectorize(split[3]); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } else { - Array split = sch->Split(fused, 
{std::nullopt, // - Integer(thread_extent_y), // - Integer(thread_extent_x)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_y), // + Integer(thread_extent_x)}); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } } else { if (vector_lane > 1) { - Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_x), // - Integer(vector_lane)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_x), // + Integer(vector_lane)}); sch->Vectorize(split[2]); sch->Bind(split[1], "threadIdx.x"); } else { - Array split = sch->Split(fused, {std::nullopt, Integer(thread_extent_x)}); + ffi::Array split = + sch->Split(fused, {std::nullopt, Integer(thread_extent_x)}); sch->Bind(split[1], "threadIdx.x"); } } @@ -227,17 +231,17 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { } Postproc Postproc::RewriteCooperativeFetch() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ RewriteCooperativeFetchNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RewriteCooperativeFetchNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.PostprocRewriteCooperativeFetch", Postproc::RewriteCooperativeFetch); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 0d645fcf8b21..17acdcc9bf2f 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -36,17 +36,17 @@ class BufferReadPosCollector : public StmtExprVisitor { const std::pair& GetBufferLocation() const { return buffer_loc_; } - const Optional GetBufferIndexMap() const { return buffer_index_map_; } + const ffi::Optional GetBufferIndexMap() const { return buffer_index_map_; } 
private: void VisitStmt_(const ForNode* op) final { - loop_stack_.push_back(GetRef(op)); + loop_stack_.push_back(ffi::GetRef(op)); StmtVisitor::VisitStmt_(op); loop_stack_.pop_back(); } void VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize outer_block_realize = GetRef(op); + BlockRealize outer_block_realize = ffi::GetRef(op); std::swap(outer_block_realize, cur_realize_); StmtVisitor::VisitStmt_(op); std::swap(cur_realize_, outer_block_realize); @@ -57,13 +57,13 @@ class BufferReadPosCollector : public StmtExprVisitor { const Buffer& buffer = op->buffer; if (buffer_ == buffer.get()) { - Map subst_map; + ffi::Map subst_map; for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) { const Var& var = cur_realize_->block->iter_vars[i]->var; const PrimExpr& value = cur_realize_->iter_values[i]; subst_map.Set(var, value); } - Array subst_indices; + ffi::Array subst_indices; for (const PrimExpr& e : op->indices) { subst_indices.push_back(Substitute(e, subst_map)); } @@ -93,10 +93,10 @@ class BufferReadPosCollector : public StmtExprVisitor { /*! \brief The block that consumes the buffer and the corresponding read index. */ std::pair buffer_loc_; /*! \brief The proposed IndexMap. */ - Optional buffer_index_map_; + ffi::Optional buffer_index_map_; /*! \brief Loop stack for calculating IndexMap. */ - Array loop_stack_; + ffi::Array loop_stack_; /*! \brief Arithmetic analyzer. */ arith::Analyzer analyzer_; /*! 
\brief Current BlockRealize scope, used in recursive visit */ @@ -108,7 +108,7 @@ class LayoutFreeBufferCollector : public StmtVisitor { void VisitStmt_(const BlockNode* block) final { StmtVisitor::VisitStmt_(block); if (auto ann = block->annotations.Get("layout_free_placeholders")) { - for (Buffer buffer : Downcast>(ann.value())) { + for (Buffer buffer : Downcast>(ann.value())) { buffers.insert(buffer); } } @@ -117,12 +117,12 @@ class LayoutFreeBufferCollector : public StmtVisitor { std::unordered_set buffers; }; -Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { +ffi::Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { // Only rewrite PrimFuncs with attr "layout_free_buffers" - Array layout_free_buffer_index = - func->GetAttr(attr::layout_free_buffers, Array()).value(); + ffi::Array layout_free_buffer_index = + func->GetAttr(attr::layout_free_buffers, ffi::Array()).value(); - Array layout_free_buffers; + ffi::Array layout_free_buffers; for (const Integer& index : layout_free_buffer_index) { ICHECK(static_cast(index->value) < func->params.size()); const Var& param = func->params[index->value]; @@ -182,14 +182,14 @@ std::vector GetCacheReadChain(const Buffer& buf, const PrimFuncNode } bool RewriteLayout(const Schedule& sch) { - std::vector> results; + std::vector> results; auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) { BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global"); sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, true); }; for (const auto& [g_var, base_func] : sch->mod()->functions) { - const String& func_name = g_var->name_hint; + const ffi::String& func_name = g_var->name_hint; const auto* prim_func = base_func.as(); // Only consider PrimFunc if (prim_func == nullptr) { @@ -261,23 +261,28 @@ class RewriteLayoutNode : public PostprocNode { } Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); 
return Postproc(n); } - static constexpr const char* _type_key = "meta_schedule.RewriteLayout"; - TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RewriteLayout", RewriteLayoutNode, PostprocNode); }; Postproc Postproc::RewriteLayout() { - auto n = make_object(); + auto n = ffi::make_object(); return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + RewriteLayoutNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.PostprocRewriteLayout", Postproc::RewriteLayout); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 945b9adbc948..d833af614221 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -146,7 +146,7 @@ void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedA } } -int CalculateNumRewritableLoops(const Array& loop_srefs, +int CalculateNumRewritableLoops(const ffi::Array& loop_srefs, const std::vector& loop_types) { int rw_loops_num = 0; ICHECK_EQ(loop_srefs.size(), loop_types.size()); @@ -174,7 +174,7 @@ int CalculateNumRewritableLoops(const Array& loop_srefs, } void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, - const Array& loop_rvs, ParsedAnnotation* parsed) { + const ffi::Array& loop_rvs, ParsedAnnotation* parsed) { StmtSRef block_sref = sch->GetSRef(block_rv); if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { return; @@ -186,7 +186,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, return; } // Extract loop_srefs, and calculate the iterator types - Array 
loop_srefs; + ffi::Array loop_srefs; std::vector loop_types; { loop_srefs.reserve(n_loops); @@ -198,7 +198,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, } // check the maximal number of axes that are vectorizable (contiguous memory access) BlockRealize realize = GetBlockRealize(sch->state(), block_sref); - Array buffer_access(realize->block->reads); + ffi::Array buffer_access(realize->block->reads); buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), realize->block->writes.end()); std::unordered_map binding_map; @@ -357,10 +357,11 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block return false; } -void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array* loop_rvs, int vec_len) { +void RewriteFuseSplitParallelVectorize(const Schedule& sch, ffi::Array* loop_rvs, + int vec_len) { size_t n_loops = loop_rvs->size(); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()}); - Array split = sch->Split(fused, {std::nullopt, Integer(vec_len)}); + ffi::Array split = sch->Split(fused, {std::nullopt, Integer(vec_len)}); ICHECK_EQ(split.size(), 2); const LoopRV& outer = split[0]; const LoopRV& inner = split[1]; @@ -372,7 +373,7 @@ void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array* loop_ loop_rvs->Set(n_loops - 1, inner); } -void RewriteParallel(const Schedule& sch, size_t n, Array* loop_rvs) { +void RewriteParallel(const Schedule& sch, size_t n, ffi::Array* loop_rvs) { ICHECK_LE(n, loop_rvs->size()); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); sch->Parallel(fused); @@ -381,7 +382,7 @@ void RewriteParallel(const Schedule& sch, size_t n, Array* loop_rvs) { } } -void RewriteVectorize(const Schedule& sch, size_t n, Array* loop_rvs) { +void RewriteVectorize(const Schedule& sch, size_t n, ffi::Array* loop_rvs) { size_t n_loops = loop_rvs->size(); ICHECK_LE(n, n_loops); LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); @@ 
-414,10 +415,10 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { bool Apply(const Schedule& sch) final { tir::ParsedAnnotation parsed_root; - tir::BlockRV root_rv{nullptr}; + tir::BlockRV root_rv{ffi::UnsafeInit()}; while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); if (loop_rvs.empty()) { continue; } @@ -451,25 +452,31 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { Postproc Clone() const { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return Postproc(n); } - static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll"; - TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RewriteParallelVectorizeUnroll", + RewriteParallelVectorizeUnrollNode, PostprocNode); }; Postproc Postproc::RewriteParallelVectorizeUnroll() { ObjectPtr n = - make_object(); + ffi::make_object(); return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + RewriteParallelVectorizeUnrollNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.PostprocRewriteParallelVectorizeUnroll", Postproc::RewriteParallelVectorizeUnroll); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index bd78855d8684..fffef8ba6856 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -27,8 +27,8 @@ namespace tir { struct ReductionBlockFinder : private StmtVisitor { public: /*! 
\brief Find all the reduction blocks that should be decomposed */ - static std::vector> Find(const ScheduleState& self) { - std::vector> results; + static std::vector> Find(const ScheduleState& self) { + std::vector> results; for (const auto& kv : self->mod->functions) { GlobalVar g_var = kv.first; BaseFunc base_func = kv.second; @@ -92,7 +92,7 @@ struct ReductionBlockFinder : private StmtVisitor { * or -1 if the `init` does not need to be decomposed. */ int FindDecomposePoint(const StmtSRef& block_sref) { - Array loop_srefs = GetLoops(block_sref); + ffi::Array loop_srefs = GetLoops(block_sref); int n = loop_srefs.size(); for (int i = 0; i < n; ++i) { if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { @@ -122,36 +122,37 @@ class RewriteReductionBlockNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final; Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } - static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; - TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RewriteReductionBlock", + RewriteReductionBlockNode, PostprocNode); }; bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { for (;;) { - std::vector> results = + std::vector> results = tir::ReductionBlockFinder::Find(sch->state()); int rewritten = 0; for (const auto& kv : results) { const tir::StmtSRef& block_sref = kv.first; - const String& global_var_name = kv.second; + const ffi::String& global_var_name = kv.second; int decompose_point = tir::FindDecomposePoint(block_sref); if (decompose_point == -1) { continue; } tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); // Rewrite auto tensorization 
related annotations - if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize).has_value()) { + if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize) + .has_value()) { // Remove tensorization annotation as it shouldn't be propagated to the init block. sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize); - Optional tensorize_init = - tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize_init); + ffi::Optional tensorize_init = + tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize_init); // The annotation of tensorization of the init statement should be moved to the init block // after 'DecomposeReduction'. // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is std::nullopt. @@ -172,17 +173,16 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { } Postproc Postproc::RewriteReductionBlock() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + RewriteReductionBlockNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.PostprocRewriteReductionBlock", Postproc::RewriteReductionBlock); -}); - -TVM_FFI_STATIC_INIT_BLOCK({ RewriteReductionBlockNode::RegisterReflection(); }); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 596bc7cb1f24..473731b5a7b5 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -30,15 +30,15 @@ using tir::BlockRV; using tir::LoopRV; void CollectTensorizationJobs( - const tir::Schedule& sch, const String& func_name, const tir::PrimFuncNode* func, + const tir::Schedule& sch, const ffi::String& func_name, const tir::PrimFuncNode* func, bool vectorize_init_loop, - std::vector>>* jobs) { + std::vector>>* jobs) { 
tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { if (const auto* block = obj.as()) { tir::StmtSRef block_sref = sch->GetSRef(block); std::string block_name = block_sref->StmtAs()->name_hint; - if (Optional intrin_name = - tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { + if (ffi::Optional intrin_name = + tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { if (intrin_name.value() != "") { jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) { try { @@ -49,9 +49,9 @@ void CollectTensorizationJobs( }); } else if (block_name.find("init") && vectorize_init_loop) { jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) { - Array child_blocks = sch->GetChildBlocks(block); + ffi::Array child_blocks = sch->GetChildBlocks(block); ICHECK(child_blocks.size() == 1); - Array init_loops = sch->GetLoops(child_blocks[0]); + ffi::Array init_loops = sch->GetLoops(child_blocks[0]); ICHECK(init_loops.size() == 1); sch->Vectorize(init_loops[0]); }); @@ -73,19 +73,19 @@ class RewriteTensorizeNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final; Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } bool vectorize_init_loop = false; - static constexpr const char* _type_key = "meta_schedule.RewriteTensorize"; - TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RewriteTensorize", RewriteTensorizeNode, + PostprocNode); }; bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { // The rewriting jobs, 3-tuple (block_name, func_name, job_func) - std::vector>> jobs; + std::vector>> jobs; for (const auto& kv : sch->mod()->functions) { GlobalVar g_var = kv.first; BaseFunc base_func = kv.second; @@ -94,8 +94,8 @@ bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { } } for (const auto& job : jobs) { - const String& block_name = 
std::get<0>(job); - const String& func_name = std::get<1>(job); + const ffi::String& block_name = std::get<0>(job); + const ffi::String& func_name = std::get<1>(job); const auto& job_func = std::get<2>(job); BlockRV block = sch->GetBlock(block_name, func_name); sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize); @@ -105,17 +105,16 @@ bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { } Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->vectorize_init_loop = vectorize_init_loop; return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ RewriteTensorizeNode::RegisterReflection(); }); - -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + RewriteTensorizeNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.PostprocRewriteTensorize", Postproc::RewriteTensorize); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index acebeb71cdf7..98e3db2522f1 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -27,7 +27,7 @@ namespace tir { /*! \brief Find all the blocks that are not bound */ class UnboundBlockFinder : private StmtVisitor { public: - static std::vector> Find(const ScheduleState& self) { + static std::vector> Find(const ScheduleState& self) { UnboundBlockFinder finder(self); for (const auto& kv : self->mod->functions) { GlobalVar g_var = kv.first; @@ -68,13 +68,13 @@ class UnboundBlockFinder : private StmtVisitor { /*! \brief The schedule state */ const ScheduleState& self_; /*! \brief The list of unbound blocks */ - std::vector> blocks_; + std::vector> blocks_; /*! \brief The number of blockIdx above the current stmt */ int n_block_idx_; /*! 
\brief The number of threadIdx above the current stmt */ int n_thread_idx_; /*! \brief The name of the global var */ - String global_var_name_; + ffi::String global_var_name_; }; } // namespace tir @@ -89,7 +89,7 @@ class RewriteUnboundBlockNode : public PostprocNode { // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { CHECK(context->target.defined()) << "ValueError: target is not defined"; - Optional max_threads_per_block = + ffi::Optional max_threads_per_block = context->target.value()->GetAttr("max_threads_per_block"); CHECK(max_threads_per_block.defined()) << "ValueError: missing attribute `max_threads_per_block` in the target"; @@ -100,7 +100,7 @@ class RewriteUnboundBlockNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final; Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } @@ -114,9 +114,8 @@ class RewriteUnboundBlockNode : public PostprocNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock"; - TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RewriteUnboundBlock", RewriteUnboundBlockNode, + PostprocNode); }; bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { @@ -128,11 +127,11 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { return Integer(std::min(t, max_extent)); }; - std::vector> unbound_blocks = + std::vector> unbound_blocks = tir::UnboundBlockFinder::Find(sch->state()); for (const auto& kv : unbound_blocks) { tir::StmtSRef block_sref = kv.first; - String global_var_name = kv.second; + ffi::String global_var_name = kv.second; BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, 
max_threads_per_block_, get_factor); } @@ -140,18 +139,18 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { } Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_threadblocks_ = max_threadblocks; n->max_threads_per_block_ = -1; return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ RewriteUnboundBlockNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RewriteUnboundBlockNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.PostprocRewriteUnboundBlock", Postproc::RewriteUnboundBlock); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 20cd0735431d..04a9cf2ea79b 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -73,9 +73,9 @@ class ThreadExtentChecker : private StmtVisitor { if (block->annotations.count(attr::warp_execution)) { thread_idx_x = thread_warp_size_; } - if (Optional low_inclusive = + if (ffi::Optional low_inclusive = GetAnn(block, attr::meta_schedule_thread_extent_low_inclusive)) { - if (Optional high_inclusive = + if (ffi::Optional high_inclusive = GetAnn(block, attr::meta_schedule_thread_extent_high_inclusive)) { int64_t low = low_inclusive.value()->value; int64_t high = high_inclusive.value()->value; @@ -104,7 +104,7 @@ namespace meta_schedule { /*! \brief Extract attribute from a target. */ Integer Extract(const Target& target, const char* name) { ICHECK(target.defined()); - if (Optional v = target->GetAttr(name)) { + if (ffi::Optional v = target->GetAttr(name)) { return v.value(); } LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; @@ -114,14 +114,14 @@ Integer Extract(const Target& target, const char* name) { /*! 
\brief Verify the correctness of the generated GPU code. */ class VerifyGPUCodeNode : public PostprocNode { public: - Target target_{nullptr}; - Map target_constraints_{nullptr}; + Target target_{ffi::UnsafeInit()}; + ffi::Map target_constraints_{ffi::UnsafeInit()}; int thread_warp_size_ = -1; void InitializeWithTuneContext(const TuneContext& context) final { ICHECK(context->target.defined()); this->target_ = context->target.value(); - this->target_constraints_ = Map{ + this->target_constraints_ = ffi::Map{ {"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")}, {"max_threads_per_block", Extract(this->target_, "max_threads_per_block")}, {"max_vthread", Integer(8)}, @@ -150,9 +150,9 @@ class VerifyGPUCodeNode : public PostprocNode { if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) { return false; } - IRModule lowered{nullptr}; + IRModule lowered{ffi::UnsafeInit()}; try { - auto pass_list = Array(); + auto pass_list = ffi::Array(); // Phase 1 pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); @@ -180,14 +180,15 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::LowerIntrin()); // Convert Function to IRModule transform::PassContext pass_ctx = transform::PassContext::Current(); - tir::PrimFunc f = - WithAttr(GetRef(prim_func), "global_symbol", String(g_var->name_hint)); + tir::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", + ffi::String(g_var->name_hint)); f = WithAttr(f, tvm::attr::kTarget, this->target_); // Required for LowerIntrin bool noalias = pass_ctx->GetConfig("tir.noalias", true).value(); if (noalias) { f = WithAttr(std::move(f), "tir.noalias", true); } - IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); + IRModule mod = + IRModule(ffi::Map({{GlobalVar(g_var->name_hint), f}})); lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); } catch (const 
std::exception&) { return false; @@ -201,24 +202,29 @@ class VerifyGPUCodeNode : public PostprocNode { } Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); n->target_constraints_ = this->target_constraints_; return Postproc(n); } - static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode"; - TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.VerifyGPUCode", VerifyGPUCodeNode, PostprocNode); }; Postproc Postproc::VerifyGPUCode() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + VerifyGPUCodeNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.PostprocVerifyGPUCode", Postproc::VerifyGPUCode); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index ee9394f16b17..f0fe8be1c1c9 100644 --- a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -56,23 +56,29 @@ class VerifyVTCMLimitNode : public PostprocNode { } Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } - static constexpr const char* _type_key = "meta_schedule.VerifyVTCMLimit"; - TVM_DECLARE_FINAL_OBJECT_INFO(VerifyVTCMLimitNode, PostprocNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.VerifyVTCMLimit", VerifyVTCMLimitNode, + PostprocNode); }; Postproc Postproc::VerifyVTCMLimit() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + VerifyVTCMLimitNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.PostprocVerifyVTCMLimit", Postproc::VerifyVTCMLimit); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc index d133e67eadef..e0bbc904c2c1 100644 --- a/src/meta_schedule/profiler.cc +++ b/src/meta_schedule/profiler.cc @@ -28,22 +28,22 @@ namespace meta_schedule { /**************** Profiler ****************/ -Map ProfilerNode::Get() const { - Map ret; +ffi::Map ProfilerNode::Get() const { + ffi::Map ret; for (const auto& kv : stats_sec) { ret.Set(kv.first, FloatImm(DataType::Float(64), kv.second)); } return ret; } -String ProfilerNode::Table() const { +ffi::String ProfilerNode::Table() const { CHECK(!stats_sec.empty()) << "ValueError: The stats are empty. Please run the profiler first."; CHECK(stats_sec.count("Total")) << "ValueError: The total time is not recorded. 
This method should be called only after " "exiting the profiler's with scope."; double total = stats_sec.at("Total"); struct Entry { - String name; + ffi::String name; double minutes; double percentage; bool operator<(const Entry& other) const { return percentage > other.percentage; } @@ -71,14 +71,14 @@ String ProfilerNode::Table() const { } Profiler::Profiler() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stats_sec.clear(); n->total_timer = nullptr; data_ = n; } -ffi::Function ProfilerTimedScope(String name) { - if (Optional opt_profiler = Profiler::Current()) { +ffi::Function ProfilerTimedScope(ffi::String name) { + if (ffi::Optional opt_profiler = Profiler::Current()) { return ffi::TypedFunction([profiler = opt_profiler.value(), // tik = std::chrono::high_resolution_clock::now(), // name = std::move(name)]() { @@ -91,7 +91,7 @@ ffi::Function ProfilerTimedScope(String name) { return nullptr; } -ScopedTimer Profiler::TimedScope(String name) { return ScopedTimer(ProfilerTimedScope(name)); } +ScopedTimer Profiler::TimedScope(ffi::String name) { return ScopedTimer(ProfilerTimedScope(name)); } /**************** Context Manager ****************/ @@ -113,7 +113,7 @@ void Profiler::ExitWithScope() { } } -Optional Profiler::Current() { +ffi::Optional Profiler::Current() { std::vector* profilers = ThreadLocalProfilers(); if (profilers->empty()) { return std::nullopt; @@ -122,9 +122,9 @@ Optional Profiler::Current() { } } -TVM_FFI_STATIC_INIT_BLOCK({ ProfilerNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ProfilerNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.Profiler", []() -> Profiler { return Profiler(); }) @@ -134,7 +134,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("meta_schedule.ProfilerGet", &ProfilerNode::Get) .def_method("meta_schedule.ProfilerTable", &ProfilerNode::Table) 
.def("meta_schedule.ProfilerTimedScope", ProfilerTimedScope); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc index 08ecb7aaa22d..1b9a3ea9a9c5 100644 --- a/src/meta_schedule/runner/runner.cc +++ b/src/meta_schedule/runner/runner.cc @@ -23,54 +23,55 @@ namespace tvm { namespace meta_schedule { -RunnerInput::RunnerInput(String artifact_path, String device_type, Array args_info) { - ObjectPtr n = make_object(); +RunnerInput::RunnerInput(ffi::String artifact_path, ffi::String device_type, + ffi::Array args_info) { + ObjectPtr n = ffi::make_object(); n->artifact_path = artifact_path; n->device_type = device_type; n->args_info = args_info; this->data_ = n; } -RunnerResult::RunnerResult(Optional> run_secs, Optional error_msg) { - ObjectPtr n = make_object(); +RunnerResult::RunnerResult(ffi::Optional> run_secs, + ffi::Optional error_msg) { + ObjectPtr n = ffi::make_object(); n->run_secs = run_secs; n->error_msg = error_msg; this->data_ = n; } RunnerFuture::RunnerFuture(RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_done = f_done; n->f_result = f_result; this->data_ = n; } Runner Runner::PyRunner(Runner::FRun f_run) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_run = f_run; return Runner(n); } /******** FFI ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { + RunnerNode::RegisterReflection(); RunnerInputNode::RegisterReflection(); RunnerResultNode::RegisterReflection(); RunnerFutureNode::RegisterReflection(); PyRunnerNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.RunnerInput", - [](String artifact_path, String device_type, Array args_info) -> RunnerInput { - return RunnerInput(artifact_path, device_type, args_info); - }) + 
[](ffi::String artifact_path, ffi::String device_type, ffi::Array args_info) + -> RunnerInput { return RunnerInput(artifact_path, device_type, args_info); }) .def("meta_schedule.RunnerResult", - [](Optional> run_secs, Optional error_msg) -> RunnerResult { - return RunnerResult(run_secs, error_msg); - }) + [](ffi::Optional> run_secs, ffi::Optional error_msg) + -> RunnerResult { return RunnerResult(run_secs, error_msg); }) .def("meta_schedule.RunnerFuture", [](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { return RunnerFuture(f_done, f_result); @@ -79,7 +80,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("meta_schedule.RunnerFutureResult", &RunnerFutureNode::Result) .def_method("meta_schedule.RunnerRun", &RunnerNode::Run) .def("meta_schedule.RunnerPyRunner", Runner::PyRunner); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index 9d2cdaedbde3..c3fd12e282b3 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -26,21 +26,21 @@ namespace meta_schedule { using namespace tvm::tir; -static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, - std::vector tiled, std::vector unrolled) { +static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, + std::vector tiled, std::vector unrolled) { using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); - Array factors{nullptr}; - Array loops = sch->GetLoops(block); + ffi::Array factors{ffi::UnsafeInit()}; + ffi::Array loops = sch->GetLoops(block); ICHECK_EQ(loops.size(), 6); factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); - Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); + ffi::Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); ICHECK_EQ(t0.size(), 2); factors = sch->SamplePerfectTile(loops[tiled[1]], 
/*n=*/2, /*max_innermost_factor=*/64); - Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); + ffi::Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); ICHECK_EQ(t1.size(), 2); sch->Unroll(loops[unrolled[0]]); @@ -60,11 +60,11 @@ static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block return {t0[0], t1[0], t0[1], t1[1]}; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> Array { + [](Schedule sch, BlockRV data_pack) -> ffi::Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); @@ -75,13 +75,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ return {sch}; }) .def("meta_schedule.cpu.conv2d_nhwc_winograd_inverse", - [](Schedule sch, BlockRV block) -> Array { + [](Schedule sch, BlockRV block) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); return {sch}; }) .def("meta_schedule.cpu.conv2d_nchw_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> Array { + [](Schedule sch, BlockRV data_pack) -> ffi::Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); @@ -92,12 +92,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ return {sch}; }) .def("meta_schedule.cpu.conv2d_nchw_winograd_inverse", - [](Schedule sch, BlockRV block) -> Array { + [](Schedule sch, BlockRV block) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); return {sch}; }); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc 
b/src/meta_schedule/schedule/cuda/thread_bind.cc index 287f764a4640..2a042553d6b9 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -31,10 +31,10 @@ namespace meta_schedule { using namespace tvm::tir; -std::function MakeFactorSampler(Schedule sch, Array thread_extents) { +std::function MakeFactorSampler(Schedule sch, ffi::Array thread_extents) { return [sch = std::move(sch), thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV { - Array extents; + ffi::Array extents; extents.reserve(thread_extents.size()); for (const Integer extent : thread_extents) { if (extent->value <= max_extent) { @@ -48,14 +48,14 @@ std::function MakeFactorSampler(Schedule sch, Array th if (n == 1) { return Integer(extents[0]); } - Array probs(n, FloatImm(DataType::Float(32), 1.0 / n)); + ffi::Array probs(n, FloatImm(DataType::Float(32), 1.0 / n)); return sch->SampleCategorical(extents, probs); }; } -Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks, - int64_t max_threads_per_block, - std::function get_factor) { +ffi::Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks, + int64_t max_threads_per_block, + std::function get_factor) { int64_t extent = -1; if (const int64_t* e = as_const_int(sch->Get(loop)->extent)) { extent = *e; @@ -67,15 +67,15 @@ Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblock get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); } ExprRV factor = get_factor(std::min(extent, max_threads_per_block)); - Array splits = sch->Split(loop, {std::nullopt, factor}); + ffi::Array splits = sch->Split(loop, {std::nullopt, factor}); ICHECK_EQ(splits.size(), 2); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); return {splits[0], splits[1]}; } else { - Array splits = sch->Split(loop, {std::nullopt, - Integer(max_threadblocks), // - Integer(max_threads_per_block)}); + ffi::Array splits = 
sch->Split(loop, {std::nullopt, + Integer(max_threadblocks), // + Integer(max_threads_per_block)}); ICHECK_EQ(splits.size(), 3); sch->Reorder({splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); @@ -95,7 +95,7 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // if (tir::HasBeenMultiLevelTiled(block_sref)) { return; } - Array loops = tir::GetLoops(block_sref); + ffi::Array loops = tir::GetLoops(block_sref); int n = loops.size(); int i_block_idx = -1; int i_thread_idx = -1; @@ -141,11 +141,11 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; throw; } - LoopRV loop_rv{nullptr}; + LoopRV loop_rv{ffi::UnsafeInit()}; { - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); if (i_spatial_loop == -1) { - LoopRV spatial_loop_rv{nullptr}; + LoopRV spatial_loop_rv{ffi::UnsafeInit()}; if (loop_rvs.empty()) { spatial_loop_rv = sch->AddUnitLoop(block_rv); } else { @@ -165,7 +165,7 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // } if (i_block_idx == -1 && i_thread_idx != -1) { int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); sch->Bind(loop_rv, "blockIdx.x"); return; diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index ea7ee90e1408..74a70da58b36 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -29,22 +29,22 @@ namespace meta_schedule { using namespace tvm::tir; -static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, - std::vector tiled, std::vector unrolled) { +static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, + std::vector tiled, 
std::vector unrolled) { // This method is used for NHWC layout only. Will likely be refactored into a more schedule using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); - Array factors{nullptr}; - Array loops = sch->GetLoops(block); + ffi::Array factors{ffi::UnsafeInit()}; + ffi::Array loops = sch->GetLoops(block); ICHECK_EQ(loops.size(), 6); factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); - Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); + ffi::Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); ICHECK_EQ(t0.size(), 2); factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64); - Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); + ffi::Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); ICHECK_EQ(t1.size(), 2); sch->Unroll(loops[unrolled[0]]); @@ -64,14 +64,14 @@ static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block return {t0[0], t1[0], t0[1], t1[1]}; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> Array { + [](Schedule sch, BlockRV data_pack) -> ffi::Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - Array loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); + ffi::Array loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); { BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); sch->ReverseComputeAt(data_pack_local, loops.back(), /*preserve_unit_loops=*/true); @@ -84,7 +84,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; - Array loops = sch->GetLoops(data_pack); + ffi::Array loops = 
sch->GetLoops(data_pack); ICHECK_EQ(loops.size(), 8); BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, max_threads_per_block); @@ -92,26 +92,26 @@ TVM_FFI_STATIC_INIT_BLOCK({ return {sch}; }) .def("meta_schedule.cuda.conv2d_nhwc_winograd_inverse", - [](Schedule sch, BlockRV inverse) -> Array { + [](Schedule sch, BlockRV inverse) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, inverse); ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5}); int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; - Array loops = sch->GetLoops(inverse); + ffi::Array loops = sch->GetLoops(inverse); ICHECK_EQ(loops.size(), 8); BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, max_threads_per_block); return {sch}; }) .def("meta_schedule.cuda.conv2d_nchw_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> Array { + [](Schedule sch, BlockRV data_pack) -> ffi::Array { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - LoopRV outer{nullptr}; + LoopRV outer{ffi::UnsafeInit()}; { - Array loops = sch->GetLoops(data_pack); + ffi::Array loops = sch->GetLoops(data_pack); ICHECK_EQ(loops.size(), 6); sch->Reorder({loops[2], loops[3], loops[0], loops[1], loops[4], loops[5]}); sch->Unroll(loops[0]); @@ -134,25 +134,25 @@ TVM_FFI_STATIC_INIT_BLOCK({ return {sch}; }) .def("meta_schedule.cuda.conv2d_nchw_winograd_inverse", - [](Schedule sch, BlockRV inverse) -> Array { + [](Schedule sch, BlockRV inverse) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, inverse); // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] int64_t tile_size = Downcast(sch->Get(inverse)->writes[0]->buffer->shape[2])->value; - LoopRV outer{nullptr}; + LoopRV outer{ffi::UnsafeInit()}; { BlockRV output 
= sch->GetConsumers(inverse)[0]; - Array nchw = sch->GetLoops(output); + ffi::Array nchw = sch->GetLoops(output); ICHECK_EQ(nchw.size(), 4); - Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); - Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); + ffi::Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); + ffi::Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); outer = ws[0]; } { sch->ComputeAt(inverse, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local"); - Array loops = sch->GetLoops(inverse); + ffi::Array loops = sch->GetLoops(inverse); ICHECK_EQ(loops.size(), 10); sch->Unroll(loops[6]); sch->Unroll(loops[7]); @@ -161,7 +161,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } return {sch}; }); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule/generic/winograd.cc b/src/meta_schedule/schedule/generic/winograd.cc index edb14667bcec..fe41e1e686f1 100644 --- a/src/meta_schedule/schedule/generic/winograd.cc +++ b/src/meta_schedule/schedule/generic/winograd.cc @@ -29,8 +29,8 @@ using namespace tvm::tir; * \return The only producer block. 
*/ BlockRV GetWinogradProducerAndInlineConst(Schedule sch, BlockRV block) { - Array producers = sch->GetProducers(block); - Array results; + ffi::Array producers = sch->GetProducers(block); + ffi::Array results; for (const BlockRV& producer : producers) { if (sch->Get(producer)->reads.empty()) { sch->ComputeInline(producer); diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index c2f3a7208f64..fad3279eb792 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -36,11 +36,11 @@ class AddRFactorNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -64,14 +64,12 @@ class AddRFactorNode : public ScheduleRuleNode { .def_ro("max_jobs_per_core", &AddRFactorNode::max_jobs_per_core) .def_ro("max_innermost_factor", &AddRFactorNode::max_innermost_factor); } - - static constexpr const char* _type_key = "meta_schedule.AddRFactor"; - TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.AddRFactor", AddRFactorNode, ScheduleRuleNode); }; ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, - Optional max_innermost_factor) { - ObjectPtr n = make_object(); + ffi::Optional max_innermost_factor) { + ObjectPtr n = ffi::make_object(); n->max_jobs_per_core = max_jobs_per_core; n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; n->max_parallel_extent_ = -1; @@ -79,7 +77,8 @@ ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, return ScheduleRule(n); } -Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) 
{ +ffi::Array AddRFactorNode::Apply(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { tir::StmtSRef block_sref = sch->GetSRef(block_rv); if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_, max_parallel_basic_)) { @@ -97,16 +96,18 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops); // Split the fused reduction loop. - Array factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); - Array split_loops = sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); + ffi::Array factors = + sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); + ffi::Array split_loops = + sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); - Array res; + ffi::Array res; for (const tir::LoopRV& split_loop : split_loops) { tir::Schedule sch_tmp = sch->Copy(); sch_tmp->Seed(sch->ForkSeed()); try { const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); - Array axes = sch_tmp->GetLoops(block_rf); + ffi::Array axes = sch_tmp->GetLoops(block_rf); ICHECK_GT(axes.size(), num_spatial_loops); // Annotate that the rfactor block, which is now the producer of the original block, needs to @@ -121,12 +122,12 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: return res; } -TVM_FFI_STATIC_INIT_BLOCK({ AddRFactorNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { AddRFactorNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleAddRFactor", ScheduleRule::AddRFactor); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index 35752b8b73eb..927ce3656c2f 100644 --- 
a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -36,24 +36,25 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { CHECK(this->target_.defined()) << "ValueError: ApplyCustomRule is not initialized with TuneContext that has a Target."; - Array keys = this->target_.value()->keys; - if (Optional ann = tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { + ffi::Array keys = this->target_.value()->keys; + if (ffi::Optional ann = + tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { if (ann.value() != "None") { - for (const String& key : keys) { + for (const ffi::String& key : keys) { if (const auto custom_schedule_fn = tvm::ffi::Function::GetGlobal(GetCustomRuleName(ann.value(), key))) { - Array result = - (*custom_schedule_fn)(sch, block_rv).cast>(); + ffi::Array result = + (*custom_schedule_fn)(sch, block_rv).cast>(); return result; } } std::ostringstream os; os << "Unknown schedule rule \"" << ann.value() << "\" for target keys \"" << keys << "\". 
Checked ffi::Functions:"; - for (const String& key : keys) { + for (const ffi::String& key : keys) { os << "\n " << GetCustomRuleName(ann.value(), key); } LOG(WARNING) << os.str(); @@ -65,25 +66,24 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); n->target_ = target_; return ScheduleRule(n); } public: - Optional target_ = std::nullopt; + ffi::Optional target_ = std::nullopt; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("target_", &ApplyCustomRuleNode::target_); } - - static constexpr const char* _type_key = "meta_schedule.ApplyCustomRule"; - TVM_DECLARE_FINAL_OBJECT_INFO(ApplyCustomRuleNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.ApplyCustomRule", ApplyCustomRuleNode, + ScheduleRuleNode); }; ScheduleRule ScheduleRule::ApplyCustomRule() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return ScheduleRule(n); } @@ -91,12 +91,12 @@ bool ScheduleRule::IsApplyCustomRule(const ScheduleRule& rule) { return rule->IsInstance(); } -TVM_FFI_STATIC_INIT_BLOCK({ ApplyCustomRuleNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ApplyCustomRuleNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleApplyCustomRule", ScheduleRule::ApplyCustomRule); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 717ec0732575..1ab276c5bec7 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -32,7 +32,7 @@ class AutoBindNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& 
context) final { CHECK(context->target.defined()) << "ValueError: target is not defined"; - Optional max_threads_per_block = + ffi::Optional max_threads_per_block = context->target.value()->GetAttr("max_threads_per_block"); CHECK(max_threads_per_block.defined()) << "ValueError: missing attribute `max_threads_per_block` in the target"; @@ -40,11 +40,11 @@ class AutoBindNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -54,39 +54,38 @@ class AutoBindNode : public ScheduleRuleNode { /*! \brief The max number of threadblocks in the cuda device */ int64_t max_threadblocks_ = -1; /*! \brief thread_extents Candidates of thread axis extent. */ - Array thread_extents_; + ffi::Array thread_extents_; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.AutoBind"; - TVM_DECLARE_FINAL_OBJECT_INFO(AutoBindNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.AutoBind", AutoBindNode, ScheduleRuleNode); }; -Array AutoBindNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) { +ffi::Array AutoBindNode::Apply(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = MakeFactorSampler(sch, this->thread_extents_); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); return {sch}; } -ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_extents, +ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, ffi::Array thread_extents, int max_threads_per_block) { - ObjectPtr n = 
make_object(); + ObjectPtr n = ffi::make_object(); n->max_threadblocks_ = max_threadblocks; n->max_threads_per_block_ = max_threads_per_block; n->thread_extents_ = std::move(thread_extents); return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ AutoBindNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { AutoBindNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleAutoBind", ScheduleRule::AutoBind); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 7d0277880cf4..3d5fc8798c13 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -39,7 +39,7 @@ bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& block_sr for (; sref->parent != nullptr; sref = sref->parent) { } ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance()); - return IsSpatialPrimFunc(GetRef(GetRootPrimFunc(sch->mod(), sref->stmt, nullptr))); + return IsSpatialPrimFunc(ffi::GetRef(GetRootPrimFunc(sch->mod(), sref->stmt, nullptr))); } /*! \brief The rule that inlines spatial blocks if it satisfies some conditions. 
*/ @@ -52,7 +52,7 @@ class AutoInlineNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { InlineType inline_type = CheckInline(sch, block_rv); if (inline_type == InlineType::kInlineIntoConsumer) { sch->ComputeInline(block_rv); @@ -64,7 +64,7 @@ class AutoInlineNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -82,7 +82,7 @@ class AutoInlineNode : public ScheduleRuleNode { /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */ bool require_ordered; /*! \brief The operators that are disallowed in auto inline */ - Array disallow_op; + ffi::Array disallow_op; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -95,9 +95,7 @@ class AutoInlineNode : public ScheduleRuleNode { .def_ro("require_ordered", &AutoInlineNode::require_ordered) .def_ro("disallow_op", &AutoInlineNode::disallow_op); } - - static constexpr const char* _type_key = "meta_schedule.AutoInline"; - TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.AutoInline", AutoInlineNode, ScheduleRuleNode); }; inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, @@ -114,7 +112,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, } // Cond 2. 
For a block that generates a constant tensor, ignore all other conditions if (inline_const_tensor && block->reads.empty()) { - Array consumer_srefs = GetConsumers(state, block_sref); + ffi::Array consumer_srefs = GetConsumers(state, block_sref); if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { return InlineType::kInlineIntoConsumer; } @@ -144,25 +142,26 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, } } // Cond 6. The block is disallowed for auto inline - if (Optional ann = - tir::GetAnn(block_sref, tir::attr::meta_schedule_inline_rule)) { + if (ffi::Optional ann = + tir::GetAnn(block_sref, tir::attr::meta_schedule_inline_rule)) { if (ann.value() == "disable") return InlineType::kNoInline; } // Last cond: Check inline into the consumers or the spatial producer tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, /*require_stage_pipeline=*/false); if (into_consumer) { - Array consumer_srefs = GetConsumers(state, block_sref); + ffi::Array consumer_srefs = GetConsumers(state, block_sref); if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { return InlineType::kInlineIntoConsumer; } } if (into_producer) { - Array producer_srefs = GetProducers(state, block_sref); + ffi::Array producer_srefs = GetProducers(state, block_sref); if (producer_srefs.size() == 1 && tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && CanReverseComputeInline(state, block_sref) && - !GetAnn(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).has_value()) { + !GetAnn(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize) + .has_value()) { return InlineType::kInlineIntoProducer; } } @@ -175,8 +174,8 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // bool disallow_if_then_else, // bool require_injective, // bool require_ordered, // - Optional> disallow_op) { - ObjectPtr n = make_object(); + ffi::Optional> disallow_op) { + ObjectPtr n = ffi::make_object(); n->into_producer 
= into_producer; n->into_consumer = into_consumer; n->inline_const_tensor = inline_const_tensor; @@ -185,28 +184,28 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // n->require_ordered = require_ordered; n->disallow_op.clear(); if (disallow_op.defined()) { - Array op_names = disallow_op.value(); + ffi::Array op_names = disallow_op.value(); n->disallow_op.reserve(op_names.size()); - for (const String& op_name : op_names) { + for (const ffi::String& op_name : op_names) { n->disallow_op.push_back(Op::Get(op_name)); } } return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ AutoInlineNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { AutoInlineNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleAutoInline", ScheduleRule::AutoInline); -}); +} /*! \brief Inline blocks that produce a constant scalar. */ class InlineConstantScalarsNode : public ScheduleRuleNode { public: void InitializeWithTuneContext(const TuneContext& context) final {} - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { // Look for a block of the form // block compile_engine_const(iter_var(vi, range(min=0, ext=1))) { // reads([]) @@ -225,7 +224,7 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { } ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -233,22 +232,21 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars"; - TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.InlineConstantScalars", + InlineConstantScalarsNode, 
ScheduleRuleNode); }; ScheduleRule ScheduleRule::InlineConstantScalars() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ InlineConstantScalarsNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { InlineConstantScalarsNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleInlineConstantScalars", ScheduleRule::InlineConstantScalars); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index ddf603db27ab..17e9552dcb60 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -30,8 +30,9 @@ class CrossThreadReductionNode : public ScheduleRuleNode { ICHECK(context->target.defined()); Target target = context->target.value(); - Optional opt_max_threads_per_block = target->GetAttr("max_threads_per_block"); - Optional opt_warp_size = target->GetAttr("thread_warp_size"); + ffi::Optional opt_max_threads_per_block = + target->GetAttr("max_threads_per_block"); + ffi::Optional opt_warp_size = target->GetAttr("thread_warp_size"); if (!opt_max_threads_per_block.defined()) { TVM_PY_LOG(WARNING, context->logger) @@ -48,7 +49,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { // Step 0. Check the conditions of this rule. if (max_threads_per_block == -1 || warp_size == -1) { return {sch}; @@ -75,7 +76,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 3. Try block fusion. 
int n_candidate = static_cast(thread_extents.size()); - Array probs(n_candidate, FloatImm(DataType::Float(32), 1.0 / n_candidate)); + ffi::Array probs(n_candidate, FloatImm(DataType::Float(32), 1.0 / n_candidate)); tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { ICHECK(target_block.defined()); @@ -87,7 +88,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // the loop before binding. // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor. if (!InThreadScope(tmp_sch, target_block)) { - const Array& split_res = + const ffi::Array& split_res = tmp_sch->Split(tgt_block_innermost_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); if (tgt_block_innermost_loop.same_as(target_loop)) { @@ -108,7 +109,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { tir::LoopRV fused_reduce_loop; ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops); // Step 5. Split the fused reduction loop and bind the inner one to threadIdx. - const Array& split_res = + const ffi::Array& split_res = tmp_sch->Split(fused_reduce_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); @@ -117,7 +118,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -130,7 +131,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \return A boolean indicating whether the block is in thread scope. 
*/ bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) { - const Array& axes = sch->GetLoops(block); + const ffi::Array& axes = sch->GetLoops(block); for (const tir::LoopRV& loop_rv : axes) { const tir::For& loop = sch->Get(loop_rv); runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get()); @@ -170,9 +171,9 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \return The extent of "threadIdx.x" in the input schedule */ tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) { - tir::ExprRV extent{nullptr}; + tir::ExprRV extent{ffi::UnsafeInit()}; for (const tir::Instruction& inst : trace->insts) { - if (inst->kind->name == "Bind" && Downcast(inst->attrs[0]) == "threadIdx.x") { + if (inst->kind->name == "Bind" && Downcast(inst->attrs[0]) == "threadIdx.x") { if (GetLoopRVExtentSource(trace, Downcast(inst->inputs[0]), &extent)) { return extent; } @@ -197,18 +198,18 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 0. Due to technical reason of some primitives (e.g., compute-at), if the block is doing // a tuple reduction, fusion is temporarily not supported. if (sch->Get(block_rv)->writes.size() != 1) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 1. Get all the consumers of the input block. - Array consumers = sch->GetConsumers(block_rv); + ffi::Array consumers = sch->GetConsumers(block_rv); // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is // not fusible. 
if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 3. Calculate the lowest common ancestor of all the consumers. @@ -220,18 +221,18 @@ class CrossThreadReductionNode : public ScheduleRuleNode { const tir::StmtSRef& lca_sref = tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers)); if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } // Step 4. Get the outer loops of the target block, and get the compute-at position index. - Array tgt_block_loops = sch->GetLoops(consumers[0]); + ffi::Array tgt_block_loops = sch->GetLoops(consumers[0]); int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref); // Step 5. A negative position index means not fusible, and vice-versa. 
if (pos < 0) { - return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, - tir::LoopRV{nullptr}); + return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()}, tir::BlockRV{ffi::UnsafeInit()}, + tir::LoopRV{ffi::UnsafeInit()}); } else { return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back()); } @@ -248,8 +249,9 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \param lca_sref The lowest common ancestor of all the consumers of the input block * \return The compute-at position index of the input block */ - int GetComputePosition(const tir::Schedule& sch, const Array& block_loops, - const Array& tgt_block_loops, const tir::StmtSRef& lca_sref) { + int GetComputePosition(const tir::Schedule& sch, const ffi::Array& block_loops, + const ffi::Array& tgt_block_loops, + const tir::StmtSRef& lca_sref) { int n_block_loop = static_cast(block_loops.size()); int n_tgt_block_loop = static_cast(tgt_block_loops.size()); @@ -271,7 +273,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { /*! \brief The number of threads per warp */ int warp_size; /*! \brief Candidates of thread axis extent (values are required to be positive). 
*/ - Array thread_extents; + ffi::Array thread_extents; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -280,27 +282,26 @@ class CrossThreadReductionNode : public ScheduleRuleNode { .def_ro("warp_size", &CrossThreadReductionNode::warp_size) .def_ro("thread_extents", &CrossThreadReductionNode::thread_extents); } - - static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction"; - TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.CrossThreadReduction", CrossThreadReductionNode, + ScheduleRuleNode); }; -ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { +ScheduleRule ScheduleRule::CrossThreadReduction(ffi::Array thread_extents) { for (const auto& extent : thread_extents) { CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive"; } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->thread_extents = std::move(thread_extents); return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ CrossThreadReductionNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { CrossThreadReductionNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleCrossThreadReduction", ScheduleRule::CrossThreadReduction); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 6a7c6ade45c1..ea78c4f6e3d3 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -55,10 +55,10 @@ using tir::IterVarType; using tir::LoopRV; using tir::Schedule; -TVM_FFI_STATIC_INIT_BLOCK({ MultiLevelTilingNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { MultiLevelTilingNode::RegisterReflection(); 
} -State::State(tir::Schedule sch, tir::BlockRV block_rv, Array> tiles) { - ObjectPtr node = make_object(); +State::State(tir::Schedule sch, tir::BlockRV block_rv, ffi::Array> tiles) { + ObjectPtr node = ffi::make_object(); node->sch = std::move(sch); node->block_rv = std::move(block_rv); node->tiles = std::move(tiles); @@ -66,22 +66,23 @@ State::State(tir::Schedule sch, tir::BlockRV block_rv, Array> } State StateNode::Copy() const { - ObjectPtr node = make_object(*this); + ObjectPtr node = ffi::make_object(*this); node->sch = sch->Copy(); return State(node); } // Do nothing; Inherited from ScheduleRuleNode void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) { - if (Optional v = context->target.value()->GetAttr("max_threads_per_block")) { + if (ffi::Optional v = + context->target.value()->GetAttr("max_threads_per_block")) { this->max_threads_per_block_ = v.value()->value; - if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { + if (ffi::Optional v = context->target.value()->GetAttr("thread_warp_size")) { this->thread_warp_size_ = v.value()->value; } else { TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target"; } } - if (Optional opt_sm = context->target.value()->GetAttr("arch")) { + if (ffi::Optional opt_sm = context->target.value()->GetAttr("arch")) { std::string sm = opt_sm.value(); if (support::StartsWith(sm, "sm_")) { sm = sm.substr(3); @@ -102,12 +103,12 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) } // Entry of the mega rule; Inherited from ScheduleRuleNode -Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& block_rv) { +ffi::Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& block_rv) { if ((filter_fn_ && filter_fn_.value()(sch, sch->GetSRef(block_rv)).cast()) || NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); 
- Array results; + ffi::Array results; for (auto&& state : ApplySubRules({State(sch, block_rv)})) { results.push_back(std::move(state->sch)); } @@ -118,7 +119,7 @@ Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& // Inherited from ScheduleRuleNode ScheduleRule MultiLevelTilingNode::Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -138,7 +139,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { } std::vector levels = config.levels; ReuseType req = config.req; - if (Optional> ann = tir::GetAnn>( + if (ffi::Optional> ann = tir::GetAnn>( state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) { req = ReuseType::kMustReuse; levels.clear(); @@ -148,7 +149,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { std::vector results; if (req == ReuseType::kMayReuse) { // Case 1. If the write cache is already there, we don't need to add another. - Array consumer_rvs = state->sch->GetConsumers(state->block_rv); + ffi::Array consumer_rvs = state->sch->GetConsumers(state->block_rv); if (consumer_rvs.size() == 1 && IsWriteCache(state->sch->GetSRef(consumer_rvs[0]))) { for (int level : levels) { State new_state = state->Copy(); @@ -180,14 +181,14 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } -std::pair, Array> MultiLevelTilingNode::SplitLoop( +std::pair, ffi::Array> MultiLevelTilingNode::SplitLoop( const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const { - Array factors = sch->SamplePerfectTile( + ffi::Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, /*max_innermost_factor=*/max_innermost_factor); - Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } @@ -196,7 +197,7 @@ std::vector 
MultiLevelTilingNode::TileLoopNest(State state, Schedule& sch = state->sch; const BlockRV& block_rv = state->block_rv; // Step 1. Assuming trivial binding, pair the loops and their iter-var-types - Array loops = sch->GetLoops(block_rv); + ffi::Array loops = sch->GetLoops(block_rv); std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv)); ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. For each loop axis, tile it @@ -210,10 +211,10 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num; int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num; - Array skipped_outer_spatial_loops; - std::vector> tiles(s_indices_.size() + r_indices_.size()); + ffi::Array skipped_outer_spatial_loops; + std::vector> tiles(s_indices_.size() + r_indices_.size()); state->tile_factors.resize(tiles.size()); - std::vector> tile_factors; + std::vector> tile_factors; tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; @@ -268,7 +269,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, sch->Bind(fused, tile_binds[i]); tiles[i] = {fused}; } - state->tiles = Array>{tiles.begin(), tiles.end()}; + state->tiles = ffi::Array>{tiles.begin(), tiles.end()}; if (this->thread_warp_size_ != -1) { int64_t low_inclusive = 1; int64_t high_inclusive = this->max_threads_per_block_; @@ -308,9 +309,9 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { // Insert cache_read block to the proper place sch->ComputeAt(cache_read_block, loop_rv, true); // Fuse the iterators of the cache_read - Array buffer_loops = sch->GetLoops(cache_read_block); - sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // - buffer_loops.end()}); + ffi::Array buffer_loops = sch->GetLoops(cache_read_block); + sch->Fuse(ffi::Array{buffer_loops.end() - buffer_ndim, // + buffer_loops.end()}); 
AnnotateCooperativeFetching(&sch, cache_read_block); new_state->read_reuse.emplace(i, cache_read_block); } @@ -330,7 +331,7 @@ std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { // therefore it matches the notation array size in the following code tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back()); const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref); - Array seq = Downcast(r_for_loop->body)->seq; + ffi::Array seq = Downcast(r_for_loop->body)->seq; if (seq.size() != 3) { return {state}; } @@ -346,11 +347,11 @@ std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { State new_state = state->Copy(); LoopRV r_loop_fused = new_state->sch->Fuse(new_state->tiles[r_indices_[0]]); new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage, - Array{0, 0, stage - 2}); + ffi::Array{0, 0, stage - 2}); new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order, - Array{0, 1, 2}); + ffi::Array{0, 1, 2}); new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages, - Array{0}); + ffi::Array{0}); ret.push_back(std::move(new_state)); } return ret; @@ -386,30 +387,31 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, double prob = 1.0 / n; tir::ExprRV vector_load_len = (*sch)->SampleCategorical(support::AsArray(valid_vector_lens), - Array(n, FloatImm(DataType::Float(32), prob))); + ffi::Array(n, FloatImm(DataType::Float(32), prob))); (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); } } // Constructor -ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> tile_binds, - Optional max_innermost_factor, - Optional> vector_load_lens, - Optional> reuse_read, - Optional> reuse_write, - Optional filter_fn) { +ScheduleRule ScheduleRule::MultiLevelTiling( + ffi::String structure, ffi::Optional> tile_binds, + ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + 
ffi::Optional> reuse_read, + ffi::Optional> reuse_write, + ffi::Optional filter_fn) { auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); node->filter_fn_ = filter_fn; return ScheduleRule(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleMultiLevelTiling", ScheduleRule::MultiLevelTiling); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 2b03d749f2b5..028d1aecbf45 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -64,7 +64,7 @@ enum class ReuseType : int32_t { * \param str The string to be converted. * \return The converted ReuseType. */ -inline ReuseType Str2ReuseType(const String& str) { +inline ReuseType Str2ReuseType(const ffi::String& str) { if (str == "no") { return ReuseType::kNoReuse; } else if (str == "may") { @@ -84,16 +84,16 @@ struct ReuseConfig { /*! \brief Which levels are caching stage inserted at */ std::vector levels; /*! \brief The storage scope */ - String scope; + ffi::String scope; /*! \brief Default constructor: no data reuse */ ReuseConfig() : req(ReuseType::kNoReuse) {} /*! \brief Construct from a configuration dictionary */ - explicit ReuseConfig(const Map& config) - : req(Str2ReuseType(Downcast(config.at("req")))), - levels(support::AsVector(Downcast>(config.at("levels")))), - scope(Downcast(config.at("scope"))) { + explicit ReuseConfig(const ffi::Map& config) + : req(Str2ReuseType(Downcast(config.at("req")))), + levels(support::AsVector(Downcast>(config.at("levels")))), + scope(Downcast(config.at("scope"))) { ICHECK_EQ(config.size(), 3); } }; @@ -109,9 +109,9 @@ class StateNode : public Object { /*! 
\brief The block to be tiled */ tir::BlockRV block_rv; /*! \brief The loop tiles */ - Array> tiles; + ffi::Array> tiles; /*! \brief The factors of the loop tiles. */ - Array> tile_factors; + ffi::Array> tile_factors; /*! \brief The mapping from buffer index to read cache block. */ std::unordered_map read_reuse; /*! \brief The mapping from buffer index to write cache block. */ @@ -123,16 +123,17 @@ class StateNode : public Object { */ virtual State Copy() const; - static constexpr const char* _type_key = "meta_schedule.State"; - TVM_DECLARE_BASE_OBJECT_INFO(StateNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.State", StateNode, Object); }; /*! \brief Managed reference to StateNode */ class State : public ObjectRef { public: /*! \brief Default constructor */ - explicit State(tir::Schedule sch, tir::BlockRV block_rv, Array> tiles = {}); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); + explicit State(tir::Schedule sch, tir::BlockRV block_rv, + ffi::Array> tiles = {}); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(State, ObjectRef, StateNode); }; /*! 
@@ -173,7 +174,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final; // Entry of the mega rule; Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; // Inherited from ScheduleRuleNode ScheduleRule Clone() const override; @@ -181,10 +182,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); - virtual std::pair, Array> SplitLoop(const tir::Schedule& sch, - tir::BlockRV block, - tir::LoopRV loop, - int n_tiles) const; + virtual std::pair, ffi::Array> SplitLoop( + const tir::Schedule& sch, tir::BlockRV block, tir::LoopRV loop, int n_tiles) const; // Annotate a block to use cooperative fetching void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; @@ -195,9 +194,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode { * - 'SSRSRS' on CPU * - 'SSSRRSRS' on GPU */ - String structure; + ffi::String structure; /*! \brief For each level of tiles, which thread axis it is bound to */ - Array tile_binds; + ffi::Array tile_binds; /*! \brief The maximum size of the innermost factor */ int max_innermost_factor; /*! \brief The length of vector lane in vectorized cooperative fetching */ @@ -219,7 +218,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { /*! \brief The logging function */ ffi::Function logger; /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. 
*/ - Optional filter_fn_; + ffi::Optional filter_fn_; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -228,18 +227,18 @@ class MultiLevelTilingNode : public ScheduleRuleNode { .def_ro("tile_binds", &MultiLevelTilingNode::tile_binds) .def_ro("max_innermost_factor", &MultiLevelTilingNode::max_innermost_factor); } - - static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; - TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO("meta_schedule.MultiLevelTiling", MultiLevelTilingNode, + ScheduleRuleNode); }; template -ObjectPtr MultiLevelTilingInitCommon(String structure, Optional> tile_binds, - Optional max_innermost_factor, - Optional> vector_load_lens, - Optional> reuse_read, - Optional> reuse_write) { - ObjectPtr n = make_object(); +ObjectPtr MultiLevelTilingInitCommon( + ffi::String structure, ffi::Optional> tile_binds, + ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write) { + ObjectPtr n = ffi::make_object(); n->structure = structure; n->tile_binds = tile_binds.value_or({}); n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 22f9699c9180..c58e81dc3343 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -36,11 +36,11 @@ using tir::LoopRV; using tir::Schedule; struct TensorCoreIntrinGroup { - String init_intrin; - String load_a_intrin; - String load_b_intrin; - String compute_intrin; - String store_intrin; + ffi::String init_intrin; + ffi::String load_a_intrin; + ffi::String load_b_intrin; + ffi::String compute_intrin; + ffi::String store_intrin; /*! \brief Create TensorCoreIntrinGroup from config in a map. 
The map should contains the * following keys: @@ -52,11 +52,12 @@ struct TensorCoreIntrinGroup { * The values of the keys should be the names of the corresponding intrinsics and should be * registered via TensorIntrin.Register beforehand. */ - static TensorCoreIntrinGroup FromConfig(const Map& config); + static TensorCoreIntrinGroup FromConfig(const ffi::Map& config); }; -TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map& config) { - auto f_initialize_intrin = [&config](String key_name, String* intrin_name) { +TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig( + const ffi::Map& config) { + auto f_initialize_intrin = [&config](ffi::String key_name, ffi::String* intrin_name) { CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set."; *intrin_name = config.at(key_name); // Check the existence of the intrin @@ -76,7 +77,7 @@ class TensorCoreStateNode : public StateNode { /*! \brief The tensor core intrinsic group. */ TensorCoreIntrinGroup intrin_group; /*! \brief The auto tensorization maping info. */ - tir::AutoTensorizeMappingInfo mapping_info{nullptr}; + tir::AutoTensorizeMappingInfo mapping_info{ffi::UnsafeInit()}; /*! \brief The Tensor Core reindex block A for Tensor Core computation */ tir::BlockRV tensor_core_reindex_A; /*! 
\brief The Tensor Core reindex block B for Tensor Core computation */ @@ -90,23 +91,30 @@ class TensorCoreStateNode : public StateNode { State Copy() const final; - static constexpr const char* _type_key = "meta_schedule.TensorCoreState"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.TensorCoreState", TensorCoreStateNode, + StateNode); }; class TensorCoreState : public State { public: explicit TensorCoreState(TensorCoreIntrinGroup intrin_group, tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, - BlockRV block_rv, bool use_async, Array> tiles = {}); + BlockRV block_rv, bool use_async, + ffi::Array> tiles = {}); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TensorCoreState, State, TensorCoreStateNode); }; TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group, tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, - BlockRV block_rv, bool use_async, Array> tiles) { - ObjectPtr node = make_object(); + BlockRV block_rv, bool use_async, + ffi::Array> tiles) { + ObjectPtr node = ffi::make_object(); node->intrin_group = intrin_group; node->mapping_info = mapping_info; node->sch = std::move(sch); @@ -118,7 +126,7 @@ TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group, } State TensorCoreStateNode::Copy() const { - ObjectPtr node = make_object(*this); + ObjectPtr node = ffi::make_object(*this); node->sch = sch->Copy(); return State(node); } @@ -145,11 +153,9 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { // Subrule: Add software pipeline inline std::vector AddSoftwarePipeline(TensorCoreState state) const; // Subrule: split loop for mma using sample partitioned tile - inline std::pair, Array> MMASplitLoop(const Schedule& sch, - BlockRV block, LoopRV loop, - 
int n_tiles, - int partition_pos, - int innerpart_factor) const; + inline std::pair, ffi::Array> MMASplitLoop( + const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles, int partition_pos, + int innerpart_factor) const; // Subrule: tile loop nest for mma // Basically same with MultiLevelTilingNode::TileLoopNest, but change SamplePerfectTile to // SamplePartitionedTile @@ -159,12 +165,12 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { std::vector ApplySubRules(std::vector states) final; // Override Apply to apply tensorization-specific analysis before applying sub-rules - Array Apply(const Schedule& sch, const BlockRV& block_rv) final; + ffi::Array Apply(const Schedule& sch, const BlockRV& block_rv) final; // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return ScheduleRule(n); } @@ -174,31 +180,38 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { * \param intrin_name The name of the tensor intrin * \return The loop to be tensorized. std::nullopt if the workload can't be tensorized. */ - Optional TransformWithTensorIntrin(TensorCoreStateNode* state, - const String& intrin_name) const; + ffi::Optional TransformWithTensorIntrin(TensorCoreStateNode* state, + const ffi::String& intrin_name) const; /*! * \brief Tile, blockize and annotate for tensorization with the given intrin * \param block_rv The block to be tensorized * \param intrin_name The name of the tensor intrin */ - void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv, const String& intrin_name, - const String& permuted_layout_annotate_value) const; + void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv, + const ffi::String& intrin_name, + const ffi::String& permuted_layout_annotate_value) const; public: /*! \brief The candidate tensor core intrin groups to apply */ std::vector intrin_groups; /*! 
\brief Whether to use software pipeline */ bool use_software_pipeline = false; - static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore"; - TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MultiLevelTilingTensorCore", + MultiLevelTilingTensorCoreNode, MultiLevelTilingNode); private: }; // Entry of the mega rule; Inherited from ScheduleRuleNode -Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, - const BlockRV& block_rv) { +ffi::Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, + const BlockRV& block_rv) { if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { return {sch}; } @@ -206,7 +219,7 @@ Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, std::unordered_map intrin_group_to_mapping_info; for (int i = 0, n = intrin_groups.size(); i < n; ++i) { TensorCoreIntrinGroup intrin_group = intrin_groups[i]; - Optional mapping_info = tir::GetAutoTensorizeMappingInfo( + ffi::Optional mapping_info = tir::GetAutoTensorizeMappingInfo( sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc); if (mapping_info.defined()) { @@ -231,7 +244,7 @@ Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv, true)); } - Array results; + ffi::Array results; for (auto&& state : ApplySubRules(initial_states)) { TVM_PY_LOG(INFO, logger) << "Sketch " << results.size() << ": tensorizing with " << state.as()->intrin_group.compute_intrin; @@ -273,9 +286,9 @@ std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); + 
Schedule* sch, const BlockRV& block_rv, const ffi::String& intrin_name, + const ffi::String& permuted_layout_annotate_value) const { + ffi::Optional loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); ICHECK(loop.defined()); BlockRV blockized_outer = (*sch)->Blockize(loop.value()); (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); @@ -308,8 +321,9 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta BlockRV cache_read_block = sch->ReadAt(loop_rv, block_rv, i, config.scope); new_state->read_reuse.emplace(i, cache_read_block); if (state->is_mma) { - new_state->sch->Annotate(cache_read_block, "permuted_layout", - String(std::string("g2s_") + std::string(i == 0 ? "A" : "B"))); + new_state->sch->Annotate( + cache_read_block, "permuted_layout", + ffi::String(std::string("g2s_") + std::string(i == 0 ? "A" : "B"))); } } results.push_back(std::move(new_state)); @@ -317,16 +331,17 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta return results; } -std::pair, Array> MultiLevelTilingTensorCoreNode::MMASplitLoop( - const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles, int partition_pos, - int innerpart_factor) const { - Array factors = sch->SamplePartitionedTile( +std::pair, ffi::Array> +MultiLevelTilingTensorCoreNode::MMASplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, + int n_tiles, int partition_pos, + int innerpart_factor) const { + ffi::Array factors = sch->SamplePartitionedTile( /*loop=*/loop, /*n=*/n_tiles, /*partition_pos=*/partition_pos, /*innerpart_factor=*/innerpart_factor); - Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } @@ -334,7 +349,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta Schedule& sch = state->sch; const BlockRV& block_rv = 
state->block_rv; // Step 1. Assuming trivial binding, pair the loops and their iter-var-types - Array loops = sch->GetLoops(block_rv); + ffi::Array loops = sch->GetLoops(block_rv); if (!(loops.size() == 3 || !state->is_mma)) { LOG(DEBUG) << "The MMA tensor core only supports SSR loops now"; return {}; @@ -343,9 +358,9 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. For each loop axis, tile it int64_t spatial_loop_product = 1; - std::vector> tiles(s_indices_.size() + r_indices_.size()); + std::vector> tiles(s_indices_.size() + r_indices_.size()); state->tile_factors.resize(tiles.size()); - std::vector> tile_factors; + std::vector> tile_factors; tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; @@ -397,7 +412,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta sch->Bind(fused, tile_binds[i]); tiles[i] = {fused}; } - state->tiles = Array>{tiles.begin(), tiles.end()}; + state->tiles = ffi::Array>{tiles.begin(), tiles.end()}; if (this->thread_warp_size_ != -1) { int64_t low_inclusive = 1; int64_t high_inclusive = this->max_threads_per_block_; @@ -445,7 +460,7 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa // This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id. // `loop_idx` can be negative, in which case it is counted from the end. 
auto f_get_inner_tile_product = [&](int loop_idx) { - Array factors; + ffi::Array factors; for (int i = tile_index_warp_id + 1; i < static_cast(s_indices_.size()); ++i) { auto s_factors = state->tile_factors[s_indices_[i]]; if (loop_idx < 0) { @@ -479,8 +494,8 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa // frag_shape_m and frag_shape_n are structural bindings that cannot // not be automatically captured until c++20 [&, frag_shape_m = frag_shape_m, - frag_shape_n = frag_shape_n](const Array& indices) { - Array result; + frag_shape_n = frag_shape_n](const ffi::Array& indices) { + ffi::Array result; result.reserve(indices.size() + 4); for (int i = 0; i < num_higher_dims; ++i) { result.push_back(indices[i]); @@ -547,7 +562,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( // Get the loops other than the innermost two loops (accum_m and accum_n). auto f_get_loops = [&](const BlockRV& block_rv) -> std::array { - Array buffer_loops = sch->GetLoops(block_rv); + ffi::Array buffer_loops = sch->GetLoops(block_rv); ICHECK_GT(buffer_loops.size(), 6); return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; @@ -571,24 +586,24 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize, state->intrin_group.store_intrin); - Array buffer_loops = sch->GetLoops(state->write_reuse[0]); + ffi::Array buffer_loops = sch->GetLoops(state->write_reuse[0]); ICHECK_GT(buffer_loops.size(), 5); - sch->Fuse(Array{buffer_loops.end() - 5, // The src shmem is always 2D - buffer_loops.end()}); + sch->Fuse(ffi::Array{buffer_loops.end() - 5, // The src shmem is always 2D + buffer_loops.end()}); AnnotateCooperativeFetching(&sch, state->write_reuse[0]); return {state}; } std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( TensorCoreState state) 
const { - const Array& r_tiles = state->tiles[r_indices_[1]]; + const ffi::Array& r_tiles = state->tiles[r_indices_[1]]; Schedule& sch = state->sch; ICHECK(!r_tiles.empty()) << "ValueError: Cannot find the suitable reduction loop in the block"; - auto f_tensorize_load = [&](int read_index, String scope, String intrin_name) { + auto f_tensorize_load = [&](int read_index, ffi::String scope, ffi::String intrin_name) { auto cache_read = sch->CacheRead(state->block_rv, read_index, scope); state->sch->ComputeAt(cache_read, r_tiles.back(), true); - String permuted_layout_annotate_value = + ffi::String permuted_layout_annotate_value = state->is_mma ? std::string("s2l_") + std::string(read_index == 0 ? "A" : "B") : ""; TileAndAnnotateTensorize(&sch, cache_read, intrin_name, permuted_layout_annotate_value); }; @@ -603,7 +618,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( sch->ComputeInline(sch->GetProducers(cache_read)[0]); const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs(); tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer( - sch->state(), GetRef(cache_read_block), 0, tir::BufferIndexType::kWrite); + sch->state(), ffi::GetRef(cache_read_block), 0, tir::BufferIndexType::kWrite); const DataType& dtype = cache_read_buffer->dtype; if (dtype.is_float16()) { sch->StorageAlign(cache_read, 0, -2, 32, 8); @@ -631,7 +646,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( // Check reduction length after blockize. 
int64_t reduction_length = 1; for (int r_index : r_indices_) { - const Array& tiles = state->tiles[r_index]; + const ffi::Array& tiles = state->tiles[r_index]; for (const LoopRV& tile : tiles) { const auto* extent = sch->Get(tile)->extent.as(); ICHECK(extent != nullptr) << "Dynamic extent is not supported."; @@ -686,16 +701,16 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( // compute matmul with fragment K1 - 1 // sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_stage, - Array{0, 0, 1}); + ffi::Array{0, 0, 1}); sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_order, - Array{0, 1, 2}); + ffi::Array{0, 1, 2}); if (state->is_mma && state->use_async) { sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_async_stages, - Array{0}); + ffi::Array{0}); sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage, - Array{0, 0, 1, 2, 2}); + ffi::Array{0, 0, 1, 2, 2}); sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order, - Array{0, 1, 3, 2, 4}); + ffi::Array{0, 1, 3, 2, 4}); } else { // Outer software pipeline: Interleave the outer loop with the (pipelined) inner loop. // The prefetching stage of the inner pipeline is executed by one iteration in the outer loop. 
@@ -738,16 +753,16 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( // compute matmul with fragment K1 - 1 of tile K0 - 1 // sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage, - Array{0, 0, 0, 0, 0, 1, 1}); + ffi::Array{0, 0, 0, 0, 0, 1, 1}); sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order, - Array{0, 3, 1, 4, 5, 2, 6}); + ffi::Array{0, 3, 1, 4, 5, 2, 6}); } return {state}; } -Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( - TensorCoreStateNode* state, const String& intrin_name) const { +ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( + TensorCoreStateNode* state, const ffi::String& intrin_name) const { BlockRV block_rv = state->block_rv; const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); @@ -755,7 +770,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // Add reindex stages const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); // Hold the reference of the block before reindex - const tir::Block block_before_reindex = GetRef(block); + const tir::Block block_before_reindex = ffi::GetRef(block); if (block->reads.size() != 2 || block->writes.size() != 1) { // only matmul-like computation is allowed return std::nullopt; @@ -792,7 +807,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( for (int i = 0; i < offset; ++i) { const tir::VarNode* var_ptr = index_map->final_indices[i].as(); ICHECK(var_ptr != nullptr); - unmapped_index_map_src.insert(GetRef(var_ptr)); + unmapped_index_map_src.insert(ffi::GetRef(var_ptr)); } for (int i = offset; i < static_cast(index_map->final_indices.size()); ++i) { rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i]; @@ -806,7 +821,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( 
ICHECK(tir::is_one(range->extent)); const tir::VarNode* var_ptr = range->min.as(); ICHECK(var_ptr != nullptr); - const tir::Var& lhs_representer = lhs_to_index_map_src[GetRef(var_ptr)]; + const tir::Var& lhs_representer = lhs_to_index_map_src[ffi::GetRef(var_ptr)]; sub_index_map_src.push_back(lhs_representer); if (unmapped_index_map_src.count(lhs_representer)) { sub_index_map_tgt.push_back(lhs_representer); @@ -815,15 +830,15 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) { const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as(); ICHECK(var != nullptr); - sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef(var)]); + sub_index_map_tgt.push_back(rhs_to_index_map_tgt[ffi::GetRef(var)]); } return tir::IndexMap(sub_index_map_src, sub_index_map_tgt); }; std::unordered_set visited_buffers; - Map buffer_sub_index_map; // cache of the sub index map associated - // with each buffer + ffi::Map buffer_sub_index_map; // cache of the sub index map + // associated with each buffer auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) { const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer( @@ -835,7 +850,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // Refresh block pointer (block sref is not invalidated) block = TVM_SREF_TO_BLOCK(block_sref); const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion( - state->sch->state(), GetRef(block), buffer_index, index_type); + state->sch->state(), ffi::GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, @@ -868,7 +883,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( inline std::vector 
MultiLevelTilingTensorCoreNode::TransformForTensorization( TensorCoreState state) const { // Do reindex and layout transformations. - Optional transformed_loop_rv = + ffi::Optional transformed_loop_rv = TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin); if (!transformed_loop_rv.defined()) { // The workload can't be tensorized. @@ -888,12 +903,13 @@ inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorizat } ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( - Array> intrin_groups, String structure, Optional> tile_binds, - Optional max_innermost_factor, Optional> vector_load_lens, - Optional> reuse_read, Optional> reuse_write, - bool use_software_pipeline) { + ffi::Array> intrin_groups, ffi::String structure, + ffi::Optional> tile_binds, ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write, bool use_software_pipeline) { if (tile_binds.defined()) { - for (const String& tile_bind : tile_binds.value()) { + for (const ffi::String& tile_bind : tile_binds.value()) { CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core."; } } @@ -921,11 +937,13 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( return ScheduleRule(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + MultiLevelTilingTensorCoreNode::RegisterReflection(); + TensorCoreStateNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore", ScheduleRule::MultiLevelTilingTensorCore); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index a560248ee2b2..080e1c9c0fbf 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ 
b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -40,22 +40,29 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { public: size_t vector_length_in_bits; - static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWideVector"; - TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MultiLevelTilingWideVector", + MultiLevelTilingWideVectorNode, MultiLevelTilingNode); protected: ScheduleRule Clone() const final { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return ScheduleRule(n); } - std::pair, Array> SplitLoop(const Schedule& sch, BlockRV block, - LoopRV loop, int n_tiles) const; + std::pair, ffi::Array> SplitLoop(const Schedule& sch, + BlockRV block, LoopRV loop, + int n_tiles) const; }; -std::pair, Array> MultiLevelTilingWideVectorNode::SplitLoop( - const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const { +std::pair, ffi::Array> +MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, + int n_tiles) const { const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); const tir::StmtSRef block_sref = sch->GetSRef(block_rv); const tir::BlockNode* block_node = block_sref->StmtAs(); @@ -93,43 +100,45 @@ std::pair, Array> MultiLevelTilingWideVectorNode // We split the innermost spatial loop in a way that always uses the maximum vector length. 
const int64_t* extent_int = tir::GetLoopIntExtent(loop); if (extent_int && *extent_int > vec_len) { - Array inner_splits = sch->Split(/*loop=*/loop_rv, - /*factors=*/{std::nullopt, PrimExpr(vec_len)}); - Array outer_factors = sch->SamplePerfectTile( + ffi::Array inner_splits = + sch->Split(/*loop=*/loop_rv, + /*factors=*/{std::nullopt, PrimExpr(vec_len)}); + ffi::Array outer_factors = sch->SamplePerfectTile( /*loop=*/inner_splits[0], /*n=*/n_tiles - 1, /*max_innermost_factor=*/max_innermost_factor); - Array outer_splits = sch->Split( + ffi::Array outer_splits = sch->Split( /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()}); outer_splits.push_back(inner_splits[1]); outer_factors.push_back(PrimExpr(vec_len)); return {outer_factors, outer_splits}; } else { - Array factors(n_tiles - 1, PrimExpr(1)); + ffi::Array factors(n_tiles - 1, PrimExpr(1)); factors.push_back(loop->extent); - Array splits = sch->Split(/*loop=*/loop_rv, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop_rv, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } } } -ScheduleRule ScheduleRule::MultiLevelTilingWideVector(String structure, - Integer vector_length_in_bits, - Optional max_innermost_factor, - Optional> reuse_read, - Optional> reuse_write) { +ScheduleRule ScheduleRule::MultiLevelTilingWideVector( + ffi::String structure, Integer vector_length_in_bits, + ffi::Optional max_innermost_factor, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write) { auto node = MultiLevelTilingInitCommon( structure, std::nullopt, max_innermost_factor, std::nullopt, reuse_read, reuse_write); node->vector_length_in_bits = vector_length_in_bits->value; return ScheduleRule(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + MultiLevelTilingWideVectorNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.ScheduleRuleMultiLevelTilingWideVector", 
ScheduleRule::MultiLevelTilingWideVector); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 85c9243e6bb1..4a375689e493 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -31,15 +31,15 @@ namespace meta_schedule { * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate * the tiled block for tensorization by postproc rewrite. */ -Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, - const std::string& intrin_name) { - Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); +ffi::Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, + const std::string& intrin_name) { + ffi::Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); if (!tiled_loop_rv) { return std::nullopt; } ICHECK(tiled_loop_rv.defined()); tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); - sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name)); + sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, ffi::String(intrin_name)); return outer_block; } @@ -48,7 +48,7 @@ Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, */ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { protected: - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc; if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) { TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; @@ -68,7 +68,7 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { // Inherited from ScheduleRuleNode 
ScheduleRule Clone() const final { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return ScheduleRule(n); } @@ -86,19 +86,23 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { } public: - /*! \brief The name of a tensor intrinsic. */ - String intrin_name; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } - static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin"; - TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode); + /*! \brief The name of a tensor intrinsic. */ + ffi::String intrin_name; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.MultiLevelTilingWithIntrin", + MultiLevelTilingWithIntrinNode, MultiLevelTilingNode); }; -ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(String intrin_name, String structure, - Optional> tile_binds, - Optional max_innermost_factor, - Optional> vector_load_lens, - Optional> reuse_read, - Optional> reuse_write) { +ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin( + ffi::String intrin_name, ffi::String structure, + ffi::Optional> tile_binds, ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write) { ICHECK(tir::TensorIntrin::Get(intrin_name).defined()) << "Provided tensor intrinsic " << intrin_name << " is not registered."; auto node = MultiLevelTilingInitCommon( @@ -107,11 +111,12 @@ ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(String intrin_name, String return ScheduleRule(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + MultiLevelTilingWithIntrinNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin", ScheduleRule::MultiLevelTilingWithIntrin); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc 
b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 28929d933762..9216c70e3328 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -30,7 +30,7 @@ bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) { bool CheckSpatialPrimFunc(const Schedule& sch, const BlockRV& root_block_rv) { return IsSpatialPrimFunc( - GetRef(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr))); + ffi::GetRef(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr))); } } // namespace tir @@ -51,7 +51,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) { // Currently only mark the root block with annotations. if (!tir::IsRootBlock(sch, root_rv)) { return {sch}; @@ -70,7 +70,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) { int n = unroll_max_steps.size(); double prob = 1.0 / n; - Array probs(n, FloatImm(DataType::Float(32), prob)); + ffi::Array probs(n, FloatImm(DataType::Float(32), prob)); PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs); if (unroll_explicit) { sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step); @@ -84,7 +84,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return ScheduleRule(n); } @@ -104,7 +104,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { * \brief The options of the maximum number of unroll steps to be done. * Use an empty array to disable unroll. */ - Array unroll_max_steps; + ffi::Array unroll_max_steps; /*! 
\brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */ bool unroll_explicit; /*! \brief The number of maximum available jobs in CPU. */ @@ -118,16 +118,15 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { .def_ro("unroll_max_steps", &ParallelizeVectorizeUnrollNode::unroll_max_steps) .def_ro("unroll_explicit", &ParallelizeVectorizeUnrollNode::unroll_explicit); } - - static constexpr const char* _type_key = "meta_schedule.ParallelizeVectorizeUnroll"; - TVM_DECLARE_FINAL_OBJECT_INFO(ParallelizeVectorizeUnrollNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.ParallelizeVectorizeUnroll", + ParallelizeVectorizeUnrollNode, ScheduleRuleNode); }; ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, int max_vectorize_extent, - Array unroll_max_steps, + ffi::Array unroll_max_steps, bool unroll_explicit) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_jobs_per_core = max_jobs_per_core; n->max_vectorize_extent = max_vectorize_extent; n->unroll_max_steps = unroll_max_steps; @@ -136,13 +135,13 @@ ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ ParallelizeVectorizeUnrollNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ParallelizeVectorizeUnrollNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll", ScheduleRule::ParallelizeVectorizeUnroll); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index a2bfa2644b1e..2c9975fcf916 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -29,7 +29,7 @@ class 
RandomComputeLocationNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { if (!CheckConditions(sch, block_rv)) { return {sch}; } @@ -40,7 +40,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer // access the input block. Hence we collect its producer ahead of time. // - Note that only single producer is allowed in this case. - Array producers{nullptr}; + ffi::Array producers{nullptr}; if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, true)) { producers = sch->GetProducers(block_rv); @@ -61,7 +61,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -82,7 +82,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { } // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child // block. 
- Array loop_srefs = tir::GetLoops(block_sref); + ffi::Array loop_srefs = tir::GetLoops(block_sref); if (loop_srefs.empty()) { return false; } @@ -117,21 +117,20 @@ class RandomComputeLocationNode : public ScheduleRuleNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation"; - TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RandomComputeLocation", + RandomComputeLocationNode, ScheduleRuleNode); }; ScheduleRule ScheduleRule::RandomComputeLocation() { - return ScheduleRule(make_object()); + return ScheduleRule(ffi::make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ RandomComputeLocationNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RandomComputeLocationNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleRuleRandomComputeLocation", ScheduleRule::RandomComputeLocation); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 9570c0d0f904..9eac4ad57b20 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -17,6 +17,7 @@ * under the License. 
*/ #include +#include #include "../utils.h" @@ -29,8 +30,8 @@ void PyScheduleRuleNode::InitializeWithTuneContext(const TuneContext& context) { f_initialize_with_tune_context(context); } -Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, - const tir::BlockRV& block) { +ffi::Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, + const tir::BlockRV& block) { ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; return f_apply(sch, block); } @@ -45,7 +46,7 @@ ScheduleRule ScheduleRule::PyScheduleRule( PyScheduleRuleNode::FApply f_apply, // PyScheduleRuleNode::FClone f_clone, // PyScheduleRuleNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); n->f_clone = std::move(f_clone); @@ -53,7 +54,7 @@ ScheduleRule ScheduleRule::PyScheduleRule( return ScheduleRule(n); } -Array ScheduleRule::DefaultLLVM() { +ffi::Array ScheduleRule::DefaultLLVM() { return { ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), @@ -64,7 +65,7 @@ Array ScheduleRule::DefaultLLVM() { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tir.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -75,21 +76,21 @@ Array ScheduleRule::DefaultLLVM() { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512}, /*unroll_explicit=*/true), 
ScheduleRule::RandomComputeLocation(), }; } -Array ScheduleRule::DefaultX86(const String& type) { - static const Map intrins = {{"vnni", "dot_16x4_vnni"}, - {"avx512", "dot_16x4_avx512"}}; +ffi::Array ScheduleRule::DefaultX86(const ffi::String& type) { + static const ffi::Map intrins = {{"vnni", "dot_16x4_vnni"}, + {"avx512", "dot_16x4_avx512"}}; return { ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), @@ -100,7 +101,7 @@ Array ScheduleRule::DefaultX86(const String& type) { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tir.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -112,9 +113,9 @@ Array ScheduleRule::DefaultX86(const String& type) { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, @@ -122,34 +123,34 @@ Array ScheduleRule::DefaultX86(const String& type) { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; } -Array ScheduleRule::DefaultCUDA() { +ffi::Array ScheduleRule::DefaultCUDA() { return { ScheduleRule::ApplyCustomRule(), ScheduleRule::MultiLevelTiling( /*structure=*/"SSSRRSRS", - 
/*tile_binds=*/Array{"blockIdx.x", "vthread.x", "threadIdx.x"}, + /*tile_binds=*/ffi::Array{"blockIdx.x", "vthread.x", "threadIdx.x"}, /*max_innermost_factor=*/Integer(64), - /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*vector_load_lens=*/ffi::Array{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ - Map{{"req", String("must")}, - {"levels", Array{4}}, // - {"scope", String("shared")}}, + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{4}}, // + {"scope", ffi::String("shared")}}, /*reuse_write=*/ - Map{{"req", String("must")}, - {"levels", Array{3}}, // - {"scope", String("local")}}), + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{3}}, // + {"scope", ffi::String("local")}}), ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/true, @@ -158,22 +159,22 @@ Array ScheduleRule::DefaultCUDA() { /*disallow_if_then_else=*/false, /*require_injective=*/false, /*require_ordered=*/false, - /*disallow_op=*/Array{}), + /*disallow_op=*/ffi::Array{}), ScheduleRule::CrossThreadReduction( - /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), + /*thread_extents=*/ffi::Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/-1, /*max_vectorize_extent=*/-1, - /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512, 1024}, /*unroll_explicit=*/true), ScheduleRule::AutoBind( /*max_threadblocks=*/256, - /*thread_extents*/ Array{32, 64, 128, 256, 512, 1024}), + /*thread_extents*/ ffi::Array{32, 64, 128, 256, 512, 1024}), }; } -Array ScheduleRule::DefaultCUDATensorCore() { - Array> wmma_intrin_groups = { +ffi::Array ScheduleRule::DefaultCUDATensorCore() { + ffi::Array> wmma_intrin_groups = { // Tensor Cores f32 += f16 * f16 { {"init", "wmma_fill_16x16x16_f32"}, @@ -220,7 +221,7 @@ Array ScheduleRule::DefaultCUDATensorCore() { {"store", "wmma_store_16x16x16_s32_shared_dyn"}, }, }; - Array> mma_intrin_groups = { + ffi::Array> 
mma_intrin_groups = { // Tensor Core MMA { {"init", "mma_init_m16n8k8_f16"}, @@ -237,45 +238,45 @@ Array ScheduleRule::DefaultCUDATensorCore() { {"store", "mma_store_m16n8k8_f32_global"}, }, }; - Array results{ + ffi::Array results{ ScheduleRule::ApplyCustomRule(), ScheduleRule::MultiLevelTilingTensorCore( /*intrin_groups=*/wmma_intrin_groups, /*structure=*/"SSSRRSRS", - /*tile_binds=*/Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, + /*tile_binds=*/ffi::Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, /*max_innermost_factor=*/Integer(4), - /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*vector_load_lens=*/ffi::Array{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ - Map{{"req", String("must")}, - {"levels", Array{4}}, // - {"scope", String("shared.dyn")}}, + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{4}}, // + {"scope", ffi::String("shared.dyn")}}, /*reuse_write=*/ - Map{{"req", String("must")}, - {"levels", Array{2}}, // - {"scope", String("shared.dyn")}}, + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{2}}, // + {"scope", ffi::String("shared.dyn")}}, /*use_software_pipeline=*/false), // ScheduleRule::MultiLevelTilingTensorCore( /*intrin_groups=*/mma_intrin_groups, /*structure=*/"SSSRRSRS", - /*tile_binds=*/Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, + /*tile_binds=*/ffi::Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, /*max_innermost_factor=*/Integer(4), - /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*vector_load_lens=*/ffi::Array{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ - Map{{"req", String("must")}, - {"levels", Array{4}}, // - {"scope", String("shared.dyn")}}, + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{4}}, // + {"scope", ffi::String("shared.dyn")}}, /*reuse_write=*/ - Map{{"req", String("no")}, - {"levels", Array{2}}, // - {"scope", String("shared.dyn")}}, + ffi::Map{{"req", ffi::String("no")}, + {"levels", ffi::Array{2}}, // + {"scope", ffi::String("shared.dyn")}}, /*use_software_pipeline=*/true) 
// }; - Array append = ScheduleRule::DefaultCUDA(); + ffi::Array append = ScheduleRule::DefaultCUDA(); results.insert(results.end(), append.begin() + 1, append.end()); return results; } -Array ScheduleRule::DefaultHexagon() { +ffi::Array ScheduleRule::DefaultHexagon() { return { ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), @@ -286,80 +287,137 @@ Array ScheduleRule::DefaultHexagon() { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tir.exp"}), ScheduleRule::MultiLevelTilingWideVector( /*structure=*/"SRSRS", /*vector_length_in_bits=*/1024, /*max_innermost_factor=*/Integer(128), /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/128, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512}, /*unroll_explicit=*/true), }; } -Array GetARMNeonSpecificRules() { +ffi::Array ScheduleRule::DefaultRISCV(const int vlen) { + ffi::Array rules; + rules.push_back(ScheduleRule::ApplyCustomRule()); + rules.push_back(ScheduleRule::InlineConstantScalars()); + rules.push_back(ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + /*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/ffi::Array{"tir.exp"})); + rules.push_back(ScheduleRule::AddRFactor( + /*max_jobs_per_core=*/16, + /*max_innermost_factor=*/Integer(64))); + auto current_target = tvm::Target::Current(); + const auto reg_rvv_intrinsics = + tvm::ffi::Function::GetGlobalRequired("tir.tensor_intrin.register_rvv_isa_intrinsics"); + const auto rvv_kernels_inventory = 
reg_rvv_intrinsics(current_target, /* inventory_only */ true) + .cast>(); + for (const auto& intrin : rvv_kernels_inventory) { + if (!tir::TensorIntrin::Get(intrin.first, /*allow_missing*/ true)) { + // on demand intrinsic register + reg_rvv_intrinsics(current_target, /* inventory_only */ false); + } + rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin( + /*intrin_name=*/intrin.first, + /*structure=*/"SSRSRS", + /*tile_binds=*/std::nullopt, + /*max_innermost_factor=*/Integer(intrin.second), + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, + /*reuse_write=*/ + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}})); + } + rules.push_back(ScheduleRule::MultiLevelTiling( + /*structure=*/"SSRSRS", + /*tile_binds=*/std::nullopt, + /*max_innermost_factor=*/Integer(64), + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, + /*reuse_write=*/ + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}})); + rules.push_back(ScheduleRule::ParallelizeVectorizeUnroll( + /*max_jobs_per_core=*/16, + /*max_vectorize_extent=*/64, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512}, + /*unroll_explicit=*/true)); + rules.push_back(ScheduleRule::RandomComputeLocation()); + + return rules; +} + +ffi::Array GetARMNeonSpecificRules() { return { ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/String("dot_4x4_i8i8s32_neon"), + /*intrin_name=*/ffi::String("dot_4x4_i8i8s32_neon"), /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), }; } -Array GetARMDotprodSpecificRules() { +ffi::Array GetARMDotprodSpecificRules() { return { 
ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/String("dot_4x4_i8i8s32_sdot"), + /*intrin_name=*/ffi::String("dot_4x4_i8i8s32_sdot"), /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/String("dot_4x4_u8u8u32_udot"), + /*intrin_name=*/ffi::String("dot_4x4_u8u8u32_udot"), /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/String("dot_4x4_u8u8i32_hdot"), + /*intrin_name=*/ffi::String("dot_4x4_u8u8i32_hdot"), /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), }; } -Array ScheduleRule::DefaultARM(const String& type) { - return Array::Agregate( +ffi::Array ScheduleRule::DefaultARM(const ffi::String& type) { + return ffi::Array::Agregate( ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/false, @@ -368,12 +426,12 @@ Array ScheduleRule::DefaultARM(const String& type) { /*disallow_if_then_else=*/true, /*require_injective=*/true, 
/*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tir.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/8, /*max_innermost_factor=*/Integer(32)), - "neon" == type ? GetARMNeonSpecificRules() : Array{}, - "dotprod" == type ? GetARMDotprodSpecificRules() : Array{}, + "neon" == type ? GetARMNeonSpecificRules() : ffi::Array{}, + "dotprod" == type ? GetARMDotprodSpecificRules() : ffi::Array{}, ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, @@ -381,13 +439,13 @@ Array ScheduleRule::DefaultARM(const String& type) { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/8, /*max_vectorize_extent=*/32, - /*unroll_max_steps=*/Array{0, 8, 32, 256}, + /*unroll_max_steps=*/ffi::Array{0, 8, 32, 256}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation()); } @@ -401,12 +459,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { ScheduleRuleNode::RegisterReflection(); PyScheduleRuleNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("meta_schedule.ScheduleRuleInitializeWithTuneContext", @@ -419,7 +477,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("meta_schedule.ScheduleRuleDefaultCUDATensorCore", ScheduleRule::DefaultCUDATensorCore) .def("meta_schedule.ScheduleRuleDefaultHexagon", ScheduleRule::DefaultHexagon) .def("meta_schedule.ScheduleRuleDefaultARM", ScheduleRule::DefaultARM); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc 
b/src/meta_schedule/search_strategy/evolutionary_search.cc index 82c0dcb746c6..8aa5aca45059 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -112,10 +112,10 @@ class SizedHeap { }; struct PerThreadData { - IRModule mod{nullptr}; + IRModule mod{ffi::UnsafeInit()}; TRandState rand_state{-1}; std::function trace_sampler = nullptr; - std::function()> mutator_sampler = nullptr; + std::function()> mutator_sampler = nullptr; /*! * \brief Set the value for the trace and mutator samplers per thread. @@ -124,7 +124,7 @@ struct PerThreadData { * \param mutator_probs The probability of each mutator as a dict. */ void Set(const std::vector& scores, double genetic_mutate_prob, - const Map& mutator_probs) { + const ffi::Map& mutator_probs) { trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state); } @@ -135,11 +135,11 @@ struct PerThreadData { * \param rand_state The random state for sampling * \return The sampler created */ - static std::function()> MakeMutatorSampler( - double genetic_mutate_prob, // - const Map& mutator_probs, // + static std::function()> MakeMutatorSampler( + double genetic_mutate_prob, // + const ffi::Map& mutator_probs, // TRandState* rand_state) { - std::vector> mutators; + std::vector> mutators; std::vector masses; mutators.push_back(std::nullopt); masses.push_back(1.0 - genetic_mutate_prob); @@ -165,7 +165,7 @@ struct PerThreadData { } } return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), - mutators = std::move(mutators)]() -> Optional { + mutators = std::move(mutators)]() -> ffi::Optional { int i = idx_sampler(); return mutators[i]; }; @@ -212,8 +212,8 @@ struct ConcurrentBitmask { * \param traces The picked candidate traces. * \return The assembled measure candidates. 
*/ -Array AssembleCandidates(const std::vector& picks) { - Array measure_inputs; +ffi::Array AssembleCandidates(const std::vector& picks) { + ffi::Array measure_inputs; measure_inputs.reserve(picks.size()); for (const Schedule& sch : picks) { measure_inputs.push_back( @@ -261,7 +261,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { /*! \brief The counter of returning empty results. */ int num_empty_iters; /*! \brief The design spaces. Decisions are not used so traces only. */ - Array design_spaces; + ffi::Array design_spaces; /*! \brief Pre thread data including module to be tuned and random state. */ std::vector per_thread_data_; /*! @@ -270,14 +270,15 @@ class EvolutionarySearchNode : public SearchStrategyNode { * */ IRModuleSet measured_workloads_; /*! \brief A Database for selecting useful candidates. */ - Database database_{nullptr}; + Database database_{ffi::UnsafeInit()}; /*! \brief A cost model helping to explore the search space */ - CostModel cost_model_{nullptr}; + CostModel cost_model_{ffi::UnsafeInit()}; /*! \brief The token registered for the given workload in database. */ - Workload token_{nullptr}; + Workload token_{ffi::UnsafeInit()}; explicit State(EvolutionarySearchNode* self, int max_trials, int num_trials_per_iter, - Array design_space_schedules, Database database, CostModel cost_model) + ffi::Array design_space_schedules, Database database, + CostModel cost_model) : self(self), max_trials(max_trials), num_trials_per_iter(num_trials_per_iter), @@ -331,10 +332,10 @@ class EvolutionarySearchNode : public SearchStrategyNode { inline std::vector PickWithEpsGreedy(const std::vector& inits, const std::vector& bests, int num); /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ - inline Optional> GenerateMeasureCandidates(); + inline ffi::Optional> GenerateMeasureCandidates(); /*! 
\brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ - inline void NotifyRunnerResults(const Array& measure_candidates, - const Array& results); + inline void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results); /*! * \brief Compute the hash for the given module. * \param mod The input TIR module. @@ -346,9 +347,9 @@ class EvolutionarySearchNode : public SearchStrategyNode { /*! \brief The tuning context of the evolutionary search strategy. */ const TuneContextNode* ctx_{nullptr}; /*! \brief The postprocessors */ - Array postprocs_; + ffi::Array postprocs_; /*! \brief The mutators and their probability. */ - Map mutator_probs_; + ffi::Map mutator_probs_; /*! \brief The random state. To be initialized with TuneContext. */ TRandState rand_state_; /*! \brief The state of the search strategy. */ @@ -394,9 +395,8 @@ class EvolutionarySearchNode : public SearchStrategyNode { .def_ro("genetic_max_fail_count", &EvolutionarySearchNode::genetic_max_fail_count) .def_ro("eps_greedy", &EvolutionarySearchNode::eps_greedy); } - - static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; - TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.EvolutionarySearch", EvolutionarySearchNode, + SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& ctx) final { CHECK(ctx->num_threads > 0) << "ValueError: `TuneContext.num_threads` must be > 0"; @@ -413,8 +413,9 @@ class EvolutionarySearchNode : public SearchStrategyNode { this->state_.reset(); } - void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, - const Optional& database, const Optional& cost_model) final { + void PreTuning(int max_trials, int num_trials_per_iter, const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) final { ICHECK(!design_spaces.empty()); CHECK(this->ctx_ != 
nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; CHECK(database.defined()) @@ -439,19 +440,19 @@ class EvolutionarySearchNode : public SearchStrategyNode { this->state_.reset(); } - Optional> GenerateMeasureCandidates() final { + ffi::Optional> GenerateMeasureCandidates() final { ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results) final { + void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(measure_candidates, results); } SearchStrategy Clone() const final { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->population_size = this->population_size; n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop; n->init_measured_ratio = this->init_measured_ratio; @@ -472,7 +473,7 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu auto _ = Profiler::TimedScope("EvoSearch/PickBestFromDatabase"); std::vector measured_traces; measured_traces.reserve(num); - Array top_records = this->database_->GetTopK(this->token_, num); + ffi::Array top_records = this->database_->GetTopK(this->token_, num); for (TuningRecord record : top_records) { measured_traces.push_back(record->trace); } @@ -487,7 +488,7 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu tir::Trace trace = measured_traces.at(trace_id); Schedule& result = results.at(trace_id); ICHECK(!result.defined()); - if (Optional sch = pp.Apply(mod, trace, rand_state)) { + if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { result = sch.value(); } else { LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; @@ -514,7 +515,7 @@ std::vector EvolutionarySearchNode::State::SampleInitPopulation(int nu ICHECK(!result.defined()); int design_space_index = 
tir::SampleInt(rand_state, 0, design_spaces.size()); tir::Trace trace(design_spaces[design_space_index]->insts, {}); - if (Optional sch = pp.Apply(mod, trace, rand_state)) { + if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { result = sch.value(); } }; @@ -546,7 +547,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( for (int iter = 0;; ++iter) { // Predict normalized score with the cost model, std::vector scores = - PredictNormalizedScore(population, GetRef(self->ctx_), this->cost_model_); + PredictNormalizedScore(population, ffi::GetRef(self->ctx_), this->cost_model_); { auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc"); @@ -583,7 +584,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( TRandState* rand_state = &data.rand_state; const IRModule& mod = data.mod; std::function& trace_sampler = data.trace_sampler; - std::function()>& mutator_sampler = data.mutator_sampler; + std::function()>& mutator_sampler = data.mutator_sampler; Schedule& result = next_population.at(trace_id); int sampled_trace_id = -1; // Loop until success @@ -591,11 +592,11 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( sampled_trace_id = trace_sampler(); sampled_trace_id = sampled_trace_id % self->population_size; tir::Trace trace = population.at(sampled_trace_id)->trace().value(); - if (Optional opt_mutator = mutator_sampler()) { + if (ffi::Optional opt_mutator = mutator_sampler()) { // Decision: mutate Mutator mutator = opt_mutator.value(); - if (Optional new_trace = mutator->Apply(trace, rand_state)) { - if (Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { + if (ffi::Optional new_trace = mutator->Apply(trace, rand_state)) { + if (ffi::Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { // note that sch's trace is different from new_trace // because it contains post-processing information result = sch.value(); @@ -694,7 +695,8 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( return 
results; } -Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { +ffi::Optional> +EvolutionarySearchNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { return std::nullopt; } @@ -737,7 +739,8 @@ Optional> EvolutionarySearchNode::State::GenerateMeasure } void EvolutionarySearchNode::State::NotifyRunnerResults( - const Array& measure_candidates, const Array& results) { + const ffi::Array& measure_candidates, + const ffi::Array& results) { st += results.size(); ed += results.size(); } @@ -757,7 +760,7 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int population_size, / TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability"); TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->population_size = population_size; n->num_empty_iters_before_early_stop = 5; n->init_measured_ratio = init_measured_ratio; @@ -772,18 +775,19 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int population_size, / class EvolutionarySearch : public SearchStrategy { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(EvolutionarySearch, SearchStrategy, - EvolutionarySearchNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(EvolutionarySearch, SearchStrategy, + EvolutionarySearchNode); }; -Array EvolutionarySearchSampleInitPopulation(EvolutionarySearch self, int num) { +ffi::Array EvolutionarySearchSampleInitPopulation(EvolutionarySearch self, int num) { std::vector results = self->state_->SampleInitPopulation(num); - return Array(results.begin(), results.end()); + return ffi::Array(results.begin(), results.end()); } -Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, - Array population, int num) { - Array result; +ffi::Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, + ffi::Array population, + int num) { + 
ffi::Array result; std::vector population_vec = std::vector(population.begin(), population.end()); std::vector schs = self->state_->EvolveWithCostModel(population_vec, num); @@ -798,9 +802,9 @@ Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, return result; } -TVM_FFI_STATIC_INIT_BLOCK({ EvolutionarySearchNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { EvolutionarySearchNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.SearchStrategyEvolutionarySearch", SearchStrategy::EvolutionarySearch) @@ -808,7 +812,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ EvolutionarySearchSampleInitPopulation) .def("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel", EvolutionarySearchEvolveWithCostModel); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index c9a219777053..9082c6c3a90f 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -49,25 +49,24 @@ class ReplayFuncNode : public SearchStrategyNode { << "ValueError: The search strategy has not been initialized."; } - inline Optional> GenerateMeasureCandidates(); - inline void NotifyRunnerResults(const Array& results); + inline ffi::Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const ffi::Array& results); }; /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The IRModule to be scheduled from TuneContext. */ - Optional mod_ = std::nullopt; + ffi::Optional mod_ = std::nullopt; /*! \brief The space generator from TuneContext. */ - Optional space_generator_ = std::nullopt; + ffi::Optional space_generator_ = std::nullopt; /*! \brief The state of the search strategy. 
*/ std::unique_ptr state_ = nullptr; static void RegisterReflection() { - // No fields to register + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - - static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.ReplayFunc", ReplayFuncNode, SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& ctx) final { CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined"; @@ -85,8 +84,10 @@ class ReplayFuncNode : public SearchStrategyNode { this->state_.reset(); } - void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, - const Optional& database, const Optional& cost_model) final { + void PreTuning(int max_trials, int num_trials_per_iter, + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) final { CHECK(this->state_ == nullptr) << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; this->state_ = std::make_unique(this, max_trials, num_trials_per_iter); @@ -98,19 +99,19 @@ class ReplayFuncNode : public SearchStrategyNode { this->state_.reset(); } - Optional> GenerateMeasureCandidates() final { + ffi::Optional> GenerateMeasureCandidates() final { ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results) final { + void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } SearchStrategy Clone() const final { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->rand_state_ = -1; n->mod_ = std::nullopt; n->space_generator_ = std::nullopt; @@ -119,17 +120,18 @@ class ReplayFuncNode : public SearchStrategyNode { } }; -inline Optional> 
ReplayFuncNode::State::GenerateMeasureCandidates() { +inline ffi::Optional> +ReplayFuncNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { return std::nullopt; } ed = std::min(ed, max_trials); - Array result; + ffi::Array result; IRModule mod = self->mod_.value(); - Array postprocs = self->space_generator_.value()->postprocs.value_or({}); + ffi::Array postprocs = self->space_generator_.value()->postprocs.value_or({}); for (int i = st; i < ed; i++) { for (;;) { - Array schs = self->space_generator_.value()->GenerateDesignSpace(mod); + ffi::Array schs = self->space_generator_.value()->GenerateDesignSpace(mod); int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); tir::Schedule sch = schs[design_space_index]; sch->EnterPostproc(); @@ -141,7 +143,7 @@ inline Optional> ReplayFuncNode::State::GenerateMeasureC } } if (!failed) { - Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); + ffi::Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); result.push_back(MeasureCandidate(sch, args_info)); break; } @@ -150,22 +152,22 @@ inline Optional> ReplayFuncNode::State::GenerateMeasureC return result; } -inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { +inline void ReplayFuncNode::State::NotifyRunnerResults(const ffi::Array& results) { st += num_trials_per_iter; ed += num_trials_per_iter; } SearchStrategy SearchStrategy::ReplayFunc() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return SearchStrategy(n); } -TVM_FFI_STATIC_INIT_BLOCK({ ReplayFuncNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ReplayFuncNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.SearchStrategyReplayFunc", SearchStrategy::ReplayFunc); -}); +} } // namespace meta_schedule } // namespace tvm diff --git 
a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 151d502ec078..7898b171d357 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -31,7 +31,7 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The search strategy itself */ ReplayTraceNode* self; /*! \brief The design spaces. */ - Array design_spaces; + ffi::Array design_spaces; /*! \brief The number of total trials. */ int max_trials; /*! \brief The number of trials per iteration. */ @@ -42,9 +42,9 @@ class ReplayTraceNode : public SearchStrategyNode { int ed; /*! \brief The module to be tuned. */ - Array per_thread_mod_{nullptr}; + ffi::Array per_thread_mod_{nullptr}; - explicit State(ReplayTraceNode* self, Array design_spaces, int max_trials, + explicit State(ReplayTraceNode* self, ffi::Array design_spaces, int max_trials, int num_trials_per_iter) : self(self), design_spaces(design_spaces), @@ -59,8 +59,8 @@ class ReplayTraceNode : public SearchStrategyNode { } } - inline Optional> GenerateMeasureCandidates(); - inline void NotifyRunnerResults(const Array& results); + inline ffi::Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const ffi::Array& results); }; /*! \brief The max number of failures during trace replaying. */ @@ -69,11 +69,11 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The IRModule to be scheduled from TuneContext. */ - Optional mod_ = std::nullopt; + ffi::Optional mod_ = std::nullopt; /*! \brief The number of threads to be used. */ int num_threads_ = -1; /*! \brief The postprocessors. */ - Array postprocs_ = {}; + ffi::Array postprocs_ = {}; /*! \brief The state of the search strategy. 
*/ std::unique_ptr state_ = nullptr; @@ -81,9 +81,8 @@ class ReplayTraceNode : public SearchStrategyNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("max_fail_count", &ReplayTraceNode::max_fail_count); } - - static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.ReplayTrace", ReplayTraceNode, + SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& ctx) final { CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined"; @@ -102,12 +101,14 @@ class ReplayTraceNode : public SearchStrategyNode { this->state_.reset(); } - void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, - const Optional& database, const Optional& cost_model) final { + void PreTuning(int max_trials, int num_trials_per_iter, + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) final { ICHECK(!design_spaces.empty()); CHECK(this->state_ == nullptr) << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; - Array design_space_traces; + ffi::Array design_space_traces; design_space_traces.reserve(design_spaces.size()); for (const tir::Schedule& space : design_spaces) { design_space_traces.push_back(space->trace().value()->Simplified(true)); @@ -121,19 +122,19 @@ class ReplayTraceNode : public SearchStrategyNode { this->state_.reset(); } - Optional> GenerateMeasureCandidates() final { + ffi::Optional> GenerateMeasureCandidates() final { ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results) final { + void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } SearchStrategy Clone() 
const final { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_fail_count = this->max_fail_count; n->rand_state_ = this->rand_state_; n->state_ = nullptr; // cleared the state @@ -141,14 +142,15 @@ class ReplayTraceNode : public SearchStrategyNode { } }; -inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { +inline ffi::Optional> +ReplayTraceNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { return std::nullopt; } ed = std::min(ed, max_trials); ICHECK_LT(st, ed); std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); - Array> per_task_result(ed - st, std::nullopt); + ffi::Array> per_task_result(ed - st, std::nullopt); ThreadedTraceApply pp(self->postprocs_); auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id, int task_id) -> void { @@ -159,41 +161,41 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); tir::Trace trace = design_spaces[design_space_index]; tir::Trace new_trace = tir::Trace(trace->insts, {}); - if (Optional opt_sch = pp.Apply(mod, new_trace, &rand_state)) { + if (ffi::Optional opt_sch = pp.Apply(mod, new_trace, &rand_state)) { tir::Schedule sch = opt_sch.value(); - Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); + ffi::Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); per_task_result.Set(task_id, MeasureCandidate(sch, args_info)); break; } } }; support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); - Array filtered; + ffi::Array filtered; filtered.reserve(ed - st); - for (Optional result : per_task_result) + for (ffi::Optional result : per_task_result) if (result.has_value()) { filtered.push_back(*std::move(result)); } return filtered; } -inline void ReplayTraceNode::State::NotifyRunnerResults(const Array& results) { +inline void 
ReplayTraceNode::State::NotifyRunnerResults(const ffi::Array& results) { st += num_trials_per_iter; ed += num_trials_per_iter; } SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_fail_count = max_fail_count; return SearchStrategy(n); } -TVM_FFI_STATIC_INIT_BLOCK({ ReplayTraceNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ReplayTraceNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.SearchStrategyReplayTrace", SearchStrategy::ReplayTrace); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 66d063b2dcba..3273e70ac1b8 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -23,8 +23,8 @@ namespace tvm { namespace meta_schedule { -MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array args_info) { - ObjectPtr n = make_object(); +MeasureCandidate::MeasureCandidate(tir::Schedule sch, ffi::Array args_info) { + ObjectPtr n = ffi::make_object(); n->sch = sch; n->args_info = args_info; data_ = std::move(n); @@ -37,9 +37,9 @@ void PySearchStrategyNode::InitializeWithTuneContext(const TuneContext& context) } void PySearchStrategyNode::PreTuning(int max_trials, int num_trials_per_iter, - const Array& design_spaces, - const Optional& database, - const Optional& cost_model) { + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) { ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; f_pre_tuning(max_trials, num_trials_per_iter, design_spaces, database, cost_model); } @@ -49,14 +49,15 @@ void PySearchStrategyNode::PostTuning() { f_post_tuning(); } -Optional> 
PySearchStrategyNode::GenerateMeasureCandidates() { +ffi::Optional> PySearchStrategyNode::GenerateMeasureCandidates() { ICHECK(f_generate_measure_candidates != nullptr) << "PySearchStrategy's GenerateMeasureCandidates method not implemented!"; return f_generate_measure_candidates(); } -void PySearchStrategyNode::NotifyRunnerResults(const Array& measure_candidates, - const Array& results) { +void PySearchStrategyNode::NotifyRunnerResults( + const ffi::Array& measure_candidates, + const ffi::Array& results) { ICHECK(f_notify_runner_results != nullptr) << "PySearchStrategy's NotifyRunnerResults method not implemented!"; f_notify_runner_results(measure_candidates, results); @@ -74,7 +75,7 @@ SearchStrategy SearchStrategy::PySearchStrategy( PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results, // PySearchStrategyNode::FClone f_clone) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_initialize_with_tune_context = f_initialize_with_tune_context; n->f_pre_tuning = f_pre_tuning; n->f_post_tuning = f_post_tuning; @@ -84,16 +85,16 @@ SearchStrategy SearchStrategy::PySearchStrategy( return SearchStrategy(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { MeasureCandidateNode::RegisterReflection(); PySearchStrategyNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.MeasureCandidate", - [](tir::Schedule sch, Optional> args_info) -> MeasureCandidate { + [](tir::Schedule sch, ffi::Optional> args_info) -> MeasureCandidate { return MeasureCandidate(sch, args_info.value_or({})); }) .def("meta_schedule.SearchStrategyPySearchStrategy", SearchStrategy::PySearchStrategy) @@ -106,7 +107,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("meta_schedule.SearchStrategyNotifyRunnerResults", &SearchStrategyNode::NotifyRunnerResults) 
.def_method("meta_schedule.SearchStrategyClone", &SearchStrategyNode::Clone); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 86f21f43e817..e3786a4d6188 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -37,7 +37,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { TRandState rand_state_ = -1; static void RegisterReflection() { - // No fields to register + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } void InitializeWithTuneContext(const TuneContext& context) final { @@ -45,8 +46,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { this->rand_state_ = ForkSeed(&context->rand_state); } - Array GenerateDesignSpace(const IRModule& mod) final { - using ScheduleAndUnvisitedBlocks = std::pair>; + ffi::Array GenerateDesignSpace(const IRModule& mod) final { + using ScheduleAndUnvisitedBlocks = std::pair>; CHECK(sch_rules.defined()) << "ValueError: `sch_rules` is not set in PostOrderApply"; tir::Schedule sch = tir::Schedule::Traced( /*mod=*/mod, @@ -55,8 +56,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); std::vector stack; - Array result{sch}; - Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); + ffi::Array result{sch}; + ffi::Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); for (ScheduleRule sch_rule : sch_rules.value()) { for (const tir::Schedule& sch : result) { @@ -80,12 +81,12 @@ class PostOrderApplyNode : public SpaceGeneratorNode { continue; } if (!ScheduleRule::IsApplyCustomRule(sch_rule)) { - if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").has_value()) { + if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").has_value()) { stack.emplace_back(sch, blocks); continue; } } - Array applied = 
sch_rule->Apply(sch, /*block=*/block_rv); + ffi::Array applied = sch_rule->Apply(sch, /*block=*/block_rv); for (const tir::Schedule& sch : applied) { stack.emplace_back(sch, blocks); } @@ -95,19 +96,19 @@ class PostOrderApplyNode : public SpaceGeneratorNode { } SpaceGenerator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); CloneRules(this, n.get()); return SpaceGenerator(n); } - static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; - TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.PostOrderApply", PostOrderApplyNode, + SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs) { - ObjectPtr n = make_object(); +SpaceGenerator SpaceGenerator::PostOrderApply( + ffi::Function f_block_filter, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs) { + ObjectPtr n = ffi::make_object(); n->sch_rules = std::move(sch_rules); n->postprocs = std::move(postprocs); n->mutator_probs = std::move(mutator_probs); @@ -115,13 +116,13 @@ SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, return SpaceGenerator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ PostOrderApplyNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { PostOrderApplyNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.SpaceGeneratorPostOrderApply", SpaceGenerator::PostOrderApply); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 1112aca88762..7d22635b76f2 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ 
-33,6 +33,8 @@ class ScheduleFnNode : public SpaceGeneratorNode { static void RegisterReflection() { // `schedule_fn_` is not registered. + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } void InitializeWithTuneContext(const TuneContext& context) final { @@ -40,7 +42,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { this->rand_state_ = ForkSeed(&context->rand_state); } - Array GenerateDesignSpace(const IRModule& mod) final { + ffi::Array GenerateDesignSpace(const IRModule& mod) final { tir::Schedule sch = tir::Schedule::Traced( /*mod=*/mod, /*rand_state=*/ForkSeed(&this->rand_state_), @@ -56,7 +58,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { return {sch.value()}; } if (const auto* arr = obj.as()) { - Array result; + ffi::Array result; result.reserve(arr->size()); for (Any val : *arr) { if (auto sch = val.as()) { @@ -76,20 +78,19 @@ class ScheduleFnNode : public SpaceGeneratorNode { } SpaceGenerator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); CloneRules(this, n.get()); return SpaceGenerator(n); } - static constexpr const char* _type_key = "meta_schedule.ScheduleFn"; - TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.ScheduleFn", ScheduleFnNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs) { - ObjectPtr n = make_object(); +SpaceGenerator SpaceGenerator::ScheduleFn( + ffi::Function schedule_fn, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs) { + ObjectPtr n = ffi::make_object(); n->sch_rules = std::move(sch_rules); n->postprocs = std::move(postprocs); n->mutator_probs = std::move(mutator_probs); @@ -97,12 +98,11 @@ SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, return SpaceGenerator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ 
ScheduleFnNode::RegisterReflection(); }); - -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + ScheduleFnNode::RegisterReflection(); refl::GlobalDef().def("meta_schedule.SpaceGeneratorScheduleFn", SpaceGenerator::ScheduleFn); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 709b36417c9e..9e458a3ad7cf 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -24,7 +24,7 @@ namespace tvm { namespace meta_schedule { -String GetRuleKindFromTarget(const Target& target) { +ffi::String GetRuleKindFromTarget(const Target& target) { if (target->kind->name == "llvm") { static auto target_has_feature_fn_ptr = tvm::ffi::Function::GetGlobalRequired("target.target_has_feature"); @@ -39,6 +39,10 @@ String GetRuleKindFromTarget(const Target& target) { return "avx512"; } } + bool have_rvv = target_has_feature_fn_ptr("v", target).cast(); + if (have_rvv) { + return "rvv"; + } TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export()); TargetFeatures afeatures = Downcast(target_json.at("features")); @@ -55,7 +59,7 @@ String GetRuleKindFromTarget(const Target& target) { return "hexagon"; } if (target->kind->name == "cuda") { - if (Optional opt_sm = target->GetAttr("arch")) { + if (ffi::Optional opt_sm = target->GetAttr("arch")) { std::string sm = opt_sm.value(); if (support::StartsWith(sm, "sm_")) { sm = sm.substr(3); @@ -88,10 +92,10 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { !(sch_rules.defined() && // postprocs.defined() && // mutator_probs.defined())) { - String kind = GetRuleKindFromTarget(context->target.value()); - Array default_sch_rules; - Array default_postprocs; - Map default_mutator_probs; + ffi::String kind = GetRuleKindFromTarget(context->target.value()); + 
ffi::Array default_sch_rules; + ffi::Array default_postprocs; + ffi::Map default_mutator_probs; // for target with skylake-avx512 if (kind == "llvm") { default_sch_rules = ScheduleRule::DefaultLLVM(); @@ -117,6 +121,13 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { default_sch_rules = ScheduleRule::DefaultX86("avx512"); default_postprocs = Postproc::DefaultCPUTensorization(); default_mutator_probs = Mutator::DefaultLLVM(); + } else if (kind == "rvv") { + static auto llvm_get_vector_width = + tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width"); + const int vlen = llvm_get_vector_width(context->target.value()).cast(); + default_sch_rules = ScheduleRule::DefaultRISCV(vlen); + default_postprocs = Postproc::DefaultRISCV(); + default_mutator_probs = Mutator::DefaultLLVM(); } else if (kind == "asimd") { default_sch_rules = ScheduleRule::DefaultARM("neon"); default_postprocs = Postproc::DefaultCPUTensorization(); @@ -163,7 +174,7 @@ void PySpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) f_initialize_with_tune_context(context); } -Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) { +ffi::Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) { ICHECK(f_generate_design_space != nullptr) << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; return f_generate_design_space(mod); @@ -175,11 +186,12 @@ SpaceGenerator PySpaceGeneratorNode::Clone() const { } SpaceGenerator SpaceGenerator::PySpaceGenerator( - Optional> sch_rules, Optional> postprocs, - Optional> mutator_probs, + ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs, FInitializeWithTuneContext f_initialize_with_tune_context, FGenerateDesignSpace f_generate_design_space, FClone f_clone) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->sch_rules = sch_rules; n->postprocs = postprocs; n->mutator_probs = mutator_probs; @@ -189,12 +201,12 @@ 
SpaceGenerator SpaceGenerator::PySpaceGenerator( return SpaceGenerator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { SpaceGeneratorNode::RegisterReflection(); PySpaceGeneratorNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("meta_schedule.SpaceGeneratorInitializeWithTuneContext", @@ -203,7 +215,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ &SpaceGeneratorNode::GenerateDesignSpace) .def("meta_schedule.SpaceGeneratorPySpaceGenerator", SpaceGenerator::PySpaceGenerator) .def_method("meta_schedule.SpaceGeneratorClone", &SpaceGeneratorNode::Clone); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index f9a8c2e71c8b..026daa68a762 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -27,7 +27,7 @@ namespace meta_schedule { class SpaceGeneratorUnionNode : public SpaceGeneratorNode { public: /*! \brief The array of design space generators unioned, could be recursive. */ - Array space_generators; + ffi::Array space_generators; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -42,11 +42,11 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { } } - Array GenerateDesignSpace(const IRModule& mod) final { - Array design_spaces; + ffi::Array GenerateDesignSpace(const IRModule& mod) final { + ffi::Array design_spaces; for (const SpaceGenerator& space_generator : space_generators) { // Generate partial design spaces from each design space generator. - Array partial = space_generator->GenerateDesignSpace(mod); + ffi::Array partial = space_generator->GenerateDesignSpace(mod); // Merge the partial design spaces. 
design_spaces.insert(design_spaces.end(), partial.begin(), partial.end()); } @@ -54,17 +54,16 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { } SpaceGenerator Clone() const final { - ObjectPtr n = make_object(*this); - n->space_generators = Array(); + ObjectPtr n = ffi::make_object(*this); + n->space_generators = ffi::Array(); for (const SpaceGenerator& space_generator : this->space_generators) { n->space_generators.push_back(space_generator->Clone()); } CloneRules(this, n.get()); return SpaceGenerator(n); } - - static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion"; - TVM_DECLARE_FINAL_OBJECT_INFO(SpaceGeneratorUnionNode, SpaceGeneratorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.SpaceGeneratorUnion", SpaceGeneratorUnionNode, + SpaceGeneratorNode); }; /*! @@ -72,11 +71,11 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { * \param space_generators Array of the design space generators to be unioned. * \return The design space generator created. 
*/ -SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_generators, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs) { - ObjectPtr n = make_object(); +SpaceGenerator SpaceGenerator::SpaceGeneratorUnion( + ffi::Array space_generators, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs) { + ObjectPtr n = ffi::make_object(); n->sch_rules = std::move(sch_rules); n->postprocs = std::move(postprocs); n->mutator_probs = std::move(mutator_probs); @@ -84,13 +83,13 @@ SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_g return SpaceGenerator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ SpaceGeneratorUnionNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { SpaceGeneratorUnionNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.SpaceGeneratorSpaceGeneratorUnion", SpaceGenerator::SpaceGeneratorUnion); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index a19754b49ccd..babf521c280c 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -39,15 +39,14 @@ class GradientBasedNode final : public TaskSchedulerNode { .def_ro("alpha", &GradientBasedNode::alpha) .def_ro("window_size", &GradientBasedNode::window_size); } - - static constexpr const char* _type_key = "meta_schedule.GradientBased"; - TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.GradientBased", GradientBasedNode, + TaskSchedulerNode); public: - void Tune(Array tasks, Array task_weights, int max_trials_global, + void Tune(ffi::Array tasks, ffi::Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, - 
Array measure_callbacks, Optional database, - Optional cost_model) final { + ffi::Array measure_callbacks, ffi::Optional database, + ffi::Optional cost_model) final { int n_tasks = tasks.size(); round_robin_rounds_ = 0; best_latency_history_.resize(n_tasks, std::vector()); @@ -122,8 +121,8 @@ class GradientBasedNode final : public TaskSchedulerNode { return task_id; } - Array JoinRunningTask(int task_id) final { - Array results = TaskSchedulerNode::JoinRunningTask(task_id); + ffi::Array JoinRunningTask(int task_id) final { + ffi::Array results = TaskSchedulerNode::JoinRunningTask(task_id); TaskRecordNode* task = this->tasks_[task_id].get(); if (task->latency_ms.size() > 0) { this->best_latency_history_.at(task_id).push_back( @@ -136,7 +135,7 @@ class GradientBasedNode final : public TaskSchedulerNode { TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->logger = logger; n->alpha = alpha; n->window_size = window_size; @@ -144,12 +143,12 @@ TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, i return TaskScheduler(n); } -TVM_FFI_STATIC_INIT_BLOCK({ GradientBasedNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { GradientBasedNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.TaskSchedulerGradientBased", TaskScheduler::GradientBased); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 9bb5a20188ec..c3b95a7cc4c6 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -33,9 +33,7 @@ class RoundRobinNode final : public TaskSchedulerNode { namespace refl = 
tvm::ffi::reflection; refl::ObjectDef().def_ro("task_id", &RoundRobinNode::task_id); } - - static constexpr const char* _type_key = "meta_schedule.RoundRobin"; - TVM_DECLARE_FINAL_OBJECT_INFO(RoundRobinNode, TaskSchedulerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("meta_schedule.RoundRobin", RoundRobinNode, TaskSchedulerNode); protected: int NextTaskId() final { @@ -58,18 +56,18 @@ class RoundRobinNode final : public TaskSchedulerNode { }; TaskScheduler TaskScheduler::RoundRobin(ffi::Function logger) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->logger = logger; n->task_id = -1; return TaskScheduler(n); } -TVM_FFI_STATIC_INIT_BLOCK({ RoundRobinNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RoundRobinNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.TaskSchedulerRoundRobin", TaskScheduler::RoundRobin); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 21827ba8ad03..85c6d71b4307 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -23,11 +23,11 @@ namespace tvm { namespace meta_schedule { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { TaskRecordNode::RegisterReflection(); TaskSchedulerNode::RegisterReflection(); PyTaskSchedulerNode::RegisterReflection(); -}); +} TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { ObjectPtr n = ffi::make_object(); @@ -48,9 +48,9 @@ TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { void SendToBuilder(TaskRecordNode* self, const Builder& builder) { auto _ = Profiler::TimedScope("SendToBuilder"); - Array candidates = self->measure_candidates.value(); + ffi::Array candidates = self->measure_candidates.value(); Target target = self->ctx->target.value(); - 
Array inputs; + ffi::Array inputs; inputs.reserve(candidates.size()); for (const MeasureCandidate& candidate : candidates) { inputs.push_back(BuilderInput(candidate->sch->mod(), target)); @@ -60,13 +60,13 @@ void SendToBuilder(TaskRecordNode* self, const Builder& builder) { void SendToRunner(TaskRecordNode* self, const Runner& runner) { auto _ = Profiler::TimedScope("SendToRunner"); - Array candidates = self->measure_candidates.value(); - Array builder_results = self->builder_results.value(); + ffi::Array candidates = self->measure_candidates.value(); + ffi::Array builder_results = self->builder_results.value(); Target target = self->ctx->target.value(); ICHECK_EQ(candidates.size(), builder_results.size()); int n = candidates.size(); int n_build_errors = 0; - Array inputs; + ffi::Array inputs; inputs.reserve(n); for (int i = 0; i < n; ++i) { const MeasureCandidate& candidate = candidates[i]; @@ -79,12 +79,12 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { /*device_type=*/target->kind->name, /*args_info=*/candidate->args_info)); } - Array futures = runner->Run(inputs); + ffi::Array futures = runner->Run(inputs); if (n_build_errors == 0) { self->runner_futures = futures; return; } - Array results; + ffi::Array results; results.reserve(n); for (int i = 0, j = 0; i < n; ++i) { const BuilderResult& builder_result = builder_results[i]; @@ -102,7 +102,7 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { self->runner_futures = results; } -void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& results) { +void TaskCleanUp(TaskRecordNode* self, int task_id, const ffi::Array& results) { ICHECK_EQ(self->builder_results.value().size(), results.size()); ICHECK_EQ(self->runner_futures.value().size(), results.size()); int n = results.size(); @@ -112,7 +112,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r const BuilderResult& builder_result = self->builder_results.value()[i]; const MeasureCandidate& candidate = 
self->measure_candidates.value()[i]; const RunnerResult& runner_result = results[i]; - Optional error_msg = std::nullopt; + ffi::Optional error_msg = std::nullopt; int trials = self->latency_ms.size() + 1; double run_ms = 1e9; if ((error_msg = builder_result->error_msg)) { @@ -148,11 +148,12 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r self->runner_futures = std::nullopt; } -void TaskSchedulerNode::Tune(Array ctxs, Array task_weights, +void TaskSchedulerNode::Tune(ffi::Array ctxs, ffi::Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, - Array measure_callbacks, Optional database, - Optional cost_model) { + ffi::Array measure_callbacks, + ffi::Optional database, + ffi::Optional cost_model) { CHECK_EQ(ctxs.size(), task_weights.size()) << "ValueError: `task_weights` must have the same " "length as `ctxs`"; int n_tasks = this->remaining_tasks_ = ctxs.size(); @@ -167,7 +168,7 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh TVM_PY_LOG(INFO, this->logger) << "Initializing Task #" << i << ": " << ctx->task_name; TVM_PY_LOG(INFO, ctx->logger) << "Initializing Task #" << i << ": " << ctx->task_name; this->tasks_.push_back(TaskRecord(ctx, weight)); - Array design_spaces = + ffi::Array design_spaces = ctx->space_generator.value()->GenerateDesignSpace(ctx->mod.value()); TVM_PY_LOG(INFO, ctx->logger) << "Total " << design_spaces.size() << " design space(s) generated"; @@ -194,7 +195,7 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh TerminateTask(task_id); continue; } - if (Optional> candidates = task->measure_candidates = + if (ffi::Optional> candidates = task->measure_candidates = task->ctx->search_strategy.value()->GenerateMeasureCandidates()) { int num_candidates = candidates.value().size(); num_trials_already += num_candidates; @@ -218,13 +219,13 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh } } -Array 
TaskSchedulerNode::JoinRunningTask(int task_id) { +ffi::Array TaskSchedulerNode::JoinRunningTask(int task_id) { TaskRecordNode* task = this->tasks_[task_id].get(); ICHECK(task->runner_futures.defined()); - Array results; + ffi::Array results; { auto _ = Profiler::TimedScope("JoinRunnerFutures"); - Array futures = task->runner_futures.value(); + ffi::Array futures = task->runner_futures.value(); results.reserve(futures.size()); for (RunnerFuture future : futures) { results.push_back(future->Result()); @@ -237,7 +238,7 @@ Array TaskSchedulerNode::JoinRunningTask(int task_id) { ICHECK_EQ(results.size(), task->measure_candidates.value().size()); ICHECK_EQ(results.size(), task->builder_results.value().size()); for (const MeasureCallback& callback : this->measure_callbacks_) { - callback->Apply(GetRef(this), task_id, task->measure_candidates.value(), + callback->Apply(ffi::GetRef(this), task_id, task->measure_candidates.value(), task->builder_results.value(), results); } TaskCleanUp(task, task_id, results); @@ -333,7 +334,7 @@ TaskScheduler TaskScheduler::PyTaskScheduler( ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune) { CHECK(f_next_task_id != nullptr) << "ValueError: next_task_id is not defined"; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->logger = logger; n->f_next_task_id = f_next_task_id; n->f_join_running_task = f_join_running_task; @@ -346,7 +347,7 @@ int PyTaskSchedulerNode::NextTaskId() { return f_next_task_id(); } -Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { +ffi::Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { if (f_join_running_task == nullptr) { return TaskSchedulerNode::JoinRunningTask(task_id); } else { @@ -354,11 +355,12 @@ Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { } } -void PyTaskSchedulerNode::Tune(Array tasks, Array task_weights, +void PyTaskSchedulerNode::Tune(ffi::Array 
tasks, ffi::Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, - Array measure_callbacks, - Optional database, Optional cost_model) { + ffi::Array measure_callbacks, + ffi::Optional database, + ffi::Optional cost_model) { if (f_tune == nullptr) { TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter, builder, runner, measure_callbacks, database, @@ -369,7 +371,7 @@ void PyTaskSchedulerNode::Tune(Array tasks, Array task_we } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.TaskSchedulerPyTaskScheduler", TaskScheduler::PyTaskScheduler) @@ -380,7 +382,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("meta_schedule.TaskSchedulerTouchTask", &TaskSchedulerNode::TouchTask) .def_method("meta_schedule.TaskSchedulerPrintTuningStatistics", &TaskSchedulerNode::PrintTuningStatistics); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 114afc0ad72e..d6300afcf9eb 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -56,7 +56,7 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { std::unordered_set get_block_names; for (const auto& inst : anchor_trace->insts) { if (inst->kind.same_as(kind_get_block)) { - auto block_name = Downcast(inst->attrs[0]); + auto block_name = Downcast(inst->attrs[0]); get_block_names.insert(block_name); } } @@ -140,9 +140,10 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { continue; } - Array inputs = TranslateInputRVs(inst->inputs, rv_map); + ffi::Array inputs = TranslateInputRVs(inst->inputs, rv_map); - if (inst->kind.same_as(kind_get_block) && !HasBlock(sch, Downcast(inst->attrs[0]))) { + if (inst->kind.same_as(kind_get_block) && + !HasBlock(sch, Downcast(inst->attrs[0]))) { // 
The anchor trace does get_block on a block that is not part of the target schedule. auto block = Downcast(inst->outputs[0]); foreign_blocks.insert(block); @@ -174,7 +175,7 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { } Any decision = anchor_trace->GetDecision(inst); - Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); + ffi::Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); if (inst->kind.same_as(kind_get_child_blocks)) { // We want to allow a trace generated for a single conv2d block to be applied to @@ -184,9 +185,9 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { // new_outputs.size(). We workaround this problem by assuming that the prefix of the "new" // outputs matches with the "old" outputs, and truncating the new outputs accordingly. ICHECK(inst->outputs.size() <= outputs.size()); - TranslateAddOutputRVs(inst->outputs, - Array(outputs.begin(), outputs.begin() + inst->outputs.size()), - &rv_map); + TranslateAddOutputRVs( + inst->outputs, ffi::Array(outputs.begin(), outputs.begin() + inst->outputs.size()), + &rv_map); } else { TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); } @@ -248,16 +249,16 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm auto auto_bind_rule = ScheduleRule::AutoBind(/*max_threadblocks=*/256, - /*thread_extents*/ Array{32, 64, 128, 256, 512, 1024}, + /*thread_extents*/ ffi::Array{32, 64, 128, 256, 512, 1024}, max_threads_per_block.value()->value); auto_bind_rule->Apply(sch, last_block); } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ScheduleUsingAnchorTrace", ScheduleUsingAnchorTrace); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 1b2cb9d0c140..5a0ca76cc284 100644 --- 
a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -25,12 +25,13 @@ namespace tvm { namespace meta_schedule { -TuneContext::TuneContext(Optional mod, Optional target, - Optional space_generator, - Optional search_strategy, Optional task_name, - int num_threads, TRandState rand_state, ffi::Function logger) { +TuneContext::TuneContext(ffi::Optional mod, ffi::Optional target, + ffi::Optional space_generator, + ffi::Optional search_strategy, + ffi::Optional task_name, int num_threads, + TRandState rand_state, ffi::Function logger) { CHECK(rand_state == -1 || rand_state >= 0) << "ValueError: Invalid random state: " << rand_state; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->mod = mod; n->target = target; n->space_generator = space_generator; @@ -43,7 +44,7 @@ TuneContext::TuneContext(Optional mod, Optional target, } TuneContext TuneContextNode::Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); if (this->space_generator.defined()) { n->space_generator = this->space_generator.value()->Clone(); } @@ -57,30 +58,30 @@ TuneContext TuneContextNode::Clone() const { void TuneContextNode::Initialize() { if (this->space_generator.defined()) { - this->space_generator.value()->InitializeWithTuneContext(GetRef(this)); + this->space_generator.value()->InitializeWithTuneContext(ffi::GetRef(this)); } if (this->search_strategy.defined()) { - this->search_strategy.value()->InitializeWithTuneContext(GetRef(this)); + this->search_strategy.value()->InitializeWithTuneContext(ffi::GetRef(this)); } } -TVM_FFI_STATIC_INIT_BLOCK({ TuneContextNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { TuneContextNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.TuneContext", - [](Optional mod, Optional target, - Optional space_generator, Optional search_strategy, - Optional task_name, 
int num_threads, TRandState rand_state, - ffi::Function logger) -> TuneContext { + [](ffi::Optional mod, ffi::Optional target, + ffi::Optional space_generator, + ffi::Optional search_strategy, ffi::Optional task_name, + int num_threads, TRandState rand_state, ffi::Function logger) -> TuneContext { return TuneContext(mod, target, space_generator, search_strategy, task_name, num_threads, rand_state, logger); }) .def("meta_schedule._SHash2Hex", SHash2Hex) .def_method("meta_schedule.TuneContextInitialize", &TuneContextNode::Initialize) .def_method("meta_schedule.TuneContextClone", &TuneContextNode::Clone); -}); +} } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 21483d3b98a4..ee94b1d2ab5e 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -136,7 +136,7 @@ inline bool using_ipython() { * \brief Print out the performance table interactively in jupyter notebook. * \param str The serialized performance table. */ -inline void print_interactive_table(const String& data) { +inline void print_interactive_table(const ffi::String& data) { const auto f_print_interactive_table = tvm::ffi::Function::GetGlobal("meta_schedule.print_interactive_table"); ICHECK(f_print_interactive_table.has_value()) @@ -214,14 +214,14 @@ std::string JSONDumps(Any json_obj); * \param hash_code The hash code * \return The string representation of the hash code */ -inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } +inline ffi::String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } /*! * \brief Converts an TVM object to the hex string representation of its structural hash. * \param obj The TVM object. * \return The hex string representation of the hash code. 
*/ -inline String SHash2Hex(const ObjectRef& obj) { +inline ffi::String SHash2Hex(const ObjectRef& obj) { std::ostringstream os; size_t hash_code = 0; if (obj.defined()) { @@ -272,7 +272,7 @@ inline IRModule DeepCopyIRModule(IRModule mod) { return LoadJSON(SaveJSON(mod)). * \param delim The delimiter * \return The concatenated string */ -inline std::string Concat(const Array& strs, const std::string& delim) { +inline std::string Concat(const ffi::Array& strs, const std::string& delim) { if (strs.empty()) { return ""; } @@ -292,7 +292,7 @@ inline std::string Concat(const Array& strs, const std::string& delim) { * \return The BlockRV */ inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, - const String& global_var_name) { + const ffi::String& global_var_name) { const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); return sch->GetBlock(block->name_hint, global_var_name); } @@ -303,7 +303,7 @@ inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& */ struct ThreadedTraceApply { /*! \brief Constructor */ - explicit ThreadedTraceApply(const Array& postprocs) + explicit ThreadedTraceApply(const ffi::Array& postprocs) : n_(postprocs.size()), items_(new Item[n_]) { for (int i = 0; i < n_; ++i) { items_[i].postproc = postprocs[i]; @@ -321,8 +321,8 @@ struct ThreadedTraceApply { * \param rand_state The random seed * \return The schedule created, or std::nullopt if any postprocessor fails */ - Optional Apply(const IRModule& mod, const tir::Trace& trace, - TRandState* rand_state) { + ffi::Optional Apply(const IRModule& mod, const tir::Trace& trace, + TRandState* rand_state) { tir::Schedule sch = tir::Schedule::Traced(mod, /*rand_state=*/ForkSeed(rand_state), @@ -360,7 +360,7 @@ struct ThreadedTraceApply { /*! \brief A helper data structure that stores the fail count for each postprocessor. */ struct Item { /*! \brief The postprocessor. 
*/ - Postproc postproc{nullptr}; + Postproc postproc{ffi::UnsafeInit()}; /*! \brief The thread-safe postprocessor failure counter. */ std::atomic fail_counter{0}; }; @@ -397,7 +397,7 @@ inline int GetTargetNumCores(const Target& target) { * \return The median of the running time in millisecond */ inline double GetRunMsMedian(const RunnerResult& runner_result) { - Array run_secs = runner_result->run_secs.value(); + ffi::Array run_secs = runner_result->run_secs.value(); ICHECK(!run_secs.empty()); std::vector v; v.reserve(run_secs.size()); @@ -417,10 +417,10 @@ inline double GetRunMsMedian(const RunnerResult& runner_result) { * \param obj The object to be converted * \return The array of floating point numbers */ -inline Array AsFloatArray(const ObjectRef& obj) { +inline ffi::Array AsFloatArray(const ObjectRef& obj) { const ffi::ArrayObj* arr = obj.as(); ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); - Array results; + ffi::Array results; results.reserve(arr->size()); for (Any val : *arr) { auto float_value = [&]() -> FloatImm { @@ -444,10 +444,10 @@ inline Array AsFloatArray(const ObjectRef& obj) { * \param obj The object to be converted * \return The array of integers */ -inline Array AsIntArray(const ObjectRef& obj) { +inline ffi::Array AsIntArray(const ObjectRef& obj) { const ffi::ArrayObj* arr = obj.as(); ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); - Array results; + ffi::Array results; results.reserve(arr->size()); for (Any val : *arr) { auto int_value = [&]() -> int64_t { @@ -467,7 +467,7 @@ inline Array AsIntArray(const ObjectRef& obj) { struct SortTuningRecordByMeanRunSecs { static const constexpr double kMaxMeanTime = 1e10; - static double Mean(const Array& a) { + static double Mean(const ffi::Array& a) { if (a.empty()) { return kMaxMeanTime; } @@ -492,8 +492,8 @@ struct SortTuningRecordByMeanRunSecs { */ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { if 
(src->sch_rules.defined()) { - Array original = src->sch_rules.value(); - Array sch_rules; + ffi::Array original = src->sch_rules.value(); + ffi::Array sch_rules; sch_rules.reserve(original.size()); for (const ScheduleRule& sch_rule : original) { sch_rules.push_back(sch_rule->Clone()); @@ -501,8 +501,8 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { dst->sch_rules = std::move(sch_rules); } if (src->postprocs.defined()) { - Array original = src->postprocs.value(); - Array postprocs; + ffi::Array original = src->postprocs.value(); + ffi::Array postprocs; postprocs.reserve(original.size()); for (const Postproc& postproc : original) { postprocs.push_back(postproc->Clone()); @@ -510,8 +510,8 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { dst->postprocs = std::move(postprocs); } if (src->mutator_probs.defined()) { - Map original = src->mutator_probs.value(); - Map mutator_probs; + ffi::Map original = src->mutator_probs.value(); + ffi::Map mutator_probs; for (const auto& kv : original) { mutator_probs.Set(kv.first->Clone(), kv.second); } @@ -532,7 +532,7 @@ inline bool IsGPUTarget(const std::string& target_name) { * \return The AutoInline schedule rule for the given target. */ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { - Array rules{nullptr}; + ffi::Array rules{nullptr}; if (target_name == "llvm") { rules = ScheduleRule::DefaultLLVM(); } else if (target_name == "hexagon") { @@ -557,7 +557,7 @@ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { * \param arr The array of FloatImm. * \return The summary of the values in the given array. */ -inline double Sum(const Array& arr) { +inline double Sum(const ffi::Array& arr) { double sum = 0; for (const FloatImm& f : arr) { sum += f->value; @@ -568,21 +568,21 @@ inline double Sum(const Array& arr) { /*! 
\brief Collecting all the blocks */ class BlockCollector : public tir::StmtVisitor { public: - static Array Collect(const tir::Schedule& sch, - const ffi::Function f_block_filter = nullptr) { // + static ffi::Array Collect(const tir::Schedule& sch, + const ffi::Function f_block_filter = nullptr) { // return BlockCollector(sch, f_block_filter).Run(); } private: /*! \brief Entry point */ - Array Run() { + ffi::Array Run() { std::vector results; - auto f_collect = [this, &results](tir::PrimFunc func, String func_name) { + auto f_collect = [this, &results](tir::PrimFunc func, ffi::String func_name) { func_name_ = func_name; block_names_.clear(); blocks_to_collect_.clear(); VisitStmt(func->body); - for (const String& name : blocks_to_collect_) { + for (const ffi::String& name : blocks_to_collect_) { results.push_back(sch_->GetBlock(name, func_name_)); } }; @@ -596,7 +596,7 @@ class BlockCollector : public tir::StmtVisitor { // `gv->name_hint` is the name of the function // `base_func` can be PrimFunc or relax::Function if (const auto* func = base_func.as()) { - f_collect(GetRef(func), gv->name_hint); + f_collect(ffi::GetRef(func), gv->name_hint); } } } @@ -617,7 +617,7 @@ class BlockCollector : public tir::StmtVisitor { // Otherwise collect all blocks. Bool collect_block = Bool(true); if (f_block_filter_ != nullptr) { - collect_block = f_block_filter_(GetRef(block)).cast(); + collect_block = f_block_filter_(ffi::GetRef(block)).cast(); } if (collect_block) { blocks_to_collect_.push_back(block->name_hint); @@ -629,15 +629,15 @@ class BlockCollector : public tir::StmtVisitor { /*! \brief An optional packed func that allows only certain blocks to be collected. */ const ffi::Function f_block_filter_; /*! \brief The set of func name and block name pair */ - std::unordered_set block_names_; + std::unordered_set block_names_; /* \brief The list of blocks to collect in order */ - Array blocks_to_collect_; + ffi::Array blocks_to_collect_; /*! 
\brief Name of the current PrimFunc */ - String func_name_; + ffi::String func_name_; }; -void JSONFileAppendLine(const String& path, const std::string& line); -std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing); +void JSONFileAppendLine(const ffi::String& path, const std::string& line); +std::vector JSONFileReadLines(const ffi::String& path, int num_threads, bool allow_missing); } // namespace meta_schedule } // namespace tvm diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index 334c15b3be97..fee7eeb26cab 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -50,7 +50,7 @@ class AttrRegistry { * \param name The name of the item. * \return The corresponding entry. */ - const EntryType* Get(const String& name) const { + const EntryType* Get(const ffi::String& name) const { auto it = entry_map_.find(name); if (it != entry_map_.end()) return it->second; return nullptr; @@ -61,7 +61,7 @@ class AttrRegistry { * \param name The name of the item. * \return The corresponding entry. */ - EntryType& RegisterOrGet(const String& name) { + EntryType& RegisterOrGet(const ffi::String& name) { auto it = entry_map_.find(name); if (it != entry_map_.end()) return *it->second; uint32_t registry_index = static_cast(entries_.size()); @@ -77,8 +77,8 @@ class AttrRegistry { * \brief List all the entry names in the registry. * \return The entry names. */ - Array ListAllNames() const { - Array names; + ffi::Array ListAllNames() const { + ffi::Array names; for (const auto& kv : entry_map_) { names.push_back(kv.first); } @@ -92,7 +92,7 @@ class AttrRegistry { * \param value The value to be set. * \param plevel The support level. 
*/ - void UpdateAttr(const String& attr_name, const KeyType& key, Any value, int plevel) { + void UpdateAttr(const ffi::String& attr_name, const KeyType& key, Any value, int plevel) { using ffi::Any; auto& op_map = attrs_[attr_name]; if (op_map == nullptr) { @@ -119,7 +119,7 @@ class AttrRegistry { * \param attr_name The name of the attribute. * \param key The key to the attribute table. */ - void ResetAttr(const String& attr_name, const KeyType& key) { + void ResetAttr(const ffi::String& attr_name, const KeyType& key) { auto& op_map = attrs_[attr_name]; if (op_map == nullptr) { return; @@ -135,7 +135,7 @@ class AttrRegistry { * \param attr_name The name of the attribute. * \return The result attribute map. */ - const AttrRegistryMapContainerMap& GetAttrMap(const String& attr_name) { + const AttrRegistryMapContainerMap& GetAttrMap(const ffi::String& attr_name) { auto it = attrs_.find(attr_name); if (it == attrs_.end()) { LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered"; @@ -148,7 +148,7 @@ class AttrRegistry { * \param attr_name The name of the attribute. * \return The check result. */ - bool HasAttrMap(const String& attr_name) { return attrs_.count(attr_name); } + bool HasAttrMap(const ffi::String& attr_name) { return attrs_.count(attr_name); } /*! * \return a global singleton of the registry. @@ -162,9 +162,9 @@ class AttrRegistry { // entries in the registry std::vector> entries_; // map from name to entries. - std::unordered_map entry_map_; + std::unordered_map entry_map_; // storage of additional attribute table. 
- std::unordered_map>> attrs_; + std::unordered_map>> attrs_; }; } // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index e666b434f8f5..2565a02b64a5 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -38,12 +38,12 @@ using ffi::PackedArgs; // key1, value1, ..., key_n, value_n void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { // TODO(tvm-team): consider further simplify by removing DictAttrsNode special handling - String type_key = args[0].cast(); + ffi::String type_key = args[0].cast(); int32_t type_index; TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); if (type_index == DictAttrsNode::RuntimeTypeIndex()) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->InitByPackedArgs(args.Slice(1), false); *rv = ObjectRef(attrs); } else { @@ -52,9 +52,9 @@ void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("node.MakeNode", MakeNode); -}); +} } // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 04a6f7533a19..b60583c6ab85 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -127,12 +127,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << Downcast(node); }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.AsRepr", [](ffi::Any obj) { std::ostringstream os; os << obj; return os.str(); }); -}); +} } // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 9b1565d2ab3a..36c61d78b345 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -28,14 +28,15 @@ namespace tvm { using AccessPath = ffi::reflection::AccessPath; -TVM_FFI_STATIC_INIT_BLOCK({ 
PrinterConfigNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { PrinterConfigNode::RegisterReflection(); } TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { static FType inst; return inst; } -std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional& cfg) { +std::string TVMScriptPrinter::Script(const ObjectRef& node, + const ffi::Optional& cfg) { if (!TVMScriptPrinter::vtable().can_dispatch(node)) { std::ostringstream os; ReprPrinter printer(os); @@ -59,34 +60,34 @@ bool IsIdentifier(const std::string& name) { [](char c) { return std::isalnum(c) || c == '_'; }); } -PrinterConfig::PrinterConfig(Map config_dict) { - runtime::ObjectPtr n = make_object(); +PrinterConfig::PrinterConfig(ffi::Map config_dict) { + runtime::ObjectPtr n = ffi::make_object(); if (auto v = config_dict.Get("name")) { - n->binding_names.push_back(Downcast(v.value())); + n->binding_names.push_back(Downcast(v.value())); } if (auto v = config_dict.Get("show_meta")) { n->show_meta = v.value().cast(); } if (auto v = config_dict.Get("ir_prefix")) { - n->ir_prefix = Downcast(v.value()); + n->ir_prefix = Downcast(v.value()); } if (auto v = config_dict.Get("tir_prefix")) { - n->tir_prefix = Downcast(v.value()); + n->tir_prefix = Downcast(v.value()); } if (auto v = config_dict.Get("relax_prefix")) { - n->relax_prefix = Downcast(v.value()); + n->relax_prefix = Downcast(v.value()); } if (auto v = config_dict.Get("module_alias")) { - n->module_alias = Downcast(v.value()); + n->module_alias = Downcast(v.value()); } if (auto v = config_dict.Get("buffer_dtype")) { - n->buffer_dtype = DataType(StringToDLDataType(Downcast(v.value()))); + n->buffer_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("int_dtype")) { - n->int_dtype = DataType(StringToDLDataType(Downcast(v.value()))); + n->int_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("float_dtype")) { - n->float_dtype = 
DataType(StringToDLDataType(Downcast(v.value()))); + n->float_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("verbose_expr")) { n->verbose_expr = v.value().cast(); @@ -101,18 +102,20 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->num_context_lines = v.value().cast(); } if (auto v = config_dict.Get("path_to_underline")) { - n->path_to_underline = Downcast>>(v).value_or(Array()); + n->path_to_underline = + Downcast>>(v).value_or(ffi::Array()); } if (auto v = config_dict.Get("path_to_annotate")) { - n->path_to_annotate = - Downcast>>(v).value_or(Map()); + n->path_to_annotate = Downcast>>(v).value_or( + ffi::Map()); } if (auto v = config_dict.Get("obj_to_underline")) { - n->obj_to_underline = Downcast>>(v).value_or(Array()); + n->obj_to_underline = + Downcast>>(v).value_or(ffi::Array()); } if (auto v = config_dict.Get("obj_to_annotate")) { - n->obj_to_annotate = - Downcast>>(v).value_or(Map()); + n->obj_to_annotate = Downcast>>(v).value_or( + ffi::Map()); } if (auto v = config_dict.Get("syntax_sugar")) { n->syntax_sugar = v.value().cast(); @@ -134,20 +137,20 @@ PrinterConfig::PrinterConfig(Map config_dict) { this->data_ = std::move(n); } -Array PrinterConfigNode::GetBuiltinKeywords() { - Array result{this->ir_prefix, this->tir_prefix, this->relax_prefix}; +ffi::Array PrinterConfigNode::GetBuiltinKeywords() { + ffi::Array result{this->ir_prefix, this->tir_prefix, this->relax_prefix}; if (!this->module_alias.empty()) { result.push_back(this->module_alias); } return result; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("node.PrinterConfig", - [](Map config_dict) { return PrinterConfig(config_dict); }) + [](ffi::Map config_dict) { return PrinterConfig(config_dict); }) .def("node.TVMScriptPrinterScript", TVMScriptPrinter::Script); -}); +} } // namespace tvm diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 
09e364bb8ee4..2faf8d170bd8 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -40,8 +40,8 @@ Any LoadJSON(std::string json_str) { return ffi::FromJSONGraph(jgraph); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.SaveJSON", SaveJSON).def("node.LoadJSON", LoadJSON); -}); +} } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index be009a77c305..e33d7c774687 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -73,12 +73,12 @@ bool NodeStructuralEqualAdapter(const Any& lhs, const Any& rhs, bool assert_mode } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("node.StructuralEqual", NodeStructuralEqualAdapter) .def("node.GetFirstStructuralMismatch", ffi::StructuralEqual::GetFirstMismatch); -}); +} bool StructuralEqual::operator()(const ffi::Any& lhs, const ffi::Any& rhs, bool map_free_params) const { diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 41a22e4d39d8..aa02d097e966 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -41,7 +41,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.StructuralHash", [](const Any& object, bool map_free_vars) -> int64_t { @@ -50,35 +50,35 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::TypeAttrDef() .def("__data_to_json__", [](const ffi::ModuleObj* node) { - std::string bytes = codegen::SerializeModuleToBytes(GetRef(node), + std::string bytes = codegen::SerializeModuleToBytes(ffi::GetRef(node), /*export_dso*/ false); return ffi::Base64Encode(ffi::Bytes(bytes)); }) - .def("__data_from_json__", [](const String& base64_bytes) { - Bytes bytes = ffi::Base64Decode(base64_bytes); + .def("__data_from_json__", [](const ffi::String& base64_bytes) { + ffi::Bytes bytes 
= ffi::Base64Decode(base64_bytes); ffi::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator std::string()); return rtmod; }); - refl::TypeAttrDef() + refl::TypeAttrDef() .def("__data_to_json__", - [](const runtime::NDArray::Container* node) { + [](const runtime::Tensor::Container* node) { std::string blob; dmlc::MemoryStringStream mstrm(&blob); support::Base64OutStream b64strm(&mstrm); runtime::SaveDLTensor(&b64strm, node); b64strm.Finish(); - return String(blob); + return ffi::String(blob); }) .def("__data_from_json__", [](const std::string& blob) { dmlc::MemoryStringStream mstrm(const_cast(&blob)); support::Base64InStream b64strm(&mstrm); b64strm.InitPosition(); - runtime::NDArray temp; + runtime::Tensor temp; ICHECK(temp.Load(&b64strm)); return temp; }); -}); +} uint64_t StructuralHash::operator()(const ffi::Any& object) const { return ffi::StructuralHash::Hash(object, false); @@ -100,7 +100,7 @@ struct ReportNodeTrait { } }; -TVM_FFI_STATIC_INIT_BLOCK({ ReportNodeTrait::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ReportNodeTrait::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -116,7 +116,7 @@ struct CountNodeTrait { } }; -TVM_FFI_STATIC_INIT_BLOCK({ CountNodeTrait::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { CountNodeTrait::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -132,7 +132,7 @@ struct DurationNodeTrait { } }; -TVM_FFI_STATIC_INIT_BLOCK({ DurationNodeTrait::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { DurationNodeTrait::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -148,7 +148,7 @@ struct PercentNodeTrait { } }; -TVM_FFI_STATIC_INIT_BLOCK({ PercentNodeTrait::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { PercentNodeTrait::RegisterReflection(); } 
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -164,7 +164,7 @@ struct RatioNodeTrait { } }; -TVM_FFI_STATIC_INIT_BLOCK({ RatioNodeTrait::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RatioNodeTrait::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index f1f47910f8b1..a61d548443a3 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -46,9 +46,9 @@ struct InsertionSet { class VarVisitor : protected ExprVisitor { public: - Array Free(const Expr& expr) { + ffi::Array Free(const Expr& expr) { this->VisitExpr(expr); - Array ret; + ffi::Array ret; for (const auto& v : vars_.data) { if (bound_vars_.set.count(v) == 0) { ret.push_back(v); @@ -57,31 +57,31 @@ class VarVisitor : protected ExprVisitor { return ret; } - Array Collect() { - Array ret; + ffi::Array Collect() { + ffi::Array ret; for (const auto& v : bound_vars_.data) { ret.push_back(v); } return ret; } - Array Bound(const Expr& expr) { + ffi::Array Bound(const Expr& expr) { this->VisitExpr(expr); return Collect(); } - Array All(const Expr& expr) { + ffi::Array All(const Expr& expr) { this->VisitExpr(expr); - Array ret; + ffi::Array ret; for (const auto& v : vars_.data) { ret.push_back(v); } return ret; } - Array AllGlobalVars(const Expr& expr) { + ffi::Array AllGlobalVars(const Expr& expr) { this->VisitExpr(expr); - Array ret; + ffi::Array ret; for (const auto& v : global_vars_.data) { ret.push_back(v); } @@ -93,7 +93,7 @@ class VarVisitor : protected ExprVisitor { vars_.Insert(v); } - void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } + void VisitExpr_(const VarNode* var) final { vars_.Insert(ffi::GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { @@ -102,7 +102,9 @@ class VarVisitor : 
protected ExprVisitor { VisitExpr(op->body); } - void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef(op)); } + void VisitExpr_(const GlobalVarNode* op) final { + global_vars_.Insert(ffi::GetRef(op)); + } void VisitExpr_(const CallNode* call_node) final { VisitSpan(call_node->span); @@ -134,25 +136,27 @@ class VarVisitor : protected ExprVisitor { InsertionSet global_vars_; }; -tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } +tvm::ffi::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } -tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } +tvm::ffi::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } -tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } +tvm::ffi::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } -tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } +tvm::ffi::Array AllGlobalVars(const Expr& expr) { + return VarVisitor().AllGlobalVars(expr); +} -Optional FindImpureCall(const Expr& expr, const Optional& own_name) { +ffi::Optional FindImpureCall(const Expr& expr, const ffi::Optional& own_name) { class ImpureCallChecker : public ExprVisitor { public: - static Optional Check(const Expr& expr, const Optional& own_name) { + static ffi::Optional Check(const Expr& expr, const ffi::Optional& own_name) { ImpureCallChecker visitor(own_name); visitor.VisitExpr(expr); return visitor.impure_expr_; } private: - explicit ImpureCallChecker(const Optional& own_name) : own_name_(own_name) {} + explicit ImpureCallChecker(const ffi::Optional& own_name) : own_name_(own_name) {} void VisitExpr(const Expr& expr) override { // Early bail-out if we found an impure expression @@ -169,7 +173,7 @@ Optional FindImpureCall(const Expr& expr, const Optional& own_name) void VisitExpr_(const CallNode* call) override { // ignore recursive calls if we find one bool is_recursive = (own_name_ && 
own_name_.value().same_as(call->op)); - auto expr = GetRef(call); + auto expr = ffi::GetRef(call); if (!is_recursive && IsImpureCall(expr)) { impure_expr_ = expr; } else { @@ -178,8 +182,8 @@ Optional FindImpureCall(const Expr& expr, const Optional& own_name) } private: - const Optional& own_name_; - Optional impure_expr_ = std::nullopt; + const ffi::Optional& own_name_; + ffi::Optional impure_expr_ = std::nullopt; }; if (own_name) { @@ -194,11 +198,11 @@ Optional FindImpureCall(const Expr& expr, const Optional& own_name) return ImpureCallChecker::Check(to_check, own_name); } -bool ContainsImpureCall(const Expr& expr, const Optional& own_name) { +bool ContainsImpureCall(const Expr& expr, const ffi::Optional& own_name) { return FindImpureCall(expr, own_name).defined(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.analysis.free_vars", FreeVars) @@ -206,7 +210,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("relax.analysis.all_vars", AllVars) .def("relax.analysis.all_global_vars", AllGlobalVars) .def("relax.analysis.contains_impure_call", ContainsImpureCall); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/collect_call_map.cc b/src/relax/analysis/collect_call_map.cc index 3e0170d3444d..85099d88ff57 100644 --- a/src/relax/analysis/collect_call_map.cc +++ b/src/relax/analysis/collect_call_map.cc @@ -38,7 +38,9 @@ using ir::CalleeCollector; struct Visitor : ExprVisitor { explicit Visitor(CalleeCollector* collector) : collector(collector) {} CalleeCollector* collector; - void VisitExpr_(const GlobalVarNode* node) override { collector->Mark(GetRef(node)); } + void VisitExpr_(const GlobalVarNode* node) override { + collector->Mark(ffi::GetRef(node)); + } }; } // namespace diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index 8b8665445d98..954240c19189 100644 --- 
a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -35,10 +35,10 @@ namespace relax { namespace { class CompileTimeCollector : ExprVisitor { public: - static Array Collect(const Function& func) { + static ffi::Array Collect(const Function& func) { CompileTimeCollector visitor; visitor(func); - return Array(visitor.known_relax_vars_.begin(), visitor.known_relax_vars_.end()); + return ffi::Array(visitor.known_relax_vars_.begin(), visitor.known_relax_vars_.end()); } private: @@ -89,14 +89,14 @@ class CompileTimeCollector : ExprVisitor { }; } // namespace -Array ComputableAtCompileTime(const Function& func) { +ffi::Array ComputableAtCompileTime(const Function& func) { return CompileTimeCollector::Collect(func); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.computable_at_compile_time", ComputableAtCompileTime); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/detect_recursion.cc b/src/relax/analysis/detect_recursion.cc index 73ad8a31f8a5..7b2a5f516e92 100644 --- a/src/relax/analysis/detect_recursion.cc +++ b/src/relax/analysis/detect_recursion.cc @@ -87,7 +87,7 @@ class DependencyGatherer : public ExprVisitor { void VisitExpr_(const GlobalVarNode* gv) override { // disregard PrimFuncs - if (!m_->Lookup(GetRef(gv)).as()) { + if (!m_->Lookup(ffi::GetRef(gv)).as()) { return; } deps_.insert(gv->name_hint); @@ -111,7 +111,7 @@ adjacency_map GatherDependencyGraph(const IRModule& m) { continue; } std::string name = gv_func.first->name_hint; - auto deps = DependencyGatherer(m).Track(GetRef(func)); + auto deps = DependencyGatherer(m).Track(ffi::GetRef(func)); ret.insert({name, deps}); } return ret; @@ -369,7 +369,7 @@ std::vector CoalesceCircuits(const std::vector& circuits) { return ret; } -tvm::Array> DetectRecursion(const IRModule& m) { +tvm::ffi::Array> DetectRecursion(const IRModule& m) { auto 
graph = GatherDependencyGraph(m); // have to decide on some ordering for names @@ -382,9 +382,9 @@ tvm::Array> DetectRecursion(const IRModule& m) { auto groups = CoalesceCircuits(DetectElementaryCircuits(indices)); // convert to expected representation - tvm::Array> ret; + tvm::ffi::Array> ret; for (auto group : groups) { - tvm::Array found; + tvm::ffi::Array found; for (size_t node : group) { found.push_back(m->GetGlobalVar(name_ordering[node])); } @@ -393,10 +393,10 @@ tvm::Array> DetectRecursion(const IRModule& m) { return ret; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.detect_recursion", DetectRecursion); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/graph_partitioner.cc b/src/relax/analysis/graph_partitioner.cc index 00f4da400657..d68626160fe9 100644 --- a/src/relax/analysis/graph_partitioner.cc +++ b/src/relax/analysis/graph_partitioner.cc @@ -252,11 +252,11 @@ size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src, } return 0; }; - if (auto call_node = GetRef(src->ref).as()) { + if (auto call_node = ffi::GetRef(src->ref).as()) { for (auto& it : call_node->args) { sum += calc_args_number(it); } - } else if (auto tuple_node = GetRef(src->ref).as()) { + } else if (auto tuple_node = ffi::GetRef(src->ref).as()) { for (auto& it : tuple_node->fields) { sum += calc_args_number(it); } @@ -288,19 +288,19 @@ size_t GraphPartitioner::CountFusedArgs(const IndexedForwardGraph& graph, void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { auto args_counter = [](const tvm::Object* obj) { size_t args_num = 0; - if (auto call_node = GetRef(obj).as()) { + if (auto call_node = ffi::GetRef(obj).as()) { for (auto& it : call_node->args) { if (it.as() || it.as()) { args_num++; } } - } else if (auto tuple_node = GetRef(obj).as()) { + } else if (auto tuple_node = ffi::GetRef(obj).as()) { for (auto& it : tuple_node->fields) { if 
(it.as() || it.as()) { args_num++; } } - } else if (GetRef(obj).as()) { + } else if (ffi::GetRef(obj).as()) { args_num++; } return args_num; diff --git a/src/relax/analysis/graph_partitioner.h b/src/relax/analysis/graph_partitioner.h index 3afb9888a162..09bf68734cc8 100644 --- a/src/relax/analysis/graph_partitioner.h +++ b/src/relax/analysis/graph_partitioner.h @@ -83,7 +83,7 @@ class IndexedForwardGraph { std::ostringstream os; for (size_t i = 0; i < post_dfs_order.size(); ++i) { Node* node = post_dfs_order[i]; - os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; + os << "node[" << i << "], " << ffi::GetRef(node->ref) << " outputs=["; for (auto* link = node->outputs.head; link != nullptr; link = link->next) { os << link->value.node->index << ", "; } @@ -194,7 +194,7 @@ class GraphPartitioner { size_t args_num{0}; /*! \brief Optional attributes to annotate the grouped function. */ - Map attrs; + ffi::Map attrs; /*! * \brief Find the group root, perform path compression * \return The root type node. diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index 109af127df2e..5bd5568a93a3 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -40,8 +40,8 @@ using namespace tir; /********** Helper Functions **********/ /*! 
\brief Checks if a transformation is bijective affine over the given ranges */ -static bool IsBijectiveAffine(const IndexMap& m, const Array& ranges) { - Map input_iters; +static bool IsBijectiveAffine(const IndexMap& m, const ffi::Array& ranges) { + ffi::Map input_iters; ICHECK_EQ(m->initial_indices.size(), ranges.size()); for (size_t i = 0; i < ranges.size(); i++) { input_iters.Set(m->initial_indices[i], ranges[i]); @@ -61,7 +61,7 @@ static bool IsBijectiveAffine(const IndexMap& m, const Array& ranges) { */ class IndexAnalyzer : public ExprVisitor { public: - Array Analyze(const arith::IterSumExpr& expr) { + ffi::Array Analyze(const arith::IterSumExpr& expr) { VisitExpr(expr); return iterators_; } @@ -86,14 +86,14 @@ class IndexAnalyzer : public ExprVisitor { void VisitIterMark(const arith::IterMark& op) { if (const auto* var = op->source.as()) - iterators_.push_back(GetRef(var)); + iterators_.push_back(ffi::GetRef(var)); else VisitExpr(op->source); VisitExpr(op->extent); } private: - Array iterators_; + ffi::Array iterators_; }; /*! @@ -111,13 +111,13 @@ class IndexAnalyzer : public ExprVisitor { * SpatialLayout(A[s0, constant, r0, s1]) = {s0, null, null, s1} * SpatialLayout(A[s0 * c + s1]) = undefined */ -using SpatialLayout = Array>; +using SpatialLayout = ffi::Array>; static SpatialLayout GetSpatialLayout(const arith::IterMapResult& iter_map_result) { ICHECK(!iter_map_result->indices.empty()); SpatialLayout result; for (const arith::IterSumExpr& index : iter_map_result->indices) { IndexAnalyzer index_analyzer; - Array iter_vars = index_analyzer.Analyze(index); + ffi::Array iter_vars = index_analyzer.Analyze(index); if (iter_vars.size() >= 2) { LOG(WARNING) << "[LayoutInference] Unable to get spatial layout of access: " << arith::NormalizeIterMapToExpr(index); @@ -173,7 +173,7 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { if (t0->final_indices.size() != t1->final_indices.size()) return false; // Create a new shape expression. 
- Array t1_initial_indices = + ffi::Array t1_initial_indices = t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; }); arith::Analyzer analyzer; auto t0_output = t0->MapIndices(t1_initial_indices, &analyzer); @@ -213,9 +213,9 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { * target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C %4) */ using VarSet = std::unordered_set; -static Optional InferLayoutTransformation(const SpatialLayout& src_spatial_layout, - const IndexMap& src_transformation, - const SpatialLayout& tgt_spatial_layout) { +static ffi::Optional InferLayoutTransformation(const SpatialLayout& src_spatial_layout, + const IndexMap& src_transformation, + const SpatialLayout& tgt_spatial_layout) { // Copy over the src transformation intial and final indices auto initial_indices = support::AsList(src_transformation->initial_indices); auto final_indices = support::AsList(src_transformation->final_indices); @@ -244,7 +244,7 @@ static Optional InferLayoutTransformation(const SpatialLayout& src_spa auto final_indices_it = final_indices.begin(); while (final_indices_it != final_indices.end()) { // Collect all the vars used in this final index. 
- Array used_vars = tir::UndefinedVars(*final_indices_it); + ffi::Array used_vars = tir::UndefinedVars(*final_indices_it); ICHECK(!used_vars.empty()) << "IndexMap expression must always contain tir::Var nodes but found none in: " << *final_indices_it; @@ -318,7 +318,7 @@ static Optional InferLayoutTransformation(const SpatialLayout& src_spa */ class BlockAnalyzer : public StmtExprVisitor { public: - explicit BlockAnalyzer(const Block& block, const Map& transformation_cache, + explicit BlockAnalyzer(const Block& block, const ffi::Map& transformation_cache, IndexMap write_transformation) : can_transform_block_(true), write_transformation_(write_transformation), @@ -380,7 +380,7 @@ class BlockAnalyzer : public StmtExprVisitor { } block_transformation_ = maybe_block_transformation.value(); - Array block_ranges = block_->iter_vars.Map([](const IterVar& i) { return i->dom; }); + ffi::Array block_ranges = block_->iter_vars.Map([](const IterVar& i) { return i->dom; }); if (!IsBijectiveAffine(block_transformation_, block_ranges)) { can_transform_block_ = false; LOG(WARNING) << "[LayoutInference] Inferred block transformation is not bijective affine, " @@ -437,7 +437,7 @@ class BlockAnalyzer : public StmtExprVisitor { }; // Helper to break down the indices of buffer access. 
- SpatialLayout DetectBufferAccessIterMap(Array indices) { + SpatialLayout DetectBufferAccessIterMap(ffi::Array indices) { auto result = arith::DetectIterMap( /*indices=*/indices, /*input_iters*/ spatial_dom_, /*predicate*/ 1, /*check_level*/ arith::IterMapLevel::NoCheck, &arith_analyzer_); @@ -516,19 +516,19 @@ class BlockAnalyzer : public StmtExprVisitor { public: bool CanBeTransformed() { return can_transform_block_; } IndexMap GetBlockTransformation() { return block_transformation_; } - Map GetReadBufferTransformations() { return read_buffer_transformations_; } + ffi::Map GetReadBufferTransformations() { return read_buffer_transformations_; } private: bool can_transform_block_; IndexMap write_transformation_; - Map spatial_dom_; + ffi::Map spatial_dom_; arith::Analyzer arith_analyzer_; Block block_; IndexMap block_transformation_; - Map read_buffer_transformations_; - const Map& buffer_transformation_cache_; + ffi::Map read_buffer_transformations_; + const ffi::Map& buffer_transformation_cache_; std::unordered_map buffer_access_info_; }; @@ -542,14 +542,14 @@ class BlockAnalyzer : public StmtExprVisitor { */ class PrimFuncAnalyzer : public StmtExprVisitor { public: - explicit PrimFuncAnalyzer(const PrimFunc& func, Array write_transformations) { + explicit PrimFuncAnalyzer(const PrimFunc& func, ffi::Array write_transformations) { ICHECK_LE(write_transformations.size(), func->params.size()) << "Incompatible PrimFunc and write_transformations"; size_t first_write_index = func->params.size() - write_transformations.size(); for (size_t i = 0; i < write_transformations.size(); ++i) { auto param = func->params[first_write_index + i]; - Optional param_buf = func->buffer_map.Get(param); + ffi::Optional param_buf = func->buffer_map.Get(param); ICHECK(param_buf.defined()); ICHECK_EQ(param_buf.value()->shape.size(), write_transformations[i]->initial_indices.size()) << "Mismatch between output buffer shape and index map"; @@ -557,10 +557,10 @@ class PrimFuncAnalyzer : 
public StmtExprVisitor { } VisitStmt(func->body); } - Map> GetSuggestedTransforms() { - Map> result; + ffi::Map> GetSuggestedTransforms() { + ffi::Map> result; for (const auto& [block, index_map] : block_transformations_) { - Map block_transformations; + ffi::Map block_transformations; block_transformations.Set(block, index_map); for (const auto& buffer : block_to_buffer_[block]) { block_transformations.Set(buffer, buffer_transformation_cache_[buffer]); @@ -578,7 +578,7 @@ class PrimFuncAnalyzer : public StmtExprVisitor { return; } - Block block = GetRef(op); + Block block = ffi::GetRef(op); // Get block write buffer transformation. if (block->writes.size() != 1) return; auto write_buffer = block->writes[0]->buffer; @@ -601,13 +601,13 @@ class PrimFuncAnalyzer : public StmtExprVisitor { } private: - Map buffer_transformation_cache_; - Map block_transformations_; - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; + ffi::Map buffer_transformation_cache_; + ffi::Map block_transformations_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; }; -Map> SuggestLayoutTransforms( - const PrimFunc& prim_func, Array write_buffer_transformations) { +ffi::Map> SuggestLayoutTransforms( + const PrimFunc& prim_func, ffi::Array write_buffer_transformations) { // No changes to the PrimFunc are required if no transformations on output buffers. 
if (write_buffer_transformations.empty()) return {}; @@ -615,13 +615,13 @@ Map> SuggestLayoutTransforms( return analyzer.GetSuggestedTransforms(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.suggest_layout_transforms", - [](PrimFunc fn, Array write_buffer_transformations) { + [](PrimFunc fn, ffi::Array write_buffer_transformations) { return SuggestLayoutTransforms(fn, write_buffer_transformations); }); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/shape_analysis.cc b/src/relax/analysis/shape_analysis.cc index 70ce5ac06e90..e2f624937773 100644 --- a/src/relax/analysis/shape_analysis.cc +++ b/src/relax/analysis/shape_analysis.cc @@ -29,7 +29,7 @@ namespace tvm { namespace relax { -bool CanProveShapeEqual(const Array& lhs, const Array& rhs, +bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Array& rhs, arith::Analyzer* ana) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) return false; diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 389fb003c6d3..3952b1ce4a6e 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -57,14 +57,14 @@ class StaticTypeDeriver : public StructInfoFunctor { // end-module: distributed Type VisitStructInfo_(const TupleStructInfoNode* op) final { - Array fields = + ffi::Array fields = op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); return TupleType(fields, op->span); } Type VisitStructInfo_(const FuncStructInfoNode* op) final { if (op->IsOpaque()) return PackedFuncType(op->span); - Array params = op->params.value().Map( + ffi::Array params = op->params.value().Map( [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); Type ret = this->VisitStructInfo(op->ret); return FuncType(params, ret, op->span); @@ -73,11 +73,11 @@ 
class StaticTypeDeriver : public StructInfoFunctor { Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.GetStaticType", [](const StructInfo& info) { return GetStaticType(info); }); -}); +} //-------------------------- // StructInfoFromType @@ -93,13 +93,13 @@ StructInfo StructInfoFromType(const Type& type) { } else if (const TensorTypeNode* tensor_type = type.as()) { return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); } else if (const TupleTypeNode* tuple_type = type.as()) { - Array fields; + ffi::Array fields; for (const Type& field : tuple_type->fields) { fields.push_back(StructInfoFromType(field)); } return TupleStructInfo(fields, type->span); } else if (const FuncTypeNode* func_type = type.as()) { - Array params = + ffi::Array params = func_type->arg_types.Map([](const Type& param) { return StructInfoFromType(param); }); StructInfo ret = StructInfoFromType(func_type->ret_type); // TODO(relax-team): Maybe add purity into the type as well @@ -117,13 +117,14 @@ class WellDefinedEraser : public StructInfoMutator, public ExprMutatorBase, public tir::ExprMutator { public: - WellDefinedEraser(std::function(const tir::Var& var)> f_shape_var_map, - std::function(const Var& var)> f_var_map, arith::Analyzer* ana) + WellDefinedEraser(std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, + arith::Analyzer* ana) : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} StructInfo VisitStructInfo_(const PrimStructInfoNode* op) final { bool has_undefined = false; - Optional value; + ffi::Optional value; if (op->value.defined()) { std::swap(has_undefined_, has_undefined); @@ -134,7 +135,7 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. 
if (!has_undefined) { if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return PrimStructInfo(value.value(), op->span); } @@ -145,7 +146,7 @@ class WellDefinedEraser : public StructInfoMutator, StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final { bool has_undefined = false; - Optional> values; + ffi::Optional> values; if (op->values.defined()) { std::swap(has_undefined_, has_undefined); @@ -155,7 +156,7 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. if (!has_undefined) { if (values.same_as(op->values)) { - return GetRef(op); + return ffi::GetRef(op); } else { return ShapeStructInfo(values.value(), op->span); } @@ -166,7 +167,7 @@ class WellDefinedEraser : public StructInfoMutator, StructInfo VisitStructInfo_(const TensorStructInfoNode* op) final { bool has_undefined = false; - Optional shape; + ffi::Optional shape; if (op->shape.defined()) { std::swap(has_undefined_, has_undefined); @@ -179,7 +180,7 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. if (!has_undefined) { if (shape.same_as(op->shape)) { - return GetRef(op); + return ffi::GetRef(op); } else { if (shape.defined()) { return TensorStructInfo(shape.value(), op->dtype, vdev, op->span); @@ -197,7 +198,7 @@ class WellDefinedEraser : public StructInfoMutator, // // All the occuring symbolic variables are defined in parameters' // struct info annotations. So there is no needed to erase. 
- return GetRef(op); + return ffi::GetRef(op); } using relax::ExprMutatorBase::VisitExpr_; @@ -215,22 +216,22 @@ class WellDefinedEraser : public StructInfoMutator, } Expr VisitExpr_(const VarNode* var) final { - Optional ret; + ffi::Optional ret; if (f_var_map_ != nullptr) { - ret = f_var_map_(GetRef(var)); + ret = f_var_map_(ffi::GetRef(var)); } has_undefined_ = has_undefined_ || !ret.defined(); if (ret.defined()) { ICHECK(ret.as() || ret.as()) << "Only allow Expr in StructInfo to be ShapeExpr or Var"; } - return ret.value_or(GetRef(var)); + return ret.value_or(ffi::GetRef(var)); } PrimExpr VisitExpr_(const tir::VarNode* var) final { - Optional ret; + ffi::Optional ret; if (f_shape_var_map_ != nullptr) { - ret = f_shape_var_map_(GetRef(var)); + ret = f_shape_var_map_(ffi::GetRef(var)); } has_undefined_ = has_undefined_ || !ret.defined(); @@ -242,20 +243,21 @@ class WellDefinedEraser : public StructInfoMutator, ICHECK(value.dtype() == DataType::Int(64)) << "Can only provide i64 expressions in shape"; return value; } else { - return GetRef(var); + return ffi::GetRef(var); } } private: bool has_undefined_ = false; - std::function(const tir::Var& var)> f_shape_var_map_; - std::function(const Var& var)> f_var_map_; + std::function(const tir::Var& var)> f_shape_var_map_; + std::function(const Var& var)> f_var_map_; arith::Analyzer* ana_; }; StructInfo EraseToWellDefined( - const StructInfo& info, std::function(const tir::Var& var)> f_shape_var_map, - std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { + const StructInfo& info, + std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { if (ana == nullptr) { arith::Analyzer inst; return WellDefinedEraser(f_shape_var_map, f_var_map, &inst).VisitStructInfo(info); @@ -264,13 +266,13 @@ StructInfo EraseToWellDefined( } } -StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, - Map var_map, arith::Analyzer* ana) { - 
std::function(const tir::Var& var)> f_shape_var_map = nullptr; - std::function(const Var& var)> f_var_map = nullptr; +StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, + ffi::Map var_map, arith::Analyzer* ana) { + std::function(const tir::Var& var)> f_shape_var_map = nullptr; + std::function(const Var& var)> f_var_map = nullptr; if (!shape_var_map.empty()) { - f_shape_var_map = [&](const tir::Var& var) -> Optional { + f_shape_var_map = [&](const tir::Var& var) -> ffi::Optional { auto it = shape_var_map.find(var); if (it != shape_var_map.end()) return (*it).second; return std::nullopt; @@ -278,7 +280,7 @@ StructInfo EraseToWellDefined(const StructInfo& info, Map sh } if (!var_map.empty()) { - f_var_map = [&](const Var& var) -> Optional { + f_var_map = [&](const Var& var) -> ffi::Optional { auto it = var_map.find(var); if (it != var_map.end()) return (*it).second; return std::nullopt; @@ -288,14 +290,13 @@ StructInfo EraseToWellDefined(const StructInfo& info, Map sh return EraseToWellDefined(info, f_shape_var_map, f_var_map, ana); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.analysis.EraseToWellDefined", - [](const StructInfo& info, Map shape_var_map, Map var_map) { - return EraseToWellDefined(info, shape_var_map, var_map); - }); -}); + [](const StructInfo& info, ffi::Map shape_var_map, + ffi::Map var_map) { return EraseToWellDefined(info, shape_var_map, var_map); }); +} //-------------------------- // IsBaseOf @@ -472,7 +473,7 @@ class StructInfoBaseChecker // // Given we only do best effort checking in these cases, and such cases // are likely not a primary concern atm, we take this approach here. 
- if (struct_equal_(GetRef(lhs), other)) return BaseCheckResult::kPass; + if (struct_equal_(ffi::GetRef(lhs), other)) return BaseCheckResult::kPass; auto param_check = FuncParamsCheck(lhs->params.value(), rhs->params.value()); auto ret_check = this->VisitStructInfo(lhs->ret, rhs->ret); @@ -511,7 +512,8 @@ class StructInfoBaseChecker * \param rhs The right hand shape. * \return CheckResult. */ - virtual BaseCheckResult ShapeMatchCheck(const Array& lhs, const Array& rhs) { + virtual BaseCheckResult ShapeMatchCheck(const ffi::Array& lhs, + const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; BaseCheckResult ret = BaseCheckResult::kPass; @@ -546,8 +548,8 @@ class StructInfoBaseChecker * \param rhs The right hand params. * \return Check result. */ - virtual BaseCheckResult FuncParamsCheck(const Array& lhs, - const Array& rhs) { + virtual BaseCheckResult FuncParamsCheck(const ffi::Array& lhs, + const ffi::Array& rhs) { auto res = ArrayCheck(lhs, rhs); // treat L1 failures in params checking as L2. if (res == BaseCheckResult::kFailL1) res = BaseCheckResult::kFailL2; @@ -578,7 +580,7 @@ class StructInfoBaseChecker * \param lhs The left operand. * \param rhs The right operand. 
*/ - BaseCheckResult ArrayCheck(const Array& lhs, const Array& rhs) { + BaseCheckResult ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; BaseCheckResult ret = BaseCheckResult::kPass; @@ -601,24 +603,24 @@ BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& de } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.StructInfoBaseCheck", [](const StructInfo& base, const StructInfo& derived) -> int { return static_cast(StructInfoBaseCheck(base, derived)); }); -}); +} bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer* ana) { return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.StructInfoIsBaseOf", [](const StructInfo& base, const StructInfo& derived) { return IsBaseOf(base, derived); }); -}); +} class StructInfoBasePreconditionCollector : public StructInfoFunctor { @@ -789,7 +791,7 @@ class StructInfoBasePreconditionCollector } private: - PrimExpr ArrayCheck(const Array& lhs, const Array& rhs) { + PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return Bool(false); } @@ -801,7 +803,7 @@ class StructInfoBasePreconditionCollector return all_equal; } - PrimExpr ArrayCheck(const Array& lhs, const Array& rhs) { + PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return Bool(false); } @@ -877,8 +879,8 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { // Whether to populate map in params. bool populate_mapping_{true}; // for simplicity, we make these fields public so the user can access them. 
- Map shape_var_map_; - Map var_map_; + ffi::Map shape_var_map_; + ffi::Map var_map_; using StructInfoBaseChecker::ShapeMatchCheck; @@ -889,7 +891,7 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { } if (auto* ptr = param.as()) { - auto var = GetRef(ptr); + auto var = ffi::GetRef(ptr); auto it = shape_var_map_.find(var); // not populated if (it == shape_var_map_.end()) { @@ -916,7 +918,7 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { } if (auto* ptr = lhs.as()) { - auto var = GetRef(ptr); + auto var = ffi::GetRef(ptr); auto it = var_map_.find(var); // not populated if (it == var_map_.end()) { @@ -936,8 +938,8 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); } - BaseCheckResult FuncParamsCheck(const Array& lhs, - const Array& rhs) final { + BaseCheckResult FuncParamsCheck(const ffi::Array& lhs, + const ffi::Array& rhs) final { // Set populate mapping to false // so we do not pick up symbolic vars in params with function type. // @@ -966,13 +968,13 @@ StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.DeriveCallRetStructInfo", [](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { return DeriveCallRetStructInfo(finfo, call, ctx); }); -}); +} //-------------------------- // UnifyToLCA @@ -990,7 +992,7 @@ class StructInfoLCAFinder // Object is based of everything, unify to object. 
StructInfo VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { - return GetRef(lhs); + return ffi::GetRef(lhs); } StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { @@ -1008,13 +1010,13 @@ class StructInfoLCAFinder if (!lhs->value.defined()) { // If the mismatch was due to extra information in the RHS, // prefer to avoid constructing a new object. - return GetRef(lhs); + return ffi::GetRef(lhs); } else { return PrimStructInfo(lhs->dtype, lhs->span); } } - return GetRef(lhs); + return ffi::GetRef(lhs); } StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { @@ -1026,13 +1028,13 @@ class StructInfoLCAFinder !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), analyzer_)) { // prefers return same when possible if (!lhs->values.defined() && lhs->ndim == ndim) { - return GetRef(lhs); + return ffi::GetRef(lhs); } else { return ShapeStructInfo(ndim, lhs->span); } } // equals to each other - return GetRef(lhs); + return ffi::GetRef(lhs); } StructInfo VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { @@ -1054,7 +1056,7 @@ class StructInfoLCAFinder // reuse lhs when possible if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim && (!lhs->vdevice.defined() || vdev.defined())) { - return GetRef(lhs); + return ffi::GetRef(lhs); } else { return TensorStructInfo(dtype, ndim, vdev, lhs->span); } @@ -1063,14 +1065,14 @@ class StructInfoLCAFinder if (lhs->dtype != dtype || (lhs->vdevice.defined() && !vdev.defined())) { return TensorStructInfo(lhs->shape.value(), dtype, vdev, lhs->span); } else { - return GetRef(lhs); + return ffi::GetRef(lhs); } } StructInfo VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) return ObjectStructInfo(lhs->span); - Optional> fields = UnifyArray(lhs->fields, rhs->fields); + ffi::Optional> fields = 
UnifyArray(lhs->fields, rhs->fields); // tuple length not the same. if (!fields.defined()) return ObjectStructInfo(lhs->span); @@ -1078,7 +1080,7 @@ class StructInfoLCAFinder if (!fields.same_as(lhs->fields)) { return TupleStructInfo(fields.value(), lhs->span); } else { - return GetRef(lhs); + return ffi::GetRef(lhs); } } @@ -1093,7 +1095,7 @@ class StructInfoLCAFinder if (lhs->IsOpaque()) { if (lhs->derive_func.defined()) { if (lhs->derive_func.same_as(rhs->derive_func)) { - return GetRef(lhs); + return ffi::GetRef(lhs); } else { // Create a new opaque with object return return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), purity, lhs->span); @@ -1101,7 +1103,7 @@ class StructInfoLCAFinder } else { // no derivation function, only depends on ret StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); - if (ret.same_as(lhs->ret)) return GetRef(lhs); + if (ret.same_as(lhs->ret)) return ffi::GetRef(lhs); return FuncStructInfo::OpaqueFunc(ret, purity, lhs->span); } } @@ -1128,15 +1130,15 @@ class StructInfoLCAFinder // // Given we only do best effort checking in these cases, and such cases // are likely not a primary concern atm, we take this approach here. 
- if (struct_equal_(GetRef(lhs), GetRef(rhs))) { - return GetRef(lhs); + if (struct_equal_(ffi::GetRef(lhs), ffi::GetRef(rhs))) { + return ffi::GetRef(lhs); } auto params = UnifyArray(lhs->params.value(), rhs->params.value()); auto ret = this->VisitStructInfo(lhs->ret, rhs->ret); if (params.same_as(lhs->params) && ret.same_as(lhs->ret)) { - return GetRef(lhs); + return ffi::GetRef(lhs); } else { // fail to unify the params if (!params.defined()) { @@ -1154,8 +1156,8 @@ class StructInfoLCAFinder StructuralEqual struct_equal_; // check arrays - Optional> UnifyArray(const Array& lhs, - const Array& rhs) { + ffi::Optional> UnifyArray(const ffi::Array& lhs, + const ffi::Array& rhs) { if (lhs.same_as(rhs)) return lhs; if (lhs.size() != rhs.size()) return std::nullopt; size_t index = 0; @@ -1172,12 +1174,12 @@ StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::An } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.analysis.StructInfoLCA", [](const StructInfo& lhs, const StructInfo& rhs) { return StructInfoLCA(lhs, rhs); }); -}); +} //-------------------------- // TIRVarsInStructInfo @@ -1189,9 +1191,9 @@ class TIRVarsDetector : public StructInfoVisitor { Definition, Usage, }; - TIRVarsDetector(VarType collection_type) : collection_type(collection_type) {} + explicit TIRVarsDetector(VarType collection_type) : collection_type(collection_type) {} - Array GetTIRVars() const { return tir_vars_; } + ffi::Array GetTIRVars() const { return tir_vars_; } private: void VisitPrimExpr(PrimExpr expr) { @@ -1208,7 +1210,7 @@ class TIRVarsDetector : public StructInfoVisitor { } } - void VisitShape(Array shape) { + void VisitShape(ffi::Array shape) { for (const PrimExpr& expr : shape) { VisitPrimExpr(expr); } @@ -1239,34 +1241,34 @@ class TIRVarsDetector : public StructInfoVisitor { } } - Array tir_vars_; + ffi::Array tir_vars_; std::unordered_set used_tir_vars_dedup_; VarType 
collection_type; }; -Array TIRVarsInStructInfo(const StructInfo& sinfo) { +ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo) { TIRVarsDetector detector(TIRVarsDetector::VarType::Usage); detector(sinfo); return detector.GetTIRVars(); } -Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { +ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { TIRVarsDetector detector(TIRVarsDetector::VarType::Definition); detector(sinfo); return detector.GetTIRVars(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.analysis.TIRVarsInStructInfo", TIRVarsInStructInfo) .def("relax.analysis.DefinableTIRVarsInStructInfo", DefinableTIRVarsInStructInfo); -}); +} class NonNegativeExpressionCollector : relax::StructInfoVisitor { public: - static Array Collect(const StructInfo& sinfo) { + static ffi::Array Collect(const StructInfo& sinfo) { NonNegativeExpressionCollector visitor; visitor(sinfo); return visitor.expressions_; @@ -1298,36 +1300,37 @@ class NonNegativeExpressionCollector : relax::StructInfoVisitor { } } - Array expressions_; + ffi::Array expressions_; std::unordered_set dedup_lookup_; }; -Array CollectNonNegativeExpressions(const StructInfo& sinfo) { +ffi::Array CollectNonNegativeExpressions(const StructInfo& sinfo) { return NonNegativeExpressionCollector::Collect(sinfo); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.CollectNonNegativeExpressions", CollectNonNegativeExpressions); -}); +} class SymbolicVarCollector : public relax::ExprVisitor, public relax::StructInfoVisitor, public tir::ExprVisitor { public: - static Array Free(const Expr& expr) { + static ffi::Array Free(const Expr& expr) { SymbolicVarCollector collector; collector.VisitExpr(expr); - Array ret{collector.free_symbolic_var_.begin(), collector.free_symbolic_var_.end()}; + ffi::Array 
ret{collector.free_symbolic_var_.begin(), + collector.free_symbolic_var_.end()}; return ret; } - static Array Defined(const Expr& expr) { + static ffi::Array Defined(const Expr& expr) { SymbolicVarCollector collector; collector.VisitExpr(expr); - Array ret{collector.defined_symbolic_var_.begin(), - collector.defined_symbolic_var_.end()}; + ffi::Array ret{collector.defined_symbolic_var_.begin(), + collector.defined_symbolic_var_.end()}; return ret; } @@ -1429,7 +1432,7 @@ class SymbolicVarCollector : public relax::ExprVisitor, } void VisitExpr_(const tir::VarNode* op) final { - tir::Var var = GetRef(op); + tir::Var var = ffi::GetRef(op); // default mode, check defined. if (defined_symbolic_var_.count(var) == 0) { free_symbolic_var_.insert(var); @@ -1452,17 +1455,17 @@ class SymbolicVarCollector : public relax::ExprVisitor, std::unordered_set free_symbolic_var_; }; -Array DefinedSymbolicVars(const Expr& expr) { +ffi::Array DefinedSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Defined(expr); } -Array FreeSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Free(expr); } +ffi::Array FreeSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Free(expr); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.analysis.DefinedSymbolicVars", DefinedSymbolicVars) .def("relax.analysis.FreeSymbolicVars", FreeSymbolicVars); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index b6809c0f35bb..58c47529a103 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -35,7 +35,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { public: explicit PatternKindAnalyzer(const tir::PrimFunc& func) { for (const tir::Var& param : func->params) { - Optional param_buf = func->buffer_map.Get(param); + ffi::Optional param_buf = 
func->buffer_map.Get(param); if (param_buf.defined()) { param_buffers_.insert(param_buf.value()); } @@ -59,12 +59,12 @@ class PatternKindAnalyzer : public StmtExprVisitor { kind_ = kOpaque; return; } - store_ = GetRef(op); + store_ = ffi::GetRef(op); StmtVisitor::VisitStmt_(op); } void VisitExpr_(const BufferLoadNode* op) final { - loads_.push_back(GetRef(op)); + loads_.push_back(ffi::GetRef(op)); ExprVisitor::VisitExpr_(op); } @@ -130,7 +130,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { // Step 4. Checking if the block contains reduce axis by looking into block iterators. bool has_reduction = false; - Array reduce_vars; + ffi::Array reduce_vars; for (const IterVar& it : op->iter_vars) { if (it->iter_type == tir::IterVarType::kCommReduce) { has_reduction = true; @@ -162,7 +162,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { /********** Helper Functions **********/ /*! \brief Checking if two arrays contains same elements. */ - static bool IsSameArray(const Array& lhs, const Array& rhs) { + static bool IsSameArray(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return false; } @@ -293,8 +293,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { if (!lhs || !rhs) { return false; } - return IsAllowReusePattern(GetRef(store), GetRef(lhs)) && - IsAllowReusePattern(GetRef(store), GetRef(rhs)); + return IsAllowReusePattern(ffi::GetRef(store), + ffi::GetRef(lhs)) && + IsAllowReusePattern(ffi::GetRef(store), ffi::GetRef(rhs)); } } } @@ -308,7 +309,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { * A[i] = sum(B[i, j + k]) is not pure reduce * pooling is not pure reduce */ - static bool IsPureReducePattern(Array reduce_loops, Array indices) { + static bool IsPureReducePattern(ffi::Array reduce_loops, ffi::Array indices) { for (const PrimExpr& e : indices) { int id = -1; if (UsesVar(e, [&](const tir::VarNode* var) { @@ -333,9 +334,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { * \brief The BufferStore 
node in the current block. * \note We only support one BufferStore node in a block (usually generated by TE compute) */ - Optional store_; + ffi::Optional store_; /*! \brief The BufferLoad nodes in the current block. */ - Array loads_; + ffi::Array loads_; /*! \brief The result of op pattern. */ OpPatternKind kind_ = kElemWise; /*! \brief The buffers from function params. I.e. the input and output buffers. */ @@ -379,8 +380,8 @@ bool HasReshapePattern(const PrimFunc& func) { // binding values. The mapping will be used in the substitution of // the flattened buffer access index. const Block& block = block_realize->block; - const Array& block_iter = block->iter_vars; - const Array& iter_values = block_realize->iter_values; + const ffi::Array& block_iter = block->iter_vars; + const ffi::Array& iter_values = block_realize->iter_values; ICHECK_EQ(block_iter.size(), iter_values.size()); int n_iter = block_iter.size(); for (int i = 0; i < n_iter; ++i) { @@ -401,7 +402,7 @@ bool HasReshapePattern(const PrimFunc& func) { return; } - Map var_range; + ffi::Map var_range; for (const IterVar& v : block->iter_vars) { ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); var_range.Set(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); @@ -429,7 +430,7 @@ bool HasReshapePattern(const PrimFunc& func) { // This check requires at least one of the src/dst side is a trivial buffer // access (e.g., buf[ax0, ax1, ax2]). 
- auto f_calc_flattened_idx = [&](const Buffer& buffer, const Array& indices) { + auto f_calc_flattened_idx = [&](const Buffer& buffer, const ffi::Array& indices) { ICHECK_EQ(indices.size(), buffer->shape.size()); int ndim = indices.size(); PrimExpr idx = 0; @@ -447,7 +448,7 @@ bool HasReshapePattern(const PrimFunc& func) { }; auto f_is_trivial_indices = [block, this](const Buffer& buffer, - const Array& indices) { + const ffi::Array& indices) { if (indices.size() != block->iter_vars.size()) { return false; } @@ -462,7 +463,7 @@ bool HasReshapePattern(const PrimFunc& func) { return true; }; - Array nontrivial_indices{nullptr}; + ffi::Array nontrivial_indices{nullptr}; Buffer nontrivial_buffer{nullptr}; if (f_is_trivial_indices(dst_buffer_, buffer_store->indices)) { nontrivial_indices = buffer_load->indices; @@ -476,7 +477,7 @@ bool HasReshapePattern(const PrimFunc& func) { DataType dtype = !block->iter_vars.empty() ? block->iter_vars[0]->var->dtype : DataType::Int(64); tir::Var fused_var("fused", dtype); - Map inverse_indices_map; + ffi::Map inverse_indices_map; PrimExpr stride = IntImm(dtype, /*value=*/1); for (int i = static_cast(block->iter_vars.size()) - 1; i >= 0; --i) { inverse_indices_map.Set( @@ -487,7 +488,7 @@ bool HasReshapePattern(const PrimFunc& func) { PrimExpr flattened_idx = f_calc_flattened_idx(nontrivial_buffer, nontrivial_indices); flattened_idx = Substitute(std::move(flattened_idx), inverse_indices_map); - Array simplify_res = arith::IterMapSimplify( + ffi::Array simplify_res = arith::IterMapSimplify( /*indices=*/{flattened_idx}, /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, /*input_pred=*/Bool(true), @@ -519,7 +520,7 @@ bool HasReshapePattern(const PrimFunc& func) { arith::Analyzer ana_; }; - Array buffer_args; + ffi::Array buffer_args; for (const auto& param : func->params) { if (auto buffer = func->buffer_map.Get(param)) { buffer_args.push_back(buffer.value()); @@ -538,10 +539,10 @@ bool HasReshapePattern(const 
PrimFunc& func) { return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.has_reshape_pattern", HasReshapePattern); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 6ec8dcfb5769..bbdbb7b644ef 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -44,23 +44,23 @@ class UDChain : relax::ExprVisitor { UDChain visitor; visitor.VisitExpr(expr); - Array output(visitor.outputs.begin(), visitor.outputs.end()); + ffi::Array output(visitor.outputs.begin(), visitor.outputs.end()); - Map> use_def; + ffi::Map> use_def; for (const auto& [var, usage] : visitor.usage_map) { - use_def.Set(var, Array(usage.begin(), usage.end())); + use_def.Set(var, ffi::Array(usage.begin(), usage.end())); } return VarUsageInfo{visitor.bound_values, use_def, output}; } private: - Map bound_values; + ffi::Map bound_values; std::unordered_set forward_declarations; std::unordered_map> usage_map; support::OrderedSet outputs; - Optional cur_user_; + ffi::Optional cur_user_; void VisitBinding_(const VarBindingNode* binding) override { CHECK(!bound_values.count(binding->var)) @@ -89,7 +89,7 @@ class UDChain : relax::ExprVisitor { } } void VisitExpr_(const VarNode* op) override { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (cur_user_) { usage_map[var].insert(cur_user_.value()); @@ -109,20 +109,20 @@ class UDChain : relax::ExprVisitor { } }; -std::pair>, Array> FunctionUseDef(const Expr& fn) { +std::pair>, ffi::Array> FunctionUseDef(const Expr& fn) { auto usage = UDChain::Collect(fn); return {usage.downstream_usage, usage.outputs}; } -Map> DataflowBlockUseDef(const DataflowBlock& dfb) { - auto usage = UDChain::Collect(SeqExpr({dfb}, Tuple(Array()))); +ffi::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { + auto usage = 
UDChain::Collect(SeqExpr({dfb}, Tuple(ffi::Array()))); return usage.downstream_usage; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.udchain", DataflowBlockUseDef); -}); +} VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); } diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index 1f28ba9edbf7..17a439b408ff 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -26,7 +26,7 @@ namespace tvm { namespace relax { class Var2ValAnalysis : public relax::ExprVisitor { public: - Map var2value_; + ffi::Map var2value_; void VisitBinding_(const VarBindingNode* binding) override { var2value_.Set(binding->var, binding->value); // Recursively visit the value to handle local functions. @@ -34,64 +34,65 @@ class Var2ValAnalysis : public relax::ExprVisitor { } }; -Map AnalyzeVar2Value(const Expr& expr) { +ffi::Map AnalyzeVar2Value(const Expr& expr) { Var2ValAnalysis var2val_analysis; var2val_analysis.VisitExpr(expr); return std::move(var2val_analysis.var2value_); } -Map AnalyzeVar2Value(const DataflowBlock& dfb) { +ffi::Map AnalyzeVar2Value(const DataflowBlock& dfb) { Var2ValAnalysis var2val_analysis; var2val_analysis.VisitBindingBlock_(dfb.get()); return std::move(var2val_analysis.var2value_); } -Map AnalyzeVar2Value(const IRModule& m) { +ffi::Map AnalyzeVar2Value(const IRModule& m) { Var2ValAnalysis var2val_analysis; for (const auto& it : m->functions) { // visit relax.Function if (auto* n = it.second.as()) { - var2val_analysis.VisitExpr(GetRef(n)); + var2val_analysis.VisitExpr(ffi::GetRef(n)); } } return std::move(var2val_analysis.var2value_); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.get_var2val", [](const Function& f) { return AnalyzeVar2Value(f); }); -}); +} class Name2BindingAnalysis : public 
relax::ExprVisitor { public: // Map is not suitable for doing in-place update. // so we use standard container for internal usage. - std::map> name2bindings_; + std::map> name2bindings_; void VisitBinding_(const VarBindingNode* binding) override { const auto& vname = binding->var->name_hint(); - name2bindings_[vname].push_back(GetRef(binding)); + name2bindings_[vname].push_back(ffi::GetRef(binding)); } void VisitBinding_(const MatchCastNode* binding) override { const auto& vname = binding->var->name_hint(); - name2bindings_[vname].push_back(GetRef(binding)); + name2bindings_[vname].push_back(ffi::GetRef(binding)); } }; -Map> NameToBinding(const Function& fn) { +ffi::Map> NameToBinding(const Function& fn) { Name2BindingAnalysis analysis{}; analysis.VisitExpr_(fn.get()); - return Map>(std::make_move_iterator(analysis.name2bindings_.begin()), - std::make_move_iterator(analysis.name2bindings_.end())); + return ffi::Map>( + std::make_move_iterator(analysis.name2bindings_.begin()), + std::make_move_iterator(analysis.name2bindings_.end())); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.name_to_binding", NameToBinding); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index a1bc99ee75bf..0cfc9efad835 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -86,7 +86,7 @@ class WellFormedChecker : public relax::ExprVisitor, public relax::StructInfoVisitor, public tir::ExprVisitor { public: - static bool Check(Variant obj, bool check_struct_info) { + static bool Check(ffi::Variant obj, bool check_struct_info) { WellFormedChecker well_formed_checker = WellFormedChecker(obj.as(), check_struct_info); @@ -94,13 +94,13 @@ class WellFormedChecker : public relax::ExprVisitor, for (const auto& it : mod->functions) { // visit relax.Function if (auto* n = it.second.as()) { - 
Function func = GetRef(n); + Function func = ffi::GetRef(n); well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func); well_formed_checker.VisitExpr(func); } } } else if (const auto* func = obj.as()) { - well_formed_checker.VisitExpr(GetRef(func)); + well_formed_checker.VisitExpr(ffi::GetRef(func)); } else { LOG(FATAL) << "Unreachable, " << "variant did not contain any of the allowed types"; @@ -109,7 +109,7 @@ class WellFormedChecker : public relax::ExprVisitor, } private: - WellFormedChecker(Optional mod, bool check_struct_info) + WellFormedChecker(ffi::Optional mod, bool check_struct_info) : mod_(std::move(mod)), check_struct_info_(check_struct_info), cur_visited_func_(nullptr) {} using relax::ExprVisitor::VisitExpr_; @@ -139,7 +139,7 @@ class WellFormedChecker : public relax::ExprVisitor, // to check again // check name in global var and gsymbol - Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (gsymbol.has_value() && gsymbol != var->name_hint) { Malformed(Diagnostic::Error(func->span) << "Name in GlobalVar is not equal to name in gsymbol: " << var @@ -155,18 +155,20 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitExpr_(const GlobalVarNode* op) final { - GlobalVar var = GetRef(op); + GlobalVar var = ffi::GetRef(op); if (mod_.defined()) { if (!(mod_.value()->ContainGlobalVar(var->name_hint) && mod_.value()->GetGlobalVar(var->name_hint).same_as(var))) { - Malformed(Diagnostic::Error(var) << "GlobalVar " << GetRef(op) << " is not defined."); + Malformed(Diagnostic::Error(var) + << "GlobalVar " << ffi::GetRef(op) << " is not defined."); } } if (op->struct_info_.defined()) { if (!op->struct_info_->IsInstance()) { - Malformed(Diagnostic::Error(var) << "The struct_info_ of GlobalVar " << GetRef(op) - << " must be either FuncStructInfo."); + Malformed(Diagnostic::Error(var) + << "The struct_info_ of GlobalVar " << ffi::GetRef(op) + << " must be either 
FuncStructInfo."); } } @@ -198,21 +200,22 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) { - Malformed(Diagnostic::Error(var) << "Var " << GetRef(op) << " is not defined."); + Malformed(Diagnostic::Error(var) << "Var " << ffi::GetRef(op) << " is not defined."); } CheckStructInfo(op); } void VisitExpr_(const DataflowVarNode* op) final { - DataflowVar var = GetRef(op); + DataflowVar var = ffi::GetRef(op); if (!is_dataflow_) { Malformed(Diagnostic::Error(var) - << "DataflowVar " << GetRef(op) << " is used outside DataflowBlock."); + << "DataflowVar " << ffi::GetRef(op) << " is used outside DataflowBlock."); } if (dataflow_var_set_.count(var) == 0) { - Malformed(Diagnostic::Error(var) << "DataflowVar " << GetRef(op) << " is not defined."); + Malformed(Diagnostic::Error(var) + << "DataflowVar " << ffi::GetRef(op) << " is not defined."); } CheckStructInfo(op); } @@ -244,8 +247,8 @@ class WellFormedChecker : public relax::ExprVisitor, // ensure the purity attributes are valid if (op->GetAttr(relax::attr::kForcePure).value_or(false) && !op->is_pure) { Malformed(Diagnostic::Error(op->span) - << "Function " << GetRef(op) << " has true for " << relax::attr::kForcePure - << " but false for is_pure; " << relax::attr::kForcePure + << "Function " << ffi::GetRef(op) << " has true for " + << relax::attr::kForcePure << " but false for is_pure; " << relax::attr::kForcePure << " should be true only if is_pure is also true."); } @@ -318,7 +321,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(call); if (is_dataflow_ && check_struct_info_) { - if (auto impure = FindImpureCall(GetRef(call))) { + if (auto impure = FindImpureCall(ffi::GetRef(call))) { Malformed(Diagnostic::Error(call) << "Impure function call " << impure << " occurs within a dataflow block."); } @@ -331,8 +334,8 @@ class 
WellFormedChecker : public relax::ExprVisitor, if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) { auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); - Call before_normalize = GetRef(call); - Optional after_normalize = std::nullopt; + Call before_normalize = ffi::GetRef(call); + ffi::Optional after_normalize = std::nullopt; try { after_normalize = func_normalize(dummy_builder, before_normalize); } catch (std::exception& err) { @@ -355,7 +358,7 @@ class WellFormedChecker : public relax::ExprVisitor, if (auto func_validate = op_map_validate_.get(call->op, nullptr); func_validate != nullptr) { try { - func_validate(GetRef(call)); + func_validate(ffi::GetRef(call)); } catch (std::exception& err) { Malformed(Diagnostic::Error(call) << "Operator-specific validation (FValidate) for " << call->op << " identified error: \n" @@ -369,13 +372,13 @@ class WellFormedChecker : public relax::ExprVisitor, // an expression that does not yet have `StructInfo`. auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); Call copied(call->op, call->args, call->attrs, call->sinfo_args); - Optional normalized = std::nullopt; + ffi::Optional normalized = std::nullopt; try { normalized = dummy_builder->Normalize(copied); } catch (std::exception& err) { Malformed(Diagnostic::Error(call) << "Each Relax expression must be able to have its StructInfo inferred. " - << "However, inferring the struct info of expression " << GetRef(call) + << "However, inferring the struct info of expression " << ffi::GetRef(call) << " resulted in the error: \n" << err.what()); } @@ -400,8 +403,9 @@ class WellFormedChecker : public relax::ExprVisitor, BaseCheckResult::kFailL1) { Malformed(Diagnostic::Error(call) << "All information in StructInfo annotations must be correct. 
" - << "However, while the expression " << GetRef(call) << " is annotated as " - << current_struct_info << ", the expression outputs " << inferred_struct_info); + << "However, while the expression " << ffi::GetRef(call) + << " is annotated as " << current_struct_info << ", the expression outputs " + << inferred_struct_info); } } } @@ -513,7 +517,7 @@ class WellFormedChecker : public relax::ExprVisitor, Malformed(Diagnostic::Error(var) << "DataflowVar " << var << " is defined outside DataflowBlock."); } - DataflowVar lv = GetRef(var); + DataflowVar lv = ffi::GetRef(var); if (dataflow_var_set_.count(lv) == 1) { Malformed(Diagnostic::Error(var) << "DataflowVar " << lv << " is defined more than once."); } @@ -523,7 +527,7 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitVarDef_(const VarNode* var) final { - Var gv = GetRef(var); + Var gv = ffi::GetRef(var); if (var_set_.count(gv) == 1) { Malformed(Diagnostic::Error(var) << "Var " << gv << " is defined more than once."); } @@ -533,7 +537,7 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitExpr_(const tir::VarNode* op) final { - tir::Var var = GetRef(op); + tir::Var var = ffi::GetRef(op); // default mode, check defined. 
if (symbolic_var_set_.count(var) == 0) { this->Malformed(Diagnostic::Error(var) << "Symbolic Var " << var << " is not defined."); @@ -571,7 +575,7 @@ class WellFormedChecker : public relax::ExprVisitor, if (mode_ == VisitMode::kMatchVarDef) { // populate symbolic var in first occurrence if (auto* op = expr.as()) { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (var_set_.count(var) == 0) { var_set_.insert(var); } @@ -590,7 +594,7 @@ class WellFormedChecker : public relax::ExprVisitor, if (mode_ == VisitMode::kMatchVarDef) { // populate symbolic var in first occurrence if (auto* op = expr.as()) { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (symbolic_var_set_.count(var) == 0) { symbolic_var_set_.insert(var); } @@ -607,7 +611,7 @@ class WellFormedChecker : public relax::ExprVisitor, auto* sinfo = op->struct_info_.as(); if (sinfo != nullptr) { - this->VisitStructInfo(GetRef(sinfo)); + this->VisitStructInfo(ffi::GetRef(sinfo)); } else { Malformed(Diagnostic::Error(op) << "Expr must have struct_info populated. 
" << " Expr.type_key=" << op->GetTypeKey()); @@ -622,7 +626,7 @@ class WellFormedChecker : public relax::ExprVisitor, std::swap(mode_, mode); } - Optional mod_; + ffi::Optional mod_; const bool check_struct_info_; bool well_formed_ = true; bool is_dataflow_; @@ -642,14 +646,14 @@ class WellFormedChecker : public relax::ExprVisitor, tvm::OpAttrMap op_map_validate_ = Op::GetAttrMap("FValidate"); }; -bool WellFormed(Variant obj, bool check_struct_info) { +bool WellFormed(ffi::Variant obj, bool check_struct_info) { return WellFormedChecker::Check(obj, check_struct_info); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.well_formed", WellFormed); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc new file mode 100644 index 000000000000..887b81872940 --- /dev/null +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -0,0 +1,755 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file src/relax/backend/adreno/annotate_custom_storage.cc + * \brief Texture Storage Annotation Pass for Adreno GPU targets. + * + * Texture realization for Adreno GPU targets fundamentally follows these stages: + * Stage 1: Transforming the shapes with inner most dimension being 4 + * Stage 2: Annotate appropriate memory_scope hint in VDevice of StructInfo + * Stage 3: TIR lowering injects texture load/store builtins looking at this scope + * Stage 4: Finally codegen handles appropriate code looking at buffer types and load/store + * builtins. + * + * Stage 1 is generic and straight forward by using convert_layout pass that transforms the + * shapes as well as injecting layout_transform ops as needed. + * + * Stage 2: This pass is responsible for injecting appropriate VDevice into StructInfo and + * adding any copies if there is a conflict between producer and consumer scopes. + * + * After convert_layout the mod looks like below + * @I.ir_module + * class Module: + * @R.function + * def main( + * x: R.Tensor((2, 64, 56, 56), dtype="float32"), + * w: R.Tensor((32, 64, 3, 3), dtype="float32") + * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): + * with R.dataflow(): + * lv: R.Tensor((2, 16, 56, 56, 4), dtype="float32") = R.layout_transform( + * x, + * index_map=T.index_map( + * lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4))) + * lv1: R.Tensor((8, 64, 3, 3, 4), dtype="float32") = R.layout_transform( + * w, + * index_map=T.index_map( + * lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4))) + * lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d( + * lv, + * lv1, + * data_layout="NCHW4c", + * kernel_layout="OIHW4o", + * out_layout="NCHW4c", + * out_dtype="float32" + * ) + * gv: R.Tensor((2, 32, 54, 54), dtype="float32") = R.layout_transform( + * lv2, + * index_map=T.index_map( + * lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3))) + * R.output(gv) + * return gv + * + * Here, the param layout transforms are injected properly and the conv2d
op is operating + * in 5D shapes. + * + * Now, the scope annotation decisions are done by + * - For op_pattern < kCommReduce we just look for shape being 5D and inner dimension = 4 + * - For op_pattern > kCommReduce we make decisions selectively. Currently we do enable texture + * scope for Conv2D, PoolOps. + * The trick here is while this pass is in action we need op_pattern information for ops that are + * below kCommReduce as well as op attributes for selective ops like Conv2D and PoolOps. + * op_pattern is available after legalization and TIROpPattern pass does an analysis. However, + * op specific attributes don't exist after legalization. + * + * To solve this issue, we do legalization in parts. + * At first, we call legalization by skipping the list of ops we wanted not to legalize. + * LegalizeOps is enhanced to accept skip_ops for this purpose. + * After legalization and AnnotateTIROpPattern this way the mod looks like + * + * class Module: + * @R.function + * def main( + * x: R.Tensor((2, 64, 56, 56), dtype="float32"), + * w: R.Tensor((32, 64, 3, 3), dtype="float32") + * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): + * with R.dataflow(): + * lv = R.call_tir(cls.te_layout_transform, (x,), + * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32") + * ) + * lv1 = R.call_tir(cls.te_layout_transform1, (w,), + * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32") + * ) + * lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d( + * lv, + * lv1, + * data_layout="NCHW4c", + * kernel_layout="OIHW4o", + * out_layout="NCHW4c", + * out_dtype="float32" + * ) + * gv = R.call_tir(cls.te_layout_transform2, (lv2,), + * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32") + * ) + * R.output(gv) + * return gv + * + * Here, the legalized prim functions do have op_pattern attribute. + * We now have what we wanted to run this pass. + * + * This pass in principle does scope annotation based on consumer priority. i.e.
+ * For any tensor object we try to assign scope based on the consumer requirement. + * The conflicts and multiple consumers for same tensor are handled by injecting + * appropriate copies. + * 1: CollectConsumerScopeInfo: Visitor collects all consumer demand for each input + * 2: CollectProducerScopeInfo: Visitor finalizes the scope for each input and output based + * on consumer scope information. It evaluates multiple consumer cases and conflicts. + * 3: DefineVDevice: Pass injects hint_on_device for each argument. It also tries to update + * out StructInfo containing VDevice information. This update for tir calls is straight forward + * as sinfo_args in CallNode is meant for this purpose. This sinfo_args for other calls by + * design is invalid as we do this by "FInferStructInfo". + * Another issue we have with "FInferStructInfo" per op is they can't decide this + * memory scope information which is done by this pass based on consumer demand. + * Hence, we are going to use the sinfo_args to indicate this information. + * So, this pass attributes sinfo_args for regular calls too and FInferStructInfo implementations + * do take VDevice information from this hint. This also solves the issue of mixed VDevice + * for arguments of an op.
+ * After these steps the mod looks like + * + * class Module: + * @R.function + * def main( + * x: R.Tensor((2, 64, 56, 56), dtype="float32"), + * w: R.Tensor((32, 64, 3, 3), dtype="float32") + * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): + * with R.dataflow(): + * lv: R.Tensor((2, 64, 56, 56), dtype="float32") = R.hint_on_device( + * x, R.device(dev_type=4, dev_id=0), "global" + * ) + * lv_1 = R.call_tir(cls.te_layout_transform, (lv,), + * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32", + * vdevice="opencl:0:global.texture-nhwc" + * ) + * ) + * lv1: R.Tensor((32, 64, 3, 3), dtype="float32") = R.hint_on_device( + * w, R.device(dev_type=4, dev_id=0), "global" + * ) + * lv1_1 = R.call_tir(cls.te_layout_transform1, (lv1,), + * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32", + * vdevice="opencl:2:global.texture-weight" + * ) + * ) + * lv2: R.Tensor((2, 16, 56, 56, 4), dtype="float32", + * vdevice="opencl:0:global.texture-nhwc" + * ) = R.hint_on_device(lv_1, R.device(dev_type=4, dev_id=0), "global.texture-nhwc") + * lv3: R.Tensor((8, 64, 3, 3, 4), dtype="float32", + * vdevice="opencl:2:global.texture-weight" + * ) = R.hint_on_device(lv1_1, R.device(dev_type=4, dev_id=0), "global.texture-weight") + * lv2_1: R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global" + * ) = R.nn.conv2d( + * lv2, lv3, + * data_layout="NCHW4c", kernel_layout="OIHW4o", + * out_layout="NCHW4c", out_dtype="float32", + * sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global"), + * ) + * ) + * lv4: R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global" + * ) = R.hint_on_device(lv2_1, R.device(dev_type=4, dev_id=0), "global") + * gv = R.call_tir(cls.te_layout_transform2, (lv4,), + * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global") + * ) + * R.output(gv) + * return gv + * + * What we have above is hint_on_device injections and out_sinfo for all calls.
+ * Now, we apply RealizeVDevice to formalize the hints. Following that, we also call + * CanonicalizeBindings that removes redundant assignments like + * + * lv: R.Tensor((2, 64, 56, 56), dtype="float32", vdevice="opencl:1:global") = x + * lv1: R.Tensor((32, 64, 3, 3), dtype="float32", vdevice="opencl:1:global") = w + * + * These assignments are the result of hint_on_device not realizing any copy while consumer and + * producer have the same memory scope or vdevice. These assignments do impact operator fusion. + * + * Now the mod looks like, + * + * class Module: + * @R.function + * def main( + * x: R.Tensor((2, 64, 56, 56), dtype="float32"), + * w: R.Tensor((32, 64, 3, 3), dtype="float32") + * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): + * with R.dataflow(): + * lv = R.call_tir(cls.te_layout_transform, (x,), + * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32", + * vdevice="opencl:0:global.texture-nhwc" + * ) + * ) + * lv1 = R.call_tir(cls.te_layout_transform1, (w,), + * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32", + * vdevice="opencl:2:global.texture-weight" + * ) + * ) + * lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global" + * ) = R.nn.conv2d( + * lv, lv1, + * data_layout="NCHW4c", kernel_layout="OIHW4o", + * out_layout="NCHW4c", out_dtype="float32", + * sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global"), + * ) + * ) + * gv = R.call_tir(cls.te_layout_transform2, (lv2,), + * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global") + * ) + * R.output(gv) + * return gv + * + * Followed by, the compilation pipeline calls + * - legalization of the remaining ops: This legalization forwards the annotated out_sinfo + * VDevice information to tir_calls + * - AnnotateTIROpPattern : TIROp Patterns for newly legalized ops + * - Fusion + * - FoldVDeviceScopeChange: There existed some ToVDevice copies from texture to buffer + * This pass removes the copies and updates producer
scope to global. + * - SpecializePrimFuncBasedOnCallSite: Finally we updates the Buffer Var maps according to + * VDevice scopes. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../../op/tensor/manipulate.h" +#include "../../transform/infer_layout_utils.h" +#include "../../transform/utils.h" + +namespace tvm { +namespace relax { +namespace backend { +namespace adreno { + +using tvm::tir::Buffer; + +static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + +/* + * \brief generates consumer information for each var + * \return scope_info is a map which contain for each var the corresponding call nodes that + * consume it and corresponding scope it expects this input to be. + * \return call_scope_info is a map of each call_node and array holding scope infor for each input. + */ +class CollectConsumerScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + std::pair>, + ffi::Map>>> + Collect(const IRModule& mod, Function func, const Target& target) { + mod_ = mod; + target_ = target; + VisitExpr(func->body); + // Extend the scope for tuple items + for (const auto& val : arg_to_binding) { + if (scope_info.find(val.first) != scope_info.end()) { + if (scope_info.find(val.second) == scope_info.end()) { + scope_info.Set(val.second, scope_info[val.first]); + } else { + auto ent = scope_info[val.second]; + for (auto ent_val : scope_info[val.first]) { + ent.Set(ent_val.first, ent_val.second); + } + scope_info.Set(val.second, ent); + } + } + } + + return std::make_pair(call_scope_info, scope_info); + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + if (arg_to_binding.find(ffi::GetRef(binding->var.get())) == arg_to_binding.end()) { + arg_to_binding.Set(ffi::GetRef(binding->var.get()), + 
ffi::GetRef(tuple_get_item_node->tuple.get())); + } + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + GlobalVar gv; + ffi::Array op_attrs; + ffi::Optional op_pattern = Integer(static_cast(OpPatternKind::kOpaque)); + Tuple func_args; + + if (call->op == call_tir_op) { + gv = Downcast(call->args[0]); + tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + op_attrs = ExtractAttrs(pfunc); + op_pattern = ExtractPattern(pfunc); + func_args = Downcast(call->args[1]); + } else { + op_attrs = {call->attrs}; + op_pattern = Integer(static_cast(OpPatternKind::kOpaque)); + func_args = Tuple(call->args); + } + + bool is_texture_supported = SupportsTexture(op_attrs, op_pattern.value()); + + ffi::Array arg_scope; + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + auto scope = is_texture_supported + ? Scope(GetShapeFromTensorStructInfo(tensor_sinfo.value())) + : "global"; + ffi::Map> ent_call; + const VarNode* arg_var = arg.as(); + if (scope_info.find(ffi::GetRef(arg_var)) != scope_info.end()) { + ent_call = scope_info[ffi::GetRef(arg_var)]; + } + ent_call.Set(ffi::GetRef(call), {scope}); + scope_info.Set(ffi::GetRef(arg_var), ent_call); + arg_scope.push_back(scope); + } + } + call_scope_info.Set(ffi::GetRef(call), arg_scope); + } + + private: + template + ffi::Array ExtractAttrs(const T& func) { + ffi::Array op_attrs; + ffi::Optional attrs = func->template GetAttr("op_attrs"); + if (attrs) { + if (auto val = attrs.value().as()) { + op_attrs.push_back(val.value()); + } else if (auto val = attrs.value().as>()) { + op_attrs = val.value(); + } + } + return op_attrs; + } + + template + ffi::Optional ExtractPattern(const T& func) { + ffi::Optional op_pat = func->template GetAttr("op_pattern"); + return op_pat; + } + + bool SupportsTexture(const ffi::Array& op_attrs, Integer op_pattern) { + if (op_pattern.IntValue() < OpPatternKind::kCommReduce) return true; + 
+ for (auto attr : op_attrs) { + if (auto conv_attr = attr.as()) { + if (conv_attr->data_layout == "NCHW4c" && conv_attr->kernel_layout == "OIHW4o") { + return true; + } + } else if (auto pool_attrs = attr.as()) { + if (pool_attrs->layout == "NCHW4c") { + return true; + } + } else if (auto avg_attrs = attr.as()) { + if (avg_attrs->layout == "NCHW4c") { + return true; + } + } else if (attr.as()) { + return true; + } + } + + return false; + } + + std::string Scope(ffi::Array shape) { + // currently we support only textures been made from 5d tensors + // 5d requirement is not limitation of textures in general, it is limitation how + // we are representing memory scopes/layout and flattening of textures in tir + if (shape.size() == 5 && shape[4].as()->value == 4) { + for (auto ind : shape) { + if (!ind.as()) { + // Dynamic tensors + return "global.texture-nchw"; + } + } + std::map diffs; + int spatial_limit = + target_->GetAttr("texture_spatial_limit").value_or(Integer(16384))->value; + int depth_limit = + target_->GetAttr("texture_depth_limit").value_or(Integer(2048))->value; + int a0 = shape[0].as()->value; + int a1 = shape[1].as()->value; + int a2 = shape[2].as()->value; + int a3 = shape[3].as()->value; + + int d1r = a0 * a1; + int d2r = a2 * a3; + int d3r = a1 * a2 * a3; + std::string scope = "global"; + if (a0 < spatial_limit && d3r < spatial_limit) + scope += ".texture-weight"; + else if (a0 < depth_limit && a1 < spatial_limit && d2r < spatial_limit) + scope += ".texture-nhwc"; + else if (d1r < depth_limit && a2 < spatial_limit && a3 < spatial_limit) + scope += ".texture"; + return scope; + } + return "global"; + } + + /* Map of each Var consumption by a call node and its scope */ + ffi::Map>> scope_info; + /* A map of call node and scope info for each argument it consunes */ + ffi::Map> call_scope_info; + ffi::Map arg_to_binding; + IRModule mod_; + Target target_; +}; + +/* + * \brief producer scope information consolidated based on consumer demands. 
+ * \return producer_info which is a map of each call node and corresponding out StructInfo + * This pass considers all consumers and their scope demand. + * Any mismatches here introduces copies as needed. + */ +class CollectProducerScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + ffi::Map Collect( + const IRModule& mod, Function func, + const ffi::Map>>& scope_info, + const Target& target, const BlockBuilder& builder) { + mod_ = mod; + scope_info_ = scope_info; + target_ = target; + builder_ = builder; + VisitExpr(func->body); + + return producer_sinfo; + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + ExprVisitor::VisitBinding_(binding, call); + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + out_sinfo = call->sinfo_args[0]; + } else { + tvm::OpAttrMap op_map_infer_struct_info_ = + Op::GetAttrMap("FInferStructInfo"); + + auto* op_ptr = call->op.as(); + Op op = ffi::GetRef(op_ptr); + ICHECK(op_map_infer_struct_info_.count(op)) + << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; + out_sinfo = op_map_infer_struct_info_[op](ffi::GetRef(call), builder_); + } + + std::unordered_map scope_count; + + // Decide the final scope based on the max consumer demand. Rest will use to_device. 
+ auto arg_var = binding->var.as(); + if (scope_info_.find(ffi::GetRef(arg_var)) != scope_info_.end()) { + for (const auto& val : scope_info_[ffi::GetRef(arg_var)]) { + auto call_node = Downcast(val.first); + if (scope_count.find(val.second[0]) == scope_count.end()) { + scope_count.insert({val.second[0], 1}); + } else { + auto curr_count = scope_count[val.second[0]]; + scope_count.emplace(val.second[0], curr_count + 1); + } + } + } + ffi::String final_scope = "global"; + int count = 0; + for (const auto& sval : scope_count) { + if (sval.second > count) { + final_scope = sval.first; + count = sval.second; + } + } + // Applying same scope for outputs + StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); + producer_sinfo.Set(ffi::GetRef(call), updated_ret_sinfo); + } + + private: + StructInfo UpdateStructInfo(const StructInfo& out_sinfo, ffi::Array scope) { + if (out_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(out_sinfo); + auto shape_arr = GetShapeFromTensorStructInfo(tensor_sinfo); + return TensorStructInfo(ShapeExpr(shape_arr), tensor_sinfo->dtype, + VDevice(target_, 0, scope[0])); + } + + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + ffi::Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + sinfo_fields.push_back( + TensorStructInfo(ShapeExpr(shape_arr), sinfo->dtype, VDevice(target_, 0, scope[0]))); + } + return TupleStructInfo(sinfo_fields); + } + + ffi::Map>> scope_info_; + ffi::Map producer_sinfo; + IRModule mod_; + Target target_; + BlockBuilder builder_; +}; + +/* + * \brief main pass that injects hint_on_device for each 
argument based on producer, + * consumer indormations. This also attributes ret StructInfo for each call node. + * This pass also calls the ReliaseVdevice that formalizes the hints by appropriately injecting + * Vdevice copies as needed. + */ + +class DefineVDevice : ExprMutator { + public: + explicit DefineVDevice(const Target& target) : target_(target) {} + + IRModule Run(IRModule& mod) { + mod_ = mod; + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + const auto& base_func = mod_->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + auto info = CollectConsumerScopeInfo().Collect(mod_, Downcast(func), target_); + call_scope_info_ = info.first; + scope_info_ = info.second; + producer_sinfo_ = CollectProducerScopeInfo().Collect(mod_, Downcast(func), + scope_info_, target_, builder_); + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + + ffi::Array global_vdevices_; + for (auto vdev : vdevices_) { + global_vdevices_.push_back(vdev.as().value()); + } + mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); + + mod_ = relax::transform::DeadCodeElimination()(mod_); + mod_ = relax::transform::RealizeVDevice()(mod_); + mod_ = relax::transform::CanonicalizeBindings()(mod_); + + return mod_; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + GlobalVar gv; + Tuple func_args; + + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + gv = Downcast(call->args[0]); + // tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + // out_sinfo = call->sinfo_args[0]; + func_args = Downcast(call->args[1]); + } else { + func_args = Tuple(call->args); + // return call; + } + + ffi::Array new_args; + StructInfo 
updated_ret_sinfo = producer_sinfo_[ffi::GetRef(call_node)]; + + if (updated_ret_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(updated_ret_sinfo); + auto shape = tensor_sinfo->shape.value(); + auto dtype = tensor_sinfo->dtype; + if (tensor_sinfo->vdevice.defined()) { + auto vdev = tensor_sinfo->vdevice.value(); + const VDevice& vdev_global = MakeGlobalVDevice(vdev); + updated_ret_sinfo = TensorStructInfo(shape, dtype, vdev_global); + } + } else { + ICHECK(updated_ret_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << updated_ret_sinfo; + + const auto& tuple_sinfo = Downcast(updated_ret_sinfo); + ffi::Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + + auto shape = sinfo->shape.value(); + auto dtype = sinfo->dtype; + if (sinfo->vdevice.defined()) { + auto vdev = sinfo->vdevice.value(); + const VDevice& vdev_global = MakeGlobalVDevice(vdev); + sinfo_fields.push_back(TensorStructInfo(shape, dtype, vdev_global)); + } else { + sinfo_fields.push_back(sinfo); + } + } + updated_ret_sinfo = TupleStructInfo(sinfo_fields); + } + + int arg_idx = 0; + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + ffi::String scope = "global"; + if (call_scope_info_.find(ffi::GetRef(call_node)) != call_scope_info_.end()) { + scope = call_scope_info_[ffi::GetRef(call_node)][arg_idx]; + } + new_args.push_back(HintArg(arg, scope)); + arg_idx++; + } else { + new_args.push_back(arg); + } + } + + if (call->op == call_tir_op) { + return builder_->Normalize( + Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo})); + } else { + return builder_->Normalize(Call(call->op, new_args, 
call->attrs, {updated_ret_sinfo})); + } + } + + private: + VDevice MakeGlobalVDevice(VDevice vdev) { + int device_type = vdev->target->GetTargetDeviceType(); + for (size_t i = 0; i < vdevices_.size(); ++i) { + int dev_type = vdevices_[i]->target->GetTargetDeviceType(); + if (dev_type == device_type && vdevices_[i]->vdevice_id == vdev->vdevice_id && + vdevices_[i]->memory_scope == vdev->memory_scope) { + return vdevices_[i]; + } + } + vdevices_.push_back(vdev); + return (vdevices_.back()); + } + + Expr HintArg(const Expr& arg, ffi::String scope) { + if (arg->IsInstance()) { + if (auto tsinfo = arg->struct_info_.as()) { + if (!tsinfo->vdevice.defined()) { + const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); + CHECK(tsinfo->shape.defined()) << "Shape not defined for a constant tensor ..!"; + arg->struct_info_ = + TensorStructInfo(tsinfo->shape.value(), tsinfo->dtype, vdev, tsinfo->span); + return arg; + } + } + } + ObjectPtr attrs = ffi::make_object(); + const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); + attrs->device_type = vdev->target->GetTargetDeviceType(); + attrs->index = vdev->vdevice_id; + attrs->memory_scope = vdev->memory_scope; + + Expr new_arg = Call(hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); + + return new_arg; + } + + ffi::Optional GetTarget(const StructInfo& sinfo) { + auto tinfo = sinfo.as(); + if (tinfo->vdevice.defined()) { + auto vdevice = tinfo->vdevice.value(); + if (vdevice->target.defined()) { + return vdevice->target; + } + } + return std::nullopt; + } + + const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); + IRModule mod_; + IRModule updates_; + Target target_; + ffi::Array vdevices_; + ffi::Map>> scope_info_; + ffi::Map producer_sinfo_; + ffi::Map> call_scope_info_; +}; + +namespace transform { + +Pass AnnotateCustomMemoryScope(Target target) { + auto pass_func = [=](IRModule mod, PassContext pc) { + return tvm::relax::backend::adreno::DefineVDevice(target).Run(mod); + }; + 
return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"AnnotateCustomMemoryScope", + /*required=*/{}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.backend.adreno.transform.AnnotateCustomMemoryScope", + AnnotateCustomMemoryScope); +} +} // namespace transform +} // namespace adreno +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc new file mode 100644 index 000000000000..c59beae78e96 --- /dev/null +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/adreno/fold_vdevice_scope_change.cc + * \brief This is a texture specific pass that can optimize unnecessary to_device copies. + * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + * store into global scope avoiding unnecessary device copy. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../../op/tensor/manipulate.h" +#include "../../transform/infer_layout_utils.h" +#include "../../transform/utils.h" + +namespace tvm { +namespace relax { +namespace backend { +namespace adreno { + +namespace { +std::tuple)>> CreatePatterns( + ffi::Map> consumers) { + auto pat_gv = WildcardPattern(); + + auto pat_inp = WildcardPattern(); + auto pat_call_tir = IsOp("relax.call_tir")(pat_gv, pat_inp); + auto pattern_out = IsOp("relax.to_vdevice")(pat_call_tir); + + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { + const auto* call_tir = matches[pat_call_tir].as(); + ICHECK(call_tir) << "InternalError: " + << "Match of relax.call_tir operator should produce Call, " + << "but instead produces " << matches[pat_call_tir] << " with type " + << matches[pat_call_tir]->GetTypeKey(); + + const auto* out = matches[pattern_out].as(); + ICHECK(out) << "InternalError: " + << "Match of relax.to_vdevice operator should produce Call, " + << "but instead produces " << matches[pattern_out] << " with type " + << matches[pattern_out]->GetTypeKey(); + + const auto* vdev_attrs = out->attrs.as(); + ICHECK(vdev_attrs) << "InternalError: " + << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " + << "but were instead " << out->attrs << " with type " << out->GetTypeKey(); + + const auto* tir_out_sinfo = call_tir->sinfo_args[0].as(); + if (!tir_out_sinfo) return expr; + + if (!tir_out_sinfo->vdevice.defined()) return expr; + + const VarNode* arg_var = out->args[0].as(); + if (consumers.find(ffi::GetRef(arg_var)) != consumers.end()) { + if (consumers[ffi::GetRef(arg_var)].size() > 1) { + /* Don't do to_device optimization as we are not the only consumer */ + return expr; + } + } + + if ((std::string(tir_out_sinfo->vdevice.value()->memory_scope).find("texture") != + std::string::npos) && + (vdev_attrs->dst_vdevice->memory_scope == "global")) { + auto 
shape_arr = tir_out_sinfo->GetShape().value(); + auto new_sinfo = + TensorStructInfo(ShapeExpr(shape_arr), tir_out_sinfo->dtype, vdev_attrs->dst_vdevice); + + return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_sinfo}); + } + return expr; + }; + + return {pattern_out, rewriter}; +} + +} // namespace + +class CollectConsumerDetails : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + ffi::Map> Collect(const IRModule& mod, Function func, + const Target& target) { + mod_ = mod; + target_ = target; + VisitExpr(func->body); + // Extend the consumer details for tuple items + for (const auto& val : arg_to_binding) { + if (consumers.find(val.first) != consumers.end()) { + if (consumers.find(val.second) == consumers.end()) { + consumers.Set(val.second, consumers[val.first]); + } else { + auto ent = consumers[val.second]; + for (auto ent_val : consumers[val.first]) { + ent.push_back(ent_val); + } + consumers.Set(val.second, ent); + } + } + } + return consumers; + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + if (arg_to_binding.find(ffi::GetRef(binding->var.get())) == arg_to_binding.end()) { + arg_to_binding.Set(ffi::GetRef(binding->var.get()), + ffi::GetRef(tuple_get_item_node->tuple.get())); + } + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + Tuple func_args; + + if (call->op == call_tir_op) { + func_args = Downcast(call->args[1]); + } else { + func_args = Tuple(call->args); + } + + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + ffi::Array call_list; + + const VarNode* arg_var = arg.as(); + + if (consumers.find(ffi::GetRef(arg_var)) != consumers.end()) { + call_list = consumers[ffi::GetRef(arg_var)]; + } + call_list.push_back(ffi::GetRef(call)); + consumers.Set(ffi::GetRef(arg_var), call_list); + } + } + } + + private: + /* Map of each Var 
consumption by a call node */ + ffi::Map> consumers; + ffi::Map arg_to_binding; + IRModule mod_; + Target target_; +}; + +namespace transform { + +Pass FoldVDeviceScopeChange() { + auto pass_func = [=](Function func, IRModule mod, PassContext pc) { + /* here Target doesn't matter as the consumers we use only to find multiple consumers */ + auto consumers = + CollectConsumerDetails().Collect(mod, Downcast(func), Target("opencl")); + auto [pattern, rewriter] = CreatePatterns(consumers); + return RewriteCall(pattern, rewriter, func); + }; + return CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.backend.adreno.transform.FoldVDeviceScopeChange", + FoldVDeviceScopeChange); +} +} // namespace transform +} // namespace adreno +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index b25bfbdb22a7..362621f4238e 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -48,18 +48,17 @@ struct OpenCLMLCompilerConfigNode : public AttrsNodeReflAdapter constant_names, Map bindings) + explicit OpenCLMLJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} /*! 
@@ -135,9 +135,9 @@ class OpenCLMLJSONSerializer : public JSONSerializer { // The call must be to an inline "Composite" function const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - auto opt_composite = fn->GetAttr(attr::kComposite); + auto opt_composite = fn->GetAttr(attr::kComposite); ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); @@ -177,7 +177,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer { VLOG(1) << name << " has " << node->GetInputs().size() << " inputs"; } - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } /*! @@ -191,8 +191,8 @@ class OpenCLMLJSONSerializer : public JSONSerializer { const auto* fn_var = cn->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); - auto opt_composite = fn->GetAttr(attr::kComposite); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); + auto opt_composite = fn->GetAttr(attr::kComposite); ICHECK(opt_composite.has_value()); nodes.pad = backend::TryGetOpInFunction(fn, "relax.nn.pad"); @@ -220,8 +220,8 @@ class OpenCLMLJSONSerializer : public JSONSerializer { const auto* fn_var = cn->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); - auto opt_composite = fn->GetAttr(attr::kComposite); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); + auto opt_composite = fn->GetAttr(attr::kComposite); ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); @@ -292,11 +292,11 @@ class OpenCLMLJSONSerializer : public JSONSerializer { private: /*! \brief The bindings to look up composite functions. 
*/ - Map bindings_; + ffi::Map bindings_; }; void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) { - for (const auto& entry : serializer_->VisitExpr(GetRef(constant_node))) { + for (const auto& entry : serializer_->VisitExpr(ffi::GetRef(constant_node))) { args_.emplace_back(entry); } } @@ -311,9 +311,10 @@ void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) * \param functions The extern functions to be compiled via OpenCLML * \return Runtime modules. */ -Array OpenCLMLCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array OpenCLMLCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "OpenCLML partition:" << std::endl << func; OpenCLMLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -328,10 +329,10 @@ Array OpenCLMLCompiler(Array functions, Map return compiled_functions; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.ext.openclml", OpenCLMLCompiler); -}); +} /*! * \brief Check whether OpenCLML graph executor is enabled. 
@@ -357,12 +358,12 @@ Integer GetOpenCLMLVersion() { #endif // TVM_GRAPH_EXECUTOR_CLML } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.is_openclml_runtime_enabled", IsOpenCLMLRuntimeEnabled) .def("relax.get_openclml_version", GetOpenCLMLVersion); -}); +} } // namespace contrib } // namespace relax diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h index 7f04091fc178..3c6469423890 100644 --- a/src/relax/backend/contrib/codegen_c/codegen_c.h +++ b/src/relax/backend/contrib/codegen_c/codegen_c.h @@ -47,7 +47,7 @@ struct GenerateBodyOutput { std::string decl; std::vector buffers; std::vector outputs; - Array headers; + ffi::Array headers; }; // The base class to generate the declaration functions in C. @@ -115,7 +115,7 @@ class CodegenCBase { * * \code * - * Array foo_consts; + * ffi::Array foo_consts; * * // An example code for the generated C function. * int foo_wrapper_(DLTensor* arg0, @@ -129,7 +129,7 @@ class CodegenCBase { * * TVM_FFI_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); * - * int foo_init_wrapper_(Array arr) { + * int foo_init_wrapper_(ffi::Array arr) { * foo_consts = arr; * return 0; * } @@ -220,7 +220,7 @@ class CodegenCBase { // codegen. Moreover, in microTVM we dont expect this part to be generated. 
code_stream_ << "#ifdef __cplusplus\n"; code_stream_ << "int " << func_name - << "_init_wrapper_(tvm::Array arr) {\n"; + << "_init_wrapper_(tvm::ffi::Array arr) {\n"; EnterScope(); PrintIndents(); code_stream_ << func_name << "_consts = arr;\n"; @@ -233,7 +233,7 @@ class CodegenCBase { } } - void GenerateBackendCFunc(const std::string& func_name, const Array& args, + void GenerateBackendCFunc(const std::string& func_name, const ffi::Array& args, const std::string& const_arr_name, const std::vector& outs, bool pass_dl_tensor = false) { std::vector arg_types; @@ -266,7 +266,7 @@ class CodegenCBase { * * \return The emitted code string. */ - std::string JitImpl(const std::string& ext_func_id, const Array& args, + std::string JitImpl(const std::string& ext_func_id, const ffi::Array& args, const std::vector& buf_decl, const std::vector& body, const std::string& const_arr_name, const std::vector& outs) { @@ -369,7 +369,7 @@ class CodegenCBase { } /*! - * \brief Creates a checker to check if the NDArray pool is initialized + * \brief Creates a checker to check if the Tensor pool is initialized * * \param symobl The Symbol of the current function * @@ -389,8 +389,8 @@ class CodegenCBase { * * \return The created declaration */ - std::string CreateNDArrayPool(const std::string& symbol) const { - return "tvm::Array " + symbol + "_consts;"; + std::string CreateTensorPool(const std::string& symbol) const { + return "tvm::ffi::Array " + symbol + "_consts;"; } /*! 
diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 3e0b6ea5e8c6..505696254209 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -87,7 +87,7 @@ class OpAttrExtractor { void Visit(const char* key, std::string* value) { SetNodeAttr(key, {*value}); } - void Visit(const char* key, Optional* value) { + void Visit(const char* key, ffi::Optional* value) { if (value->has_value()) { SetNodeAttr(key, {Fp2String(value->value())}); } else { @@ -95,7 +95,7 @@ class OpAttrExtractor { } } - void Visit(const char* key, Optional* value) { + void Visit(const char* key, ffi::Optional* value) { if (value->has_value()) { SetNodeAttr(key, {std::to_string(value->value())}); } else { @@ -119,7 +119,7 @@ class OpAttrExtractor { attr.push_back(std::to_string(im->value)); } else if (const auto* fm = (*an)[i].as()) { attr.push_back(Fp2String(fm->value)); - } else if (auto opt_str = (*an)[i].as()) { + } else if (auto opt_str = (*an)[i].as()) { attr.push_back(*opt_str); } else { LOG(FATAL) << "Not supported type: " << (*an)[i].GetTypeKey(); @@ -174,7 +174,7 @@ class OpAttrExtractor { this->Visit(field_info->name.data, &value); break; } - case ffi::TypeIndex::kTVMFFINDArray: { + case ffi::TypeIndex::kTVMFFITensor: { this->Visit(field_info->name.data, &field_value); break; } @@ -201,7 +201,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { * \brief Constructor * \param constant_names The names of all constants in the original module. */ - explicit JSONSerializer(const Map& constant_names) + explicit JSONSerializer(const ffi::Map& constant_names) : constant_names_(constant_names) {} void serialize(Function func) { @@ -214,7 +214,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } /*!\brief Return the required constants. 
*/ - Array GetConstantNames() const { return constants_used_; } + ffi::Array GetConstantNames() const { return constants_used_; } /*!\brief Return the generated json. */ std::string GetJSON() { @@ -284,7 +284,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { extractor.Extract(const_cast(call_attr)); } else if (const auto* fn = cn->op.as()) { ICHECK(false); - auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); + auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); ICHECK(pattern.has_value()); std::vector values; values.push_back(pattern.value()); @@ -361,12 +361,12 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExpr_(const ConstantNode* cn) { - auto name = constant_names_.find(GetRef(cn)); + auto name = constant_names_.find(ffi::GetRef(cn)); ICHECK(name != constant_names_.end()) - << "Cannot find the name of the constant: " << GetRef(cn); + << "Cannot find the name of the constant: " << ffi::GetRef(cn); constants_used_.push_back((*name).second); auto node = std::make_shared((*name).second, "const" /* op_type_ */); - return AddNode(node, GetRef(cn)); + return AddNode(node, ffi::GetRef(cn)); } NodeEntries VisitExpr_(const TupleNode* tn) { @@ -379,12 +379,12 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExpr_(const CallNode* cn) { - Expr expr = GetRef(cn); + Expr expr = ffi::GetRef(cn); std::string name; if (const auto* op_node = cn->op.as()) { name = op_node->name; } else if (const auto* fn = cn->op.as()) { - auto comp = fn->GetAttr(attr::kComposite); + auto comp = fn->GetAttr(attr::kComposite); ICHECK(comp.has_value()) << "JSON runtime only supports composite functions."; name = comp.value(); } else { @@ -404,7 +404,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); SetCallNodeAttribute(node, cn); - return AddNode(node, GetRef(cn)); + return AddNode(node, ffi::GetRef(cn)); } NodeEntries 
VisitExpr_(const TupleGetItemNode* gtn) { @@ -413,7 +413,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExpr_(const FunctionNode* fn) { - ICHECK(fn->GetAttr(attr::kComposite).has_value()) + ICHECK(fn->GetAttr(attr::kComposite).has_value()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. @@ -453,9 +453,9 @@ class JSONSerializer : public relax::MemoizedExprTranslator { /*! \brief Output of the JSON graph. */ NodeEntries heads_; /*! \brief The list of required constants, ordered. */ - Array constants_used_; + ffi::Array constants_used_; /*! \brief The names of all constants in the original module. */ - const Map& constant_names_; + const ffi::Map& constant_names_; }; } // namespace contrib diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 0cd0150970e6..ab8336bfd5b2 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -41,7 +41,7 @@ using backend::contrib::NodeEntries; class CublasJSONSerializer : public JSONSerializer { public: - CublasJSONSerializer(Map constant_names, Map bindings) + CublasJSONSerializer(ffi::Map constant_names, ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -49,10 +49,10 @@ class CublasJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -101,17 +101,18 @@ class 
CublasJSONSerializer : public JSONSerializer { const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul"); SetCallNodeAttribute(node, root_call); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: /*! \brief The bindings to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; }; -Array CublasCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array CublasCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -126,10 +127,10 @@ Array CublasCompiler(Array functions, Map constant_names, Map bindings) + cuDNNJSONSerializer(ffi::Map constant_names, ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -48,10 +48,10 @@ class cuDNNJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -89,7 +89,7 @@ class cuDNNJSONSerializer : public JSONSerializer { const CallNode* root_call = backend::GetOpInFunction(fn, "relax.nn.conv2d"); SetCallNodeAttribute(node, root_call); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } NodeEntries HandleAttention(const CallNode* call_node, const Function& fn, @@ -125,17 +125,18 @@ class cuDNNJSONSerializer : public JSONSerializer { 
node->SetAttr("head_size", to_str_array(head_size)); node->SetAttr("head_size_v", to_str_array(head_size_v)); node->SetAttr("layout", std::vector{std::vector{layout}}); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: /*! \brief The bindings to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; }; -Array cuDNNCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array cuDNNCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { cuDNNJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -150,10 +151,10 @@ Array cuDNNCompiler(Array functions, Map& out, const std::string& fun return code_stream_.str(); } -ffi::Module Finalize(const std::string& code, const Array& func_names) { +ffi::Module Finalize(const std::string& code, const ffi::Array& func_names) { ICHECK(!func_names.empty()) << "Should only create CUTLASS CSourceModule if there is at least one CUTLASS partition"; @@ -71,14 +71,14 @@ ffi::Module Finalize(const std::string& code, const Array& func_names) { const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.CSourceModuleCreate"); VLOG(1) << "Generated CUTLASS code:" << std::endl << code; return pf(default_headers.str() + code, "cu", func_names, - /*const_vars=*/Array()) + /*const_vars=*/ffi::Array()) .cast(); } class CodegenResultNode : public Object { public: - String code; - Array headers; + ffi::String code; + ffi::Array headers; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -86,36 +86,35 @@ class CodegenResultNode : public Object { .def_ro("code", &CodegenResultNode::code) .def_ro("headers", &CodegenResultNode::headers); } - - static constexpr const char* _type_key = "contrib.cutlass.CodegenResult"; - TVM_DECLARE_FINAL_OBJECT_INFO(CodegenResultNode, Object); + 
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("contrib.cutlass.CodegenResult", CodegenResultNode, Object); }; class CodegenResult : public ObjectRef { public: - CodegenResult(String code, Array headers) { - auto n = make_object(); + CodegenResult(ffi::String code, ffi::Array headers) { + auto n = ffi::make_object(); n->code = std::move(code); n->headers = std::move(headers); data_ = std::move(n); } - TVM_DEFINE_OBJECT_REF_METHODS(CodegenResult, ObjectRef, CodegenResultNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CodegenResult, ObjectRef, CodegenResultNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ CodegenResultNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { CodegenResultNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("contrib.cutlass.CodegenResult", [](String code, Array headers) { - return CodegenResult(code, headers); - }); -}); + refl::GlobalDef().def("contrib.cutlass.CodegenResult", + [](ffi::String code, ffi::Array headers) { + return CodegenResult(code, headers); + }); +} GenerateBodyOutput GenerateBody(const std::string& func_name, const std::string& ext_func_id, const std::vector& output_types, - const Array& func_args, const Map& attrs, - int* buf_idx) { + const ffi::Array& func_args, + const ffi::Map& attrs, int* buf_idx) { // Make function call with input buffers when visiting arguements ICHECK_GT(func_args.size(), 0); std::ostringstream decl_stream; @@ -150,7 +149,7 @@ using OutputType = std::vector; class CodegenCutlass : public relax::MemoizedExprTranslator, public relax::contrib::CodegenCBase { public: - CodegenCutlass(const std::string& id, const Map& bindings) + CodegenCutlass(const std::string& id, const ffi::Map& bindings) : ext_func_id_(id), bindings_(bindings) {} void AddParm(Var param) { @@ -195,7 +194,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, return code_stream_.str(); } - Array GetHeaders() { return headers_; } + 
ffi::Array GetHeaders() { return headers_; } protected: OutputType VisitExpr_(const VarNode* node) final { @@ -209,8 +208,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, OutputType VisitExpr_(const CallNode* call) final { const auto* fn_var = call->op.as(); ICHECK(fn_var); - const auto func = Downcast(bindings_[GetRef(fn_var)]); - const auto pattern_name_opt = func->GetAttr(attr::kComposite); + const auto func = Downcast(bindings_[ffi::GetRef(fn_var)]); + const auto pattern_name_opt = func->GetAttr(attr::kComposite); ICHECK(pattern_name_opt) << "Only composite function is supported for CUTLASS."; auto ret = GenerateBody(call, pattern_name_opt.value(), func->attrs->dict); ext_func_body_.push_back(ret.decl); @@ -219,7 +218,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, } OutputType VisitExpr_(const FunctionNode* fn) final { - ICHECK(fn->GetAttr(attr::kComposite).has_value()) + ICHECK(fn->GetAttr(attr::kComposite).has_value()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. return {}; @@ -282,8 +281,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, } private: - Array GetArgumentNames(const CallNode* call) { - Array arg_names; + ffi::Array GetArgumentNames(const CallNode* call) { + ffi::Array arg_names; for (size_t i = 0; i < call->args.size(); ++i) { auto res = VisitExpr(call->args[i]); for (const auto& out : res) { @@ -294,9 +293,9 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, } GenerateBodyOutput GenerateBody(const CallNode* call, const std::string& func_name, - const Map& attrs) { + const ffi::Map& attrs) { auto func_args = GetArgumentNames(call); - auto struct_info = GetStructInfo(GetRef(call)); + auto struct_info = GetStructInfo(ffi::GetRef(call)); std::vector out_types; if (const auto* tensor_sinfo = struct_info.as()) { @@ -316,15 +315,15 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, */ int buf_idx_{0}; /*! 
\brief The arguments used by a wrapped function that calls CUTLASS kernels. */ - Array ext_func_args_; + ffi::Array ext_func_args_; /*! \brief The statements of the function that will be compiled using CUTLASS kernels. */ std::vector ext_func_body_; /*! \brief The declaration of intermediate buffers. */ std::vector buf_decl_; /*! \brief The binding to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; /*! \brief Required header-file names. */ - Array headers_; + ffi::Array headers_; /*! * \brief A mapping from a variable to its unique name. * We use this since sometimes different parameters to the same function end up having the same @@ -337,7 +336,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, class CutlassModuleCodegen { public: - ffi::Module CreateCSourceModule(Array functions, const Map& options) { + ffi::Module CreateCSourceModule(ffi::Array functions, + const ffi::Map& options) { std::string headers = ""; std::string code = ""; for (const auto& f : functions) { @@ -351,8 +351,8 @@ class CutlassModuleCodegen { } private: - std::pair> GenCutlassFunc(const Function& function, - const Map& options) { + std::pair> GenCutlassFunc( + const Function& function, const ffi::Map& options) { ICHECK(function.defined()) << "Input error: expect a Relax function."; auto sid = GetExtSymbol(function); @@ -369,17 +369,18 @@ class CutlassModuleCodegen { } /*! \brief The accumulated function names. 
*/ - Array func_names_; + ffi::Array func_names_; }; -Array CUTLASSCompiler(Array functions, Map options, - Map /*unused*/) { +ffi::Array CUTLASSCompiler(ffi::Array functions, + ffi::Map options, + ffi::Map /*unused*/) { const auto tune_func = tvm::ffi::Function::GetGlobal("contrib.cutlass.tune_relax_function"); ICHECK(tune_func.has_value()) << "The packed function contrib.cutlass.tune_relax_function not found, " "please import tvm.contrib.cutlass.build"; - auto annotated_functions = (*tune_func)(functions, options).cast>(); + auto annotated_functions = (*tune_func)(functions, options).cast>(); auto source_mod = CutlassModuleCodegen().CreateCSourceModule(annotated_functions, options); const auto pf = tvm::ffi::Function::GetGlobal("contrib.cutlass.compile"); @@ -390,10 +391,10 @@ Array CUTLASSCompiler(Array functions, Map constant_names, Map bindings) + DNNLJSONSerializer(ffi::Map constant_names, ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -48,10 +48,10 @@ class DNNLJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -73,17 +73,18 @@ class DNNLJSONSerializer : public JSONSerializer { } SetCallNodeAttribute(node, root_call); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: /*! \brief The bindings to look up composite functions. 
*/ - Map bindings_; + ffi::Map bindings_; }; -Array DNNLCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array DNNLCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { DNNLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -98,10 +99,10 @@ Array DNNLCompiler(Array functions, Map return compiled_functions; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.ext.dnnl", DNNLCompiler); -}); +} } // namespace contrib } // namespace relax diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc index e1104ac3d6c7..09a0f0026789 100644 --- a/src/relax/backend/contrib/hipblas/codegen.cc +++ b/src/relax/backend/contrib/hipblas/codegen.cc @@ -40,7 +40,8 @@ using backend::contrib::NodeEntries; class HipblasJSONSerializer : public JSONSerializer { public: - HipblasJSONSerializer(Map constant_names, Map bindings) + HipblasJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -48,10 +49,10 @@ class HipblasJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -78,17 +79,18 @@ class HipblasJSONSerializer : public JSONSerializer { const CallNode* root_call = 
backend::GetOpInFunction(fn, "relax.matmul"); SetCallNodeAttribute(node, root_call); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: /*! \brief The bindings to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; }; -Array HipblasCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array HipblasCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { HipblasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -103,10 +105,10 @@ Array HipblasCompiler(Array functions, Map constant_names, Map bindings) + explicit NNAPIJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; std::vector VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -221,11 +222,11 @@ class NNAPIJSONSerializer : public JSONSerializer { VLOG(1) << "Adding node " << composite_name << " with " << node->GetInputs().size() << " inputs"; - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: - Map bindings_; + ffi::Map bindings_; }; void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { @@ -247,11 +248,12 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { ExprVisitor::VisitExpr_(call_node); } -Array NNAPICompiler(Array 
functions, Map /*unused*/, - Map constant_names) { +ffi::Array NNAPICompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { VLOG(1) << "NNAPI Compiler"; - Array compiled_functions; + ffi::Array compiled_functions; for (const auto& func : functions) { NNAPIJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); serializer.serialize(func); @@ -267,10 +269,10 @@ Array NNAPICompiler(Array functions, Map { - Array tensorrt_version; + ffi::Array tensorrt_version; bool use_implicit_batch; size_t max_workspace_size; bool remove_no_mac_subgraphs; @@ -58,7 +58,7 @@ struct TensorRTCompilerConfigNode : public AttrsNodeReflAdapter() .def_ro("tensorrt_version", &TensorRTCompilerConfigNode::tensorrt_version, "TensorRT version as (major, minor, patch).", - refl::DefaultValue(Array({6, 0, 1}))) + refl::DefaultValue(ffi::Array({6, 0, 1}))) .def_ro("use_implicit_batch", &TensorRTCompilerConfigNode::use_implicit_batch, "Use implicit batch", refl::DefaultValue(true)) .def_ro("max_workspace_size", &TensorRTCompilerConfigNode::max_workspace_size, @@ -70,18 +70,17 @@ struct TensorRTCompilerConfigNode : public AttrsNodeReflAdapter constant_names, Map bindings) + explicit TensorRTJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -137,9 +137,9 @@ class TensorRTJSONSerializer : public JSONSerializer { // The call must be to an inline "Composite" function const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - auto opt_composite = fn->GetAttr(attr::kComposite); + auto opt_composite = fn->GetAttr(attr::kComposite); ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); @@ -172,7 +172,7 @@ class TensorRTJSONSerializer : public JSONSerializer { VLOG(1) << name << " has " << node->GetInputs().size() << " inputs"; - 
return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } static void SaveGlobalAttributes(std::shared_ptr node) { @@ -206,11 +206,11 @@ class TensorRTJSONSerializer : public JSONSerializer { private: /*! \brief The bindings to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; }; void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) { - for (const auto& entry : serializer_->VisitExpr(GetRef(constant_node))) { + for (const auto& entry : serializer_->VisitExpr(ffi::GetRef(constant_node))) { args_.emplace_back(entry); } } @@ -225,9 +225,10 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. */ -Array TensorRTCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array TensorRTCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "TensorRT partition:" << std::endl << func; TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -243,10 +244,10 @@ Array TensorRTCompiler(Array functions, Map GetTensorRTVersion() { +ffi::Array GetTensorRTVersion() { #if TVM_GRAPH_EXECUTOR_TENSORRT return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; #else @@ -273,12 +274,12 @@ Array GetTensorRTVersion() { #endif // TVM_GRAPH_EXECUTOR_TENSORRT } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.is_tensorrt_runtime_enabled", IsTensorRTRuntimeEnabled) .def("relax.get_tensorrt_version", GetTensorRTVersion); -}); +} } // namespace contrib } // namespace relax diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index b555d1fc0f74..1840986c019d 100644 --- 
a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -31,8 +31,8 @@ namespace tvm { namespace relax { namespace backend { -Map ExtractArgIdx(String pattern_name, Function f) { - Map arg_idx; +ffi::Map ExtractArgIdx(ffi::String pattern_name, Function f) { + ffi::Map arg_idx; auto pattern = backend::GetPattern(pattern_name); ICHECK(pattern) << "Unsupported op_type " << pattern_name; @@ -44,7 +44,7 @@ Map ExtractArgIdx(String pattern_name, Function f) { << "\", expected to find a match for " << pattern.value()->pattern << ". However, the function did not include this pattern " << f; - auto find_index = [](const Array& params, Var v) -> std::optional { + auto find_index = [](const ffi::Array& params, Var v) -> std::optional { for (size_t i = 0; i < params.size(); ++i) { if (params[i] == v) { return i; @@ -56,7 +56,7 @@ Map ExtractArgIdx(String pattern_name, Function f) { for (const auto& [name, pat] : pattern.value()->annotation_patterns) { auto exp = matched_expr.value()[pat]; if (auto arg_var = exp.as()) { - if (auto idx = find_index(f->params, GetRef(arg_var))) { + if (auto idx = find_index(f->params, ffi::GetRef(arg_var))) { arg_idx.Set(name, IntImm(DataType::Int(64), *idx)); } } @@ -76,10 +76,10 @@ bool EndsWithPattern(const std::string& str, const std::string& pattern) { return str.compare(str.length() - pattern.length(), pattern.length(), pattern) == 0; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.contrib.extract_arg_idx", ExtractArgIdx); -}); +} } // namespace backend } // namespace relax diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index bbff798b8623..e1bcfd0aee1e 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -43,7 +43,7 @@ namespace backend { * \return The converted shape in std::vector */ -inline std::vector GetIntShape(const Array& shape) { +inline std::vector 
GetIntShape(const ffi::Array& shape) { std::vector ret; for (const auto& dim : shape) { const int64_t* pval = tir::as_const_int(dim); @@ -71,7 +71,7 @@ inline std::string DType2String(const tvm::DataType dtype) { inline bool IsOp(const CallNode* call, const std::string& op_name) { const auto* op_node = call->op.as(); if (!op_node) return false; - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); return op == Op::Get(op_name); } @@ -116,12 +116,12 @@ inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { * \return A mapping between variable pattern names and their positions in the partitioned * function parameter list. */ -Map ExtractArgIdx(String pattern_name, Function f); +ffi::Map ExtractArgIdx(ffi::String pattern_name, Function f); /*! * \brief Converts a numeric value to std::string. * \param value A numeric value to convert. - * \return String representation of a numeric value. + * \return ffi::String representation of a numeric value. */ template std::string to_str(const Type& value) { diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc index 6689aca2f9f4..c11ef6a35e07 100644 --- a/src/relax/backend/pattern_registry.cc +++ b/src/relax/backend/pattern_registry.cc @@ -31,15 +31,15 @@ static std::vector* GetRegistryTable() { return &table; } -void RegisterPatterns(Array entries) { +void RegisterPatterns(ffi::Array entries) { auto* table = GetRegistryTable(); for (const auto& entry : entries) { table->push_back(entry); } } -void RemovePatterns(Array names) { - std::unordered_set name_set{names.begin(), names.end()}; +void RemovePatterns(ffi::Array names) { + std::unordered_set name_set{names.begin(), names.end()}; auto* table = GetRegistryTable(); table->erase( @@ -48,9 +48,9 @@ void RemovePatterns(Array names) { table->end()); } -Array GetPatternsWithPrefix(const String& prefix) { +ffi::Array GetPatternsWithPrefix(const ffi::String& prefix) { auto* table = GetRegistryTable(); - Array result; 
+ ffi::Array result; for (auto it = table->rbegin(); it != table->rend(); ++it) { if (support::StartsWith((*it)->name, prefix.data())) { result.push_back(*it); @@ -59,7 +59,7 @@ Array GetPatternsWithPrefix(const String& prefix) { return result; } -Optional GetPattern(const String& pattern_name) { +ffi::Optional GetPattern(const ffi::String& pattern_name) { auto* table = GetRegistryTable(); for (auto it = table->rbegin(); it != table->rend(); ++it) { if ((*it)->name == pattern_name) { @@ -69,14 +69,14 @@ Optional GetPattern(const String& pattern_name) { return std::nullopt; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.backend.RegisterPatterns", RegisterPatterns) .def("relax.backend.RemovePatterns", RemovePatterns) .def("relax.backend.GetPatternsWithPrefix", GetPatternsWithPrefix) .def("relax.backend.GetPattern", GetPattern); -}); +} } // namespace backend } // namespace relax diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h index 2c1f385a2dda..72956c33d625 100644 --- a/src/relax/backend/pattern_registry.h +++ b/src/relax/backend/pattern_registry.h @@ -44,27 +44,27 @@ using transform::FusionPattern; * \param patterns Patterns to be registered. Patterns that appear later in the list have * higher priority when partitioning DataflowBlock. */ -void RegisterPatterns(Array patterns); +void RegisterPatterns(ffi::Array patterns); /*! * \brief Remove patterns from the registry by their name. * \param names The name of patterns to be removed */ -void RemovePatterns(Array names); +void RemovePatterns(ffi::Array names); /*! * \brief Find patterns whose name starts with a particular prefix. * \param prefx The pattern name prefix. * \return Matched patterns, ordered by priority from high to low. */ -Array GetPatternsWithPrefix(const String& prefix); +ffi::Array GetPatternsWithPrefix(const ffi::String& prefix); /*! 
* \brief Find the pattern with a particular name. * \param name The pattern name. * \return The matched pattern. std::nullopt if not found. */ -Optional GetPattern(const String& name); +ffi::Optional GetPattern(const ffi::String& name); } // namespace backend } // namespace relax diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index b0571913049c..71c024b9d7a0 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -67,15 +67,16 @@ class BlockCounter : public tir::StmtVisitor { class TaskExtractor : public ExprVisitor { public: - static Array ExtractTask(IRModule mod, Target target, String mod_eq_name) { + static ffi::Array ExtractTask(IRModule mod, Target target, + ffi::String mod_eq_name) { TaskExtractor extractor(mod, target, mod_eq_name); // We go through each Relax function in the module. for (const auto& kv : mod->functions) { if (const auto* func = kv.second.as()) { - extractor(GetRef(func)); + extractor(ffi::GetRef(func)); } } - Array tasks; + ffi::Array tasks; for (const auto& it : extractor.func2task_) { tasks.push_back(it.second); } @@ -83,7 +84,7 @@ class TaskExtractor : public ExprVisitor { } private: - explicit TaskExtractor(IRModule mod, Target target, String mod_eq_name) + explicit TaskExtractor(IRModule mod, Target target, ffi::String mod_eq_name) : mod_(std::move(mod)), target_(std::move(target)), mod_eq_(ModuleEquality::Create(mod_eq_name)), @@ -140,13 +141,13 @@ class TaskExtractor : public ExprVisitor { std::optional normalize_mod_func_; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.backend.MetaScheduleExtractTask", [](IRModule mod, Target target, - String mod_eq_name) { + ffi::String mod_eq_name) { return TaskExtractor::ExtractTask(std::move(mod), std::move(target), std::move(mod_eq_name)); }); -}); +} } // namespace backend } // namespace relax diff --git 
a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 1f9e8c0378a7..e2d9b5b068b7 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -60,7 +60,7 @@ class CodeGenVM : public ExprFunctor { // Remove relax function and turn into TIR func. for (const auto& [gvar, f] : mod->functions) { if (auto* func = f.as()) { - codegen.Codegen(GetRef(func)); + codegen.Codegen(ffi::GetRef(func)); res_mod->Remove(gvar); } } @@ -82,11 +82,11 @@ class CodeGenVM : public ExprFunctor { } void Codegen(const Function& func) { - Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase. " "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; - Array param_names; + ffi::Array param_names; for (Var param : func->params) { param_names.push_back(param->name_hint()); } @@ -132,7 +132,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const CallNode* call_node) final { - Call call = GetRef(call_node); + Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { return Instruction::Arg::Register(Instruction::kVoidRegister); @@ -163,7 +163,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const IfNode* op) final { - const If& ife = GetRef(op); + const If& ife = ffi::GetRef(op); Instruction::Arg cond_value = this->VisitExpr(ife->cond); // Reserve a register for cond @@ -207,7 +207,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto it = this->var_arg_map_.find(var); ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined"; return it->second; @@ -236,7 +236,8 @@ class CodeGenVM : public ExprFunctor { return builder_->ConvertConstant(float_imm->value); } else { LOG(FATAL) << 
"PrimValue should only contain constant after VMShapeLower, " - << "but received " << GetRef(op) << " with type " << op->value->GetTypeKey(); + << "but received " << ffi::GetRef(op) << " with type " + << op->value->GetTypeKey(); } } @@ -249,7 +250,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const TupleNode* op) final { - Tuple tuple = GetRef(op); + Tuple tuple = ffi::GetRef(op); std::vector args; for (Expr arg : tuple->fields) { args.push_back(this->VisitExpr(arg)); @@ -261,7 +262,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final { - TupleGetItem expr = GetRef(op); + TupleGetItem expr = ffi::GetRef(op); std::vector args = {this->VisitExpr(expr->tuple)}; args.push_back(builder_->ConvertConstant(expr->index)); @@ -273,8 +274,8 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const GlobalVarNode* op) final { - GlobalVar gvar = GetRef(op); - Optional symbol; + GlobalVar gvar = ffi::GetRef(op); + ffi::Optional symbol; VMFuncInfo::FuncKind kind = VMFuncInfo::FuncKind::kPackedFunc; // Run a look up in the env to see if it maps to an extern func. 
@@ -306,10 +307,10 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const ExternFuncNode* op) final { static const constexpr char* kCSource = "c_source"; static const constexpr char* kCSourceFmt = "c_source_fmt"; - if (Optional opt_code = op->attrs.GetAttr(kCSource)) { - String sym = op->global_symbol; - String fmt = op->attrs.GetAttr(kCSourceFmt).value_or("c"); - String code = opt_code.value(); + if (ffi::Optional opt_code = op->attrs.GetAttr(kCSource)) { + ffi::String sym = op->global_symbol; + ffi::String fmt = op->attrs.GetAttr(kCSourceFmt).value_or("c"); + ffi::String code = opt_code.value(); ffi::Module c_source_module = codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt, /*func_names=*/{sym}, /*const_vars=*/{}); @@ -367,7 +368,6 @@ class CodeGenVM : public ExprFunctor { builder_->EmitCall(func, args, dst_reg); } - void EmitNormalCall(const Call& call_node, RegName dst_reg) { Instruction::Arg func = VisitExpr(call_node->op); std::vector args = VisitArray(call_node->args); @@ -388,7 +388,7 @@ class CodeGenVM : public ExprFunctor { builder_->EmitCall(name, args, dst_reg); } - std::vector VisitArray(const Array& arr) { + std::vector VisitArray(const ffi::Array& arr) { std::vector ret; for (size_t i = 0; i < arr.size(); ++i) { ret.push_back(this->VisitExpr(arr[i])); @@ -426,10 +426,10 @@ IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) { return CodeGenVM::Run(exec_builder, mod); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.VMCodeGen", VMCodeGen); -}); +} /*! * \brief Link the modules together, possibly create a constant module. @@ -440,8 +440,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ * module(s). * \return The created module. 
*/ -void LinkModules(ObjectPtr exec, const Map& params, - const tvm::ffi::Module& lib, const Array& ext_libs) { +void LinkModules(ObjectPtr exec, const ffi::Map& params, + const tvm::ffi::Module& lib, const ffi::Array& ext_libs) { // query if we need const loader for ext_modules // Wrap all submodules in the initialization wrapper. std::unordered_map> const_vars_by_symbol; @@ -450,8 +450,8 @@ void LinkModules(ObjectPtr exec, const MapGetFunction("get_const_vars"); std::vector symbol_const_vars; if (pf_sym.has_value() && pf_var.has_value()) { - String symbol = (*pf_sym)().cast(); - Array variables = (*pf_var)().cast>(); + ffi::String symbol = (*pf_sym)().cast(); + ffi::Array variables = (*pf_var)().cast>(); for (size_t i = 0; i < variables.size(); i++) { symbol_const_vars.push_back(variables[i].operator std::string()); } @@ -461,12 +461,12 @@ void LinkModules(ObjectPtr exec, const Map const_var_ndarray; + std::unordered_map const_var_tensor; for (const auto& [name, param] : params) { - const_var_ndarray[name] = param; + const_var_tensor[name] = param; } ffi::Module const_loader_mod = - runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); + runtime::ConstLoaderModuleCreate(const_var_tensor, const_vars_by_symbol); const_loader_mod->ImportModule(lib); for (const auto& it : ext_libs) { const_loader_mod->ImportModule(it); @@ -484,20 +484,21 @@ void LinkModules(ObjectPtr exec, const Map lib, - Array ext_libs, Map params) { +ffi::Module VMLink(ExecBuilder builder, Target target, ffi::Optional lib, + ffi::Array ext_libs, + ffi::Map params) { ObjectPtr executable = builder->Get(); if (!lib.defined()) { - lib = codegen::CSourceModuleCreate(";", "c", Array{}); + lib = codegen::CSourceModuleCreate(";", "c", ffi::Array{}); } LinkModules(executable, params, lib.value(), ext_libs); return ffi::Module(executable); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.VMLink", VMLink); 
-}); +} } // namespace codegen_vm } // namespace relax diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index c7cf06ea9d7f..a5bb83d406a5 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -50,11 +50,11 @@ using vm::VMFuncInfo; * \note Skip CallPacked with special attrs for now, as they can be * further simplified with PrimValue. */ -class CodeGenVMTIR : public ExprFunctor(const Expr&)> { +class CodeGenVMTIR : public ExprFunctor(const Expr&)> { public: explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod) : builder_(builder), ctx_mod_(ctx_mod) { - system_lib_prefix_ = ctx_mod_->GetAttr(tvm::attr::kSystemLibPrefix); + system_lib_prefix_ = ctx_mod_->GetAttr(tvm::attr::kSystemLibPrefix); } static IRModule Run(relax::ExecBuilder builder, IRModule mod) { @@ -66,8 +66,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { // Remove relax function and turn into TIR func. for (auto& p : mod->functions) { if (auto* func = p.second.as()) { - auto tir_func = codegen.Codegen(GetRef(func)); - auto gsymbol = tir_func->GetAttr(tvm::attr::kGlobalSymbol); + auto tir_func = codegen.Codegen(ffi::GetRef(func)); + auto gsymbol = tir_func->GetAttr(tvm::attr::kGlobalSymbol); res_mod->Add(GlobalVar(gsymbol.value()), tir_func); res_mod->Remove(p.first); } @@ -105,8 +105,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { stmt_stack_.back().emplace_back(stmt); } - void EmitCallPacked(String name, const Array& args, int64_t dst_anylist_slot = -1) { - Array all_args; + void EmitCallPacked(ffi::String name, const ffi::Array& args, + int64_t dst_anylist_slot = -1) { + ffi::Array all_args; // negative index indicate return value can be discarded, emit call_packed if (dst_anylist_slot >= 0) { all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; @@ -124,11 +125,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } } - void EmitCallCPacked(const 
tir::PrimFunc& prim_func, const Array& args, + void EmitCallCPacked(const tir::PrimFunc& prim_func, const ffi::Array& args, int64_t dst_anylist_slot = -1) { - Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(gsymbol.has_value()) << "All functions must have global symbol at this phase"; - Array all_args; + ffi::Array all_args; // negative index indicate return value can be discarded, emit call_packed if (dst_anylist_slot >= 0) { all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; @@ -147,7 +148,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } tir::PrimFunc Codegen(const Function& func) { - Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase. " "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; // initialize the state @@ -159,7 +160,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { func_anylist_handle_ = tir::Var("f", DataType::Handle()); const_anylist_handle_ = tir::Var("c", DataType::Handle()); - Array param_names; + ffi::Array param_names; for (Var param : func->params) { param_names.push_back(param->name_hint()); } @@ -174,7 +175,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t ret_reg = NewRegister(); tir::Stmt body = WithNewScope([&]() { - Optional ret = ExprFunctor::VisitExpr(func->body); + ffi::Optional ret = ExprFunctor::VisitExpr(func->body); if (ret.defined()) { this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg); } @@ -186,9 +187,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { builder_->EndFunction(gsymbol.value()); Type ret_type = VoidType(); - Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, - func_anylist_handle_}; - String tir_func_name = system_lib_prefix_.value_or("") + 
"__vmtir__" + gsymbol.value(); + ffi::Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, + func_anylist_handle_}; + ffi::String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" + gsymbol.value(); tir::PrimFunc tir_func(tir_params, body, ret_type, {}); tir_func = WithAttr(tir_func, "global_symbol", tir_func_name); registers_num_ = 0; @@ -197,11 +198,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return tir_func; } - Optional VisitExpr_(const SeqExprNode* op) final { + ffi::Optional VisitExpr_(const SeqExprNode* op) final { for (auto block : op->blocks) { for (Binding binding : block->bindings) { Expr expr = GetBoundValue(binding); - Optional value = VisitExpr(expr); + ffi::Optional value = VisitExpr(expr); if (expr.as() && value.defined()) { // For a normalized relax module, there should be one @@ -220,8 +221,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return this->VisitExpr(op->body); } - Optional VisitExpr_(const CallNode* call_node) final { - Call call = GetRef(call_node); + ffi::Optional VisitExpr_(const CallNode* call_node) final { + Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { return tir::Call(DataType::Handle(), tir::builtin::reinterpret(), @@ -252,7 +253,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } } - Optional VisitExpr_(const IfNode* op) final { + ffi::Optional VisitExpr_(const IfNode* op) final { // Reserve a register for return size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); @@ -272,18 +273,18 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return RegListGet(merge_register); } - Optional VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + ffi::Optional VisitExpr_(const VarNode* op) final { + Var var = ffi::GetRef(op); auto it = this->var_map_.find(var); ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined"; return it->second; } - Optional 
VisitExpr_(const ConstantNode* op) final { + ffi::Optional VisitExpr_(const ConstantNode* op) final { return ConstListGet(builder_->ConvertConstant(op->data).value()); } - Optional VisitExpr_(const ShapeExprNode* op) final { + ffi::Optional VisitExpr_(const ShapeExprNode* op) final { std::vector shape; for (PrimExpr e : op->values) { if (auto* int_value = e.as()) { @@ -295,19 +296,19 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return ConstListGet(builder_->ConvertConstant(ffi::Shape(shape)).value()); } - Optional VisitExpr_(const PrimValueNode* op) final { return op->value; } + ffi::Optional VisitExpr_(const PrimValueNode* op) final { return op->value; } - Optional VisitExpr_(const StringImmNode* op) final { + ffi::Optional VisitExpr_(const StringImmNode* op) final { return ConstListGet(builder_->ConvertConstant(op->value).value()); } - Optional VisitExpr_(const DataTypeImmNode* op) final { + ffi::Optional VisitExpr_(const DataTypeImmNode* op) final { return ConstListGet(builder_->ConvertConstant(op->value).value()); } - Optional VisitExpr_(const TupleNode* op) final { - Tuple tuple = GetRef(op); - Array args; + ffi::Optional VisitExpr_(const TupleNode* op) final { + Tuple tuple = ffi::GetRef(op); + ffi::Array args; for (auto arg : tuple->fields) { args.push_back(this->VisitExpr(arg).value()); } @@ -316,9 +317,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return RegListGet(dst_register); } - Optional VisitExpr_(const TupleGetItemNode* op) final { - TupleGetItem expr = GetRef(op); - Array args = {this->VisitExpr(expr->tuple).value()}; + ffi::Optional VisitExpr_(const TupleGetItemNode* op) final { + TupleGetItem expr = ffi::GetRef(op); + ffi::Array args = {this->VisitExpr(expr->tuple).value()}; args.push_back(ConstInt64(expr->index)); @@ -328,12 +329,12 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } // Lookup the function and see if it matches - Optional LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) { + 
ffi::Optional LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) { if (auto* ext_func = expr.as()) { *kind = VMFuncInfo::FuncKind::kPackedFunc; return ext_func->global_symbol; } else if (auto* gvar_ptr = expr.as()) { - GlobalVar gvar = GetRef(gvar_ptr); + GlobalVar gvar = ffi::GetRef(gvar_ptr); // Run a look up in the env to see if it maps to an extern func. auto it = ctx_mod_->functions.find(gvar); if (it != ctx_mod_->functions.end()) { @@ -362,7 +363,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } // Lookup PrimFunc in the same module // We can do direct PrimFunc call in such cases - Optional LookupPrimFunc(const String& name) { + ffi::Optional LookupPrimFunc(const ffi::String& name) { if (!ctx_mod_->ContainGlobalVar(name)) return std::nullopt; GlobalVar gvar = ctx_mod_->GetGlobalVar(name); @@ -370,28 +371,28 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (it != ctx_mod_->functions.end()) { BaseFunc func = (*it).second; if (auto* prim_func = func.as()) { - return GetRef(prim_func); + return ffi::GetRef(prim_func); } } return std::nullopt; } - Optional VisitExpr_(const GlobalVarNode* op) final { + ffi::Optional VisitExpr_(const GlobalVarNode* op) final { VMFuncInfo::FuncKind kind; - auto symbol = LookupFunction(GetRef(op), &kind); + auto symbol = LookupFunction(ffi::GetRef(op), &kind); ICHECK(symbol.has_value()); builder_->DeclareFunction(symbol.value(), kind); return FuncListGet(builder_->GetFunction(symbol.value()).value()); } - Optional VisitExpr_(const ExternFuncNode* op) final { + ffi::Optional VisitExpr_(const ExternFuncNode* op) final { builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); return FuncListGet(builder_->GetFunction(op->global_symbol).value()); } void EmitAllocStorage(const Call& call_node, int64_t dst_reg) { // Handle args of the call - Array args; + ffi::Array args; args.push_back(ctx_ptr_); for (Expr arg : call_node->args) { args.push_back(this->VisitExpr(arg).value()); @@ 
-401,7 +402,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { void EmitAllocTensor(const Call& call_node, int64_t dst_reg) { ICHECK_EQ(call_node->args.size(), 4); - Array args; + ffi::Array args; args.reserve(4); for (Expr arg : call_node->args) { args.push_back(this->VisitExpr(arg).value()); @@ -429,7 +430,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) { - Array args; + ffi::Array args; // if context is required, pass as first argument. args.push_back(ctx_ptr_); auto* func = call_node->args[0].as(); @@ -446,7 +447,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } void EmitNormalCall(const Call& call_node, int64_t dst_reg) { - Array args = VisitArray(call_node->args); + ffi::Array args = VisitArray(call_node->args); // A function can be a closure that comes from parent // Do call closure to be safe. VMFuncInfo::FuncKind kind; @@ -455,14 +456,14 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (symbol.has_value() && kind == VMFuncInfo::FuncKind::kPackedFunc) { // primfunc in the same module. 
// use cpacked to directly invoke without named based lookup - if (Optional prim_func = LookupPrimFunc(symbol.value())) { + if (ffi::Optional prim_func = LookupPrimFunc(symbol.value())) { this->EmitCallCPacked(prim_func.value(), args, dst_reg); } else { this->EmitCallPacked(symbol.value(), args, dst_reg); } } else { // Default path, leverage function table and invoke as closure - Array all_args; + ffi::Array all_args; all_args.push_back(ctx_ptr_); all_args.push_back(this->VisitExpr(call_node->op).value()); for (auto arg : args) { @@ -481,8 +482,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return stmt; } - Array VisitArray(const Array& arr) { - Array ret; + ffi::Array VisitArray(const ffi::Array& arr) { + ffi::Array ret; for (size_t i = 0; i < arr.size(); ++i) { ret.push_back(this->VisitExpr(arr[i]).value()); } @@ -506,11 +507,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { /*! \brief Stack to build up statements */ std::vector> stmt_stack_; /*! \brief Map from var to Expr. */ - std::unordered_map> var_map_; + std::unordered_map> var_map_; /*! \brief the context module. */ IRModule ctx_mod_; /*! \brief system lib prefix */ - Optional system_lib_prefix_; + ffi::Optional system_lib_prefix_; /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. 
*/ const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); @@ -531,10 +532,10 @@ IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) { return CodeGenVMTIR::Run(exec_builder, mod); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.VMTIRCodeGen", VMTIRCodeGen); -}); +} } // namespace codegen_vm } // namespace relax diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index 8e229c4fe641..b893b48830ce 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -30,11 +30,11 @@ namespace relax { using namespace vm; -TVM_FFI_STATIC_INIT_BLOCK({ ExecBuilderNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ExecBuilderNode::RegisterReflection(); } ExecBuilder ExecBuilderNode::Create() { - ExecBuilder ret(make_object()); - ret->exec_ = make_object(); + ExecBuilder ret(ffi::make_object()); + ret->exec_ = ffi::make_object(); return ret; } @@ -90,7 +90,7 @@ vm::Instruction::Arg ExecBuilderNode::GetFunction(const std::string& func_name) } void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inputs, - Optional> param_names, + ffi::Optional> param_names, vm::VMFuncInfo::FuncKind kind, int64_t init_register_size) { auto it = exec_->func_map.find(func_name); if (it == exec_->func_map.end()) { @@ -319,7 +319,7 @@ void ExecBuilderNode::Formalize() { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.ExecBuilderCreate", ExecBuilderNode::Create) @@ -331,17 +331,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ *ret = builder->ConvertConstant(rt).data(); }) .def("relax.ExecBuilderEmitFunction", - [](ExecBuilder builder, String func, int64_t num_inputs, - Optional> param_names) { + [](ExecBuilder builder, ffi::String func, int64_t num_inputs, + 
ffi::Optional> param_names) { builder->EmitFunction(func, num_inputs, param_names); }) .def_method("relax.ExecBuilderEndFunction", &ExecBuilderNode::EndFunction) .def("relax.ExecBuilderDeclareFunction", - [](ExecBuilder builder, String name, int32_t kind) { + [](ExecBuilder builder, ffi::String name, int32_t kind) { builder->DeclareFunction(name, static_cast(kind)); }) .def("relax.ExecBuilderEmitCall", - [](ExecBuilder builder, String name, Array args, int64_t dst) { + [](ExecBuilder builder, ffi::String name, ffi::Array args, int64_t dst) { std::vector args_; for (size_t i = 0; i < args.size(); ++i) { args_.push_back(Instruction::Arg::FromData(args[i]->value)); @@ -370,13 +370,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](ExecBuilder builder, int64_t value) { return Instruction::Arg::ConstIdx(value).data(); }) - .def("relax.ExecBuilderF", - [](ExecBuilder builder, String value) { return builder->GetFunction(value).data(); }) + .def( + "relax.ExecBuilderF", + [](ExecBuilder builder, ffi::String value) { return builder->GetFunction(value).data(); }) .def("relax.ExecBuilderGet", [](ExecBuilder builder) { ObjectPtr p_exec = builder->Get(); return ffi::Module(p_exec); }); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 06adc3daba4c..71b8413e9889 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -52,6 +53,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return ShapeOf(call); } else if (call->op == tensor_to_shape_op_) { return TensorToShape(call); + } else if (call->op == call_py_func_op_) { + return CallPyFunc(call); } else if (call->op == to_vdevice_op_) { return ToDevice(call); } else if (call->op == make_closure_op_) { @@ -60,7 +63,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return 
InvokeClosure(call); } else if (call->op == alloc_tensor_op_) { LOG(FATAL) << "VMBuiltinLower encountered " << call->op << " in expression " - << GetRef(call_node) << ". " + << ffi::GetRef(call_node) << ". " << "This operation should have been lowered earlier " << "using the 'relax.transform.LowerAllocTensor' pass."; } else if (call->op == mem_alloc_storage_op_) { @@ -70,7 +73,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { } else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) { return MakeMemKillObject(call); } else if (const auto* op_node = call->op.as()) { - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); if (lower_builtin_fmap.count(op)) { return lower_builtin_fmap[op](builder_, call); } @@ -101,7 +104,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args.size() == 2); ICHECK(call_node->args[0]->IsInstance()); ICHECK(call_node->args[1]->IsInstance()); - Array args; + ffi::Array args; auto tir_args = Downcast(call_node->args[1]); args.push_back(call_node->args[0]); @@ -139,12 +142,27 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } + Expr CallPyFunc(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->struct_info_.defined()); + + // Create tuple with function name and arguments tuple + ffi::Array tuple_fields; + tuple_fields.push_back(call_node->args[0]); // function name + tuple_fields.push_back(call_node->args[1]); // arguments tuple + auto combined_tuple = Tuple(tuple_fields); + + // Direct call to vm.builtin.call_py_func + return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args, + call_node->span); + } + Expr ToDevice(const Call& call_node) { // TODO(yongwww): replace ToVDeviceAttrs with related Expr ICHECK(call_node->args.size() == 1); ICHECK(call_node->struct_info_.defined()); auto attrs = 
call_node->attrs.as(); - Array args; + ffi::Array args; args.push_back(call_node->args[0]); // Get the DLDeviceType and device_id from VDevice VDevice vdev = attrs->dst_vdevice; @@ -160,7 +178,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args[0]->IsInstance()); ICHECK(call_node->args[1]->IsInstance()); - Array args; + ffi::Array args; auto func = call_node->args[0]; auto closure_args = Downcast(call_node->args[1]); @@ -177,7 +195,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args[0]->IsInstance()); ICHECK(call_node->args[1]->IsInstance()); - Array args; + ffi::Array args; args.push_back(call_node->args[0]); @@ -192,12 +210,13 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); const StructInfo object_sinfo_ = ObjectStructInfo(); - const StructInfo void_sinfo_ = TupleStructInfo(Array({})); + const StructInfo void_sinfo_ = TupleStructInfo(ffi::Array({})); // object to pattern match. 
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); + const Op& call_py_func_op_ = Op::Get("relax.call_py_func"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); @@ -216,6 +235,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"}; const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"}; + const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"}; const ExternFunc builtin_to_device_{"vm.builtin.to_device"}; const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; @@ -232,10 +252,10 @@ Pass LowerRuntimeBuiltin() { return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.LowerRuntimeBuiltin", LowerRuntimeBuiltin); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 397490023cbe..bbc227d1d559 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -63,8 +63,8 @@ struct PrimExprSlot { */ struct MatchShapeTodoItem { Expr input; - Array pattern; - String err_ctx; + ffi::Array pattern; + ffi::String err_ctx; }; /*! \brief Slot map used for shape lowering. 
*/ @@ -200,7 +200,7 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { */ class VMShapeLowerMutator : public ExprMutator, - public StructInfoFunctor*)> { public: static IRModule Lower(IRModule mod, bool emit_err_ctx) { @@ -208,7 +208,7 @@ class VMShapeLowerMutator for (auto& kv : mod->functions) { if (auto* func = kv.second.as()) { - Function updated_func = mutator.Rewrite(kv.first, GetRef(func)); + Function updated_func = mutator.Rewrite(kv.first, ffi::GetRef(func)); mutator.builder_->UpdateFunction(kv.first, updated_func); } } @@ -235,7 +235,7 @@ class VMShapeLowerMutator // prepare slot information this->PopulateSlotInfo(); - Array blocks; + ffi::Array blocks; builder_->BeginScope(func->params); @@ -305,7 +305,7 @@ class VMShapeLowerMutator for (auto& kv : slot_map_) { auto* slot = kv.second; if (!slot->expr.as()) { - Array dep_vars = tir::UndefinedVars(slot->expr); + ffi::Array dep_vars = tir::UndefinedVars(slot->expr); for (auto var : dep_vars) { auto it = slot_map_.find(var); ICHECK(it != slot_map_.end()) @@ -323,7 +323,7 @@ class VMShapeLowerMutator //------------------------------------------------------- // Helper functions //------------------------------------------------------- - StringImm GetErrContext(String err_ctx) const { + StringImm GetErrContext(ffi::String err_ctx) const { return emit_err_ctx_ ? 
StringImm(err_ctx) : StringImm(""); } @@ -350,7 +350,7 @@ class VMShapeLowerMutator Expr VisitExpr_(const FunctionNode* op) final { LOG(FATAL) << "VMShapeLower do not work for local functions, make sure " << " to run it after LambdaLift"; - return GetRef(op); + return ffi::GetRef(op); } std::pair MakeSymbolicShapeArg(const PrimExpr& expr) { @@ -376,10 +376,10 @@ class VMShapeLowerMutator bool is_const_value = op->value->IsInstance() || op->value->IsInstance(); if (is_const_value) { - return GetRef(op); + return ffi::GetRef(op); } - Array args = {shape_heap_}; + ffi::Array args = {shape_heap_}; auto [code, value_or_index] = MakeSymbolicShapeArg(op->value); args.push_back(code); args.push_back(value_or_index); @@ -396,10 +396,11 @@ class VMShapeLowerMutator return e->IsInstance(); }); if (is_const_shape) { - return GetRef(op); + return ffi::GetRef(op); } - Array args = {shape_heap_, PrimValue::Int64(static_cast(op->values.size()))}; + ffi::Array args = {shape_heap_, + PrimValue::Int64(static_cast(op->values.size()))}; for (PrimExpr expr : op->values) { auto [code, value_or_index] = MakeSymbolicShapeArg(expr); args.push_back(code); @@ -502,7 +503,7 @@ class VMShapeLowerMutator bool all_nop = true; bool any_nop = false; - Array args = {item.input, shape_heap_}; + ffi::Array args = {item.input, shape_heap_}; Expr match_op; if (item.input->struct_info_.as()) { @@ -567,18 +568,18 @@ class VMShapeLowerMutator ICHECK_GT(heap_size_->value, 0); // construct a PrimFunc that compute the shape. 
tir::Var heap("heap", DataType::Handle()); - Array buffer_shape{heap_size_}; + ffi::Array buffer_shape{heap_size_}; tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); - Map buffer_map; + ffi::Map buffer_map; buffer_map.Set(heap, buffer); - auto var_map = [&](const tir::Var& var) -> Optional { + auto var_map = [&](const tir::Var& var) -> ffi::Optional { auto it = slot_map_.find(var); ICHECK(it != slot_map_.end()); return tir::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); }; - Array seq; + ffi::Array seq; for (PrimExprSlot* slot : to_compute) { ICHECK(!slot->value_computed); slot->value_computed = true; @@ -587,7 +588,7 @@ class VMShapeLowerMutator } tir::Stmt body = tir::SeqStmt::Flatten(seq); - Array params{heap}; + ffi::Array params{heap}; Type ret_type = VoidType(); // TODO(relax-team): Consider attach the target attribute to @@ -623,14 +624,14 @@ class VMShapeLowerMutator * visit the match cast. */ void CheckMatchCast(const StructInfo& struct_info, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) { return this->VisitStructInfo(struct_info, value, always_check, dynamic_only, err_ctx, match_todos); } void VisitStructInfo(const StructInfo& struct_info, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // short-cut, if the struct info already satisfies the // constraint during match cast, we can skip matching @@ -640,11 +641,11 @@ class VMShapeLowerMutator } void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final {} void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, 
const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape if (always_check || !IsBaseOf(PrimStructInfo(op->dtype), GetStructInfo(value))) { @@ -663,7 +664,7 @@ class VMShapeLowerMutator } void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), GetStructInfo(value))) { @@ -683,7 +684,7 @@ class VMShapeLowerMutator } void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape auto* shape_expr = op->shape.as(); @@ -734,7 +735,7 @@ class VMShapeLowerMutator } void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { auto* value_tinfo = GetStructInfoAs(value); if (value_tinfo) { @@ -757,7 +758,7 @@ class VMShapeLowerMutator } void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // we only check function is callable. if (!always_check && MatchStructInfo(value)) return; @@ -779,7 +780,7 @@ class VMShapeLowerMutator std::vector> slot_vec_; /*! \brief Expr => slot. */ PrimExprSlotMap slot_map_; - Optional current_gvar_ = std::nullopt; + ffi::Optional current_gvar_ = std::nullopt; /*! * \brief List of vars that are being defined but * have not go through outstanding shape compute check. 
@@ -790,7 +791,7 @@ class VMShapeLowerMutator const Op& null_value_op_ = Op::Get("relax.null_value"); // common struct info const StructInfo object_sinfo_ = ObjectStructInfo(); - const StructInfo void_sinfo_ = TupleStructInfo(Array({})); + const StructInfo void_sinfo_ = TupleStructInfo(ffi::Array({})); // check function const ExternFunc builtin_alloc_shape_heap_{"vm.builtin.alloc_shape_heap"}; const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"}; @@ -814,11 +815,11 @@ Pass VMShapeLower(bool emit_err_ctx) { return CreateModulePass(pass_func, 0, "VMShapeLower", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.VMShapeLower", [](bool emit_err_ctx) { return VMShapeLower(emit_err_ctx); }); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/distributed/axis_group_graph.cc b/src/relax/distributed/axis_group_graph.cc index 491ffc12fa57..12feeacc8b0b 100644 --- a/src/relax/distributed/axis_group_graph.cc +++ b/src/relax/distributed/axis_group_graph.cc @@ -29,7 +29,8 @@ namespace tvm { namespace tir { -Var GetShardingVarFromIndex(PrimExpr index, Map var_range, arith::Analyzer* analyzer) { +Var GetShardingVarFromIndex(PrimExpr index, ffi::Map var_range, + arith::Analyzer* analyzer) { if (index.as()) { return Downcast(index); } @@ -47,12 +48,12 @@ Var GetShardingVarFromIndex(PrimExpr index, Map var_range, arith::An return Var(); } // the floormod must take no effect - if (!analyzer->CanProve( - floordiv(var_range[GetRef(source_var)]->extent, highest_iter_split->lower_factor) <= - highest_iter_split->extent)) { + if (!analyzer->CanProve(floordiv(var_range[ffi::GetRef(source_var)]->extent, + highest_iter_split->lower_factor) <= + highest_iter_split->extent)) { return Var(); } - return GetRef(source_var); + return ffi::GetRef(source_var); } } // namespace tir } // namespace tvm @@ -75,7 +76,7 @@ const TensorStructInfoNode* GetTensorStructInfo(Expr 
tensor) { throw; } -void UnaryOpHelper(Array tensor_list, distributed::AxisGroupGraph* axis_group_graph) { +void UnaryOpHelper(ffi::Array tensor_list, distributed::AxisGroupGraph* axis_group_graph) { int n_dim = GetTensorStructInfo(tensor_list[0])->ndim; for (const auto& tensor : tensor_list) { ICHECK(GetTensorStructInfo(tensor)->ndim == n_dim); @@ -91,7 +92,7 @@ void UnaryOpHelper(Array tensor_list, distributed::AxisGroupGraph* axis_gr void BuildAxisGraphUnary(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { - Array tensor_list; // vars in param and output + ffi::Array tensor_list; // vars in param and output if (call->args[0]->IsInstance()) { tensor_list.push_back(call->args[0]); } @@ -101,7 +102,7 @@ void BuildAxisGraphUnary(const Var& output_var, const Call& call, void BuildAxisGraphBinary(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { - Array tensor_list; // vars in param and output + ffi::Array tensor_list; // vars in param and output if (call->args[0]->struct_info_.as() || call->args[0]->struct_info_.as()) { tensor_list.push_back(call->args[0]); @@ -162,7 +163,7 @@ void BuildAxisGraphBinary(const Var& output_var, const Call& call, void BuildAxisGraphReduce(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { Expr input_tensor = call->args[0]; - Array axes; + ffi::Array axes; bool keepdims; if (const auto* attrs = call->attrs.as()) { if (attrs->axis.defined()) { @@ -228,10 +229,10 @@ void BuildAxisGraphMatmul(const Var& output_var, const Call& call, const auto* x1_shape = x1_sinfo->shape.as(); const auto* x2_shape = x2_sinfo->shape.as(); ICHECK(x1_shape && x2_shape); - Array x1_shape_prefix{x1_shape->values.begin(), - x1_shape->values.end() - 2 + x1_prepended}; - Array x2_shape_prefix{x2_shape->values.begin(), - x2_shape->values.end() - 2 + x2_appended}; + ffi::Array x1_shape_prefix{x1_shape->values.begin(), + x1_shape->values.end() - 2 + 
x1_prepended}; + ffi::Array x2_shape_prefix{x2_shape->values.begin(), + x2_shape->values.end() - 2 + x2_appended}; int x1_prefix_ndim = x1_shape_prefix.size(); int x2_prefix_ndim = x2_shape_prefix.size(); @@ -311,8 +312,8 @@ void BuildAxisGraphReshape(const Var& output_var, const Call& call, const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); const auto* old_shape_sinfo = GetStructInfoAs(tensor_sinfo->shape.value()); ICHECK_NOTNULL(old_shape_sinfo); - Array old_shape_values = old_shape_sinfo->values.value(); - Array new_shape_values = new_shape_sinfo->values.value(); + ffi::Array old_shape_values = old_shape_sinfo->values.value(); + ffi::Array new_shape_values = new_shape_sinfo->values.value(); int i = old_shape_values.size(); int j = new_shape_values.size(); PrimExpr old_shape_product = 1, new_shape_product = 1; @@ -349,8 +350,8 @@ inline int GetNumOutput(Call call) { void BuildAxisGraphCallTIR(const Var& output_var, const Call& call, const tir::PrimFunc& func, distributed::AxisGroupGraph* axis_group_graph) { auto tir_var_axis_group_list = tir::BufferAxisGraphExtractor::GetTIRVarAxisGraph(func); - Map input_var_to_relax_expr; - Array input_list = Downcast(call->args[1])->fields; + ffi::Map input_var_to_relax_expr; + ffi::Array input_list = Downcast(call->args[1])->fields; input_list.push_back(output_var); for (int i = 0; i < static_cast(input_list.size()); i++) { if (func->buffer_map.count(func->params[i])) { diff --git a/src/relax/distributed/global_info.cc b/src/relax/distributed/global_info.cc index b4f435569330..408d31680c79 100644 --- a/src/relax/distributed/global_info.cc +++ b/src/relax/distributed/global_info.cc @@ -24,14 +24,14 @@ namespace tvm { namespace relax { namespace distributed { -TVM_FFI_STATIC_INIT_BLOCK({ DeviceMeshNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { DeviceMeshNode::RegisterReflection(); } -DeviceMesh::DeviceMesh(ffi::Shape shape, Array device_ids) { +DeviceMesh::DeviceMesh(ffi::Shape shape, ffi::Array 
device_ids) { int prod = 1; for (int i = 0; i < static_cast(shape.size()); i++) { prod *= shape[i]; } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); CHECK_EQ(prod, static_cast(device_ids.size())) << "The number of device ids must match the product of the shape"; n->shape = std::move(shape); @@ -40,8 +40,8 @@ DeviceMesh::DeviceMesh(ffi::Shape shape, Array device_ids) { } DeviceMesh::DeviceMesh(ffi::Shape shape, Range device_range) { - ObjectPtr n = make_object(); - Array device_ids; + ObjectPtr n = ffi::make_object(); + ffi::Array device_ids; int range_start = device_range->min.as()->value; int range_extent = device_range->extent.as()->value; for (int i = range_start; i < range_start + range_extent; i++) { @@ -59,17 +59,17 @@ DeviceMesh::DeviceMesh(ffi::Shape shape, Range device_range) { data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.distributed.DeviceMesh", - [](ffi::Shape shape, Array device_ids, Optional device_range) { + [](ffi::Shape shape, ffi::Array device_ids, ffi::Optional device_range) { if (device_range.defined()) return DeviceMesh(shape, device_range.value()); else return DeviceMesh(shape, device_ids); }); -}); +} } // namespace distributed } // namespace relax diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/struct_info.cc index 0b6f3624cc10..5c51920fa7e6 100644 --- a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/struct_info.cc @@ -28,34 +28,34 @@ namespace tvm { namespace relax { namespace distributed { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { DTensorStructInfoNode::RegisterReflection(); PlacementNode::RegisterReflection(); PlacementSpecNode::RegisterReflection(); -}); +} PlacementSpec PlacementSpec::Sharding(int axis) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->axis = axis; n->kind = PlacementSpecKind::kSharding; return PlacementSpec(n); 
} PlacementSpec PlacementSpec::Replica() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->axis = -1; n->kind = PlacementSpecKind::kReplica; return PlacementSpec(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.distributed.Sharding", [](int axis) { return PlacementSpec::Sharding(axis); }) .def("relax.distributed.Replica", []() { return PlacementSpec::Replica(); }); -}); +} -String PlacementNode::ToString() const { +ffi::String PlacementNode::ToString() const { std::stringstream ss; for (size_t i = 0; i < dim_specs.size(); ++i) { if (i != 0) { @@ -70,14 +70,14 @@ String PlacementNode::ToString() const { return ss.str(); } -Placement::Placement(Array dim_specs) { - ObjectPtr n = make_object(); +Placement::Placement(ffi::Array dim_specs) { + ObjectPtr n = ffi::make_object(); n->dim_specs = std::move(dim_specs); data_ = std::move(n); } -Placement Placement::FromText(String text_repr) { - Array dim_specs; +Placement Placement::FromText(ffi::String text_repr) { + ffi::Array dim_specs; std::stringstream ss(text_repr); while (true) { char indicator = 0; @@ -109,13 +109,13 @@ Placement Placement::FromText(String text_repr) { return Placement(dim_specs); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.distributed.PlacementFromText", Placement::FromText) .def("relax.distributed.Placement", - [](Array dim_specs) { return Placement(dim_specs); }); -}); + [](ffi::Array dim_specs) { return Placement(dim_specs); }); +} // DTensor DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, @@ -127,7 +127,7 @@ DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh d CHECK_LT(spec->axis, tensor_sinfo->ndim) << "ValueError: Sharding dimension should be smaller than tensor ndim"; } - ObjectPtr n = make_object(); + ObjectPtr n = 
ffi::make_object(); n->device_mesh = std::move(device_mesh); n->placement = std::move(placement); n->tensor_sinfo = std::move(tensor_sinfo); @@ -135,14 +135,14 @@ DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh d data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.distributed.DTensorStructInfo", [](TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span) { return DTensorStructInfo(tensor_sinfo, device_mesh, placement, span); }); -}); +} } // namespace distributed } // namespace relax diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index 47f28252ff51..aaac39c61b20 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -55,7 +55,7 @@ class RedistributeLegalizer : public ExprMutator { continue; } Expr new_func_body = VisitExpr(func_->body); - auto new_func = make_object(*func_); + auto new_func = ffi::make_object(*func_); new_func->body = new_func_body; builder_->UpdateFunction(gv, Function(new_func)); } @@ -116,10 +116,10 @@ Pass LegalizeRedistribute() { }; return CreateModulePass(pass_func, 1, "LegalizeRedistribute", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.distributed.transform.LegalizeRedistribute", LegalizeRedistribute); -}); +} } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 036867043f71..7930e2dfe7fc 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -52,10 +52,10 @@ class DistIRSharder : public ExprMutator { auto mod = builder_->GetContextIRModule(); for (const 
auto& [gv, base_func] : mod->functions) { const auto* func_ = base_func.as(); - if (func_ == nullptr || !IsDistIRFunc(GetRef(func_))) { + if (func_ == nullptr || !IsDistIRFunc(ffi::GetRef(func_))) { continue; } - Function func = RewriteFunction(GetRef(func_)); + Function func = RewriteFunction(ffi::GetRef(func_)); builder_->UpdateFunction(gv, func); } return builder_->GetContextIRModule(); @@ -63,7 +63,7 @@ class DistIRSharder : public ExprMutator { ShapeExpr ShardShape(ShapeExpr orig_shape, DeviceMesh device_mesh, Placement placement) { ffi::Shape device_mesh_shape = device_mesh->shape; - Array new_tensor_shape_value = orig_shape->values; + ffi::Array new_tensor_shape_value = orig_shape->values; for (int i = 0; i < static_cast(device_mesh_shape.size()); i++) { if (placement->dim_specs[i]->kind == PlacementSpecKind::kSharding) { int shard_size = device_mesh_shape[i]; @@ -78,25 +78,25 @@ class DistIRSharder : public ExprMutator { TensorStructInfo tensor_sinfo = orig_sinfo->tensor_sinfo; ICHECK(tensor_sinfo->shape); const auto* orig_shape = tensor_sinfo->shape.as(); - auto new_tensor_sinfo = make_object(*tensor_sinfo.get()); - new_tensor_sinfo->shape = - ShardShape(GetRef(orig_shape), orig_sinfo->device_mesh, orig_sinfo->placement); + auto new_tensor_sinfo = ffi::make_object(*tensor_sinfo.get()); + new_tensor_sinfo->shape = ShardShape(ffi::GetRef(orig_shape), + orig_sinfo->device_mesh, orig_sinfo->placement); return TensorStructInfo(new_tensor_sinfo); } StructInfo ConvertSinfo(StructInfo orig_sinfo, bool shard_shape) { if (const auto* dtensor_sinfo = orig_sinfo.as()) { if (shard_shape) { - return ShardDTensorSinfo(GetRef(dtensor_sinfo)); + return ShardDTensorSinfo(ffi::GetRef(dtensor_sinfo)); } else { return dtensor_sinfo->tensor_sinfo; } } else if (const auto* tuple_sinfo = orig_sinfo.as()) { - Array new_fields; + ffi::Array new_fields; for (const auto& field_sinfo : tuple_sinfo->fields) { if (const auto* dtensor_sinfo = field_sinfo.as()) { if (shard_shape) { - 
new_fields.push_back(ShardDTensorSinfo(GetRef(dtensor_sinfo))); + new_fields.push_back(ShardDTensorSinfo(ffi::GetRef(dtensor_sinfo))); } else { new_fields.push_back(dtensor_sinfo->tensor_sinfo); } @@ -157,12 +157,13 @@ class DistIRSharder : public ExprMutator { for (int i = 0; i < static_cast(func_->params.size()); i++) { Var param = func_->params[i]; if (const auto* dtensor_sinfo = GetStructInfoAs(param)) { - EmitBroadcastOrScatter(param, new_params_[i], GetRef(dtensor_sinfo)); + EmitBroadcastOrScatter(param, new_params_[i], + ffi::GetRef(dtensor_sinfo)); } else if (const auto* tuple_sinfo = GetStructInfoAs(param)) { for (int j = 0; j < static_cast(tuple_sinfo->fields.size()); j++) { if (const auto* dtensor_sinfo = tuple_sinfo->fields[j].as()) { EmitBroadcastOrScatter(TupleGetItem(param, j), TupleGetItem(new_params_[i], j), - GetRef(dtensor_sinfo)); + ffi::GetRef(dtensor_sinfo)); } } } @@ -170,7 +171,7 @@ class DistIRSharder : public ExprMutator { } Function RewriteFunction(Function func) { - Array new_params; + ffi::Array new_params; for (const Var& var : func->params) { Var new_param = Downcast(ShardInputParamTensorAndConstant(var)); var_remap_[var->vid] = new_param; @@ -184,8 +185,8 @@ class DistIRSharder : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { - if (tuple_getitem_remap_.count(GetRef(val))) { - var_remap_[binding->var->vid] = tuple_getitem_remap_[GetRef(val)]; + if (tuple_getitem_remap_.count(ffi::GetRef(val))) { + var_remap_[binding->var->vid] = tuple_getitem_remap_[ffi::GetRef(val)]; } else { ExprMutator::VisitBinding_(binding, val); } @@ -217,19 +218,19 @@ class DistIRSharder : public ExprMutator { ICHECK(call->args[1].as()); const auto* out_sinfo = GetStructInfoAs(binding_var); ICHECK(out_sinfo); - auto new_call_node = make_object(*call); + auto new_call_node = ffi::make_object(*call); new_call_node->args.Set(1, ShardShape(Downcast(call->args[1]), out_sinfo->device_mesh, 
out_sinfo->placement)); return Call(new_call_node); } else if (call->op.same_as(call_tir_local_view_op)) { - auto new_call_node = make_object(*call); + auto new_call_node = ffi::make_object(*call); new_call_node->op = call_tir_op; new_call_node->sinfo_args = {ConvertSinfo(GetStructInfo(binding_var), true)}; return Call(new_call_node); } else if (call->op.same_as(call_tir_op)) { LOG(FATAL) << "call_tir should be lowered to call_tir_local_view before lowering to relax"; } else if (const auto* extern_func = call->op.as()) { - auto new_call_node = make_object(*call); + auto new_call_node = ffi::make_object(*call); if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_append") { new_call_node->op = ExternFunc("vm.builtin.attention_kv_cache_append"); } else if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_view") { @@ -243,7 +244,7 @@ class DistIRSharder : public ExprMutator { } return Call(new_call_node); } - return GetRef(call); + return ffi::GetRef(call); } void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { @@ -253,7 +254,7 @@ class DistIRSharder : public ExprMutator { } Function func_; - Array new_params_; + ffi::Array new_params_; std::unordered_map tuple_getitem_remap_; }; @@ -263,10 +264,10 @@ Pass LowerDistIR() { auto pass_func = [=](IRModule m, PassContext pc) { return DistIRSharder::LowerDistIR(m); }; return CreateModulePass(pass_func, 1, "LowerDistIR", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.distributed.transform.LowerDistIR", LowerDistIR); -}); +} } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 7baf49508d58..837f2f0a5dcb 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ 
b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -36,18 +36,18 @@ using namespace tvm::relax::distributed; class DistBufferReplacer : public StmtExprMutator { public: - static Stmt BufferReplace(Stmt stmt, Map buffer_map) { + static Stmt BufferReplace(Stmt stmt, ffi::Map buffer_map) { DistBufferReplacer replacer(buffer_map); return replacer(stmt); } private: - explicit DistBufferReplacer(Map buffer_map) : buffer_map_(buffer_map) {} + explicit DistBufferReplacer(ffi::Map buffer_map) : buffer_map_(buffer_map) {} Stmt VisitStmt_(const BufferStoreNode* _store) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); if (buffer_map_.count(store->buffer)) { - ObjectPtr new_store = make_object(*store.get()); + ObjectPtr new_store = ffi::make_object(*store.get()); new_store->buffer = buffer_map_[store->buffer]; return BufferStore(new_store); } @@ -57,7 +57,7 @@ class DistBufferReplacer : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* _load) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); if (buffer_map_.count(load->buffer)) { - ObjectPtr new_load = make_object(*load.get()); + ObjectPtr new_load = ffi::make_object(*load.get()); new_load->buffer = buffer_map_[load->buffer]; return BufferLoad(new_load); } @@ -65,15 +65,15 @@ class DistBufferReplacer : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* _block) final { - Block old_block = GetRef(_block); + Block old_block = ffi::GetRef(_block); Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); - ObjectPtr new_block = make_object(*block.get()); + ObjectPtr new_block = ffi::make_object(*block.get()); new_block->reads = ReplaceBuffer(new_block->reads, buffer_map_); new_block->writes = ReplaceBuffer(new_block->writes, buffer_map_); return Block(new_block); } - Map buffer_map_; + ffi::Map buffer_map_; }; class DistBlockInfoCollector : public StmtExprVisitor { @@ -136,7 +136,7 @@ class DistBlockInfoCollector : public 
StmtExprVisitor { Buffer reduce_buffer_; public: - std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> + std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> buffer_access_indices; std::string reduce_kind; }; @@ -151,8 +151,8 @@ class DistributedBufferCompactor : StmtExprMutator { const std::vector& sharding_specs, PrimFunc prim_func) { prim_func = RenewDefs(prim_func); DistributedBufferCompactor compactor(sharding_specs, prim_func); - Map new_func_buffer_map; - Map replace_buffer_map; + ffi::Map new_func_buffer_map; + ffi::Map replace_buffer_map; for (const auto& pr : prim_func->buffer_map) { Buffer shard_buffer = compactor.ShardBuffer(pr.second); new_func_buffer_map.Set(pr.first, shard_buffer); @@ -162,7 +162,7 @@ class DistributedBufferCompactor : StmtExprMutator { } Stmt new_body = compactor(prim_func->body); new_body = DistBufferReplacer::BufferReplace(new_body, replace_buffer_map); - ObjectPtr new_func = make_object(*prim_func.get()); + ObjectPtr new_func = ffi::make_object(*prim_func.get()); new_func->buffer_map = new_func_buffer_map; new_func->body = new_body; return std::make_tuple(PrimFunc(new_func), compactor.add_allreduce_kind_); @@ -200,10 +200,9 @@ class DistributedBufferCompactor : StmtExprMutator { } } - Array ShardIterVar( - Block block, - const std::unordered_map>, ObjectPtrHash, ObjectPtrEqual>& - buffer_access_indices) { + ffi::Array ShardIterVar( + Block block, const std::unordered_map>, ObjectPtrHash, + ObjectPtrEqual>& buffer_access_indices) { std::vector buffers; for (const auto& read : block->reads) { buffers.push_back(read->buffer); @@ -211,7 +210,7 @@ class DistributedBufferCompactor : StmtExprMutator { for (const auto& write : block->writes) { buffers.push_back(write->buffer); } - Map iter_var_range; + ffi::Map iter_var_range; for (const auto& iter_var : block->iter_vars) { iter_var_range.Set(iter_var->var, iter_var->dom); } @@ -220,7 +219,7 @@ class DistributedBufferCompactor : StmtExprMutator { if (buffer_access_indices.count(buffer) == 
0 || buffer_shards_.count(buffer) == 0) { continue; } - Array> access_indices = buffer_access_indices.at(buffer); + ffi::Array> access_indices = buffer_access_indices.at(buffer); DimShard dim_shards = buffer_shards_[buffer]; for (const auto& access_index : access_indices) { for (const auto& pr : dim_shards) { @@ -234,7 +233,7 @@ class DistributedBufferCompactor : StmtExprMutator { } } - Array new_iter_vars; + ffi::Array new_iter_vars; for (const auto& iter_var : block->iter_vars) { if (iter_var_shards_.count(iter_var->var)) { int shard = iter_var_shards_[iter_var->var]; @@ -259,7 +258,7 @@ class DistributedBufferCompactor : StmtExprMutator { return buffer; } DimShard dim_shards = buffer_shards_[buffer]; - Array shape; + ffi::Array shape; for (int i = 0; i < static_cast(buffer->shape.size()); i++) { if (dim_shards.count(i)) { shape.push_back(floordiv(buffer->shape[i], dim_shards[i])); @@ -267,7 +266,7 @@ class DistributedBufferCompactor : StmtExprMutator { shape.push_back(buffer->shape[i]); } } - ObjectPtr new_buffer = make_object(*buffer.get()); + ObjectPtr new_buffer = ffi::make_object(*buffer.get()); new_buffer->shape = shape; return Buffer(new_buffer); } @@ -276,9 +275,9 @@ class DistributedBufferCompactor : StmtExprMutator { Block block = Downcast(StmtExprMutator::VisitStmt_(op)); DistBlockInfoCollector collector; collector(block); - Array new_iter_vars = ShardIterVar(block, collector.buffer_access_indices); - Array new_alloc_buffers; - Map buffer_map; + ffi::Array new_iter_vars = ShardIterVar(block, collector.buffer_access_indices); + ffi::Array new_alloc_buffers; + ffi::Map buffer_map; for (const Buffer& buffer : block->alloc_buffers) { Buffer sharded_buffer = ShardBuffer(buffer); if (!sharded_buffer.same_as(buffer)) { @@ -295,7 +294,7 @@ class DistributedBufferCompactor : StmtExprMutator { break; } } - ObjectPtr new_block = make_object(*block.operator->()); + ObjectPtr new_block = ffi::make_object(*block.operator->()); new_block->iter_vars = new_iter_vars; 
new_block->alloc_buffers = new_alloc_buffers; if (new_block->name_hint == "root") { @@ -331,8 +330,8 @@ class DistributedBufferCompactor : StmtExprMutator { if (shard > 1) { arith::Analyzer analyzer; ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0)); - return For(new_loop->loop_var, new_loop->min, floordiv(new_loop->extent, shard), - new_loop->kind, new_loop->body, new_loop->thread_binding, new_loop->annotations); + new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard); + return new_loop; } } return new_loop; @@ -340,7 +339,7 @@ class DistributedBufferCompactor : StmtExprMutator { std::unordered_map iter_var_shards_; std::unordered_map loop_var_shards_; - Array allocated_buffer_under_root; + ffi::Array allocated_buffer_under_root; BufferAxisGraphExtractor extractor_; std::vector sharding_specs_; std::unordered_map buffer_shards_; @@ -362,11 +361,11 @@ class LowerTIRToLocalView : public ExprMutator { auto mod = builder_->GetContextIRModule(); for (const auto& [gv, base_func] : mod->functions) { const auto* func_ = base_func.as(); - if (func_ == nullptr || !IsDistIRFunc(GetRef(func_))) { + if (func_ == nullptr || !IsDistIRFunc(ffi::GetRef(func_))) { continue; } Expr new_func_body = this->VisitExpr(func_->body); - ObjectPtr new_func = make_object(*func_); + ObjectPtr new_func = ffi::make_object(*func_); new_func->body = new_func_body; builder_->UpdateFunction(gv, Function(new_func)); } @@ -374,11 +373,11 @@ class LowerTIRToLocalView : public ExprMutator { } private: - inline Array ExtractDTensorStructInfo(Var var) { + inline ffi::Array ExtractDTensorStructInfo(Var var) { if (const auto* dtensor_sinfo = GetStructInfoAs(var)) { - return {GetRef(dtensor_sinfo)}; + return {ffi::GetRef(dtensor_sinfo)}; } else if (const auto* tuple_sinfo = GetStructInfoAs(var)) { - Array ret; + ffi::Array ret; for (const auto& field : tuple_sinfo->fields) { ret.push_back(Downcast(field)); } @@ -395,14 +394,14 @@ class LowerTIRToLocalView : public ExprMutator { 
return; } std::vector sharding_specs; - Array args = Downcast(val->args[1])->fields; + ffi::Array args = Downcast(val->args[1])->fields; for (const auto& arg : args) { const auto* sinfo = GetStructInfoAs(arg); ICHECK(sinfo); sharding_specs.push_back(ShardingSpec(sinfo->device_mesh, sinfo->placement)); } Var output_var = binding->var; - Array output_sinfos = ExtractDTensorStructInfo(output_var); + ffi::Array output_sinfos = ExtractDTensorStructInfo(output_var); for (const auto& sinfo : output_sinfos) { sharding_specs.push_back(ShardingSpec(sinfo->device_mesh, sinfo->placement)); } @@ -414,12 +413,12 @@ class LowerTIRToLocalView : public ExprMutator { tir::DistributedBufferCompactor::DistBufferCompact(sharding_specs, prim_func); auto new_gvar = builder_->AddFunction(new_prim_func, gvar->name_hint); Call call = Downcast(this->VisitExpr(binding->value)); - ObjectPtr new_call_node = make_object(*call.get()); + ObjectPtr new_call_node = ffi::make_object(*call.get()); new_call_node->op = Op::Get("relax.dist.call_tir_local_view"); new_call_node->args.Set(0, new_gvar); Call new_call(new_call_node); if (allreduce_kind != "") { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->op_type = allreduce_kind; new_call = Call(Op::Get("relax.ccl.allreduce"), {new_call}, Attrs(attrs), {}); } @@ -433,11 +432,11 @@ Pass LowerGlobalViewToLocalView() { auto pass_func = [=](IRModule m, PassContext pc) { return LowerTIRToLocalView(m).Lower(); }; return CreateModulePass(pass_func, 1, "LowerGlobalViewToLocalView", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.distributed.transform.LowerGlobalViewToLocalView", LowerGlobalViewToLocalView); -}); +} } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index 1f46b54cfe50..1ff614c019c8 100644 --- 
a/src/relax/distributed/transform/propagate_sharding.cc +++ b/src/relax/distributed/transform/propagate_sharding.cc @@ -48,7 +48,7 @@ void CollectAxisGraphBinary(const VarBindingNode* binding, const CallNode* call, for (const auto& op_name : binary_op_names) { const Op& binary_op = Op::Get("relax." + op_name); if (call->op.same_as(binary_op)) { - BuildAxisGraphBinary(binding->var, GetRef(call), axis_group_graph); + BuildAxisGraphBinary(binding->var, ffi::GetRef(call), axis_group_graph); break; } } @@ -71,7 +71,7 @@ void CollectAxisGraphUnary(const VarBindingNode* binding, const CallNode* call, for (const auto& op_name : unary_op_names) { const Op& unary_op = Op::Get("relax." + op_name); if (call->op.same_as(unary_op)) { - BuildAxisGraphUnary(binding->var, GetRef(call), axis_group_graph); + BuildAxisGraphUnary(binding->var, ffi::GetRef(call), axis_group_graph); } } } @@ -83,7 +83,7 @@ void CollectAxisGraphReduce(const VarBindingNode* binding, const CallNode* call, for (const auto& op_name : reduction_op_names) { const Op& reduction_op = Op::Get("relax." 
+ op_name); if (call->op.same_as(reduction_op)) { - BuildAxisGraphReduce(binding->var, GetRef(call), axis_group_graph); + BuildAxisGraphReduce(binding->var, ffi::GetRef(call), axis_group_graph); break; } } @@ -93,7 +93,7 @@ void CollectAxisGraphMatmul(const VarBindingNode* binding, const CallNode* call, AxisGroupGraph* axis_group_graph) { static const Op& matmul_op = Op::Get("relax.matmul"); if (call->op.same_as(matmul_op)) { - BuildAxisGraphMatmul(binding->var, GetRef(call), axis_group_graph); + BuildAxisGraphMatmul(binding->var, ffi::GetRef(call), axis_group_graph); } } @@ -101,7 +101,7 @@ void CollectAxisGraphPermuteDims(const VarBindingNode* binding, const CallNode* AxisGroupGraph* axis_group_graph) { static const Op& permute_dims_op = Op::Get("relax.permute_dims"); if (call->op.same_as(permute_dims_op)) { - BuildAxisGraphPermuteDims(binding->var, GetRef(call), axis_group_graph); + BuildAxisGraphPermuteDims(binding->var, ffi::GetRef(call), axis_group_graph); } } @@ -109,15 +109,15 @@ void CollectAxisGraphReshape(const VarBindingNode* binding, const CallNode* call AxisGroupGraph* axis_group_graph) { static const Op& reshape_op = Op::Get("relax.reshape"); if (call->op.same_as(reshape_op)) { - BuildAxisGraphReshape(binding->var, GetRef(call), axis_group_graph); + BuildAxisGraphReshape(binding->var, ffi::GetRef(call), axis_group_graph); } } void CollectAxisGraphForDeviceMesh(const VarBindingNode* binding, const CallNode* call, AxisGroupGraph* axis_group_graph) { - Array tensor_list; + ffi::Array tensor_list; static const Op& call_tir_op = Op::Get("relax.call_tir"); - Array args; + ffi::Array args; if (call->op.same_as(call_tir_op)) { args = Downcast(call->args[1])->fields; } else { @@ -158,8 +158,9 @@ class AxisGroupGraphBuilder : public ExprVisitor { CollectAxisGraphReshape(binding, val, axis_group_graph_); static const Op& call_tir_op = Op::Get("relax.call_tir"); if (val->op.same_as(call_tir_op)) { - if (Optional func = MatchPrimFunc(mod_, val->args[0])) { - 
BuildAxisGraphCallTIR(binding->var, GetRef(val), func.value(), axis_group_graph_); + if (ffi::Optional func = MatchPrimFunc(mod_, val->args[0])) { + BuildAxisGraphCallTIR(binding->var, ffi::GetRef(val), func.value(), + axis_group_graph_); } } CollectAxisGraphForDeviceMesh(binding, val, axis_group_graph_); @@ -183,9 +184,9 @@ class AxisGroupGraphBuilder : public ExprVisitor { } void VisitBinding_(const VarBindingNode* binding, const VarNode* val) { - Array tensor_sinfos; + ffi::Array tensor_sinfos; if (const auto* tensor_sinfo = binding->var->struct_info_.as()) { - tensor_sinfos.push_back(GetRef(tensor_sinfo)); + tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo)); } else if (const auto* tuple_sinfo = binding->var->struct_info_.as()) { ICHECK(tuple_sinfo); for (const auto& sinfo : tuple_sinfo->fields) { @@ -271,7 +272,7 @@ class ShardingConflictHandler : public ExprVisitor { ICHECK(shape); int ndim = sinfo->ndim; std::unordered_set sharded_mesh_dim; - Optional device_mesh; + ffi::Optional device_mesh; for (int i = -1; i < ndim; i++) { AxisShardingSpec sharding_spec; int has_sharding_spec; @@ -318,7 +319,7 @@ class ShardingConflictHandler : public ExprVisitor { } void VisitExpr_(const CallNode* op) final { - Array args = GetCallArgs(GetRef(op)); + ffi::Array args = GetCallArgs(ffi::GetRef(op)); for (const auto& arg : args) { if (arg.as()) { CheckConstantNoSharding(Downcast(arg)); @@ -348,10 +349,10 @@ class DistributedIRBuilder : public ExprMutator { auto mod = builder_->GetContextIRModule(); for (const auto& [gv, base_func] : mod->functions) { const auto* func_ = base_func.as(); - if (func_ == nullptr || !IsShardingAnnotatedFunc(GetRef(func_))) { + if (func_ == nullptr || !IsShardingAnnotatedFunc(ffi::GetRef(func_))) { continue; } - Function func = RewriteFunction(GetRef(func_), mod); + Function func = RewriteFunction(ffi::GetRef(func_), mod); builder_->UpdateFunction(gv, func); } return builder_->GetContextIRModule(); @@ -366,7 +367,7 @@ class DistributedIRBuilder : 
public ExprMutator { DeviceMesh device_mesh = std::get<0>(axis_group_graph_.GetAxisShardingSpec({expr.get(), -1, tuple_idx})).first; ICHECK(device_mesh.defined()) << expr << "[" << tuple_idx << "] is not assigned device mesh"; - Array placement_specs( + ffi::Array placement_specs( std::vector(device_mesh->shape.size(), PlacementSpec::Replica())); for (int i = 0; i < ndim; i++) { AxisShardingSpec sharding_spec; @@ -387,7 +388,7 @@ class DistributedIRBuilder : public ExprMutator { new_sinfo = ConvertToDTensorStructInfo(Downcast(tensor->struct_info_), tensor); } else if (const auto* tuple = tensor->struct_info_.as()) { - Array tuple_sinfo_fields; + ffi::Array tuple_sinfo_fields; for (int i = 0; i < static_cast(tuple->fields.size()); i++) { if (tuple->fields[i].as()) { tuple_sinfo_fields.push_back( @@ -419,7 +420,7 @@ class DistributedIRBuilder : public ExprMutator { // Step 3. Handle Sharding Conflict ShardingConflictHandler::HandleShardingConflict(&axis_group_graph_, func); // Step 4. Rewrite Function - Array new_params; + ffi::Array new_params; for (const Var& var : func->params) { if (GetStructInfoAs(var) || GetStructInfoAs(var)) { Var new_param = Downcast(RewriteInputTensorAndConstant(var)); @@ -437,20 +438,20 @@ class DistributedIRBuilder : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { static const Op& call_tir_op = Op::Get("relax.call_tir"); FBuildAxisGraph f = [&](const Var& var, const Call& call, AxisGroupGraph* axis_group_graph) { - Optional prim_func = + ffi::Optional prim_func = MatchPrimFunc(this->builder_->GetContextIRModule(), call->args[0]); ICHECK(prim_func); return BuildAxisGraphCallTIR(var, call, prim_func.value(), axis_group_graph); }; Call new_call = Downcast(ExprMutator::VisitExpr_(call)); - Array args = GetCallArgs(new_call); + ffi::Array args = GetCallArgs(new_call); for (int i = 0; i < static_cast(args.size()); i++) { if (args[i].as()) { args.Set(i, RewriteInputTensorAndConstant(args[i])); } } - ObjectPtr n = 
make_object(*new_call.get()); + ObjectPtr n = ffi::make_object(*new_call.get()); if (new_call->op.same_as(call_tir_op)) { // do not infer output sinfo when arg size is 0 if (!args.empty()) { @@ -484,13 +485,13 @@ class DistributedIRBuilder : public ExprMutator { return redistribute(expr, device_mesh, placement); } - Call RewriteOutSinfo(Call call, DeviceMesh device_mesh, Array placements) { + Call RewriteOutSinfo(Call call, DeviceMesh device_mesh, ffi::Array placements) { // in cases when infer fails (like arg size is 0), we use propagated sinfo for output Call new_call = call; static Op call_tir_op = Op::Get("relax.call_tir"); if (const auto* extern_func = call->op.as()) { if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_view") { - ObjectPtr new_call_node = make_object(*call.get()); + ObjectPtr new_call_node = ffi::make_object(*call.get()); StructInfo new_dtensor_sinfo = DTensorStructInfo( Downcast(call->sinfo_args[0]), device_mesh, placements[0]); new_call_node->sinfo_args = {new_dtensor_sinfo}; @@ -500,14 +501,14 @@ class DistributedIRBuilder : public ExprMutator { } else if (call->op.same_as(call_tir_op)) { ICHECK(call->sinfo_args.size() == 1); if (!SinfoCompatibleWithDistIR(call->sinfo_args)) { - ObjectPtr new_call_node = make_object(*call.get()); + ObjectPtr new_call_node = ffi::make_object(*call.get()); if (placements.size() == 1) { new_call_node->sinfo_args = {DTensorStructInfo( Downcast(call->sinfo_args[0]), device_mesh, placements[0])}; } else { const auto* tuple_sinfo = call->sinfo_args[0].as(); ICHECK(placements.size() == tuple_sinfo->fields.size()); - Array new_tuple_sinfo_fields; + ffi::Array new_tuple_sinfo_fields; for (int i = 0; i < static_cast(placements.size()); i++) { new_tuple_sinfo_fields.push_back(DTensorStructInfo( Downcast(tuple_sinfo->fields[i]), device_mesh, placements[i])); @@ -522,9 +523,9 @@ class DistributedIRBuilder : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* 
val) { - Array orig_output_tensor_sinfos; + ffi::Array orig_output_tensor_sinfos; if (const auto* tensor_sinfo = GetStructInfoAs(binding->var)) { - orig_output_tensor_sinfos.push_back(GetRef(tensor_sinfo)); + orig_output_tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo)); } else if (const auto* tuple_sinfo = GetStructInfoAs(binding->var)) { for (const auto& sinfo : tuple_sinfo->fields) { orig_output_tensor_sinfos.push_back(Downcast(sinfo)); @@ -537,9 +538,9 @@ class DistributedIRBuilder : public ExprMutator { DeviceMesh device_mesh = std::get<0>(axis_group_graph_.GetAxisShardingSpec({binding->var.get(), -1})).first; ICHECK(device_mesh.defined()); - Array placements; // every tuple element has a placement + ffi::Array placements; // every tuple element has a placement for (int idx = 0; idx < static_cast(orig_output_tensor_sinfos.size()); idx++) { - Array placement_specs( + ffi::Array placement_specs( std::vector(device_mesh->shape.size(), PlacementSpec::Replica())); for (int i = 0; i < orig_output_tensor_sinfos[idx]->ndim; i++) { AxisShardingSpec sharding_spec; @@ -565,7 +566,7 @@ class DistributedIRBuilder : public ExprMutator { new_value = InsertRedistribute(new_value, device_mesh, placements[0]); } if (const auto* var = new_value.as()) { - var_remap_[binding->var->vid] = GetRef(var); + var_remap_[binding->var->vid] = ffi::GetRef(var); } else { ReEmitBinding(binding, builder_->Normalize(new_value)); } @@ -589,22 +590,22 @@ class DistributedIRBuilder : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { - if (tuple_getitem_remap_.count(GetRef(val))) { - var_remap_[binding->var->vid] = tuple_getitem_remap_[GetRef(val)]; + if (tuple_getitem_remap_.count(ffi::GetRef(val))) { + var_remap_[binding->var->vid] = tuple_getitem_remap_[ffi::GetRef(val)]; } else { ExprMutator::VisitBinding_(binding, val); } } Expr VisitExpr_(const VarNode* var) final { - auto it = input_tensor_remap_.find(GetRef(var)); + auto it = 
input_tensor_remap_.find(ffi::GetRef(var)); if (it != input_tensor_remap_.end()) { var_remap_[var->vid] = (*it).second; } return ExprMutator::VisitExpr_(var); } - Map input_tensor_remap_; + ffi::Map input_tensor_remap_; std::unordered_map tuple_getitem_remap_; AxisGroupGraph axis_group_graph_; }; @@ -616,10 +617,10 @@ Pass PropagateSharding() { }; return CreateModulePass(pass_func, 1, "PropagateSharding", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.distributed.transform.PropagateSharding", PropagateSharding); -}); +} } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/utils.cc b/src/relax/distributed/transform/utils.cc index 42b914617e73..0bcd730d42c8 100644 --- a/src/relax/distributed/transform/utils.cc +++ b/src/relax/distributed/transform/utils.cc @@ -22,7 +22,7 @@ namespace tvm { namespace relax { namespace distributed { -bool SinfoCompatibleWithDistIR(Array sinfos) { +bool SinfoCompatibleWithDistIR(ffi::Array sinfos) { bool compatible = true; for (const auto& sinfo : sinfos) { if (const auto* tuple_sinfo = sinfo.as()) { @@ -34,7 +34,7 @@ bool SinfoCompatibleWithDistIR(Array sinfos) { return compatible; } -bool SinfoCompatibleWithRelax(Array sinfos) { +bool SinfoCompatibleWithRelax(ffi::Array sinfos) { bool compatible = true; for (const auto& sinfo : sinfos) { if (const auto* tuple_sinfo = sinfo.as()) { @@ -46,7 +46,7 @@ bool SinfoCompatibleWithRelax(Array sinfos) { return compatible; } bool IsDistIRFunc(Function func) { - Array param_sinfos; + ffi::Array param_sinfos; for (const auto& param : func->params) { ICHECK(param->struct_info_); param_sinfos.push_back(Downcast(param->struct_info_.value())); diff --git a/src/relax/distributed/transform/utils.h b/src/relax/distributed/transform/utils.h index 2680c892695c..963efc15f6a0 100644 --- a/src/relax/distributed/transform/utils.h +++ b/src/relax/distributed/transform/utils.h @@ 
-33,12 +33,12 @@ namespace distributed { * \brief Pattern match op to a TIR function and look it up. * \return The TIR function, or nullopt if pattern match fails. */ -inline Optional MatchPrimFunc(const IRModule& mod_, const Expr& op) { +inline ffi::Optional MatchPrimFunc(const IRModule& mod_, const Expr& op) { const GlobalVar& global_var = Downcast(op); // NOTE: as check works for nullptr(returns null) - Optional base_func = mod_->functions.Get(global_var); + ffi::Optional base_func = mod_->functions.Get(global_var); if (auto* pfunc = base_func.as()) { - return GetRef(pfunc); + return ffi::GetRef(pfunc); } return std::nullopt; } @@ -46,7 +46,7 @@ inline Optional MatchPrimFunc(const IRModule& mod_, const Expr& o * \brief Check whether the given struct infos can appear in DistIR * \return Whether the given struct infos can appear in DistIR */ -bool SinfoCompatibleWithDistIR(Array sinfos); +bool SinfoCompatibleWithDistIR(ffi::Array sinfos); /*! * \brief Check whether the given function is a DistIR function diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index 9dae9175ef27..0bbfef31b83a 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -36,10 +36,10 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ DataflowBlockRewriteNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { DataflowBlockRewriteNode::RegisterReflection(); } DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) { - auto n = make_object(); + auto n = ffi::make_object(); n->dfb_ = dfb; n->root_fn_ = root_fn; n->original_fn_ptr_ = root_fn.get(); @@ -52,12 +52,12 @@ DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.DataflowBlockRewrite", [](DataflowBlock dfb, Function root_fn) { return DataflowBlockRewrite(dfb, 
root_fn); }); -}); +} void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { class ReplaceAllUsePass : public ExprMutator { @@ -73,7 +73,7 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { using ExprMutator::VisitExpr_; Expr VisitExpr_(const VarNode* op) override { - return (op == old_var.get()) ? new_var : GetRef(op); + return (op == old_var.get()) ? new_var : ffi::GetRef(op); } BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { @@ -113,13 +113,13 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dfb_rewrite_replace_all_uses", [](DataflowBlockRewrite rwt, Var old_var, Var new_var) { rwt->ReplaceAllUses(old_var, new_var); }); -}); +} class UpdateDFB : public ExprMutator { private: @@ -177,29 +177,29 @@ void DataflowBlockRewriteNode::Add(Binding binding) { } for (const VarNode* v : used_vars) { - auto var = GetRef(v); + auto var = ffi::GetRef(v); if (auto users = to_users_.Get(var)) { users.value().push_back(var); } } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.dfb_rewrite_add_binding", [](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }) .def("relax.dfb_rewrite_add", - [](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { + [](DataflowBlockRewrite rwt, Expr expr, ffi::Optional name, bool is_dfvar) { if (name.has_value()) { rwt->Add(name.value(), expr, is_dfvar); } else { rwt->Add(expr, is_dfvar); } }); -}); +} -std::set GetUnusedVars(Map> users_map, Array fn_outputs) { +std::set GetUnusedVars(ffi::Map> users_map, ffi::Array fn_outputs) { std::vector unused; // iterative dataflow algorithm. @@ -227,7 +227,7 @@ std::set GetUnusedVars(Map> users_map, Array fn_output // remove def site. 
for (const auto& used_var : used) { ICHECK(users_map.count(used_var)); - Array var_users = users_map[used_var]; + ffi::Array var_users = users_map[used_var]; // remove the unused var from the use site. if (auto it = std::find(var_users.begin(), var_users.end(), unused[i]); it != var_users.end()) { @@ -244,11 +244,11 @@ std::set GetUnusedVars(Map> users_map, Array fn_output class RemoveUnusedVars : public ExprMutator { public: std::set unused_vars; - Optional caught_rewrite = std::nullopt; + ffi::Optional caught_rewrite = std::nullopt; - RemoveUnusedVars(std::set unused_vars) : unused_vars(std::move(unused_vars)) {} + explicit RemoveUnusedVars(std::set unused_vars) : unused_vars(std::move(unused_vars)) {} - RemoveUnusedVars(Map> users, Array fn_outputs) + RemoveUnusedVars(ffi::Map> users, ffi::Array fn_outputs) : RemoveUnusedVars(GetUnusedVars(users, fn_outputs)) {} void VisitBinding_(const VarBindingNode* binding) override { @@ -301,13 +301,13 @@ void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { to_users_.erase(unused); // update use-def chain. 
} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dfb_rewrite_remove_unused", [](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { rwt->RemoveUnused(unused, allow_undef); }); -}); +} void DataflowBlockRewriteNode::RemoveAllUnused() { RemoveUnusedVars remover(to_users_, fn_outputs_); @@ -326,11 +326,11 @@ void DataflowBlockRewriteNode::RemoveAllUnused() { for (const auto& unused : remover.unused_vars) to_users_.erase(unused); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dfb_rewrite_remove_all_unused", [](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); -}); +} Expr RemoveAllUnused(Expr expr) { auto var_usage = CollectVarUsage(expr); @@ -345,14 +345,14 @@ Expr RemoveAllUnused(Expr expr) { } RemoveUnusedVars remover(var_usage.downstream_usage, - Array(externally_exposed.begin(), externally_exposed.end())); + ffi::Array(externally_exposed.begin(), externally_exposed.end())); return remover.VisitExpr(std::move(expr)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.remove_all_unused", RemoveAllUnused); -}); +} IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { BlockBuilder builder = BlockBuilder::Create(irmod); @@ -367,12 +367,12 @@ IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { return builder->GetContextIRModule(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.dfb_rewrite_mutate_irmodule", [](DataflowBlockRewrite rwt, IRModule irmod) { return rwt->MutateIRModule(irmod); }); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 3cf24d8a8c1a..09f404d29cbd 100644 --- a/src/relax/ir/block_builder.cc +++ 
b/src/relax/ir/block_builder.cc @@ -74,13 +74,13 @@ class BlockBuilderImpl : public BlockBuilderNode { IRModule Finalize() final { return transform::NormalizeGlobalVar()(context_mod_); } - GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) final { + GlobalVar AddFunction(const BaseFunc& func, ffi::String func_name_hint) final { LazyInitCtxFuncDedupMap(); auto it = ctx_func_dedup_map_->find(func); if (it == ctx_func_dedup_map_->end()) { context_mod_.CopyOnWrite(); - String func_name = GetUniqueName(func_name_hint); + ffi::String func_name = GetUniqueName(func_name_hint); while (context_mod_->ContainGlobalVar(func_name)) { func_name = GetUniqueName(func_name_hint); } @@ -160,7 +160,7 @@ class BlockBuilderImpl : public BlockBuilderNode { //------------------------------- // Scope management //------------------------------- - Optional LookupBinding(const Var& var) final { + ffi::Optional LookupBinding(const Var& var) final { auto it = binding_table_.find(var->vid); if (it == binding_table_.end()) return std::nullopt; return it->second; @@ -170,7 +170,7 @@ class BlockBuilderImpl : public BlockBuilderNode { void BeginBindingBlock() final { block_stack_.emplace_back(BlockFrame{{}, false}); } - void BeginScope(Optional> params) final { + void BeginScope(ffi::Optional> params) final { // The current implementation handles the collection of shape var // defined in parameter struct info annotations. The implementation // is correct (since we will simply erase all relax Vars in EraseToWellDefined), @@ -205,7 +205,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // defined in parameter struct info annotations. The implementation // is correct (since we will simply erase all relax Vars in EraseToWellDefined), // but can be further improved. 
- Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + ffi::Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); for (const auto& kv : var_map) { const tir::Var& shape_var = kv.first; const PrimExpr& shape_expr = kv.second; @@ -239,11 +239,11 @@ class BlockBuilderImpl : public BlockBuilderNode { bool CurrentBlockIsDataFlow() final { return CurrentBlockFrame()->is_dataflow; } - Var Emit(Expr expr, String name_hint) final { + Var Emit(Expr expr, ffi::String name_hint) final { return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint); } - Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint) final { + Var EmitMatchCast(Expr value, StructInfo struct_info, ffi::String name_hint) final { value = this->Normalize(value); CHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != BaseCheckResult::kFailL0) @@ -265,7 +265,7 @@ class BlockBuilderImpl : public BlockBuilderNode { return var; } - Var EmitOutput(Expr output, String name_hint) final { + Var EmitOutput(Expr output, ffi::String name_hint) final { BlockFrame* cur_frame = CurrentBlockFrame(); ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; @@ -317,7 +317,7 @@ class BlockBuilderImpl : public BlockBuilderNode { /*! * \brief List of bindings */ - Array bindings; + ffi::Array bindings; /*! \brief Whether current block is dataflow block. */ bool is_dataflow; /*! @@ -341,7 +341,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // // TODO(relax-team) tracks the var defined also through match-cast. /*! \brief set of defined symbolic vars, value as themself. */ - Map shape_var_map; + ffi::Map shape_var_map; }; /*! \brief A stack to store block frames. */ @@ -391,7 +391,7 @@ class BlockBuilderImpl : public BlockBuilderNode { * and performs shape/type deductions by calling Normalize. * \return The new variable that \p expr is bound to. 
*/ - Var Emit(Expr expr, bool is_dataflow, String name_hint) { + Var Emit(Expr expr, bool is_dataflow, ffi::String name_hint) { expr = this->Normalize(expr); Var var = CreateVar(is_dataflow, name_hint); @@ -413,7 +413,7 @@ class BlockBuilderImpl : public BlockBuilderNode { * \param name_hint Name hint for the bound variable. * \return The created var. */ - Var CreateVar(bool is_dataflow, String name_hint) { + Var CreateVar(bool is_dataflow, ffi::String name_hint) { if (name_hint.empty()) { name_hint = is_dataflow ? "lv" : "gv"; } @@ -427,12 +427,12 @@ class BlockBuilderImpl : public BlockBuilderNode { return name_supply_->FreshName(prefix, /*add_prefix*/ false, /*add_underscore*/ false); } - /*! \brief A custom structural hashing that ignores NDArray raw data. */ + /*! \brief A custom structural hashing that ignores Tensor raw data. */ class StructuralHashIgnoreNDarray { public: uint64_t operator()(const ObjectRef& key) const { return ffi::StructuralHash::Hash(key, /*map_free_vars=*/false, - /*skip_ndarray_content=*/true); + /*skip_tensor_content=*/true); } }; @@ -466,7 +466,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // shape vars as defined when calling BeginScope(params) class StructInfoVarCollector : public StructInfoVisitor { public: - static Map Collect(const StructInfo& struct_info) { + static ffi::Map Collect(const StructInfo& struct_info) { StructInfoVarCollector collector; collector(struct_info); return collector.shape_var_map_; @@ -478,17 +478,17 @@ class BlockBuilderImpl : public BlockBuilderNode { for (const PrimExpr& s : shape_expr->values) { // Only collect single var defined shape. 
Ignore something like `R.Tensor((m + 1, n + 1)) if (const auto* var = s.as()) { - shape_var_map_.Set(GetRef(var), s); + shape_var_map_.Set(ffi::GetRef(var), s); } } } } void VisitStructInfo_(const ShapeStructInfoNode* op) final { - for (const PrimExpr& s : op->values.value_or(Array())) { + for (const PrimExpr& s : op->values.value_or(ffi::Array())) { // Only collect single var defined shape. Ignore something like `R.Shape((m + 1, n + 1)) if (const auto* var = s.as()) { - shape_var_map_.Set(GetRef(var), s); + shape_var_map_.Set(ffi::GetRef(var), s); } } } @@ -503,7 +503,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } private: - Map shape_var_map_; + ffi::Map shape_var_map_; }; }; @@ -511,7 +511,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // Normalization //--------------------------------------- #define RELAX_EXPR_NORMALIZER_LEAF(OP) \ - Expr VisitExpr_(const OP* op) final { return GetRef(op); } + Expr VisitExpr_(const OP* op) final { return ffi::GetRef(op); } // TODO(relax-team): Check normalize logic after struct info. @@ -589,13 +589,13 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorstruct_info_.defined()) << "Var " << var->name_hint() << " does not have struct info."; - return GetRef(var); + return ffi::GetRef(var); } Expr VisitExpr_(const VarNode* var_ptr) final { auto var = VisitVar_(var_ptr); if (HasVoidStructInfo(var)) { - return VisitExpr(Tuple(Array{})); + return VisitExpr(Tuple(ffi::Array{})); } else { return var; } @@ -617,7 +617,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor new_fields; + ffi::Array new_fields; for (const Expr& field : op->fields) { Expr new_field = this->NormalizeArgument(field); @@ -625,10 +625,10 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(op) : Tuple(new_fields, op->span); + Tuple tuple = unchanged ? ffi::GetRef(op) : Tuple(new_fields, op->span); // Update tuple fields. 
if (!tuple->struct_info_.defined()) { - Array tuple_sinfo; + ffi::Array tuple_sinfo; for (Expr field : tuple->fields) { tuple_sinfo.push_back(GetStructInfo(field)); } @@ -641,7 +641,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorVisitWithNewScope(op->body, op->params); if (new_body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Function(op->params, new_body, op->ret_struct_info, op->is_pure, op->attrs); } @@ -650,11 +650,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorNormalizeArgument(op->op); - Array new_args = op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); }); + ffi::Array new_args = + op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); }); Call call; if (new_op.same_as(op->op) && new_args.same_as(op->args)) { - call = GetRef(op); + call = ffi::GetRef(op); } else { call = Call(new_op, new_args, op->attrs, op->sinfo_args); } @@ -670,7 +671,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorop, nullptr); func_normalize != nullptr) { - Expr normalized = func_normalize(GetRef(this), call); + Expr normalized = func_normalize(ffi::GetRef(this), call); if (!normalized.same_as(call)) { return VisitExpr(normalized); } @@ -682,7 +683,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor new_blocks; + ffi::Array new_blocks; for (BindingBlock block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); new_blocks.push_back(new_block); @@ -711,12 +712,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor normalized_blocks = NormalizeBlocks(new_blocks); + ffi::Array normalized_blocks = NormalizeBlocks(new_blocks); unchanged &= normalized_blocks.same_as(new_blocks); SeqExpr seq_expr; if (unchanged) { - seq_expr = GetRef(op); + seq_expr = ffi::GetRef(op); } else { seq_expr = SeqExpr(normalized_blocks, new_body, op->span); } @@ -736,7 +737,7 @@ class Normalizer : public BlockBuilderImpl, private 
ExprFunctorcond) && new_true.same_as(op->true_branch) && new_false.same_as(op->false_branch)) { - if_node = GetRef(op); + if_node = ffi::GetRef(op); } else { if_node = If(new_cond, new_true, new_false, op->span); } @@ -751,7 +752,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorNormalizeArgument(op->tuple); - TupleGetItem node = new_tuple.same_as(op->tuple) ? GetRef(op) + TupleGetItem node = new_tuple.same_as(op->tuple) ? ffi::GetRef(op) : TupleGetItem(new_tuple, op->index); if (!node->struct_info_.defined()) { @@ -767,11 +768,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor()) { - return this->VisitVarBinding(GetRef(var_binding)); + return this->VisitVarBinding(ffi::GetRef(var_binding)); } else { auto* match_cast = binding.as(); ICHECK(match_cast) << "Unsupported binding type: " << binding->GetTypeKey(); - return this->VisitMatchCast(GetRef(match_cast)); + return this->VisitMatchCast(ffi::GetRef(match_cast)); } } @@ -824,7 +825,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorop.as()) { // Case 1: the op field is a primitive op, look up FInferStructInfo attribute - Op op = GetRef(op_ptr); + Op op = ffi::GetRef(op_ptr); bool is_dist_op = false; for (const auto& arg : call->args) { if (arg->struct_info_.as()) { @@ -839,18 +840,18 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorname; - return op_map_dist_infer_struct_info_[op](call, GetRef(this)); + return op_map_dist_infer_struct_info_[op](call, ffi::GetRef(this)); } ICHECK(op_map_infer_struct_info_.count(op)) << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; - return op_map_infer_struct_info_[op](call, GetRef(this)); + return op_map_infer_struct_info_[op](call, ffi::GetRef(this)); } else { // derive using function parameters ICHECK(call->op->struct_info_.defined()); auto opt = MatchStructInfo(call->op); ICHECK(opt) << "Call->op must contains a function struct info"; FuncStructInfo finfo = opt.value(); - 
return DeriveCallRetStructInfo(finfo, call, GetRef(this), &analyzer_); + return DeriveCallRetStructInfo(finfo, call, ffi::GetRef(this), &analyzer_); } } @@ -862,7 +863,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor Optional { + auto f_shape_var_map = [curr_scope](tir::Var var) -> ffi::Optional { auto it = curr_scope->shape_var_map.find(var); if (it != curr_scope->shape_var_map.end()) return (*it).second; return std::nullopt; @@ -870,7 +871,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor> params = std::nullopt) { + Expr VisitWithNewScope(const Expr& expr, ffi::Optional> params = std::nullopt) { if (params.defined()) { this->BeginScope(params.value()); } else { @@ -891,7 +892,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor() && prologue->bindings.empty()) { return post; } - Array bindings; + ffi::Array bindings; if (!prologue->bindings.empty()) { bindings.push_back(prologue); } @@ -906,15 +907,15 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor FlattenBlocks(const Array& blocks) { + ffi::Array FlattenBlocks(const ffi::Array& blocks) { // If there is a binding that is a seq expr, split the current block, // add the nested blocks prior to the seq expr, and bind the seq expr body // to the var - Array ret; + ffi::Array ret; bool changed = false; for (const BindingBlock& block : blocks) { bool is_dataflow = block->IsInstance(); - Array current; + ffi::Array current; for (const Binding& binding : block->bindings) { Expr value; if (const auto* var_binding = binding.as()) { @@ -950,8 +951,8 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor{}))); - Array free_dataflow_vars; + auto free_vars = FreeVars(SeqExpr({block}, Tuple(ffi::Array{}))); + ffi::Array free_dataflow_vars; for (const auto& var : free_vars) { if (auto opt = var.as()) { free_dataflow_vars.push_back(opt.value()); @@ -987,9 +988,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor 
NormalizeBlocks(const Array& blocks) { + ffi::Array NormalizeBlocks(const ffi::Array& blocks) { bool changed = false; - Array ret; + ffi::Array ret; auto flattened = FlattenBlocks(blocks); if (!flattened.same_as(blocks)) { changed = true; @@ -1003,11 +1004,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor()) { - auto n = make_object(*dataflow_block); + auto n = ffi::make_object(*dataflow_block); n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); merged = DataflowBlock(n); } else if (const auto* binding_block = ret.back().as()) { - auto n = make_object(*binding_block); + auto n = ffi::make_object(*binding_block); n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); merged = BindingBlock(n); } else { @@ -1036,14 +1037,14 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor mod) { - ObjectPtr n = make_object(mod.value_or(IRModule())); +BlockBuilder BlockBuilder::Create(ffi::Optional mod) { + ObjectPtr n = ffi::make_object(mod.value_or(IRModule())); return BlockBuilder(n); } -BlockBuilder BlockBuilder::Create(Optional mod, +BlockBuilder BlockBuilder::Create(ffi::Optional mod, BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript) { - ObjectPtr n = make_object( + ObjectPtr n = ffi::make_object( mod.value_or(IRModule()), BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript()); return BlockBuilder(n); } @@ -1052,31 +1053,32 @@ BlockBuilder BlockBuilder::Create(Optional mod, // User facing function registration. 
//--------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); refl::GlobalDef() .def("relax.BlockBuilderCreate", - [](Optional mod) { return BlockBuilder::Create(mod); }) + [](ffi::Optional mod) { return BlockBuilder::Create(mod); }) .def_method("relax.BlockBuilderBeginDataflowBlock", &BlockBuilderNode::BeginDataflowBlock) .def_method("relax.BlockBuilderBeginBindingBlock", &BlockBuilderNode::BeginBindingBlock) .def_method("relax.BlockBuilderEndBlock", &BlockBuilderNode::EndBlock) .def_method("relax.BlockBuilderNormalize", &BlockBuilderNode::Normalize) .def("relax.BlockBuilderEmit", - [](BlockBuilder builder, Expr expr, String name_hint) { + [](BlockBuilder builder, Expr expr, ffi::String name_hint) { return builder->Emit(expr, name_hint); }) .def("relax.BlockBuilderEmitMatchCast", - [](BlockBuilder builder, Expr value, StructInfo struct_info, String name_hint) { + [](BlockBuilder builder, Expr value, StructInfo struct_info, ffi::String name_hint) { return builder->EmitMatchCast(value, struct_info, name_hint); }) .def("relax.BlockBuilderEmitOutput", - [](BlockBuilder builder, const Expr& output, String name_hint) { + [](BlockBuilder builder, const Expr& output, ffi::String name_hint) { return builder->EmitOutput(output, name_hint); }) .def("relax.BlockBuilderEmitNormalized", [](BlockBuilder builder, Binding binding) { return builder->EmitNormalized(binding); }) .def("relax.BlockBuilderGetUniqueName", - [](BlockBuilder builder, String name_hint) { + [](BlockBuilder builder, ffi::String name_hint) { return builder->name_supply()->FreshName(name_hint, /*add_prefix*/ false, /*add_underscore*/ false); }) @@ -1089,6 +1091,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("relax.BlockBuilderLookupBinding", &BlockBuilderNode::LookupBinding) .def_method("relax.BlockBuilderBeginScope", &BlockBuilderNode::BeginScope) .def_method("relax.BlockBuilderEndScope", 
&BlockBuilderNode::EndScope); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index def0e61c986c..249ec14f89dd 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -135,7 +135,7 @@ struct MatchState { static std::optional TryMatch(const PNode& p, const RNode& r, const MatchState& current_match, DFPatternMatcher* m, const MatcherUseDefAnalysis& ud_analysis) { - if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; + if (!m->Match(ffi::GetRef(p.ptr), ffi::GetRef(r.ptr))) return std::nullopt; MatchState new_match; @@ -192,15 +192,15 @@ static std::optional TryValidate( const std::vector& validation_constraints, arith::Analyzer* analyzer) { MatchState new_match; - std::function(const DFPatternNode*)> query_match_state = - [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> Optional { + std::function(const DFPatternNode*)> query_match_state = + [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> ffi::Optional { auto it = pattern2node.find(pattern); ICHECK(it != pattern2node.end()) - << "DFConstraint attempted to access DFPattern " << GetRef(pattern) + << "DFConstraint attempted to access DFPattern " << ffi::GetRef(pattern) << ", which does not appear in the PatternContext"; const auto& p_node = it->second; if (auto ptr = current_match.matched(p_node)) { - return GetRef(ptr); + return ffi::GetRef(ptr); } else { return std::nullopt; } @@ -289,9 +289,9 @@ static std::optional MatchTree( return std::nullopt; } -Optional> MatchGraph(const PatternContext& ctx, - const Array& binding_arr, - const Map& bindings) { +ffi::Optional> MatchGraph(const PatternContext& ctx, + const ffi::Array& binding_arr, + const ffi::Map& bindings) { // TODO(@ganler): Handle non-may external use. 
ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; DFPatternMatcher matcher(bindings); @@ -351,31 +351,33 @@ Optional> MatchGraph(const PatternContext& ctx, return std::nullopt; } - Map ret; + ffi::Map ret; for (const auto& [pat, p_node] : pattern2node) { ICHECK(match->matched(p_node)); - ret.Set(GetRef(pat), GetRef(match->matched(p_node))); + ret.Set(ffi::GetRef(pat), ffi::GetRef(match->matched(p_node))); } return ret; } -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { +ffi::Optional> MatchGraph(const PatternContext& ctx, + const DataflowBlock& dfb) { return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.dpl.match_dfb", [](const PatternContext& ctx, const DataflowBlock& dfb) { return MatchGraph(ctx, dfb); }); -}); +} class PatternContextRewriterNode : public PatternMatchingRewriterNode { public: PatternContext pattern; - ffi::TypedFunction(Map, Map)> rewriter_func; + ffi::TypedFunction(ffi::Map, ffi::Map)> + rewriter_func; - RewriteSpec RewriteBindings(const Array& bindings) const override; + RewriteSpec RewriteBindings(const ffi::Array& bindings) const override; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -383,19 +385,18 @@ class PatternContextRewriterNode : public PatternMatchingRewriterNode { .def_ro("pattern", &PatternContextRewriterNode::pattern) .def_ro("rewriter_func", &PatternContextRewriterNode::rewriter_func); } - - static constexpr const char* _type_key = "relax.dpl.PatternContextRewriter"; - TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, PatternMatchingRewriterNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.dpl.PatternContextRewriter", PatternContextRewriterNode, + PatternMatchingRewriterNode); private: - Optional> MatchBindings(const Array& bindings) const { - Map var_lookup; + ffi::Optional> 
MatchBindings(const ffi::Array& bindings) const { + ffi::Map var_lookup; for (const auto& binding : bindings) { var_lookup.Set(binding->var, GetBoundValue(binding)); } if (auto matches = MatchGraph(pattern, bindings, var_lookup)) { - Map replacements = rewriter_func(matches.value(), var_lookup); + ffi::Map replacements = rewriter_func(matches.value(), var_lookup); if (replacements.size()) { return replacements; } @@ -409,16 +410,17 @@ class PatternContextRewriter : public PatternMatchingRewriter { public: PatternContextRewriter( PatternContext pattern, - ffi::TypedFunction(Map, Map)> rewriter_func); + ffi::TypedFunction(ffi::Map, ffi::Map)> + rewriter_func); - TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, PatternMatchingRewriter, - PatternContextRewriterNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PatternContextRewriter, PatternMatchingRewriter, + PatternContextRewriterNode); }; -RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bindings) const { +RewriteSpec PatternContextRewriterNode::RewriteBindings(const ffi::Array& bindings) const { std::vector remaining_bindings{bindings.begin(), bindings.end()}; - Map variable_rewrites; + ffi::Map variable_rewrites; while (auto opt = MatchBindings(remaining_bindings)) { auto new_rewrites = opt.value(); remaining_bindings.erase(std::remove_if(remaining_bindings.begin(), remaining_bindings.end(), @@ -436,8 +438,9 @@ RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bi PatternContextRewriter::PatternContextRewriter( PatternContext pattern, - ffi::TypedFunction(Map, Map)> rewriter_func) { - auto node = make_object(); + ffi::TypedFunction(ffi::Map, ffi::Map)> + rewriter_func) { + auto node = ffi::make_object(); node->pattern = std::move(pattern); node->rewriter_func = std::move(rewriter_func); data_ = std::move(node); @@ -445,18 +448,18 @@ PatternContextRewriter::PatternContextRewriter( Function RewriteBindings( const PatternContext& ctx, - ffi::TypedFunction(Map, Map)> 
rewriter, + ffi::TypedFunction(ffi::Map, ffi::Map)> rewriter, Function func) { // return BlockPatternRewriter::Run(ctx, rewriter, func); return Downcast(PatternContextRewriter(ctx, rewriter)(func)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.rewrite_bindings", RewriteBindings); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ PatternContextRewriterNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { PatternContextRewriterNode::RegisterReflection(); } } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 21000fec0cb8..4aca923a4b80 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -45,11 +45,11 @@ namespace relax { namespace { class GlobalVarReplacer : public ExprMutator { public: - explicit GlobalVarReplacer(Map gvar_map) : gvar_map_(gvar_map) {} + explicit GlobalVarReplacer(ffi::Map gvar_map) : gvar_map_(gvar_map) {} using ExprMutator::VisitExpr_; Expr VisitExpr_(const GlobalVarNode* op) override { - auto gvar = GetRef(op); + auto gvar = ffi::GetRef(op); if (auto opt = gvar_map_.Get(gvar)) { gvar = opt.value(); } @@ -57,10 +57,10 @@ class GlobalVarReplacer : public ExprMutator { } private: - Map gvar_map_; + ffi::Map gvar_map_; }; -Array TopologicalSort(const Array& bindings) { +ffi::Array TopologicalSort(const ffi::Array& bindings) { std::unordered_set remaining_bindings; for (const auto& binding : bindings) { remaining_bindings.insert(binding->var); @@ -74,7 +74,7 @@ Array TopologicalSort(const Array& bindings) { bool emitted; }; std::vector delayed_bindings; - Array sorted_bindings; + ffi::Array sorted_bindings; // Utility function to append the auto push_sorted_binding = [&](Binding binding) { @@ -159,7 +159,7 @@ void RewriteSpec::Append(RewriteSpec other) { gvar_name_supply->ReserveName(gvar->name_hint); } - Map gvar_rewrites; + ffi::Map 
gvar_rewrites; for (auto [gvar, func] : other.new_subroutines) { if (auto it = new_subroutines.find(gvar); it != new_subroutines.end()) { // The two rewrites provide the same GlobalVar. @@ -192,19 +192,19 @@ void RewriteSpec::Append(RewriteSpec other) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.dpl.PatternMatchingRewriterFromPattern", [](DFPattern pattern, - ffi::TypedFunction(Expr, Map)> func) { + ffi::TypedFunction(Expr, ffi::Map)> func) { return PatternMatchingRewriter::FromPattern(pattern, func); }) .def("relax.dpl.PatternMatchingRewriterFromModule", [](IRModule mod) { return PatternMatchingRewriter::FromModule(mod); }) .def("relax.dpl.PatternMatchingRewriterApply", [](PatternMatchingRewriter rewriter, - Variant obj) -> Variant { + ffi::Variant obj) -> ffi::Variant { if (auto expr = obj.as()) { return rewriter(expr.value()); } else if (auto mod = obj.as()) { @@ -213,11 +213,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ LOG(FATAL) << "Unreachable: object does not contain either variant type"; } }); -}); +} -RewriteSpec ExprPatternRewriterNode::RewriteBindings(const Array& bindings) const { - Map variable_rewrites; - Map binding_lookup; +RewriteSpec ExprPatternRewriterNode::RewriteBindings(const ffi::Array& bindings) const { + ffi::Map variable_rewrites; + ffi::Map binding_lookup; for (const auto& binding : bindings) { auto bound_value = GetBoundValue(binding); if (auto new_expr = RewriteExpr(bound_value, binding_lookup)) { @@ -233,8 +233,8 @@ RewriteSpec ExprPatternRewriterNode::RewriteBindings(const Array& bindi } } -Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr, - const Map& bindings) const { +ffi::Optional ExprPatternRewriterNode::RewriteExpr( + const Expr& expr, const ffi::Map& bindings) const { if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings)) { auto matches = opt_matches.value(); if (additional_bindings) { @@ -249,7 +249,7 @@ Optional 
ExprPatternRewriterNode::RewriteExpr(const Expr& expr, } } - Optional rewritten_expr = func(expr, matches); + ffi::Optional rewritten_expr = func(expr, matches); if (rewritten_expr.defined() && !rewritten_expr.same_as(expr)) { return rewritten_expr.value(); } @@ -257,19 +257,22 @@ Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr, return std::nullopt; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.dpl.PatternRewriter", - [](DFPattern pattern, ffi::TypedFunction(Expr, Map)> func) { + [](DFPattern pattern, + ffi::TypedFunction(Expr, ffi::Map)> func) { return ExprPatternRewriter(pattern, func); }); -}); +} ExprPatternRewriter::ExprPatternRewriter( - DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings, Map new_subroutines) { - auto node = make_object(); + DFPattern pattern, + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings, + ffi::Map new_subroutines) { + auto node = ffi::make_object(); node->pattern = std::move(pattern); node->func = std::move(func); node->additional_bindings = std::move(additional_bindings); @@ -277,7 +280,7 @@ ExprPatternRewriter::ExprPatternRewriter( data_ = std::move(node); } -RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) const { +RewriteSpec OrRewriterNode::RewriteBindings(const ffi::Array& bindings) const { auto lhs_match = lhs->RewriteBindings(bindings); if (!lhs_match) { // If no rewrites found on LHS, RHS is allowed to modify any @@ -291,7 +294,7 @@ RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) cons // the LHS. Variable replacements from the RHS may still occur, // but will need to wait for the next round of // iterate-until-converged. 
- Array remaining_bindings; + ffi::Array remaining_bindings; for (const auto& binding : bindings) { if (!lhs_match.variable_rewrites.count(binding->var)) { remaining_bindings.push_back(binding); @@ -307,26 +310,26 @@ RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) cons return lhs_match; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.OrRewriter", [](PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { return OrRewriter(lhs, rhs); }); -}); +} OrRewriter::OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { - auto node = make_object(); + auto node = ffi::make_object(); node->lhs = std::move(lhs); node->rhs = std::move(rhs); data_ = std::move(node); } -RewriteSpec TupleRewriterNode::RewriteBindings(const Array& bindings) const { +RewriteSpec TupleRewriterNode::RewriteBindings(const ffi::Array& bindings) const { CHECK_LE(patterns.size(), 3) << "For performance reasons, " << "matching of implicit tuple patterns is currently limited" << " to tuples with 3 elements or fewer."; - Map variable_rewrites = GenerateVariableRewrites(bindings); + ffi::Map variable_rewrites = GenerateVariableRewrites(bindings); if (variable_rewrites.size()) { return RewriteSpec{variable_rewrites, new_subroutines}; @@ -335,10 +338,11 @@ RewriteSpec TupleRewriterNode::RewriteBindings(const Array& bindings) c } } -Map TupleRewriterNode::GenerateVariableRewrites(const Array& bindings) const { - Map rewrites; +ffi::Map TupleRewriterNode::GenerateVariableRewrites( + const ffi::Array& bindings) const { + ffi::Map rewrites; - Map binding_lookup; + ffi::Map binding_lookup; std::vector info_vec; @@ -534,7 +538,7 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( } } - Map merged_matches = info_vec[indices[0]].matches[0].value(); + ffi::Map merged_matches = info_vec[indices[0]].matches[0].value(); for (size_t i = 1; i < indices.size(); i++) { for (const auto& [pat, 
expr] : info_vec[indices[i]].matches[i].value()) { if (auto it = merged_matches.find(pat); it != merged_matches.end()) { @@ -572,7 +576,7 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( } auto full_tuple = [&]() -> relax::Expr { - Array fields; + ffi::Array fields; for (size_t index : indices) { fields.push_back(info_vec[index].expr); } @@ -604,20 +608,22 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( return rewrites; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.TupleRewriter", - [](Array patterns, - ffi::TypedFunction(Expr, Map)> func) { - return TupleRewriter(patterns, func); - }); -}); + refl::GlobalDef().def( + "relax.dpl.TupleRewriter", + [](ffi::Array patterns, + ffi::TypedFunction(Expr, ffi::Map)> func) { + return TupleRewriter(patterns, func); + }); +} -TupleRewriter::TupleRewriter(Array patterns, - ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings, - Map new_subroutines) { - auto node = make_object(); +TupleRewriter::TupleRewriter( + ffi::Array patterns, + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings, + ffi::Map new_subroutines) { + auto node = ffi::make_object(); node->patterns = std::move(patterns); node->func = std::move(func); node->additional_bindings = std::move(additional_bindings); @@ -626,8 +632,10 @@ TupleRewriter::TupleRewriter(Array patterns, } PatternMatchingRewriter PatternMatchingRewriter::FromPattern( - DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings, Map new_subroutines) { + DFPattern pattern, + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings, + ffi::Map new_subroutines) { if (auto or_pattern = pattern.as()) { auto new_additional_bindings = additional_bindings.value_or({}); new_additional_bindings.push_back(pattern); @@ -678,10 +686,10 @@ PatternMatchingRewriter 
PatternMatchingRewriter::FromModule(IRModule mod) { return Downcast(base_func); }(); - Map new_subroutines; + ffi::Map new_subroutines; for (const auto& [gvar, func] : mod->functions) { if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") { - bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); CHECK(!is_public) << "ValueError: " << "Expected module to have no publicly-exposed functions " << "other than 'pattern' and 'replacement'. " @@ -699,8 +707,8 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { << "but the pattern has struct info " << sinfo_pattern << ", while the replacement has struct info " << sinfo_replacement; - Array param_wildcards; - Map pattern_lookup; + ffi::Array param_wildcards; + ffi::Map pattern_lookup; for (const auto& param : func_pattern->params) { WildcardPattern wildcard; param_wildcards.push_back(wildcard); @@ -752,15 +760,15 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { DFPattern top_pattern = make_pattern(func_pattern->body->body); - ffi::TypedFunction(Expr, Map)> rewriter_func = + ffi::TypedFunction(Expr, ffi::Map)> rewriter_func = [param_wildcards = std::move(param_wildcards), orig_func_replacement = std::move(func_replacement)]( - Expr expr, Map matches) -> Optional { + Expr expr, ffi::Map matches) -> ffi::Optional { auto func_replacement = CopyWithNewVars(orig_func_replacement); - Array new_blocks; + ffi::Array new_blocks; - Array wildcard_bindings; + ffi::Array wildcard_bindings; ICHECK_EQ(param_wildcards.size(), func_replacement->params.size()); for (size_t i = 0; i < param_wildcards.size(); i++) { Expr matched_expr = matches[param_wildcards[i]]; @@ -787,8 +795,8 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { new_subroutines); } -Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, - Optional> bindings_opt) { +ffi::Optional> 
ExtractMatchedExpr( + DFPattern pattern, Expr expr, ffi::Optional> bindings_opt) { auto bindings = bindings_opt.value_or({}); DFPatternMatcher matcher(bindings); @@ -799,19 +807,19 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, return matcher.GetMemo(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.extract_matched_expr", ExtractMatchedExpr); -}); +} -bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { +bool MatchExpr(DFPattern pattern, Expr expr, ffi::Optional> bindings_opt) { return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.match_expr", MatchExpr); -}); +} /*! * \brief Apply pattern matching to each expression, replacing @@ -821,9 +829,10 @@ class PatternMatchingMutator : public ExprMutator { public: using ExprMutator::VisitExpr_; - PatternMatchingMutator(const PatternMatchingRewriterNode* rewriter) : rewriter_(rewriter) {} + explicit PatternMatchingMutator(const PatternMatchingRewriterNode* rewriter) + : rewriter_(rewriter) {} - Map GetNewSubroutines() const { return new_subroutines_; } + ffi::Map GetNewSubroutines() const { return new_subroutines_; } Expr VisitExpr_(const SeqExprNode* seq) override { SeqExpr prev = Downcast(ExprMutator::VisitExpr_(seq)); @@ -861,13 +870,13 @@ class PatternMatchingMutator : public ExprMutator { return prev; } - Optional TryRewriteSeqExpr(const SeqExpr& seq) { - Array old_blocks = seq->blocks; + ffi::Optional TryRewriteSeqExpr(const SeqExpr& seq) { + ffi::Array old_blocks = seq->blocks; // If the SeqExpr's output is not a variable, treat it as if it // were the last variable binding of the last block. This // simplifies the special handling of the SeqExpr's body. 
- Optional dummy_output_var = std::nullopt; + ffi::Optional dummy_output_var = std::nullopt; if (!seq->body->IsInstance()) { dummy_output_var = Var("dummy_output_var", GetStructInfo(seq->body)); VarBinding dummy_binding(dummy_output_var.value(), seq->body); @@ -878,7 +887,7 @@ class PatternMatchingMutator : public ExprMutator { old_blocks.pop_back(); return last_block; } else { - return BindingBlock(Array{}); + return BindingBlock(ffi::Array{}); } }(); @@ -886,7 +895,7 @@ class PatternMatchingMutator : public ExprMutator { old_blocks.push_back(last_block); } - auto rewrite_block = [&](Array orig_bindings) -> Array { + auto rewrite_block = [&](ffi::Array orig_bindings) -> ffi::Array { auto rewrites = rewriter_->RewriteBindings(orig_bindings); if (!rewrites) return orig_bindings; @@ -921,7 +930,7 @@ class PatternMatchingMutator : public ExprMutator { // Utility function to return the rewrites that should be applied // to a given block. - auto get_rewrites = [&](BindingBlock block) -> Array { + auto get_rewrites = [&](BindingBlock block) -> ffi::Array { if (block.as()) { // Early return for DataflowBlock. 
Since neither control flow // nor impure functions are allowed within the dataflow block, @@ -931,8 +940,8 @@ class PatternMatchingMutator : public ExprMutator { RewriteSpec rewrites; - Array collected_bindings; - Array finalized_bindings; + ffi::Array collected_bindings; + ffi::Array finalized_bindings; auto handle_collected_rewrites = [&]() { if (collected_bindings.size()) { @@ -1029,17 +1038,17 @@ class PatternMatchingMutator : public ExprMutator { private: const PatternMatchingRewriterNode* rewriter_; - Map new_subroutines_; + ffi::Map new_subroutines_; }; Expr PatternMatchingRewriter::operator()(Expr expr) { PatternMatchingMutator mutator(get()); auto new_expr = mutator(expr); auto new_subroutines = mutator.GetNewSubroutines(); - CHECK_EQ(new_subroutines.size(), 0) - << "If PatternMatchingRewriter provides subroutines, " - << "then it must be applied to an entire IRModule. " - << "However, PatternMatchingRewriter produced subroutines " << [&]() -> Array { + CHECK_EQ(new_subroutines.size(), 0) << "If PatternMatchingRewriter provides subroutines, " + << "then it must be applied to an entire IRModule. 
" + << "However, PatternMatchingRewriter produced subroutines " + << [&]() -> ffi::Array { std::vector vec; for (const auto& [gvar, func] : new_subroutines) { vec.push_back(gvar); @@ -1079,21 +1088,22 @@ tvm::transform::PassInfo PatternMatchingRewriterNode::Info() const { } Function RewriteCall(const DFPattern& pat, - ffi::TypedFunction)> rewriter, Function func) { + ffi::TypedFunction)> rewriter, + Function func) { return Downcast(PatternMatchingRewriter::FromPattern(pat, rewriter)(func)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.rewrite_call", RewriteCall); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { PatternMatchingRewriterNode::RegisterReflection(); ExprPatternRewriterNode::RegisterReflection(); OrRewriterNode::RegisterReflection(); TupleRewriterNode::RegisterReflection(); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index b70c97cc3d13..5c0fd6d8f554 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -60,7 +60,7 @@ using tvm::arith::Analyzer; * \param attributes The attributes to match. * \return True if the attributes match, false otherwise. 
*/ -bool MatchAttrs(const Any& attrs, const Map& attributes) { +bool MatchAttrs(const Any& attrs, const ffi::Map& attributes) { // TODO(tqchen): consider lift to common utils if (auto* dict_attrs = attrs.as()) { for (auto kv : attributes) { @@ -85,7 +85,7 @@ bool MatchAttrs(const Any& attrs, const Map& attributes) { const Object* obj = attrs.cast(); ffi::reflection::ForEachFieldInfoWithEarlyStop( type_info, [&](const TVMFFIFieldInfo* field_info) { - String field_name(field_info->name); + ffi::String field_name(field_info->name); if (attributes.count(field_name)) { ffi::reflection::FieldGetter field_getter(field_info); ffi::Any field_value = field_getter(obj); @@ -108,12 +108,12 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { return VisitDFPattern(pattern, expr); } -Expr DFPatternMatcher::UnwrapBindings(Expr expr, const Map& var2val) { - auto unwrap = [&](Expr expr) -> Optional { +Expr DFPatternMatcher::UnwrapBindings(Expr expr, const ffi::Map& var2val) { + auto unwrap = [&](Expr expr) -> ffi::Optional { // Unwrap variables into the value to which they are bound. 
if (var2val.size()) { if (const VarNode* var = expr.as()) { - if (auto may = var2val.Get(GetRef(var))) { + if (auto may = var2val.Get(ffi::GetRef(var))) { return may.value(); } } @@ -187,7 +187,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons VLOG(1) << "considering AttrPatternNode at:\n" << expr; auto attributes = attr_pattern->attrs.as()->dict; if (const auto* op_node = expr.as()) { - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); for (auto kv : attributes) { auto attr_name = kv.first; auto attr_value = kv.second; @@ -257,8 +257,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex if (matches_op) { auto watermark2 = matched_nodes_.size(); - auto match_args = [this, &watermark2](const Array& pattern_args, auto expr_begin, - auto expr_end) { + auto match_args = [this, &watermark2](const ffi::Array& pattern_args, + auto expr_begin, auto expr_end) { bool matches = true; auto pattern_it = pattern_args.begin(); auto expr_it = expr_begin; @@ -385,8 +385,8 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e return matches; } -bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::Array patterns, - const tvm::Array fields, +bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::ffi::Array patterns, + const tvm::ffi::Array fields, std::vector& match_cache, std::vector& matched) { if (idx >= patterns.size()) return true; @@ -456,7 +456,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { return condition; } - auto sort_key = [](PrimExpr expr) -> String { + auto sort_key = [](PrimExpr expr) -> ffi::String { if (const auto* equal = expr.as()) { if (const auto* var = equal->a.as()) { return var->name_hint; @@ -476,7 +476,8 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { return analyzer_.Simplify(sorted_condition); } -static bool ShapeEqual(Analyzer* analyzer, const Array& lhs, const Array& rhs) { +static bool 
ShapeEqual(Analyzer* analyzer, const ffi::Array& lhs, + const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) if (!tir::is_one(analyzer->Simplify(lhs[i] == rhs[i]))) return false; @@ -495,8 +496,8 @@ bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& e } std::tuple SameShapeConstraintNode::AsPrimExpr( - std::function(const DFPatternNode*)> match_state) const { - Optional> expected_shape; + std::function(const DFPatternNode*)> match_state) const { + ffi::Optional> expected_shape; bool all_shapes_defined = true; // The expression that must be true in order @@ -505,7 +506,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( for (const auto& arg : args) { if (auto opt_var = match_state(arg.get())) { auto var = opt_var.value(); - auto opt_var_shape = [&]() -> Optional> { + auto opt_var_shape = [&]() -> ffi::Optional> { auto sinfo = GetStructInfo(var); if (auto tensor = sinfo.as()) { return tensor->GetShape(); diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index 71fa4a4c35c1..bece0af12070 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -38,15 +38,15 @@ namespace relax { class DFPatternMatcher : public DFPatternFunctor { public: - using var2val_t = Map; + using var2val_t = ffi::Map; explicit DFPatternMatcher() {} explicit DFPatternMatcher(var2val_t var2val) : var2val_(std::move(var2val)) {} bool Match(const DFPattern& pattern, const Expr& expr); - Map GetMemo() { return memo_; } + ffi::Map GetMemo() { return memo_; } /* \brief Unwrap trivial expressions/bindings */ - static Expr UnwrapBindings(Expr expr, const Map& bindings); + static Expr UnwrapBindings(Expr expr, const ffi::Map& bindings); protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; @@ -73,8 +73,8 @@ class DFPatternMatcher : public DFPatternFunctor patterns, - const tvm::Array fields, std::vector& match_cache, + bool 
TryUnorderedMatch(size_t idx, const tvm::ffi::Array patterns, + const tvm::ffi::Array fields, std::vector& match_cache, std::vector& matched); /* \brief Simplify a boolean condition using the analyzer diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index f0f40e4df1a1..99e7dc6dfe05 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -32,7 +32,7 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { PatternSeqNode::RegisterReflection(); ExprPatternNode::RegisterReflection(); VarPatternNode::RegisterReflection(); @@ -54,7 +54,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ AttrPatternNode::RegisterReflection(); ExternFuncPatternNode::RegisterReflection(); ConstantPatternNode::RegisterReflection(); -}); +} #define RELAX_PATTERN_PRINTER_DEF(NODE_TYPE, REPR_LAMBDA) \ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) \ @@ -63,41 +63,41 @@ TVM_FFI_STATIC_INIT_BLOCK({ REPR_LAMBDA(p, node); \ }) -ExternFuncPattern::ExternFuncPattern(String global_symbol) { - ObjectPtr n = make_object(); +ExternFuncPattern::ExternFuncPattern(ffi::String global_symbol) { + ObjectPtr n = ffi::make_object(); n->global_symbol_ = std::move(global_symbol); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.ExternFuncPattern", - [](String global_symbol) { return ExternFuncPattern(global_symbol); }); -}); + [](ffi::String global_symbol) { return ExternFuncPattern(global_symbol); }); +} RELAX_PATTERN_PRINTER_DEF(ExternFuncPatternNode, [](auto p, auto node) { p->stream << "ExternFuncPattern(" << node->global_symbol() << ")"; }); -VarPattern::VarPattern(String name_hint) { - ObjectPtr n = make_object(); +VarPattern::VarPattern(ffi::String name_hint) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name_hint); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace 
refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.VarPattern", - [](String name_hint) { return VarPattern(name_hint); }); -}); + [](ffi::String name_hint) { return VarPattern(name_hint); }); +} RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { p->stream << "VarPattern(" << node->name_hint() << ")"; }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.DataflowVarPattern", - [](String name_hint) { return DataflowVarPattern(name_hint); }); -}); -DataflowVarPattern::DataflowVarPattern(String name_hint) { - ObjectPtr n = make_object(); + [](ffi::String name_hint) { return DataflowVarPattern(name_hint); }); +} +DataflowVarPattern::DataflowVarPattern(ffi::String name_hint) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name_hint); data_ = std::move(n); } @@ -105,55 +105,55 @@ RELAX_PATTERN_PRINTER_DEF(DataflowVarPatternNode, [](auto p, auto node) { p->stream << "DataflowVarPattern(" << node->name_hint() << ")"; }); -GlobalVarPattern::GlobalVarPattern(String name_hint) { - ObjectPtr n = make_object(); +GlobalVarPattern::GlobalVarPattern(ffi::String name_hint) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name_hint); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.GlobalVarPattern", - [](String name_hint) { return GlobalVarPattern(name_hint); }); -}); + [](ffi::String name_hint) { return GlobalVarPattern(name_hint); }); +} RELAX_PATTERN_PRINTER_DEF(GlobalVarPatternNode, [](auto p, auto node) { p->stream << "GlobalVarPattern(" << node->name_hint() << ")"; }); ExprPattern::ExprPattern(Expr expr) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->expr = std::move(expr); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("relax.dpl.ExprPattern", [](Expr e) { return ExprPattern(e); }); -}); +} RELAX_PATTERN_PRINTER_DEF(ExprPatternNode, [](auto p, auto node) { p->Print(node->expr); }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.ConstantPattern", []() { - auto c = ConstantPattern(make_object()); + auto c = ConstantPattern(ffi::make_object()); return c; }); -}); +} RELAX_PATTERN_PRINTER_DEF(ConstantPatternNode, [](auto p, auto node) { p->stream << "ConstantPattern()"; }); -CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_wildcard) { - ObjectPtr n = make_object(); +CallPattern::CallPattern(DFPattern op, ffi::Array args, bool varg_default_wildcard) { + ObjectPtr n = ffi::make_object(); n->op = std::move(op); n->args = std::move(args); n->varg_default_wildcard = varg_default_wildcard; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.CallPattern", - [](DFPattern op, Array args, bool varg_default_wildcard) { + [](DFPattern op, ffi::Array args, bool varg_default_wildcard) { return CallPattern(op, args, varg_default_wildcard); }); -}); +} RELAX_PATTERN_PRINTER_DEF(CallPatternNode, [](auto p, auto node) { p->stream << node->op << "("; for (size_t i = 0; i < node->args.size(); ++i) { @@ -167,166 +167,167 @@ RELAX_PATTERN_PRINTER_DEF(CallPatternNode, [](auto p, auto node) { p->stream << ")"; }); -PrimArrPattern::PrimArrPattern(Array arr) { - ObjectPtr n = make_object(); +PrimArrPattern::PrimArrPattern(ffi::Array arr) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(arr); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.PrimArrPattern", - [](Array arr) { return PrimArrPattern(std::move(arr)); }); -}); + [](ffi::Array arr) { return 
PrimArrPattern(std::move(arr)); }); +} RELAX_PATTERN_PRINTER_DEF(PrimArrPatternNode, [](auto p, auto node) { p->stream << "PrimArrPattern(" << node->fields << ")"; }); -FunctionPattern::FunctionPattern(Array params, DFPattern body) { - ObjectPtr n = make_object(); +FunctionPattern::FunctionPattern(ffi::Array params, DFPattern body) { + ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.FunctionPattern", [](Array params, DFPattern body) { - return FunctionPattern(params, body); - }); -}); + refl::GlobalDef().def( + "relax.dpl.FunctionPattern", + [](ffi::Array params, DFPattern body) { return FunctionPattern(params, body); }); +} RELAX_PATTERN_PRINTER_DEF(FunctionPatternNode, [](auto p, auto node) { p->stream << "FunctionPattern(" << node->params << ", " << node->body << ")"; }); -TuplePattern::TuplePattern(tvm::Array fields) { - ObjectPtr n = make_object(); +TuplePattern::TuplePattern(tvm::ffi::Array fields) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.TuplePattern", - [](tvm::Array fields) { return TuplePattern(fields); }); -}); + [](tvm::ffi::Array fields) { return TuplePattern(fields); }); +} RELAX_PATTERN_PRINTER_DEF(TuplePatternNode, [](auto p, auto node) { p->stream << "TuplePattern(" << node->fields << ")"; }); -UnorderedTuplePattern::UnorderedTuplePattern(tvm::Array fields) { - ObjectPtr n = make_object(); +UnorderedTuplePattern::UnorderedTuplePattern(tvm::ffi::Array fields) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - 
refl::GlobalDef().def("relax.dpl.UnorderedTuplePattern", - [](tvm::Array fields) { return UnorderedTuplePattern(fields); }); -}); + refl::GlobalDef().def("relax.dpl.UnorderedTuplePattern", [](tvm::ffi::Array fields) { + return UnorderedTuplePattern(fields); + }); +} RELAX_PATTERN_PRINTER_DEF(UnorderedTuplePatternNode, [](auto p, auto node) { p->stream << "UnorderedTuplePattern(" << node->fields << ")"; }); TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->tuple = std::move(tuple); n->index = index; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.TupleGetItemPattern", [](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); }); -}); +} RELAX_PATTERN_PRINTER_DEF(TupleGetItemPatternNode, [](auto p, auto node) { p->stream << "TupleGetItemPattern(" << node->tuple << ", " << node->index << ")"; }); AndPattern::AndPattern(DFPattern left, DFPattern right) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->left = std::move(left); n->right = std::move(right); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.AndPattern", [](DFPattern left, DFPattern right) { return AndPattern(left, right); }); -}); +} RELAX_PATTERN_PRINTER_DEF(AndPatternNode, [](auto p, auto node) { p->stream << "AndPattern(" << node->left << " & " << node->right << ")"; }); OrPattern::OrPattern(DFPattern left, DFPattern right) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->left = std::move(left); n->right = std::move(right); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.OrPattern", [](DFPattern left, DFPattern right) { return 
OrPattern(left, right); }); -}); +} RELAX_PATTERN_PRINTER_DEF(OrPatternNode, [](auto p, auto node) { p->stream << "OrPattern(" << node->left << " | " << node->right << ")"; }); NotPattern::NotPattern(DFPattern reject) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->reject = std::move(reject); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.NotPattern", [](DFPattern reject) { return NotPattern(reject); }); -}); +} RELAX_PATTERN_PRINTER_DEF(NotPatternNode, [](auto p, auto node) { p->stream << "!(" << node->reject << ")"; }); -WildcardPattern::WildcardPattern() { data_ = make_object(); } -TVM_FFI_STATIC_INIT_BLOCK({ +WildcardPattern::WildcardPattern() { data_ = ffi::make_object(); } +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.WildcardPattern", []() { return WildcardPattern(); }); -}); +} RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo struct_info) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); n->struct_info = std::move(struct_info); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.StructInfoPattern", [](DFPattern pattern, StructInfo struct_info) { return StructInfoPattern(pattern, struct_info); }); -}); +} RELAX_PATTERN_PRINTER_DEF(StructInfoPatternNode, [](auto p, auto node) { p->stream << "StructInfoPattern(" << node->pattern << " has relax StructInfo " << node->struct_info << ")"; }); -ShapePattern::ShapePattern(DFPattern pattern, Array shape) { - ObjectPtr n = make_object(); +ShapePattern::ShapePattern(DFPattern pattern, ffi::Array shape) { + ObjectPtr n = ffi::make_object(); n->pattern = 
std::move(pattern); n->shape = std::move(shape); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.ShapePattern", [](DFPattern pattern, Array shape) { - return ShapePattern(pattern, shape); - }); -}); + refl::GlobalDef().def( + "relax.dpl.ShapePattern", + [](DFPattern pattern, ffi::Array shape) { return ShapePattern(pattern, shape); }); +} RELAX_PATTERN_PRINTER_DEF(ShapePatternNode, [](auto p, auto node) { p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; }); -SameShapeConstraint::SameShapeConstraint(Array args) { - ObjectPtr n = make_object(); +SameShapeConstraint::SameShapeConstraint(ffi::Array args) { + ObjectPtr n = ffi::make_object(); n->args = std::move(args); data_ = std::move(n); @@ -334,11 +335,11 @@ SameShapeConstraint::SameShapeConstraint(Array args) { ctx.value().add_constraint(*this); } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.SameShapeConstraint", - [](Array args) { return SameShapeConstraint(args); }); -}); + [](ffi::Array args) { return SameShapeConstraint(args); }); +} RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { p->stream << "SameShapeConstraint("; for (size_t i = 0; i < node->args.size(); i++) { @@ -351,33 +352,33 @@ RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { }); DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); n->dtype = std::move(dtype); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.DataTypePattern", [](DFPattern pattern, DataType dtype) { return DataTypePattern(pattern, dtype); }); -}); +} 
RELAX_PATTERN_PRINTER_DEF(DataTypePatternNode, [](auto p, auto node) { p->stream << "DataTypePattern(" << node->pattern << " has dtype " << node->dtype << ")"; }); AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); n->attrs = std::move(attrs); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.AttrPattern", [](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); }); -}); +} RELAX_PATTERN_PRINTER_DEF(AttrPatternNode, [](auto p, auto node) { p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; }); @@ -396,10 +397,10 @@ class DFPatternDuplicator : public DFPatternFunctor DFPattern VisitDFPattern_(const NotPatternNode* op) override { return NotPattern(op->reject); } DFPattern VisitDFPattern_(const VarPatternNode* op) override { return VarPattern(op->name); } DFPattern VisitDFPattern_(const ConstantPatternNode* op) override { - return ConstantPattern(make_object()); + return ConstantPattern(ffi::make_object()); } DFPattern VisitDFPattern_(const WildcardPatternNode* op) override { - return WildcardPattern(make_object()); + return WildcardPattern(ffi::make_object()); } DFPattern VisitDFPattern_(const ExprPatternNode* op) override { return ExprPattern(op->expr); } DFPattern VisitDFPattern_(const GlobalVarPatternNode* op) override { @@ -443,7 +444,7 @@ class DFPatternDuplicator : public DFPatternFunctor // Syntatic Sugar CallPattern DFPattern::operator()(const std::vector& args) const { - return CallPattern(*this, Array(args)); + return CallPattern(*this, ffi::Array(args)); } OrPattern DFPattern::operator|(const DFPattern& other) const { return OrPattern(*this, other); } @@ -451,7 +452,7 @@ AndPattern DFPattern::operator&(const DFPattern& other) const { return AndPatter NotPattern DFPattern::operator~() 
const { return NotPattern(*this); } -AttrPattern DFPattern::HasAttr(const Map& attrs) const { +AttrPattern DFPattern::HasAttr(const ffi::Map& attrs) const { return AttrPattern(*this, DictAttrs(attrs)); } StructInfoPattern DFPattern::HasStructInfo(const StructInfo& struct_info) const { @@ -463,7 +464,7 @@ DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { return HasDtype(DataType(ffi::StringToDLDataType(dtype))); } -ShapePattern DFPattern::HasShape(const Array& shape) const { +ShapePattern DFPattern::HasShape(const ffi::Array& shape) const { return ShapePattern(*this, shape); } @@ -474,13 +475,13 @@ std::stack& pattern_ctx_stack() { return graph_pattern_managers; } -Optional PatternContext::Current() { +ffi::Optional PatternContext::Current() { if (pattern_ctx_stack().empty()) return std::nullopt; return pattern_ctx_stack().top(); } PatternContext::PatternContext(bool incremental) { - auto n = make_object(); + auto n = ffi::make_object(); if (incremental) { ICHECK(!pattern_ctx_stack().empty()) << "Incremental context needs to be built inside a existing context."; @@ -506,16 +507,16 @@ static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs, P } PatternSeq::PatternSeq(DFPattern init_pattern) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->patterns = {init_pattern}; n->pair_constraints = {}; data_ = std::move(n); } -PatternSeq::PatternSeq(tvm::Array patterns, bool only_used_by) { +PatternSeq::PatternSeq(tvm::ffi::Array patterns, bool only_used_by) { ICHECK_GE(patterns.size(), 1) << "PatternSeq must have at least one pattern"; const auto cons = PairCons(only_used_by ? 
PairCons::kOnlyUsedBy : PairCons::kUsedBy); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->patterns = std::move(patterns); n->pair_constraints = std::vector(n->patterns.size() - 1, cons); data_ = std::move(n); @@ -532,8 +533,8 @@ PatternSeq PatternSeq::OnlyUsedBy(PatternSeq other, int index) const { PatternSeq PatternSeq::dup() const { PatternSeq ret; - ObjectPtr n = make_object(); - n->patterns = Array{}; + ObjectPtr n = ffi::make_object(); + n->patterns = ffi::Array{}; n->patterns.reserve(get()->patterns.size()); n->pair_constraints = this->get()->pair_constraints; @@ -547,12 +548,13 @@ PatternSeq PatternSeq::dup() const { return ret; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.PatternSeq", [](Array patterns, bool only_used_by) { - return PatternSeq(std::move(patterns), only_used_by); - }); -}); + refl::GlobalDef().def("relax.dpl.PatternSeq", + [](ffi::Array patterns, bool only_used_by) { + return PatternSeq(std::move(patterns), only_used_by); + }); +} RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { p->stream << "["; for (size_t i = 0; i < node->patterns.size(); ++i) { @@ -563,14 +565,14 @@ RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { p->stream << "]"; }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.dpl.used_by", [](PatternSeq lhs, PatternSeq rhs, int index) { return lhs.UsedBy(rhs, index); }) .def("relax.dpl.only_used_by", [](PatternSeq lhs, PatternSeq rhs, int index) { return lhs.OnlyUsedBy(rhs, index); }); -}); +} PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { PatternSeq ret; @@ -580,7 +582,7 @@ PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), PairCons{PairCons::kUsedBy, index}); - Array patterns; + 
ffi::Array patterns; patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); @@ -591,7 +593,7 @@ PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), rhs->pair_constraints.end()); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->patterns = std::move(patterns); n->pair_constraints = std::move(pair_constraints); ret.data_ = std::move(n); @@ -607,7 +609,7 @@ PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), constraint); - Array patterns; + ffi::Array patterns; patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); @@ -618,7 +620,7 @@ PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), rhs->pair_constraints.end()); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->patterns = std::move(patterns); n->pair_constraints = std::move(pair_constraints); ret.data_ = std::move(n); @@ -627,13 +629,13 @@ PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { } PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs) { return lhs.OnlyUsedBy(rhs); } -VarPattern IsVar(const String& name) { return VarPattern(name); } -ConstantPattern IsConst() { return ConstantPattern(make_object()); } -WildcardPattern Wildcard() { return WildcardPattern(make_object()); } +VarPattern IsVar(const ffi::String& name) { return VarPattern(name); } +ConstantPattern IsConst() { return 
ConstantPattern(ffi::make_object()); } +WildcardPattern Wildcard() { return WildcardPattern(ffi::make_object()); } ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } -ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } -CallPattern IsCallTIR(const String& name, Optional var_args, - Optional tir_vars) { +ExprPattern IsOp(const ffi::String& op_name) { return IsExpr(Op::Get(op_name)); } +CallPattern IsCallTIR(const ffi::String& name, ffi::Optional var_args, + ffi::Optional tir_vars) { DFPattern arg_pattern; if (!var_args.defined()) { arg_pattern = Wildcard(); @@ -647,10 +649,10 @@ CallPattern IsCallTIR(const String& name, Optional var_args, return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern); } -CallPattern IsCallTIR(const String& name, TuplePattern var_args) { +CallPattern IsCallTIR(const ffi::String& name, TuplePattern var_args) { return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args); } -CallPattern IsCallDPSPacked(const String& name, Optional var_args) { +CallPattern IsCallDPSPacked(const ffi::String& name, ffi::Optional var_args) { DFPattern arg_pattern; if (!var_args.defined()) { arg_pattern = Wildcard(); @@ -661,11 +663,11 @@ CallPattern IsCallDPSPacked(const String& name, Optional var_args) return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), arg_pattern); } -CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) { +CallPattern IsCallDPSPacked(const ffi::String& name, TuplePattern var_args) { return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), var_args); } -DFPattern IsTuple(const Array& fields, bool unordered) { +DFPattern IsTuple(const ffi::Array& fields, bool unordered) { if (unordered) return UnorderedTuplePattern(fields); else @@ -680,7 +682,7 @@ DFPattern DFPattern::dup() const { return pattern; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.dpl.dup_pattern", [](DFPattern 
pattern) { return pattern.dup(); }) @@ -689,7 +691,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("relax.dpl.current_context", [] { return PatternContext::Current(); }) .def("relax.dpl.enter_context", [](const PatternContext& ctx) { ctx.EnterWithScope(); }) .def("relax.dpl.exit_context", [](const PatternContext& ctx) { ctx.ExitWithScope(); }); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h index 6b64226d77b7..85f892e3815b 100644 --- a/src/relax/ir/dataflow_rewriter.h +++ b/src/relax/ir/dataflow_rewriter.h @@ -40,8 +40,8 @@ namespace tvm { namespace relax { struct RewriteSpec { - Map variable_rewrites; - Map new_subroutines; + ffi::Map variable_rewrites; + ffi::Map new_subroutines; explicit operator bool() const { return variable_rewrites.size(); } @@ -50,7 +50,7 @@ struct RewriteSpec { class PatternMatchingRewriterNode : public tvm::transform::PassNode { public: - virtual RewriteSpec RewriteBindings(const Array& bindings) const { + virtual RewriteSpec RewriteBindings(const ffi::Array& bindings) const { return RewriteSpec(); } @@ -60,36 +60,37 @@ class PatternMatchingRewriterNode : public tvm::transform::PassNode { IRModule operator()(IRModule mod, const tvm::transform::PassContext& pass_ctx) const override; tvm::transform::PassInfo Info() const override; - - static constexpr const char* _type_key = "relax.dpl.PatternMatchingRewriter"; - TVM_DECLARE_BASE_OBJECT_INFO(PatternMatchingRewriterNode, PassNode); + TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.PatternMatchingRewriter", PatternMatchingRewriterNode, + PassNode); }; class PatternMatchingRewriter : public tvm::transform::Pass { public: static PatternMatchingRewriter FromPattern( - DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = std::nullopt, - Map new_subroutines = {}); + DFPattern pattern, + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings = std::nullopt, + ffi::Map 
new_subroutines = {}); static PatternMatchingRewriter FromModule(IRModule mod); Expr operator()(Expr expr); using Pass::operator(); - TVM_DEFINE_OBJECT_REF_METHODS(PatternMatchingRewriter, Pass, PatternMatchingRewriterNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PatternMatchingRewriter, Pass, + PatternMatchingRewriterNode); }; class ExprPatternRewriterNode : public PatternMatchingRewriterNode { public: DFPattern pattern; - ffi::TypedFunction(Expr, Map)> func; - Optional> additional_bindings; - Map new_subroutines; + ffi::TypedFunction(Expr, ffi::Map)> func; + ffi::Optional> additional_bindings; + ffi::Map new_subroutines; - RewriteSpec RewriteBindings(const Array& bindings) const final; + RewriteSpec RewriteBindings(const ffi::Array& bindings) const final; - Optional RewriteExpr(const Expr& expr, const Map& bindings) const; + ffi::Optional RewriteExpr(const Expr& expr, const ffi::Map& bindings) const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -97,20 +98,19 @@ class ExprPatternRewriterNode : public PatternMatchingRewriterNode { .def_ro("pattern", &ExprPatternRewriterNode::pattern) .def_ro("func", &ExprPatternRewriterNode::func); } - - static constexpr const char* _type_key = "relax.dpl.ExprPatternRewriter"; - TVM_DECLARE_BASE_OBJECT_INFO(ExprPatternRewriterNode, PatternMatchingRewriterNode); + TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.ExprPatternRewriter", ExprPatternRewriterNode, + PatternMatchingRewriterNode); }; class ExprPatternRewriter : public PatternMatchingRewriter { public: ExprPatternRewriter(DFPattern pattern, - ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = std::nullopt, - Map new_subroutines = {}); + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings = std::nullopt, + ffi::Map new_subroutines = {}); - TVM_DEFINE_OBJECT_REF_METHODS(ExprPatternRewriter, PatternMatchingRewriter, - ExprPatternRewriterNode); + 
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExprPatternRewriter, PatternMatchingRewriter, + ExprPatternRewriterNode); }; class OrRewriterNode : public PatternMatchingRewriterNode { @@ -118,7 +118,7 @@ class OrRewriterNode : public PatternMatchingRewriterNode { PatternMatchingRewriter lhs; PatternMatchingRewriter rhs; - RewriteSpec RewriteBindings(const Array& bindings) const override; + RewriteSpec RewriteBindings(const ffi::Array& bindings) const override; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -126,26 +126,24 @@ class OrRewriterNode : public PatternMatchingRewriterNode { .def_ro("lhs", &OrRewriterNode::lhs) .def_ro("rhs", &OrRewriterNode::rhs); } - - static constexpr const char* _type_key = "relax.dpl.OrRewriter"; - TVM_DECLARE_BASE_OBJECT_INFO(OrRewriterNode, PatternMatchingRewriterNode); + TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.OrRewriter", OrRewriterNode, PatternMatchingRewriterNode); }; class OrRewriter : public PatternMatchingRewriter { public: OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs); - TVM_DEFINE_OBJECT_REF_METHODS(OrRewriter, PatternMatchingRewriter, OrRewriterNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(OrRewriter, PatternMatchingRewriter, OrRewriterNode); }; class TupleRewriterNode : public PatternMatchingRewriterNode { public: - Array patterns; - ffi::TypedFunction(Expr, Map)> func; - Optional> additional_bindings; - Map new_subroutines; + ffi::Array patterns; + ffi::TypedFunction(Expr, ffi::Map)> func; + ffi::Optional> additional_bindings; + ffi::Map new_subroutines; - RewriteSpec RewriteBindings(const Array& bindings) const override; + RewriteSpec RewriteBindings(const ffi::Array& bindings) const override; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -153,20 +151,19 @@ class TupleRewriterNode : public PatternMatchingRewriterNode { .def_ro("patterns", &TupleRewriterNode::patterns) .def_ro("func", &TupleRewriterNode::func); } - - static constexpr 
const char* _type_key = "relax.dpl.TupleRewriter"; - TVM_DECLARE_BASE_OBJECT_INFO(TupleRewriterNode, PatternMatchingRewriterNode); + TVM_FFI_DECLARE_OBJECT_INFO("relax.dpl.TupleRewriter", TupleRewriterNode, + PatternMatchingRewriterNode); private: struct VarInfo { Var var; Expr expr; - Array>> matches; + ffi::Array>> matches; std::unordered_set downstream_usage; bool used = false; }; - Map GenerateVariableRewrites(const Array& bindings) const; + ffi::Map GenerateVariableRewrites(const ffi::Array& bindings) const; std::optional> TryMatchByBindingIndex(const std::vector& info_vec, const std::vector& indices) const; @@ -174,12 +171,13 @@ class TupleRewriterNode : public PatternMatchingRewriterNode { class TupleRewriter : public PatternMatchingRewriter { public: - TupleRewriter(Array patterns, - ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = std::nullopt, - Map new_subroutines = {}); + TupleRewriter(ffi::Array patterns, + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings = std::nullopt, + ffi::Map new_subroutines = {}); - TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, PatternMatchingRewriter, TupleRewriterNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TupleRewriter, PatternMatchingRewriter, + TupleRewriterNode); }; } // namespace relax diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index d46b634ca7c9..ee10a97aa0e7 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -36,10 +36,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "rxplaceholder(" << op->name << ", " << op << ")"; }); -TVM_FFI_STATIC_INIT_BLOCK({ RXPlaceholderOpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RXPlaceholderOpNode::RegisterReflection(); } -te::Tensor TETensor(Expr value, Map tir_var_map, std::string name) { - auto n = make_object(); +te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::string name) { + auto n = ffi::make_object(); n->name = name; n->value = value; @@ -51,7 
+51,7 @@ te::Tensor TETensor(Expr value, Map tir_var_map, std::string int ndim = constant->data->ndim; ffi::Shape shape_tuple = constant->data.Shape(); - Array shape; + ffi::Array shape; shape.reserve(ndim); for (int i = 0; i < ndim; ++i) { shape.push_back(IntImm(DataType::Int(64), shape_tuple[i])); @@ -73,10 +73,10 @@ te::Tensor TETensor(Expr value, Map tir_var_map, std::string return te::PlaceholderOp(n).output(0); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.TETensor", TETensor); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index aa7cb9db538e..f09dcb7f8230 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -52,8 +52,11 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { .def_ro("dtype", &RXPlaceholderOpNode::dtype); } - static constexpr const char* _type_key = "relax.TEPlaceholderOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode); + // FFI system configuration for structural equality and hashing + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TEPlaceholderOp", RXPlaceholderOpNode, + te::PlaceholderOpNode); }; /*! @@ -64,7 +67,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { * shape of the input Expr. * \param name The name of the created tensor. 
*/ -te::Tensor TETensor(Expr value, Map tir_var_map, std::string name); +te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::string name); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 8fbe05e891ee..2c681b00bc22 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -29,7 +29,7 @@ namespace relax { using tvm::ReprPrinter; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { IdNode::RegisterReflection(); CallNode::RegisterReflection(); TupleNode::RegisterReflection(); @@ -50,22 +50,23 @@ TVM_FFI_STATIC_INIT_BLOCK({ IfNode::RegisterReflection(); FunctionNode::RegisterReflection(); ExternFuncNode::RegisterReflection(); -}); +} -Id::Id(String name_hint) { - ObjectPtr n = make_object(); +Id::Id(ffi::String name_hint) { + ObjectPtr n = ffi::make_object(); n->name_hint = std::move(name_hint); data_ = std::move(n); } -Call::Call(Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { +Call::Call(Expr op, ffi::Array args, Attrs attrs, ffi::Array sinfo_args, + Span span) { CHECK(!op->struct_info_.defined() || op->struct_info_->IsInstance()) << "ValueError: " << "Call expects its operator to have FuncStructInfo, " << "but operator " << op << ", which was called with arguments " << args << ", has struct info " << op->struct_info_; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); @@ -74,14 +75,15 @@ Call::Call(Expr op, Array args, Attrs attrs, Array sinfo_args, data_ = std::move(n); } -Call WithFields(Call call, Optional opt_op, Optional> opt_args, - Optional opt_attrs, Optional> opt_sinfo_args, - Optional opt_span) { +Call WithFields(Call call, ffi::Optional opt_op, ffi::Optional> opt_args, + ffi::Optional opt_attrs, + ffi::Optional> opt_sinfo_args, + ffi::Optional opt_span) { // Collect new values for fields. 
Expr op = opt_op.value_or(call->op); - Array args = opt_args.value_or(call->args); + ffi::Array args = opt_args.value_or(call->args); Attrs attrs = opt_attrs.value_or(call->attrs); - Array sinfo_args = opt_sinfo_args.value_or(call->sinfo_args); + ffi::Array sinfo_args = opt_sinfo_args.value_or(call->sinfo_args); Span span = opt_span.value_or(call->span); // Check if anything changed. @@ -117,15 +119,16 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args return call; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.Call", - [](Expr op, Array args, Attrs attrs, Array sinfo_args, - Span span) { return Call(op, args, attrs, sinfo_args, span); }); -}); + refl::GlobalDef().def("relax.Call", [](Expr op, ffi::Array args, Attrs attrs, + ffi::Array sinfo_args, Span span) { + return Call(op, args, attrs, sinfo_args, span); + }); +} If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); @@ -133,8 +136,8 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { data_ = std::move(n); } -If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branch, - Optional opt_false_branch, Optional opt_span) { +If WithFields(If if_expr, ffi::Optional opt_cond, ffi::Optional opt_true_branch, + ffi::Optional opt_false_branch, ffi::Optional opt_span) { Expr cond = opt_cond.value_or(if_expr->cond); Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); @@ -153,16 +156,16 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc return if_expr; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.If", [](Expr cond, Expr 
true_branch, Expr false_branch, Span span) { return If(cond, true_branch, false_branch, span); }); -}); +} -Tuple::Tuple(tvm::Array fields, Span span) { - Optional tuple_sinfo = [&]() -> Optional { - Array field_sinfo; +Tuple::Tuple(tvm::ffi::Array fields, Span span) { + ffi::Optional tuple_sinfo = [&]() -> ffi::Optional { + ffi::Array field_sinfo; for (const auto& field : fields) { if (field->struct_info_.defined()) { field_sinfo.push_back(GetStructInfo(field)); @@ -173,21 +176,22 @@ Tuple::Tuple(tvm::Array fields, Span span) { return TupleStructInfo(field_sinfo); }(); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); n->span = std::move(span); n->struct_info_ = tuple_sinfo; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.Tuple", - [](tvm::Array fields, Span span) { return Tuple(fields, span); }); -}); + refl::GlobalDef().def( + "relax.Tuple", [](tvm::ffi::Array fields, Span span) { return Tuple(fields, span); }); +} -Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_span) { - Array fields = opt_fields.value_or(tuple->fields); +Tuple WithFields(Tuple tuple, ffi::Optional> opt_fields, + ffi::Optional opt_span) { + ffi::Array fields = opt_fields.value_or(tuple->fields); Span span = opt_span.value_or(tuple->span); bool all_fields_unchanged = true; @@ -211,7 +215,7 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple << " cannot be accessed with negative index " << index; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); if (auto* tuple_info = tuple->struct_info_.as()) { CHECK_LT(index, tuple_info->fields.size()) @@ -226,8 +230,8 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { data_ = std::move(n); } -TupleGetItem 
WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, - Optional opt_index, Optional opt_span) { +TupleGetItem WithFields(TupleGetItem tuple_get_item, ffi::Optional opt_tuple, + ffi::Optional opt_index, ffi::Optional opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); Integer index = opt_index.value_or(tuple_get_item->index); Span span = opt_span.value_or(tuple_get_item->span); @@ -243,15 +247,15 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, return tuple_get_item; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.TupleGetItem", [](Expr tuple, int index, Span span) { return TupleGetItem(tuple, index, span); }); -}); +} -ShapeExpr::ShapeExpr(Array values, Span span) { - ObjectPtr n = make_object(); +ShapeExpr::ShapeExpr(ffi::Array values, Span span) { + ObjectPtr n = ffi::make_object(); n->values = values.Map([](PrimExpr value) { if (value->IsInstance()) { @@ -266,14 +270,15 @@ ShapeExpr::ShapeExpr(Array values, Span span) { data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ShapeExpr", - [](Array values, Span span) { return ShapeExpr(values, span); }); -}); + refl::GlobalDef().def("relax.ShapeExpr", [](ffi::Array values, Span span) { + return ShapeExpr(values, span); + }); +} -Var::Var(Id vid, Optional struct_info_annotation, Span span) { - ObjectPtr n = make_object(); +Var::Var(Id vid, ffi::Optional struct_info_annotation, Span span) { + ObjectPtr n = ffi::make_object(); n->vid = std::move(vid); n->struct_info_ = std::move(struct_info_annotation); n->span = std::move(span); @@ -290,27 +295,26 @@ VarNode* Var::CopyOnWrite() { if (!data_.unique()) { ObjectPtr node; if (auto dataflow_var = as()) { - node = make_object(*dataflow_var); + node = ffi::make_object(*dataflow_var); } else { - node = make_object(*(operator->())); + node = 
ffi::make_object(*(operator->())); } ObjectPtr(std::move(node)).swap(data_); } return static_cast(data_.get()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("relax.Var", [](String name_hint, Optional struct_info_annotation, + .def("relax.Var", [](ffi::String name_hint, ffi::Optional struct_info_annotation, Span span) { return Var(name_hint, struct_info_annotation, span); }) - .def("relax.VarFromId", [](Id vid, Optional struct_info_annotation, Span span) { - return Var(vid, struct_info_annotation, span); - }); -}); + .def("relax.VarFromId", [](Id vid, ffi::Optional struct_info_annotation, + Span span) { return Var(vid, struct_info_annotation, span); }); +} -DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Span span) { - ObjectPtr n = make_object(); +DataflowVar::DataflowVar(Id vid, ffi::Optional struct_info_annotation, Span span) { + ObjectPtr n = ffi::make_object(); n->vid = std::move(vid); n->struct_info_ = std::move(struct_info_annotation); n->span = std::move(span); @@ -318,26 +322,27 @@ DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Sp data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.DataflowVar", - [](String name_hint, Optional struct_info_annotation, Span span) { + [](ffi::String name_hint, ffi::Optional struct_info_annotation, Span span) { return DataflowVar(name_hint, struct_info_annotation, span); }) .def("relax.DataflowVarFromId", - [](Id vid, Optional struct_info_annotation, Span span) { + [](Id vid, ffi::Optional struct_info_annotation, Span span) { return DataflowVar(vid, struct_info_annotation, span); }); -}); +} -Constant::Constant(runtime::NDArray data, Optional struct_info_annotation, Span span) { - ObjectPtr n = make_object(); +Constant::Constant(runtime::Tensor data, ffi::Optional struct_info_annotation, + Span span) { + 
ObjectPtr n = ffi::make_object(); n->data = std::move(data); n->span = std::move(span); // set struct info. - Array values; + ffi::Array values; auto shape_tuple = n->data.Shape(); for (size_t dim = 0; dim < shape_tuple.size(); ++dim) { values.push_back(IntImm(DataType::Int(64), shape_tuple[dim])); @@ -352,16 +357,16 @@ Constant::Constant(runtime::NDArray data, Optional struct_info_annot data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.Constant", - [](runtime::NDArray data, Optional struct_info_annotation = std::nullopt, + [](runtime::Tensor data, ffi::Optional struct_info_annotation = std::nullopt, Span span = Span()) { return Constant(data, struct_info_annotation, span); }); -}); +} PrimValue::PrimValue(PrimExpr value, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->struct_info_ = PrimStructInfo(value); n->value = std::move(value); n->span = std::move(span); @@ -372,42 +377,42 @@ PrimValue PrimValue::Int64(int64_t value, Span span) { return PrimValue(IntImm(DataType::Int(64), value), span); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.PrimValue", [](PrimExpr value, Span span) { return PrimValue(value, span); }); -}); +} -StringImm::StringImm(String value, Span span) { - ObjectPtr n = make_object(); +StringImm::StringImm(ffi::String value, Span span) { + ObjectPtr n = ffi::make_object(); n->value = std::move(value); n->span = std::move(span); n->struct_info_ = ObjectStructInfo(); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.StringImm", - [](String value, Span span) { return StringImm(value, span); }); -}); + [](ffi::String value, Span span) { return StringImm(value, span); }); +} DataTypeImm::DataTypeImm(DataType value, Span span) { - 
ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->value = std::move(value); n->span = std::move(span); n->struct_info_ = ObjectStructInfo(); data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.DataTypeImm", [](DataType value, Span span) { return DataTypeImm(value, span); }); -}); +} MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); ICHECK(var.defined()) << "MatchCast requires var to be defined"; n->var = std::move(var); n->value = std::move(value); @@ -416,28 +421,28 @@ MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.MatchCast", [](Var var, Expr value, StructInfo struct_info, Span span) { return MatchCast(var, value, struct_info, span); }); -}); +} VarBinding::VarBinding(Var var, Expr value, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->var = std::move(var); n->value = std::move(value); n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.VarBinding", [](Var var, Expr value, Span span) { return VarBinding(var, value, span); }); -}); +} bool VarBindingNode::SEqual(const VarBindingNode* other, ffi::TypedFunction equal) const { @@ -467,8 +472,8 @@ uint64_t VarBindingNode::SHash(uint64_t init_hash, return hash_value; } -BindingBlock::BindingBlock(Array bindings, Span span) { - ObjectPtr n = make_object(); +BindingBlock::BindingBlock(ffi::Array bindings, Span span) { + ObjectPtr n = ffi::make_object(); n->bindings = std::move(bindings); n->span = span; data_ = std::move(n); @@ -484,61 +489,61 @@ BindingBlockNode* 
BindingBlock::CopyOnWrite() { if (!data_.unique()) { ObjectPtr node; if (auto dataflow_block = as()) { - node = make_object(*dataflow_block); + node = ffi::make_object(*dataflow_block); } else { - node = make_object(*(operator->())); + node = ffi::make_object(*(operator->())); } ObjectPtr(std::move(node)).swap(data_); } return static_cast(data_.get()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.BindingBlock", [](Array bindings, Span span) { + refl::GlobalDef().def("relax.BindingBlock", [](ffi::Array bindings, Span span) { return BindingBlock(bindings, span); }); -}); +} -DataflowBlock::DataflowBlock(Array bindings, Span span) { - ObjectPtr n = make_object(); +DataflowBlock::DataflowBlock(ffi::Array bindings, Span span) { + ObjectPtr n = ffi::make_object(); n->bindings = std::move(bindings); n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.DataflowBlock", [](Array bindings, Span span) { + refl::GlobalDef().def("relax.DataflowBlock", [](ffi::Array bindings, Span span) { return DataflowBlock(bindings, span); }); -}); +} SeqExpr::SeqExpr(Expr body) { if (auto seq = body.as()) { *this = seq.value(); } else { - *this = SeqExpr(Array{}, body); + *this = SeqExpr(ffi::Array{}, body); } } -SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { - ObjectPtr n = make_object(); +SeqExpr::SeqExpr(ffi::Array blocks, Expr body, Span span) { + ObjectPtr n = ffi::make_object(); n->blocks = std::move(blocks); n->body = std::move(body); n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.SeqExpr", [](Array blocks, Expr body, Span span) { + refl::GlobalDef().def("relax.SeqExpr", [](ffi::Array blocks, Expr body, Span span) { return SeqExpr(blocks, body, span); }); 
-}); +} -Function::Function(Array params, Expr body, Optional ret_struct_info, bool is_pure, - DictAttrs attrs, Span span) { +Function::Function(ffi::Array params, Expr body, ffi::Optional ret_struct_info, + bool is_pure, DictAttrs attrs, Span span) { if (!attrs.defined()) { attrs = DictAttrs(); } @@ -546,7 +551,7 @@ Function::Function(Array params, Expr body, Optional ret_struct // Set the function type. // For function, we take a conservative approach and require the function type // to be known at construction time. - Array param_sinfo; + ffi::Array param_sinfo; for (const Var& param : params) { CHECK(param->struct_info_.defined()) @@ -554,7 +559,7 @@ Function::Function(Array params, Expr body, Optional ret_struct param_sinfo.push_back(GetStructInfo(param)); } - Optional body_sinfo; + ffi::Optional body_sinfo; if (body->struct_info_.defined()) { body_sinfo = GetStructInfo(body); @@ -580,7 +585,7 @@ Function::Function(Array params, Expr body, Optional ret_struct auto f_shape_var_map = [&] { auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); - return [lookup = std::move(lookup)](const tir::Var& var) -> Optional { + return [lookup = std::move(lookup)](const tir::Var& var) -> ffi::Optional { if (lookup.count(var)) { return var; } else { @@ -594,7 +599,7 @@ Function::Function(Array params, Expr body, Optional ret_struct FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure); // set the fields - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); n->ret_struct_info = ret_struct_info.value(); @@ -605,18 +610,18 @@ Function::Function(Array params, Expr body, Optional ret_struct data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.Function", - [](Array params, Expr body, Optional 
ret_struct_info, - bool is_pure, DictAttrs attrs, Span span) { - return Function(params, body, ret_struct_info, is_pure, attrs, span); - }); -}); + refl::GlobalDef().def("relax.Function", [](ffi::Array params, Expr body, + ffi::Optional ret_struct_info, + bool is_pure, DictAttrs attrs, Span span) { + return Function(params, body, ret_struct_info, is_pure, attrs, span); + }); +} -Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bool is_pure, +Function Function::CreateEmpty(ffi::Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { - Array param_sinfo; + ffi::Array param_sinfo; for (const Var& param : params) { ICHECK(param->struct_info_.defined()) << "relax.Function requires params to contain struct_info_."; @@ -634,7 +639,7 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo }(); // set the fields - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); n->is_pure = is_pure; @@ -645,18 +650,18 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo return Function(std::move(n)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "relax.FunctionCreateEmpty", - [](Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { + "relax.FunctionCreateEmpty", [](ffi::Array params, StructInfo ret_struct_info, + bool is_pure, DictAttrs attrs, Span span) { return Function::CreateEmpty(params, ret_struct_info, is_pure, attrs, span); }); -}); +} // Special opaque derivation function for ExternFunc // Take look at sinfo_args to figure out the return StructInfo. 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tvm.relax.struct_info.infer_by_sinfo_args", [](const Call& call, const BlockBuilder& ctx) -> StructInfo { @@ -670,7 +675,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return TupleStructInfo(call->sinfo_args); } }); -}); +} // Get the derive function. FuncStructInfo GetExternFuncStructInfo() { @@ -680,32 +685,32 @@ FuncStructInfo GetExternFuncStructInfo() { return FuncStructInfo::OpaqueFunc(derive); } -ExternFunc::ExternFunc(String global_symbol, Span span) +ExternFunc::ExternFunc(ffi::String global_symbol, Span span) : ExternFunc(global_symbol, GetExternFuncStructInfo(), span) {} -ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) { +ExternFunc::ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span) { CHECK(struct_info.as()) << "ExternFunc must have FuncStructInfo, " << "but declaration of '" << global_symbol << "' received " << struct_info; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->global_symbol = std::move(global_symbol); n->span = span; n->struct_info_ = struct_info; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ExternFunc", - [](String global_symbol, Optional struct_info, Span span) { - if (struct_info.defined()) { - return ExternFunc(global_symbol, struct_info.value(), span); - } else { - return ExternFunc(global_symbol, span); - } - }); -}); + refl::GlobalDef().def("relax.ExternFunc", [](ffi::String global_symbol, + ffi::Optional struct_info, Span span) { + if (struct_info.defined()) { + return ExternFunc(global_symbol, struct_info.value(), span); + } else { + return ExternFunc(global_symbol, span); + } + }); +} Expr GetShapeOf(const Expr& expr) { // default case, to be normalized. 
@@ -722,31 +727,31 @@ Expr GetShapeOf(const Expr& expr) { return call_shape_of; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.GetShapeOf", [](const Expr& expr) { return GetShapeOf(expr); }) .def("relax.FuncWithAttr", - [](BaseFunc func, String key, ObjectRef value) -> Optional { + [](BaseFunc func, ffi::String key, ObjectRef value) -> ffi::Optional { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); } return std::nullopt; }) .def("relax.FuncWithAttrs", - [](BaseFunc func, Map attr_map) -> Optional { + [](BaseFunc func, ffi::Map attr_map) -> ffi::Optional { if (func->IsInstance()) { return WithAttrs(Downcast(std::move(func)), attr_map); } return std::nullopt; }) - .def("relax.FuncWithoutAttr", [](BaseFunc func, String key) -> Optional { + .def("relax.FuncWithoutAttr", [](BaseFunc func, ffi::String key) -> ffi::Optional { if (func->IsInstance()) { return WithoutAttr(Downcast(std::move(func)), key); } return std::nullopt; }); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index d772613b5d04..6ebc56feebe2 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -127,7 +127,7 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) { this->VisitExpr(field); } if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -135,7 +135,7 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) { void ExprVisitor::VisitExpr_(const VarNode* op) { this->VisitSpan(op->span); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -167,7 +167,7 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { } if (auto* sinfo = op->struct_info_.as()) { - 
this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -178,7 +178,7 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { this->VisitExpr(op->false_branch); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -189,7 +189,7 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -200,7 +200,7 @@ void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { this->VisitSpan(op->span); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -217,14 +217,14 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) { this->VisitExpr(op->body); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } void ExprVisitor::VisitExpr_(const PrimValueNode* op) { this->VisitPrimExpr(op->value); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } this->VisitSpan(op->span); } @@ -327,12 +327,12 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.post_order_visit", [](Expr expr, ffi::Function f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); }); }); -}); +} // ================== // ExprMutatorBase @@ -360,24 +360,24 @@ StructInfo ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfo_( const FuncStructInfoNode* op) { // Do not recurse into function 
struct info // as they won't contain ref to values in current scope. - return GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { // Constant' struct info won't be affected by Expr/PrimExpr change. - return GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { // FuncStructInfo won't be affected by Expr/PrimExpr change. - return GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { bool unchanged = true; - tvm::Array fields; + tvm::ffi::Array fields; for (Expr field : op->fields) { Expr new_field = this->VisitExpr(field); fields.push_back(new_field); @@ -388,7 +388,7 @@ Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { // If tuple's struct info change it means that // one of its fields' struct info will change // so un-changed already implies that struct info won't change - return GetRef(op); + return ffi::GetRef(op); } else { // when there is a change return a new tuple node return Tuple(fields, op->span); @@ -399,7 +399,7 @@ Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { // struct info of var-use should remain stable // or the var itself will get replaced - return GetRef(op); + return ffi::GetRef(op); } // Visit the use-site of a defined DataflowVar @@ -413,7 +413,7 @@ Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { Expr body = this->VisitExpr(op->body); if (body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); } @@ -423,14 +423,14 @@ Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { Expr new_op = this->VisitExpr(call_node->op); bool unchanged = call_node->op.same_as(new_op); - Array sinfo_args; + ffi::Array sinfo_args; for (StructInfo sinfo_arg 
: call_node->sinfo_args) { StructInfo new_sinfo_arg = this->VisitExprDepStructInfoField(sinfo_arg); sinfo_args.push_back(new_sinfo_arg); unchanged &= new_sinfo_arg.same_as(sinfo_arg); } - tvm::Array call_args; + tvm::ffi::Array call_args; for (Expr arg : call_node->args) { Expr new_arg = this->VisitExpr(arg); call_args.push_back(new_arg); @@ -438,7 +438,7 @@ Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { } if (unchanged && VisitAndCheckStructInfoFieldUnchanged(call_node->struct_info_)) { - return GetRef(call_node); + return ffi::GetRef(call_node); } else { return Call(new_op, call_args, call_node->attrs, sinfo_args, call_node->span); } @@ -451,20 +451,20 @@ Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } else { return If(guard, true_b, false_b, op->span); } } -Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { auto t = this->VisitExpr(op->tuple); if (op->tuple.same_as(t)) { // struct info can be deterministically derived by tuple and index // if t does not change, then struct info won't change. - return GetRef(op); + return ffi::GetRef(op); } else { return TupleGetItem(t, op->index, op->span); } @@ -475,21 +475,21 @@ Expr ExprMutatorBase::VisitExpr_(const PrimValueNode* op) { if (op->value.same_as(value)) { // struct info can be deterministically derived by value // if value does not change, then struct info won't change. 
- return GetRef(op); + return ffi::GetRef(op); } return PrimValue(value, op->span); } -Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return ffi::GetRef(op); } -Expr ExprMutatorBase::VisitExpr_(const DataTypeImmNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const DataTypeImmNode* op) { return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); if (values.same_as(op->values)) { // If values does not change, struct info won't change. - return GetRef(op); + return ffi::GetRef(op); } else { return ShapeExpr(values, op->span); } @@ -497,12 +497,12 @@ Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { // StructInfo of function remains value independent. - return GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; - Array blocks; + ffi::Array blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); if (!new_block->bindings.empty()) { @@ -515,13 +515,13 @@ Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { if (all_blocks_unchanged && body.same_as(op->body) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } return SeqExpr(blocks, body); } BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { - Array bindings; + ffi::Array bindings; if (const auto* node = block.as()) { for (auto binding : node->bindings) { if (auto var_binding = binding.as()) { @@ -562,7 +562,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { } // default case return self. 
- return GetRef(op); + return ffi::GetRef(op); } // Visit the use-site of a defined DataflowVar @@ -571,7 +571,7 @@ Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { } Expr ExprMutator::VisitExpr_(const FunctionNode* op) { - tvm::Array params; + tvm::ffi::Array params; bool all_params_unchanged = true; for (Var param : op->params) { Var new_param = this->VisitVarDef(param); @@ -586,7 +586,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { if (all_params_unchanged && body.same_as(op->body)) { // No changes to the function, return the original object - return GetRef(op); + return ffi::GetRef(op); } else if (IsBaseOf(GetStructInfo(body), op->ret_struct_info)) { // If the function was mutated into a form that can no longer // propagate shape information all the way to the return value, we @@ -615,7 +615,7 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } else { return If(guard, true_b, false_b, op->span); } @@ -623,7 +623,7 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; - Array blocks; + ffi::Array blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); if (!new_block->bindings.empty()) { @@ -642,7 +642,7 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { if (all_blocks_unchanged && body.same_as(op->body) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } else { return SeqExpr(blocks, body); } @@ -671,7 +671,7 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { // fast path: re-emit binding if nothing changes if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { - 
builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(ffi::GetRef(binding)); return; } @@ -704,7 +704,7 @@ void ExprMutator::VisitBinding_(const MatchCastNode* binding) { if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && new_struct_info.same_as(binding->struct_info)) { // re-emit old binding if nothing changes - return GetRef(binding); + return ffi::GetRef(binding); } else { new_value = builder_->NormalizeArgument(new_value); new_var = WithStructInfo(new_var, new_struct_info); @@ -749,14 +749,14 @@ Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { Var ExprMutator::VisitVarDef_(const VarNode* var) { if (auto* sinfo = var->struct_info_.as()) { - StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + StructInfo struct_info = this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); if (struct_info.same_as(var->struct_info_)) { - return GetRef(var); + return ffi::GetRef(var); } else { return Var(var->vid, struct_info, var->span); } } else { - return GetRef(var); + return ffi::GetRef(var); } } @@ -794,7 +794,7 @@ Var ExprMutator::VisitVarDef(const Var& var) { return ret; } -Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> params) { +Expr ExprMutator::VisitWithNewScope(const Expr& expr, ffi::Optional> params) { ICHECK(expr->IsInstance()) << "Normal form requires all new scope is stored as SeqExpr"; @@ -838,7 +838,9 @@ Expr ExprMutator::VisitWithInnerScope(const Expr& expr) { return ret; } -Optional ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } +ffi::Optional ExprMutator::LookupBinding(const Var& var) { + return builder_->LookupBinding(var); +} Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { ICHECK(struct_info.defined()); diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 299839d31f4b..b7d61bfda8ec 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -110,28 +110,29 
@@ class PyExprVisitorNode : public Object, public ExprVisitor { PY_EXPR_VISITOR_DEFAULT(binding, f_visit_binding, ExprVisitor::VisitBinding(binding)); void VisitBinding_(const VarBindingNode* binding) - PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_var_binding_, + PY_EXPR_VISITOR_DEFAULT(ffi::GetRef(binding), f_visit_var_binding_, ExprVisitor::VisitBinding_(binding)); void VisitBinding_(const MatchCastNode* binding) - PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_match_cast_, + PY_EXPR_VISITOR_DEFAULT(ffi::GetRef(binding), f_visit_match_cast_, ExprVisitor::VisitBinding_(binding)); void VisitBindingBlock(const BindingBlock& block) PY_EXPR_VISITOR_DEFAULT(block, f_visit_binding_block, ExprVisitor::VisitBindingBlock(block)); void VisitBindingBlock_(const BindingBlockNode* block) - PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_binding_block_, + PY_EXPR_VISITOR_DEFAULT(ffi::GetRef(block), f_visit_binding_block_, ExprVisitor::VisitBindingBlock_(block)); void VisitBindingBlock_(const DataflowBlockNode* block) - PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + PY_EXPR_VISITOR_DEFAULT(ffi::GetRef(block), f_visit_dataflow_block_, ExprVisitor::VisitBindingBlock_(block)); void VisitVarDef(const Var& var) PY_EXPR_VISITOR_DEFAULT(var, f_visit_var_def, ExprVisitor::VisitVarDef(var)); void VisitVarDef_(const VarNode* var) - PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_var_def_, ExprVisitor::VisitVarDef_(var)); + PY_EXPR_VISITOR_DEFAULT(ffi::GetRef(var), f_visit_var_def_, + ExprVisitor::VisitVarDef_(var)); void VisitVarDef_(const DataflowVarNode* var) - PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + PY_EXPR_VISITOR_DEFAULT(ffi::GetRef(var), f_visit_dataflow_var_def_, ExprVisitor::VisitVarDef_(var)); void VisitSpan(const Span& span) @@ -141,8 +142,8 @@ class PyExprVisitorNode : public Object, public ExprVisitor { // PyExprVisitorNode has no fields to register } - static constexpr const char* _type_key = "expr_functor.PyExprVisitor"; - 
TVM_DECLARE_BASE_OBJECT_INFO(PyExprVisitorNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("expr_functor.PyExprVisitor", PyExprVisitorNode, Object); private: // initialize the vtable. @@ -176,6 +177,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor { */ class PyExprVisitor : public ObjectRef { public: + explicit PyExprVisitor(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyExprVisitor with customized methods on the python-side. * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. @@ -227,7 +231,7 @@ class PyExprVisitor : public ObjectRef { ffi::Function f_visit_dataflow_block_, ffi::Function f_visit_var_def, ffi::Function f_visit_var_def_, ffi::Function f_visit_dataflow_var_def_, ffi::Function f_visit_span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_visit_expr = f_visit_expr; n->f_visit_binding = f_visit_binding; n->f_visit_binding_block = f_visit_binding_block; @@ -258,7 +262,7 @@ class PyExprVisitor : public ObjectRef { return PyExprVisitor(n); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprVisitor, ObjectRef, PyExprVisitorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyExprVisitor, ObjectRef, PyExprVisitorNode); }; /*! 
@@ -348,14 +352,14 @@ class PyExprMutatorNode : public Object, public ExprMutator { void VisitBinding_(const VarBindingNode* binding) { if (f_visit_var_binding_ != nullptr) - f_visit_var_binding_(GetRef(binding)); + f_visit_var_binding_(ffi::GetRef(binding)); else ExprMutator::VisitBinding_(binding); } void VisitBinding_(const MatchCastNode* binding) { if (f_visit_match_cast_ != nullptr) - f_visit_match_cast_(GetRef(binding)); + f_visit_match_cast_(ffi::GetRef(binding)); else ExprMutator::VisitBinding_(binding); } @@ -365,18 +369,19 @@ class PyExprMutatorNode : public Object, public ExprMutator { BindingBlock); BindingBlock VisitBindingBlock_(const BindingBlockNode* block) - PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_binding_block_, + PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef(block), f_visit_binding_block_, ExprMutator::VisitBindingBlock_(block), BindingBlock); BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) - PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef(block), f_visit_dataflow_block_, ExprMutator::VisitBindingBlock_(block), BindingBlock); Var VisitVarDef(const Var& var) PY_EXPR_MUTATOR_DEFAULT(var, f_visit_var_def, ExprMutator::VisitVarDef(var), Var); - Var VisitVarDef_(const VarNode* var) PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_var_def_, - ExprMutator::VisitVarDef_(var), Var); + Var VisitVarDef_(const VarNode* var) + PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef(var), f_visit_var_def_, + ExprMutator::VisitVarDef_(var), Var); Var VisitVarDef_(const DataflowVarNode* var) - PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef(var), f_visit_dataflow_var_def_, ExprMutator::VisitVarDef_(var), Var); /*! 
@@ -400,8 +405,8 @@ class PyExprMutatorNode : public Object, public ExprMutator { refl::ObjectDef().def_ro("builder_", &PyExprMutatorNode::builder_); } - static constexpr const char* _type_key = "expr_functor.PyExprMutator"; - TVM_DECLARE_BASE_OBJECT_INFO(PyExprMutatorNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("expr_functor.PyExprMutator", PyExprMutatorNode, Object); private: // initialize the vtable. @@ -459,6 +464,9 @@ class PyExprMutatorNode : public Object, public ExprMutator { */ class PyExprMutator : public ObjectRef { public: + explicit PyExprMutator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyExprMutator with customized methods on the python-side. * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. @@ -510,7 +518,7 @@ class PyExprMutator : public ObjectRef { ffi::Function f_visit_dataflow_block_, ffi::Function f_visit_var_def, ffi::Function f_visit_var_def_, ffi::Function f_visit_dataflow_var_def_, ffi::Function f_visit_span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->builder_ = builder_; n->f_visit_expr = f_visit_expr; n->f_visit_constant_ = f_visit_constant_; @@ -541,11 +549,12 @@ class PyExprMutator : public ObjectRef { n->f_visit_span = f_visit_span; return PyExprMutator(n); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyExprMutator, ObjectRef, PyExprMutatorNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); refl::GlobalDef() .def("relax.MakePyExprVisitor", PyExprVisitor::MakePyExprVisitor) .def("relax.PyExprVisitorVisitExpr", @@ -652,12 +661,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](PyExprMutator mutator, Id id, Var var) { return mutator->var_remap_[id] = var; }) .def("relax.PyExprMutatorGetVarRemap", [](PyExprMutator 
mutator, Id id) { return mutator->var_remap_[id]; }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { PyExprVisitorNode::RegisterReflection(); PyExprMutatorNode::RegisterReflection(); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index d2460a42ce75..22ed4e9ea382 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -30,7 +30,7 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { StructInfoNode::RegisterReflection(); ObjectStructInfoNode::RegisterReflection(); PrimStructInfoNode::RegisterReflection(); @@ -38,22 +38,22 @@ TVM_FFI_STATIC_INIT_BLOCK({ TensorStructInfoNode::RegisterReflection(); TupleStructInfoNode::RegisterReflection(); FuncStructInfoNode::RegisterReflection(); -}); +} ObjectStructInfo::ObjectStructInfo(Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.ObjectStructInfo", [](Span span) { return ObjectStructInfo(span); }); -}); +} // Prim PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = value->dtype; n->value = std::move(value); n->span = span; @@ -61,25 +61,25 @@ PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) { } PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = dtype; n->value = std::nullopt; n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.PrimStructInfoFromDtype", [](DataType dtype, Span span) { return PrimStructInfo(dtype, span); }) .def("relax.PrimStructInfoFromValue", [](PrimExpr value, Span 
span) { return PrimStructInfo(value, span); }); -}); +} // Shape -ShapeStructInfo::ShapeStructInfo(Array values, Span span) { - ObjectPtr n = make_object(); +ShapeStructInfo::ShapeStructInfo(ffi::Array values, Span span) { + ObjectPtr n = ffi::make_object(); n->ndim = static_cast(values.size()); n->values = values.Map([](PrimExpr value) { if (value->IsInstance()) { @@ -94,17 +94,17 @@ ShapeStructInfo::ShapeStructInfo(Array values, Span span) { } ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim; n->ndim = ndim; n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "relax.ShapeStructInfo", [](Optional> values, int ndim, Span span) { + "relax.ShapeStructInfo", [](ffi::Optional> values, int ndim, Span span) { if (values.defined()) { CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; return ShapeStructInfo(values.value(), span); @@ -112,14 +112,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ return ShapeStructInfo(ndim, span); } }); -}); +} // Tensor -TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Optional vdevice, +TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, ffi::Optional vdevice, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); // assign ndim before move - Optional sinfo = MatchStructInfo(shape); + ffi::Optional sinfo = MatchStructInfo(shape); ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; ICHECK(shape.defined()) << "Must provide a shape in this constructor"; ICHECK(shape->IsInstance() || shape->IsInstance()) @@ -133,8 +133,9 @@ TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Optional data_ = std::move(n); } -TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Optional vdevice, 
Span span) { - ObjectPtr n = make_object(); +TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, ffi::Optional vdevice, + Span span) { + ObjectPtr n = ffi::make_object(); CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim; n->ndim = ndim; n->dtype = dtype; @@ -143,37 +144,39 @@ TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Optional v data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.TensorStructInfo", [](Optional shape, Optional dtype, - int ndim, VDevice vdevice, Span span) { - if (shape.defined()) { - CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; - return TensorStructInfo(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); - } else { - return TensorStructInfo(dtype.value_or(DataType::Void()), ndim, vdevice, span); - } - }); -}); + refl::GlobalDef().def( + "relax.TensorStructInfo", [](ffi::Optional shape, ffi::Optional dtype, + int ndim, VDevice vdevice, Span span) { + if (shape.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; + return TensorStructInfo(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); + } else { + return TensorStructInfo(dtype.value_or(DataType::Void()), ndim, vdevice, span); + } + }); +} // Tuple -TupleStructInfo::TupleStructInfo(Array fields, Span span) { - ObjectPtr n = make_object(); +TupleStructInfo::TupleStructInfo(ffi::Array fields, Span span) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.TupleStructInfo", [](Array fields, Span span) { + refl::GlobalDef().def("relax.TupleStructInfo", [](ffi::Array fields, Span span) { return TupleStructInfo(fields, span); }); -}); +} // Func 
-FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, bool purity, Span span) { - ObjectPtr n = make_object(); +FuncStructInfo::FuncStructInfo(ffi::Array params, StructInfo ret, bool purity, + Span span) { + ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->ret = std::move(ret); n->purity = std::move(purity); @@ -183,7 +186,7 @@ FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, bool pu FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->derive_func = std::move(derive_func); n->ret = ObjectStructInfo(); n->purity = std::move(purity); @@ -192,23 +195,23 @@ FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, bool } FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, bool purity, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->ret = std::move(ret); n->purity = std::move(purity); n->span = span; return FuncStructInfo(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.FuncStructInfo", - [](Array params, StructInfo ret, bool purity, Span span) { + [](ffi::Array params, StructInfo ret, bool purity, Span span) { return FuncStructInfo(params, ret, purity, span); }) .def("relax.FuncStructInfoOpaqueFunc", - [](Optional ret, Optional derive_func, bool purity, - Span span) { + [](ffi::Optional ret, ffi::Optional derive_func, + bool purity, Span span) { if (derive_func.defined()) { ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; return FuncStructInfo::OpaqueFunc(derive_func.value(), purity, span); @@ -216,7 +219,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), purity, span); } }); -}); +} // Helper functions void UpdateStructInfo(Expr expr, StructInfo struct_info) { @@ -229,13 +232,13 @@ void 
UpdateStructInfo(Expr expr, StructInfo struct_info) { expr->struct_info_ = struct_info; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.UpdateStructInfo", [](Expr expr, StructInfo struct_info) { UpdateStructInfo(expr, struct_info); }) .def("ir.ExprStructInfo", [](Expr expr) { return GetStructInfo(expr); }); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc index ea8f1da8f04b..58df3c24ff8e 100644 --- a/src/relax/ir/struct_info_functor.cc +++ b/src/relax/ir/struct_info_functor.cc @@ -68,24 +68,24 @@ void StructInfoVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { } StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) { - return GetRef(op); + return ffi::GetRef(op); } StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) { if (!op->value.defined()) { - return GetRef(op); + return ffi::GetRef(op); } auto new_expr = VisitStructInfoExprField(op->value.value()); if (new_expr.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return PrimStructInfo(new_expr); } } StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) { - Optional> values; + ffi::Optional> values; if (op->values.defined()) { // if no changes are made the original array will be returned. 
@@ -94,14 +94,14 @@ StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) { } if (values.same_as(op->values)) { - return GetRef(op); + return ffi::GetRef(op); } else { return ShapeStructInfo(values.value(), op->span); } } StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) { - Optional shape; + ffi::Optional shape; if (op->shape.defined()) { shape = this->VisitStructInfoExprField(op->shape.value()); @@ -110,7 +110,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) { VDevice vdev = op->vdevice.value_or(VDevice()); if (shape.same_as(op->shape)) { - return GetRef(op); + return ffi::GetRef(op); } else { return TensorStructInfo(shape.value(), op->dtype, vdev, op->span); } @@ -123,18 +123,18 @@ StructInfo StructInfoMutator::VisitStructInfo_(const distributed::DTensorStructI } StructInfo StructInfoMutator::VisitStructInfo_(const TupleStructInfoNode* op) { - Array fields = + ffi::Array fields = op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); if (fields.same_as(op->fields)) { - return GetRef(op); + return ffi::GetRef(op); } else { return TupleStructInfo(fields, op->span); } } StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { - Optional> params; + ffi::Optional> params; if (op->params.defined()) { params = op->params.value().Map( @@ -144,7 +144,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { StructInfo ret = this->VisitStructInfo(op->ret); if (params.same_as(op->params) && ret.same_as(op->ret)) { - return GetRef(op); + return ffi::GetRef(op); } else { ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; return FuncStructInfo(params.value(), ret, op->purity, op->span); diff --git a/src/relax/ir/tir_pattern.cc b/src/relax/ir/tir_pattern.cc index ab2d91abcc86..d579aea632bc 100644 --- a/src/relax/ir/tir_pattern.cc +++ b/src/relax/ir/tir_pattern.cc @@ -22,11 +22,11 
@@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ MatchResultNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { MatchResultNode::RegisterReflection(); } -MatchResult::MatchResult(TIRPattern pattern, Array symbol_values, - Array matched_buffers) { - auto n = make_object(); +MatchResult::MatchResult(TIRPattern pattern, ffi::Array symbol_values, + ffi::Array matched_buffers) { + auto n = ffi::make_object(); n->pattern = std::move(pattern); n->symbol_values = std::move(symbol_values); n->matched_buffers = std::move(matched_buffers); diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index fb106e2092db..39c754361360 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -81,9 +81,7 @@ class FunctionPassNode : public tvm::transform::PassNode { * \brief Get the pass information/meta data. */ PassInfo Info() const override { return pass_info; } - - static constexpr const char* _type_key = "relax.FunctionPass"; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.FunctionPass", FunctionPassNode, PassNode); private: }; @@ -98,12 +96,12 @@ class FunctionPass : public Pass { TVM_DLL FunctionPass(std::function pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FunctionPass, Pass, FunctionPassNode); }; FunctionPass::FunctionPass(std::function pass_func, PassInfo pass_info) { - auto n = make_object(); + auto n = ffi::make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); data_ = std::move(n); @@ -138,7 +136,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) for (const auto& it : updated_mod->functions) { // only picks up relax::Function if (auto* n = it.second.as()) { - Function func = GetRef(n); + Function func = ffi::GetRef(n); auto updated_func = pass_func(func, updated_mod, pass_ctx); 
updates.push_back({it.first, updated_func}); } @@ -160,12 +158,13 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) } Pass CreateFunctionPass(std::function pass_func, - int opt_level, String name, tvm::Array required, bool traceable) { + int opt_level, ffi::String name, tvm::ffi::Array required, + bool traceable) { PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return FunctionPass(std::move(pass_func), pass_info); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.transform.MakeFunctionPass", @@ -176,7 +175,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }; return FunctionPass(wrapped_pass_func, pass_info); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -218,9 +217,7 @@ class DataflowBlockPassNode : public tvm::transform::PassNode { IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; PassInfo Info() const override { return pass_info; } - - static constexpr const char* _type_key = "relax.DataflowBlockPass"; - TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockPassNode, PassNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.DataflowBlockPass", DataflowBlockPassNode, PassNode); }; /*! 
\brief Helper to apply the passed function to dataflow blocks.*/ @@ -238,14 +235,14 @@ class DataflowBlockMutator : public ExprMutator { */ BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final { // collect Global Scope Vars and Symbolic Vars inside the DataflowBlock - Map global_scope_vars; - Map symbolic_vars; + ffi::Map global_scope_vars; + ffi::Map symbolic_vars; for (const Binding& binding : n->bindings) { Var var = binding->var; if (const auto* match_cast = binding.as()) { auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); for (const tir::VarNode* var : collected_vars) { - symbolic_vars.Set(var->name_hint, GetRef(var)); + symbolic_vars.Set(var->name_hint, ffi::GetRef(var)); } } if (!var.as()) { @@ -254,7 +251,7 @@ class DataflowBlockMutator : public ExprMutator { } // apply pass_func_ to the DataflowBlock - DataflowBlock block = GetRef(n); + DataflowBlock block = ffi::GetRef(n); DataflowBlock updated_block = pass_func_(block, mod_, pass_ctx_); // raise error if there are updates of recorded Global Scope Vars and Symbolic Vars @@ -319,13 +316,13 @@ class DataflowBlockPass : public Pass { std::function pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockPass, Pass, DataflowBlockPassNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataflowBlockPass, Pass, DataflowBlockPassNode); }; DataflowBlockPass::DataflowBlockPass( std::function pass_func, PassInfo pass_info) { - auto n = make_object(); + auto n = ffi::make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); data_ = std::move(n); @@ -361,7 +358,7 @@ IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass for (const auto& it : updated_mod->functions) { // only picks up relax::Function if (auto* n = it.second.as()) { - Function func = GetRef(n); + Function func = ffi::GetRef(n); Function updated_func = Downcast(dataflow_block_mutator.VisitExpr(func)); updates.push_back({it.first, 
updated_func}); } @@ -384,12 +381,12 @@ IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass Pass CreateDataflowBlockPass( std::function pass_func, int opt_level, - String name, tvm::Array required, bool traceable) { + ffi::String name, tvm::ffi::Array required, bool traceable) { PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return DataflowBlockPass(std::move(pass_func), pass_info); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.transform.MakeDataflowBlockPass", @@ -401,7 +398,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }; return DataflowBlockPass(wrapped_pass_func, pass_info); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -411,10 +408,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << info->opt_level; }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { FunctionPassNode::RegisterReflection(); DataflowBlockPassNode::RegisterReflection(); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 1f0de47f1f83..faa0814f4c9d 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -28,39 +28,39 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { ShapeTypeNode::RegisterReflection(); TensorTypeNode::RegisterReflection(); ObjectTypeNode::RegisterReflection(); PackedFuncTypeNode::RegisterReflection(); -}); +} ShapeType::ShapeType(int ndim, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->ndim = ndim; n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.ShapeType", [](int ndim, Span span) { return ShapeType(ndim, span); }); -}); +} ObjectType::ObjectType(Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = 
ffi::make_object(); n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.ObjectType", [](Span span) { return ObjectType(span); }); -}); +} TensorType::TensorType(int ndim, DataType dtype, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->ndim = std::move(ndim); n->dtype = std::move(dtype); n->span = span; @@ -68,30 +68,30 @@ TensorType::TensorType(int ndim, DataType dtype, Span span) { } TensorType TensorType::CreateUnknownNDim(DataType dtype, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->ndim = -1; n->dtype = std::move(dtype); n->span = std::move(span); return TensorType(std::move(n)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.TensorType", [](int ndim, DataType dtype, Span span) { return TensorType(ndim, dtype, span); }); -}); +} PackedFuncType::PackedFuncType(Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->span = span; data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.PackedFuncType", [](Span span) { return PackedFuncType(span); }); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index f46150654f0e..29036f42f846 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -28,14 +28,14 @@ namespace relax { /* relax.ccl.allreduce */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { AllReduceAttrs::RegisterReflection(); AllGatherAttrs::RegisterReflection(); ScatterCollectiveAttrs::RegisterReflection(); -}); +} -Expr allreduce(Expr x, String op_type, bool in_group) { - ObjectPtr attrs = make_object(); +Expr allreduce(Expr x, ffi::String op_type, bool in_group) { + ObjectPtr attrs = 
ffi::make_object(); attrs->op_type = std::move(op_type); attrs->in_group = std::move(in_group); @@ -43,10 +43,10 @@ Expr allreduce(Expr x, String op_type, bool in_group) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.ccl.allreduce", allreduce); -}); +} StructInfo InferStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -64,7 +64,7 @@ TVM_REGISTER_OP("relax.ccl.allreduce") /* relax.ccl.allgather */ Expr allgather(Expr x, int num_workers, bool in_group) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->num_workers = std::move(num_workers); attrs->in_group = std::move(in_group); @@ -72,10 +72,10 @@ Expr allgather(Expr x, int num_workers, bool in_group) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.ccl.allgather", allgather); -}); +} StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -88,7 +88,7 @@ StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { if (!input_shape.defined()) { return input_sinfo; } - Array output_shape = input_shape.value(); + ffi::Array output_shape = input_shape.value(); output_shape.Set(0, floor(output_shape[0] * num_workers)); return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice); } @@ -106,10 +106,10 @@ Expr broadcast_from_worker0(Expr x) { return Call(op, {std::move(x)}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.ccl.broadcast_from_worker0", broadcast_from_worker0); -}); +} StructInfo 
InferStructInfoBroadcastFromZero(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -126,7 +126,7 @@ TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0") /* relax.ccl.scatter_from_worker0 */ Expr scatter_from_worker0(Expr data, int num_workers, int axis) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->num_workers = std::move(num_workers); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.ccl.scatter_from_worker0"); @@ -134,10 +134,10 @@ Expr scatter_from_worker0(Expr data, int num_workers, int axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.ccl.scatter_from_worker0", scatter_from_worker0); -}); +} StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -158,7 +158,7 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { << " while num_workers is " << num_workers); } - Array output_shape = input_shape.value(); + ffi::Array output_shape = input_shape.value(); output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers)); return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice); } diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h index 82ea3935675d..1d049382d0ae 100644 --- a/src/relax/op/ccl/ccl.h +++ b/src/relax/op/ccl/ccl.h @@ -33,7 +33,7 @@ namespace tvm { namespace relax { /*! \brief AllReduce. */ -Expr allreduce(Expr data, String op_type, bool in_group); +Expr allreduce(Expr data, ffi::String op_type, bool in_group); /*! \brief AllGather. 
*/ Expr allgather(Expr data, int num_workers, bool in_group); diff --git a/src/relax/op/distributed/binary.h b/src/relax/op/distributed/binary.h index 7e89c6497dcc..127dec433afa 100644 --- a/src/relax/op/distributed/binary.h +++ b/src/relax/op/distributed/binary.h @@ -36,7 +36,8 @@ namespace distributed { template StructInfo InferDistStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = + GetInputDTensorStructInfo(call, ctx); TensorStructInfo x1_sinfo, x2_sinfo; x1_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; x2_sinfo = input_dtensor_sinfos[1]->tensor_sinfo; @@ -55,7 +56,7 @@ StructInfo InferDistStructInfoBroadcast(const Call& call, const BlockBuilder& ct // Shapes and ndims if (x1_shape && x2_shape) { // If all inputs have shapes, directly infer shapes - Optional> output_shape = + ffi::Optional> output_shape = InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); if (!output_shape.defined()) { output_tensor_sinfo = TensorStructInfo(output_dtype, /*ndim=*/output_ndim); diff --git a/src/relax/op/distributed/ccl.cc b/src/relax/op/distributed/ccl.cc index 885b084856a1..6ba63986980e 100644 --- a/src/relax/op/distributed/ccl.cc +++ b/src/relax/op/distributed/ccl.cc @@ -25,7 +25,7 @@ namespace relax { namespace distributed { StructInfo InferDistStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); ICHECK(input_dtensor_sinfos.size() == 1); DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index f9651d8225a4..636891366194 100644 --- a/src/relax/op/distributed/distributed.cc +++ 
b/src/relax/op/distributed/distributed.cc @@ -37,13 +37,13 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ DistributionAttrs::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { DistributionAttrs::RegisterReflection(); } /* relax.dist.annotate_sharding */ Expr annotate_sharding(Expr input, distributed::DeviceMesh device_mesh, distributed::Placement placement) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->device_mesh = device_mesh; attrs->placement = placement; @@ -51,10 +51,10 @@ Expr annotate_sharding(Expr input, distributed::DeviceMesh device_mesh, return Call(op, {std::move(input)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.dist.annotate_sharding", annotate_sharding); -}); +} StructInfo InferStructInfoAnnotateSharding(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[0]); @@ -71,7 +71,7 @@ TVM_REGISTER_OP("relax.dist.annotate_sharding") Expr redistribute(Expr input, distributed::DeviceMesh device_mesh, distributed::Placement placement) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->device_mesh = device_mesh; attrs->placement = placement; @@ -79,10 +79,10 @@ Expr redistribute(Expr input, distributed::DeviceMesh device_mesh, return Call(op, {std::move(input)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.dist.redistribute", redistribute); -}); +} StructInfo InferDistStructInfoRedistribute(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -120,8 +120,8 @@ TVM_REGISTER_OP("relax.dist.call_tir_local_view") .set_attr("FPurity", Bool(true)); Expr MakeCallTIRLocalView(Expr func, Tuple args, - Array out_sinfo_list, - Optional packed_ints) { + ffi::Array out_sinfo_list, + ffi::Optional packed_ints) { 
for (const distributed::DTensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->tensor_sinfo->shape.as(); CHECK(shape != nullptr) @@ -148,10 +148,10 @@ Expr MakeCallTIRLocalView(Expr func, Tuple args, return call; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.dist.call_tir_local_view", MakeCallTIRLocalView); -}); +} StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -175,14 +175,14 @@ StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { << " while num_workers is " << num_workers); } - Array output_shape = input_shape.value(); + ffi::Array output_shape = input_shape.value(); output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers)); return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice); } StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { using namespace distributed; - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); ICHECK(input_dtensor_sinfos.size() == 1); DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; @@ -212,7 +212,7 @@ StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { } Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->num_workers = std::move(num_workers); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.dist.redistribute_replica_to_shard"); @@ -220,11 +220,11 @@ Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) { return Call(op, {std::move(input)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.dist.redistribute_replica_to_shard", redistribute_replica_to_shard); -}); +} TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard") .set_num_inputs(1) diff --git a/src/relax/op/distributed/linear_algebra.cc b/src/relax/op/distributed/linear_algebra.cc index 727b52c462ec..8fc9cd58d1fc 100644 --- a/src/relax/op/distributed/linear_algebra.cc +++ b/src/relax/op/distributed/linear_algebra.cc @@ -25,7 +25,8 @@ namespace relax { namespace distributed { StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = + GetInputDTensorStructInfo(call, ctx); TensorStructInfo x1_sinfo, x2_sinfo; x1_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; x2_sinfo = input_dtensor_sinfos[1]->tensor_sinfo; @@ -67,11 +68,11 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) ctx->ReportFatal(Diagnostic::Error(call) << "input of distributed operator must have shape"); } - Array x1_shape_prefix{x1_shape->values.begin(), - x1_shape->values.end() - 2 + x1_prepended}; - Array x2_shape_prefix{x2_shape->values.begin(), - x2_shape->values.end() - 2 + x2_appended}; - Optional> output_shape_prefix = + ffi::Array x1_shape_prefix{x1_shape->values.begin(), + x1_shape->values.end() - 2 + x1_prepended}; + ffi::Array x2_shape_prefix{x2_shape->values.begin(), + x2_shape->values.end() - 2 + x2_appended}; + ffi::Optional> output_shape_prefix = InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); ICHECK(output_shape_prefix.defined()) << "Failed to infer output shape of Matmul"; arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -84,7 +85,7 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) << x1_reduction_length << " and " << x2_reduction_length << " respectively."); } - Array output_shape = 
output_shape_prefix.value(); + ffi::Array output_shape = output_shape_prefix.value(); if (!x1_prepended) { output_shape.push_back(x1_shape->values[x1_ndim - 2]); } diff --git a/src/relax/op/distributed/manipulate.cc b/src/relax/op/distributed/manipulate.cc index 8b18b9578eda..edd5fa7ee7f9 100644 --- a/src/relax/op/distributed/manipulate.cc +++ b/src/relax/op/distributed/manipulate.cc @@ -29,7 +29,8 @@ namespace relax { namespace distributed { StructInfo InferDistStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) { - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = + GetInputDTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; const auto* attrs = call->attrs.as(); @@ -84,7 +85,8 @@ StructInfo InferDistStructInfoReshape(const Call& call, const BlockBuilder& ctx) if (call->args.size() != 2) { ctx->ReportFatal(Diagnostic::Error(call) << "Reshape op should take 2 arguments"); } - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = + GetInputDTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); @@ -100,7 +102,7 @@ StructInfo InferDistStructInfoReshape(const Call& call, const BlockBuilder& ctx) << call->args[1]->struct_info_->GetTypeKey()); } - Optional> old_shape_values; + ffi::Optional> old_shape_values; if (data_sinfo->shape.defined()) { const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); ICHECK_NOTNULL(old_shape_sinfo); diff --git a/src/relax/op/distributed/nn.cc b/src/relax/op/distributed/nn.cc index ec0bdaeb3242..b020d7902f9b 100644 --- a/src/relax/op/distributed/nn.cc +++ b/src/relax/op/distributed/nn.cc @@ -24,7 +24,8 @@ namespace relax { namespace distributed { StructInfo InferDistStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { - Array input_dtensor_sinfos = 
GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = + GetInputDTensorStructInfo(call, ctx); ICHECK(input_dtensor_sinfos.size() == 1); TensorStructInfo input_tensor_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; diff --git a/src/relax/op/distributed/statistical.cc b/src/relax/op/distributed/statistical.cc index 3bd0f0651718..44ee90e78976 100644 --- a/src/relax/op/distributed/statistical.cc +++ b/src/relax/op/distributed/statistical.cc @@ -25,7 +25,8 @@ namespace relax { namespace distributed { StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& ctx) { - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = + GetInputDTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; const auto* attrs = call->attrs.as(); @@ -60,7 +61,7 @@ StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& ctx->ReportFatal(Diagnostic::Error(call) << "Input of distributed operator must be known shape"); } - Array out_shape; + ffi::Array out_shape; out_shape.reserve(out_ndim); for (int i = 0; i < data_sinfo->ndim; ++i) { if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { diff --git a/src/relax/op/distributed/unary.h b/src/relax/op/distributed/unary.h index cfde689421f7..727707a98525 100644 --- a/src/relax/op/distributed/unary.h +++ b/src/relax/op/distributed/unary.h @@ -34,7 +34,8 @@ namespace distributed { template StructInfo InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = + GetInputDTensorStructInfo(call, ctx); ICHECK(input_dtensor_sinfos.size() == 1); distributed::DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; TensorStructInfo input_tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; @@ -47,7 +48,7 @@ StructInfo 
InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx, << " requires the input tensor to have float dtype. However, the given input dtype is " << input_tensor_sinfo->dtype); } - auto output_sinfo = make_object(*input_tensor_sinfo.get()); + auto output_sinfo = ffi::make_object(*input_tensor_sinfo.get()); output_sinfo->dtype = f_compute_out_dtype(input_tensor_sinfo); TensorStructInfo out_tensor_sinfo(output_sinfo); return distributed::DTensorStructInfo(out_tensor_sinfo, input_dtensor_sinfo->device_mesh, diff --git a/src/relax/op/distributed/utils.cc b/src/relax/op/distributed/utils.cc index 39bdeea037c5..ffa7dbfa3085 100644 --- a/src/relax/op/distributed/utils.cc +++ b/src/relax/op/distributed/utils.cc @@ -24,16 +24,16 @@ namespace tvm { namespace relax { namespace distributed { -Array GetInputDTensorStructInfo(const Call& call, - const BlockBuilder& ctx) { +ffi::Array GetInputDTensorStructInfo(const Call& call, + const BlockBuilder& ctx) { Op op = Downcast(call->op); - Array args = GetCallArgs(call); - Array input_tensor_sinfo; + ffi::Array args = GetCallArgs(call); + ffi::Array input_tensor_sinfo; input_tensor_sinfo.reserve(args.size()); for (const Expr& arg : args) { const auto* sinfo = GetStructInfoAs(arg); if (sinfo != nullptr) { - input_tensor_sinfo.push_back(GetRef(sinfo)); + input_tensor_sinfo.push_back(ffi::GetRef(sinfo)); } } return input_tensor_sinfo; @@ -42,7 +42,8 @@ Array GetInputDTensorStructInfo(const Call& call StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, const StructInfo& orig_output_sinfo, distributed::FBuildAxisGraph f_build_graph) { - Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); + ffi::Array input_dtensor_sinfos = + GetInputDTensorStructInfo(call, ctx); for (int i = 1; i < static_cast(input_dtensor_sinfos.size()); i++) { ICHECK(StructuralEqual()(input_dtensor_sinfos[0]->device_mesh, input_dtensor_sinfos[i]->device_mesh)); @@ -51,7 +52,7 @@ StructInfo InferShardingSpec(const Call& call, 
const BlockBuilder& ctx, Var output_var("output", orig_output_sinfo); distributed::AxisGroupGraph axis_group_graph; f_build_graph(output_var, call, &axis_group_graph); - Array args = GetCallArgs(call); + ffi::Array args = GetCallArgs(call); int n_input_var = input_dtensor_sinfos.size(); for (int i = 0; i < n_input_var; i++) { distributed::DTensorStructInfo dtensor_sinfo = input_dtensor_sinfos[i]; @@ -66,9 +67,9 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, } } axis_group_graph.PropagateShardingSpec(); - Array orig_output_tensor_sinfos; + ffi::Array orig_output_tensor_sinfos; if (const auto* tensor_sinfo = orig_output_sinfo.as()) { - orig_output_tensor_sinfos.push_back(GetRef(tensor_sinfo)); + orig_output_tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo)); } else { const auto* tuple_sinfo = orig_output_sinfo.as(); ICHECK(tuple_sinfo); @@ -76,9 +77,9 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, orig_output_tensor_sinfos.push_back(Downcast(sinfo)); } } - Array new_output_dtensor_sinfos; + ffi::Array new_output_dtensor_sinfos; for (int idx = 0; idx < static_cast(orig_output_tensor_sinfos.size()); idx++) { - Array output_placement_specs( + ffi::Array output_placement_specs( std::vector(device_mesh->shape.size(), distributed::PlacementSpec::Replica())); for (int i = 0; i < orig_output_tensor_sinfos[idx]->ndim; i++) { diff --git a/src/relax/op/distributed/utils.h b/src/relax/op/distributed/utils.h index 1656df286784..125a2d242ba5 100644 --- a/src/relax/op/distributed/utils.h +++ b/src/relax/op/distributed/utils.h @@ -42,8 +42,8 @@ namespace distributed { * \return The dtensor struct info of each input. * \note This function require every input tensor to be DTensor. */ -Array GetInputDTensorStructInfo(const Call& call, - const BlockBuilder& ctx); +ffi::Array GetInputDTensorStructInfo(const Call& call, + const BlockBuilder& ctx); /*! 
* \brief Perform a local sharding spec propagation to infer the output dtensor diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index f6923ecb3ab4..59d845d867f6 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -31,14 +31,15 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ Resize2DAttrs::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { Resize2DAttrs::RegisterReflection(); } /* relax.resize2d */ -Expr resize2d(Expr data, Expr size, Array roi, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double cubic_alpha, - int cubic_exclude, double extrapolation_value, Optional out_dtype) { - ObjectPtr attrs = make_object(); +Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout, + ffi::String method, ffi::String coordinate_transformation_mode, + ffi::String rounding_method, double cubic_alpha, int cubic_exclude, + double extrapolation_value, ffi::Optional out_dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->roi = std::move(roi); attrs->layout = std::move(layout); attrs->method = std::move(method); @@ -53,10 +54,10 @@ Expr resize2d(Expr data, Expr size, Array roi, String layout, String m return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.image.resize2d", resize2d); -}); +} StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1 && call->args.size() != 2) { @@ -93,30 +94,30 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { DataType out_dtype = attrs->out_dtype.is_void() ? 
data_sinfo->dtype : attrs->out_dtype; - Optional data_shape = - CheckNdimPerLayoutAndGetShape(call, ctx, GetRef(data_sinfo), data_layout); + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape( + call, ctx, ffi::GetRef(data_sinfo), data_layout); if (!data_shape.defined() || size_value == nullptr) { return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice); } - Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); - Array out_NCHW_shape(data_NCHW_shape); + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + ffi::Array out_NCHW_shape(data_NCHW_shape); out_NCHW_shape.Set(2, size_value->values[0]); out_NCHW_shape.Set(3, size_value->values[1]); - Array out_shape = data2NCHW.BackwardShape(out_NCHW_shape); + ffi::Array out_shape = data2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutResize2d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutResize2d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& it = desired_layouts.find("relax.image.resize2d"); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout; - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); if (it != desired_layouts.end()) { // We have a desired layout for resize2d. 
@@ -147,5 +148,83 @@ TVM_REGISTER_OP("relax.image.resize2d") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); +/* relax.grid_sample */ + +TVM_FFI_STATIC_INIT_BLOCK() { GridSampleAttrs::RegisterReflection(); } + +Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout, + ffi::String padding_mode, bool align_corners) { + ObjectPtr attrs = ffi::make_object(); + attrs->method = std::move(method); + attrs->layout = std::move(layout); + attrs->padding_mode = std::move(padding_mode); + attrs->align_corners = align_corners; + + static const Op& op = Op::Get("relax.image.grid_sample"); + return Call(op, {std::move(data), std::move(grid)}, Attrs(attrs), {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.image.grid_sample", grid_sample); +} + +StructInfo InferStructInfoGridSample(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "GridSample expects two arguments, while the given number of arguments is " + << call->args.size()); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* grid_sinfo = GetStructInfoAs(call->args[1]); + + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "GridSample expects the input data to be a Tensor, while the given data is " + << call->args[0]->GetTypeKey()); + } + if (grid_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "GridSample expects the grid to be a Tensor, while the given grid is " + << call->args[1]->GetTypeKey()); + } + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, + /*tgt_layout=*/"NCHW", + /*tensor_name=*/"data"); + + DataType out_dtype = data_sinfo->dtype; + + // Output shape: [N, C, grid_H, grid_W] + // grid shape for NCHW layout input is [N, H_out, W_out, 2] + ffi::Optional 
data_shape = CheckNdimPerLayoutAndGetShape( + call, ctx, ffi::GetRef(data_sinfo), data_layout); + const auto* grid_shape = grid_sinfo->shape.as(); + + if (!data_shape.defined() || grid_shape == nullptr) { + return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice); + } + + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + // grid is [N, H_out, W_out, 2], output is [N, C, H_out, W_out] + ffi::Array out_NCHW_shape(data_NCHW_shape); + out_NCHW_shape.Set(2, grid_shape->values[1]); // H_out + out_NCHW_shape.Set(3, grid_shape->values[2]); // W_out + + ffi::Array out_shape = data2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.image.grid_sample") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("grid", "Tensor", "The grid tensor for sampling.") + .set_attr("FInferStructInfo", InferStructInfoGridSample) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h index 3af171c7bfff..a208aae0921d 100644 --- a/src/relax/op/image/resize.h +++ b/src/relax/op/image/resize.h @@ -33,9 +33,14 @@ namespace tvm { namespace relax { /*! \brief Image resize2d operator. */ -Expr resize2d(Expr data, Expr size, Array roi, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double cubic_alpha, - int cubic_exclude, double extrapolation_value, Optional out_dtype); +Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout, + ffi::String method, ffi::String coordinate_transformation_mode, + ffi::String rounding_method, double cubic_alpha, int cubic_exclude, + double extrapolation_value, ffi::Optional out_dtype); + +/*! \brief Image grid_sample operator. 
*/ +Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout, + ffi::String padding_mode, bool align_corners); } // namespace relax } // namespace tvm diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 1af12b475136..04a845bd816d 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -30,8 +30,9 @@ namespace tvm { namespace relax { /* relax.op.memory.view */ -Expr view(Expr x, Optional shape, Optional dtype, Optional relative_byte_offset) { - Tuple void_expr(Array{}); +Expr view(Expr x, ffi::Optional shape, ffi::Optional dtype, + ffi::Optional relative_byte_offset) { + Tuple void_expr(ffi::Array{}); static const Op& op = Op::Get("relax.memory.view"); return Call(op, { @@ -42,10 +43,10 @@ Expr view(Expr x, Optional shape, Optional dtype, Optional rel }); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.memory.view", view); -}); +} StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 4) { @@ -123,7 +124,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { } }(); - auto view_relative_byte_offset = [&]() -> Optional { + auto view_relative_byte_offset = [&]() -> ffi::Optional { StructInfo sinfo = GetStructInfo(arg_relative_byte_offset); if (HasVoidStructInfo(arg_relative_byte_offset)) { @@ -152,9 +153,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { } }(); - Optional> input_shape = data_sinfo->GetShape(); + ffi::Optional> input_shape = data_sinfo->GetShape(); - Optional> output_shape = std::nullopt; + ffi::Optional> output_shape = std::nullopt; int output_ndim = kUnknownNDim; if (view_shape_sinfo && view_shape_sinfo->values.defined()) { output_shape = view_shape_sinfo->values.value(); @@ -171,7 +172,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // Helper function, returns the 
number of bytes per vectorized // element. Cannot use `DataType::bytes`, as it returns the // number of bytes per scalar element. - auto get_size_bytes = [](const DataType& dtype) -> Optional { + auto get_size_bytes = [](const DataType& dtype) -> ffi::Optional { if (dtype.is_void()) { return std::nullopt; } else { @@ -182,7 +183,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // Helper function, returns the number of elements in an array, // given the shape of that array. - auto get_num_elements = [&ctx](const Optional>& shape) -> Optional { + auto get_num_elements = + [&ctx](const ffi::Optional>& shape) -> ffi::Optional { if (!shape.defined()) { return std::nullopt; } @@ -194,11 +196,11 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return ctx->GetAnalyzer()->Simplify(num_elements); }; - Optional input_nelements = get_num_elements(input_shape); - Optional output_nelements = get_num_elements(output_shape); + ffi::Optional input_nelements = get_num_elements(input_shape); + ffi::Optional output_nelements = get_num_elements(output_shape); - Optional input_element_size = get_size_bytes(data_sinfo->dtype); - Optional output_element_size = get_size_bytes(output_dtype); + ffi::Optional input_element_size = get_size_bytes(data_sinfo->dtype); + ffi::Optional output_element_size = get_size_bytes(output_dtype); if (input_nelements && output_nelements && input_element_size && output_element_size && view_relative_byte_offset) { @@ -294,10 +296,10 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tvm.relax.struct_info.infer_view_sinfo", InferStructInfoView); -}); +} Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { Expr data = call->args[0]; @@ -346,7 +348,7 @@ Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { infer_sinfo_env_func 
= EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); - ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); + ExternFunc runtime_view_func("runtime.TVMTensorCreateView", runtime_view_sinfo); return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); } @@ -368,10 +370,10 @@ Expr ensure_zero_offset(const Expr& x) { return Call(op, {x}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.memory.ensure_zero_offset", ensure_zero_offset); -}); +} StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { diff --git a/src/relax/op/memory/view.h b/src/relax/op/memory/view.h index 77ec7e9833cc..6c23ef7b27a2 100644 --- a/src/relax/op/memory/view.h +++ b/src/relax/op/memory/view.h @@ -30,7 +30,8 @@ namespace tvm { namespace relax { /*! \brief View a tensor with different properties. */ -Expr view(Expr x, Optional shape, Optional dtype, Optional relative_byte_offset); +Expr view(Expr x, ffi::Optional shape, ffi::Optional dtype, + ffi::Optional relative_byte_offset); /*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. 
*/ Expr ensure_aligned(const Expr& x); diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 916fa2f39f33..bf384e863443 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -28,9 +28,10 @@ namespace relax { /* relax.nn.attention */ -Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional scale, - Optional causal_mask, Optional window_size) { - ObjectPtr attrs = make_object(); +Expr attention(Expr query, Expr key, Expr value, ffi::Optional bias, + ffi::Optional scale, ffi::Optional causal_mask, + ffi::Optional window_size) { + ObjectPtr attrs = ffi::make_object(); attrs->scale = scale; attrs->causal_mask = causal_mask; attrs->window_size = window_size; @@ -45,9 +46,9 @@ Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional scale, - Optional causal_mask, Optional window_size) { - ObjectPtr attrs = make_object(); + Expr max_seqlen_q, Expr max_seqlen_k, ffi::Optional scale, + ffi::Optional causal_mask, ffi::Optional window_size) { + ObjectPtr attrs = ffi::make_object(); attrs->scale = scale; attrs->causal_mask = causal_mask; attrs->window_size = window_size; @@ -57,19 +58,19 @@ Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr s {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.op.nn.attention", attention) .def("relax.op.nn.attention_var_len", attention_var_len); -}); +} StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo q_sinfo = input_sinfo[0]; TensorStructInfo k_sinfo = input_sinfo[1]; TensorStructInfo v_sinfo = input_sinfo[2]; - auto diag_dim = [&](TensorStructInfo sinfo, String name) { + auto diag_dim = [&](TensorStructInfo sinfo, ffi::String name) { if (sinfo->ndim != 4) { 
ctx->ReportFatal(Diagnostic::Error(call) << "The " << name << " should have 4 dimension, namely " @@ -89,7 +90,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { PrimExpr num_keys = k_shape->values[1]; PrimExpr head_dim_value = v_shape->values[3]; arith::Analyzer* analyzer = ctx->GetAnalyzer(); - auto diag_equal = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) { + auto diag_equal = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2, ffi::String dim) { if (analyzer->CanProve(v1 != v2)) { ctx->ReportFatal(Diagnostic::Error(call) << "The " << m1 << " " << dim << " and the " << m2 << " " << dim @@ -97,7 +98,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { << v1 << " while the " << dim << " of " << m2 << " is " << v2); } }; - auto multiple_of = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) { + auto multiple_of = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2, + ffi::String dim) { if (analyzer->CanProve(indexmod(v1, v2) != 0)) { ctx->ReportFatal(Diagnostic::Error(call) << "The " << m1 << " " << dim << " should be a multiple of " << m2 << " " @@ -121,7 +123,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { << "The bias should have 4 dimensions." 
<< "However, the bias input has " << bias_sinfo->ndim << " dimensions."); } - auto diag_equal_or_broadcast = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) { + auto diag_equal_or_broadcast = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2, + ffi::String dim) { if (analyzer->CanProve(v1 != v2) && !tir::is_one(v2)) { ctx->ReportFatal(Diagnostic::Error(call) << "The " << m1 << " " << dim << " and the " << m2 << " " << dim @@ -136,7 +139,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length"); } - Array output_shape = {num_batches, num_queries, num_heads, head_dim_value}; + ffi::Array output_shape = {num_batches, num_queries, num_heads, head_dim_value}; return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype, q_sinfo->vdevice); } @@ -183,7 +186,7 @@ TVM_REGISTER_OP("relax.nn.attention_var_len") .set_attr("FInferStructInfo", InferStructInfoAttention) .set_attr("FPurity", Bool(true)); -TVM_FFI_STATIC_INIT_BLOCK({ AttentionAttrs::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { AttentionAttrs::RegisterReflection(); } } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h index 346907f8e938..f4fe8ad88fd4 100644 --- a/src/relax/op/nn/attention.h +++ b/src/relax/op/nn/attention.h @@ -33,8 +33,9 @@ namespace tvm { namespace relax { /*! 
\brief fused multi head attention */ -Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional scale, - Optional causal_mask, Optional window_size); +Expr attention(Expr query, Expr key, Expr value, ffi::Optional bias, + ffi::Optional scale, ffi::Optional causal_mask, + ffi::Optional window_size); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 7346af3b1c98..49e92719ba15 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -31,19 +31,20 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { Conv1DAttrs::RegisterReflection(); Conv2DAttrs::RegisterReflection(); Conv3DAttrs::RegisterReflection(); Conv1DTransposeAttrs::RegisterReflection(); Conv2DTransposeAttrs::RegisterReflection(); -}); +} /* relax.nn.conv1d */ -Expr conv1d(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype) { +Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. 
However, " @@ -60,13 +61,13 @@ Expr conv1d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv1d"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.conv1d", conv1d); -}); +} StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo weight_sinfo = input_sinfo[1]; @@ -81,21 +82,22 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCW", // /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); - Optional weight_shape = + ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() ? 
InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) : attrs->out_dtype; - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = + InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); if (!data_shape.defined() || !weight_shape.defined()) { return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); } - Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); - Array weight_OIW_shape = weight2OIW.ForwardShape(weight_shape.value()->values); + ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + ffi::Array weight_OIW_shape = weight2OIW.ForwardShape(weight_shape.value()->values); arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCW_shape[1]; @@ -133,19 +135,19 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1; out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); - Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); } -InferLayoutOutput InferLayoutConv1d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutConv1d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& it = desired_layouts.find("relax.nn.conv1d"); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); if (it != desired_layouts.end()) { // We have a desired layout for conv1d. 
@@ -200,9 +202,10 @@ TVM_REGISTER_OP("relax.nn.conv1d") /* relax.nn.conv2d */ -Expr conv2d(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype) { +Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); if (strides.size() == 1) { strides.push_back(strides[0]); @@ -225,13 +228,13 @@ Expr conv2d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv2d"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.conv2d", conv2d); -}); +} StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo weight_sinfo = input_sinfo[1]; @@ -246,21 +249,22 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCHW", // /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); - Optional weight_shape = + ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() ? 
InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) : attrs->out_dtype; - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = + InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); if (!data_shape.defined() || !weight_shape.defined()) { return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); } - Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); - Array weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values); + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + ffi::Array weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values); arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCHW_shape[1]; @@ -303,19 +307,21 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); - Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); } -InferLayoutOutput InferLayoutConv2d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutConv2d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& it = desired_layouts.find("relax.nn.conv2d"); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; - ObjectPtr new_attrs = make_object(*attrs); + data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); + ObjectPtr new_attrs = ffi::make_object(*attrs); if (it != desired_layouts.end()) { // 
We have a desired layout for conv2d. @@ -343,8 +349,10 @@ InferLayoutOutput InferLayoutConv2d(const Call& call, auto kernel_si = GetStructInfo(call->args[1]); TensorStructInfo data_sinfo = data_si.as().value(); TensorStructInfo kernel_sinfo = kernel_si.as().value(); - Optional data_shape = GetRef(data_sinfo->shape.as()); - Optional kernel_shape = GetRef(kernel_sinfo->shape.as()); + ffi::Optional data_shape = + ffi::GetRef(data_sinfo->shape.as()); + ffi::Optional kernel_shape = + ffi::GetRef(kernel_sinfo->shape.as()); bool can_data_proved = CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values); @@ -360,14 +368,16 @@ InferLayoutOutput InferLayoutConv2d(const Call& call, new_attrs->kernel_layout = (*it).second[1]; new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); + } else { + data_layout = LayoutDecision(InitialLayout(4)); + weight_layout = LayoutDecision(InitialLayout(4)); } } } // We don't have a desired layout for conv2d or desired layouts not compatible. // We can just propagate the layout from the input. 
- data_layout = GetLayoutDecision(var_layout_map, call->args[0]); - weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); + output_layout = data_layout; new_attrs->data_layout = TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name(); @@ -399,9 +409,10 @@ TVM_REGISTER_OP("relax.nn.conv2d") /* relax.nn.conv3d */ -Expr conv3d(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype) { +Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype) { padding = GetCompletePadding3D(std::move(padding)); if (strides.size() == 1) { strides.push_back(strides[0]); @@ -426,13 +437,13 @@ Expr conv3d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv3d"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.conv3d", conv3d); -}); +} StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo weight_sinfo = input_sinfo[1]; @@ -447,21 +458,22 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCDHW", // /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); - Optional weight_shape = + ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() ? 
InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) : attrs->out_dtype; - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = + InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); if (!data_shape.defined() || !weight_shape.defined()) { return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); } - Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); - Array weight_OIDHW_shape = weight2OIDHW.ForwardShape(weight_shape.value()->values); + ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); + ffi::Array weight_OIDHW_shape = weight2OIDHW.ForwardShape(weight_shape.value()->values); arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCDHW_shape[1]; @@ -510,19 +522,19 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[1]) + 1); out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[2]) + 1); - Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); + ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); } -InferLayoutOutput InferLayoutConv3d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutConv3d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& it = desired_layouts.find("relax.nn.conv3d"); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); if (it != desired_layouts.end()) { // We have a desired layout for conv3d. 
@@ -575,10 +587,11 @@ TVM_REGISTER_OP("relax.nn.conv3d") .set_attr("FInferMixedPrecision", InferMixedPrecisionConv3d) .set_attr("FPurity", Bool(true)); -Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array padding, - Array output_padding, Array dilation, int groups, - String data_layout, String kernel_layout, Optional out_layout, - Optional out_dtype) { +Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " @@ -593,7 +606,7 @@ Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array(); + auto attrs = ffi::make_object(); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); attrs->output_padding = ConvertIntImmToInt64(output_padding); @@ -607,13 +620,13 @@ Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo weight_sinfo = input_sinfo[1]; @@ -627,21 +640,22 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout, // /*tgt_layout=*/"NCW", // /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); - Optional weight_shape = + ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() ? 
InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) : attrs->out_dtype; - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = + InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); if (!data_shape.defined() || !weight_shape.defined()) { return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); } - Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); - Array weight_IOW_shape = weight2IOW.ForwardShape(weight_shape.value()->values); + ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + ffi::Array weight_IOW_shape = weight2IOW.ForwardShape(weight_shape.value()->values); arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCW_shape[1]; @@ -689,7 +703,7 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& attrs->dilation[0] * (kernel_w - 1) + attrs->output_padding[0] + 1; out_NCW_shape[2] = analyzer->Simplify(out_w); - Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); } @@ -705,10 +719,11 @@ TVM_REGISTER_OP("relax.nn.conv1d_transpose") /* relax.nn.conv2d_transpose */ -Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array padding, - Array output_padding, Array dilation, int groups, - String data_layout, String kernel_layout, Optional out_layout, - Optional out_dtype) { +Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); if (output_padding.size() == 1) { output_padding.push_back(output_padding[0]); @@ -732,7 +747,7 @@ Expr conv2d_transpose(Expr data, Expr weight, Array 
strides, Array(); + auto attrs = ffi::make_object(); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); attrs->output_padding = ConvertIntImmToInt64(output_padding); @@ -746,13 +761,13 @@ Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo weight_sinfo = input_sinfo[1]; @@ -767,21 +782,22 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& /*tgt_layout=*/"NCHW", // /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); - Optional weight_shape = + ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) : attrs->out_dtype; - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = + InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); if (!data_shape.defined() || !weight_shape.defined()) { return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); } - Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); - Array weight_IOHW_shape = weight2IOHW.ForwardShape(weight_shape.value()->values); + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + ffi::Array weight_IOHW_shape = weight2IOHW.ForwardShape(weight_shape.value()->values); arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCHW_shape[1]; @@ -837,7 +853,7 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& out_NCHW_shape[2] = analyzer->Simplify(out_h); out_NCHW_shape[3] = analyzer->Simplify(out_w); - Array out_shape = 
out2NCHW.BackwardShape(out_NCHW_shape); + ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); } diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h index c99f03388e19..4fc175b5aa07 100644 --- a/src/relax/op/nn/convolution.h +++ b/src/relax/op/nn/convolution.h @@ -36,10 +36,11 @@ namespace tvm { namespace relax { template -inline Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - String out_layout, DataType out_dtype, std::string op_name) { - auto attrs = make_object(); +inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::String out_layout, DataType out_dtype, + std::string op_name) { + auto attrs = ffi::make_object(); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); attrs->dilation = ConvertIntImmToInt64(dilation); @@ -53,19 +54,22 @@ inline Expr MakeConv(Expr data, Expr weight, Array strides, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype); +Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); /*! \brief 2D convolution */ -Expr conv2d(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype); +Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); /*! 
\brief 3D convolution */ -Expr conv3d(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype); +Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); /*! * \brief One dimensional transposed convolution operator. @@ -73,10 +77,11 @@ Expr conv3d(Expr data, Expr weight, Array strides, Array padding * This operator is intended to be the backward operator of conv1d. It can be used to calculate the * gradient of the result of conv1d w.r.t. the input of conv1d. */ -Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array padding, - Array output_padding, Array dilation, int groups, - String data_layout, String kernel_layout, Optional out_layout, - Optional out_dtype); +Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); /*! * \brief Two dimensional transposed convolution operator. 
@@ -84,10 +89,11 @@ Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array strides, Array padding, - Array output_padding, Array dilation, int groups, - String data_layout, String kernel_layout, Optional out_layout, - Optional out_dtype); +Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 3597b16a5bcc..f4b9fe400bee 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -27,7 +27,7 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { SoftmaxAttrs::RegisterReflection(); LeakyReluAttrs::RegisterReflection(); SoftplusAttrs::RegisterReflection(); @@ -41,7 +41,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ DropoutAttrs::RegisterReflection(); PadAttrs::RegisterReflection(); PixelShuffleAttrs::RegisterReflection(); -}); +} /* relax.nn.relu */ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(relu, "nn.relu", /*require_float_dtype=*/false); @@ -61,16 +61,16 @@ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/tru /* relax.nn.leakyrelu */ Expr leakyrelu(Expr data, double alpha) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("relax.nn.leakyrelu"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.leakyrelu", leakyrelu); -}); +} TVM_REGISTER_OP("relax.nn.leakyrelu") .set_num_inputs(1) @@ -83,17 +83,17 @@ TVM_REGISTER_OP("relax.nn.leakyrelu") /* relax.nn.softplus */ Expr softplus(Expr data, double beta, double threshold) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->beta = beta; 
attrs->threshold = threshold; static const Op& op = Op::Get("relax.nn.softplus"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.softplus", softplus); -}); +} TVM_REGISTER_OP("relax.nn.softplus") .set_num_inputs(1) @@ -106,16 +106,16 @@ TVM_REGISTER_OP("relax.nn.softplus") /* relax.nn.prelu */ Expr prelu(Expr data, Expr alpha, int axis = 1) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.nn.prelu"); return Call(op, {data, alpha}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.prelu", prelu); -}); +} StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -133,9 +133,9 @@ StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) { return data_sinfo; } -InferLayoutOutput InferLayoutPRelu(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPRelu( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; @@ -151,7 +151,7 @@ InferLayoutOutput InferLayoutPRelu(const Call& call, layout = LayoutDecision(InitialLayout(ndim)); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, attrs->axis); LayoutDecision alpha_layout = GetLayoutDecision(var_layout_map, call->args[1]); @@ -170,16 +170,16 @@ TVM_REGISTER_OP("relax.nn.prelu") /* relax.nn.softmax */ Expr softmax(Expr data, int axis) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); 
attrs->axis = axis; static const Op& op = Op::Get("relax.nn.softmax"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.softmax", softmax); -}); +} StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -198,9 +198,9 @@ StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { return data_sinfo; } -InferLayoutOutput InferLayoutSoftmax(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutSoftmax( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; @@ -216,7 +216,7 @@ InferLayoutOutput InferLayoutSoftmax(const Call& call, layout = LayoutDecision(InitialLayout(ndim)); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, attrs->axis); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); } @@ -231,16 +231,16 @@ TVM_REGISTER_OP("relax.nn.softmax") /* relax.nn.log_softmax */ Expr log_softmax(Expr data, int axis) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.nn.log_softmax"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.log_softmax", log_softmax); -}); +} TVM_REGISTER_OP("relax.nn.log_softmax") .set_num_inputs(1) @@ -251,8 +251,8 @@ TVM_REGISTER_OP("relax.nn.log_softmax") /* relax.nn.pad */ -Expr pad(Expr data, Array pad_width, String pad_mode, double pad_value) { - auto attrs = make_object(); +Expr 
pad(Expr data, ffi::Array pad_width, ffi::String pad_mode, double pad_value) { + auto attrs = ffi::make_object(); attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); attrs->pad_value = pad_value; @@ -260,19 +260,19 @@ Expr pad(Expr data, Array pad_width, String pad_mode, double pad_value) return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.pad", pad); -}); +} StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); int ndim = input_sinfo[0]->ndim; - Array pad_width = attrs->pad_width; + ffi::Array pad_width = attrs->pad_width; ICHECK(static_cast(pad_width.size()) == 2 * ndim) << "Illegal pad_width"; - Array out_shape; + ffi::Array out_shape; if (input_sinfo[0]->shape.defined()) { // Compute output shape by adding corresponding pad width to each axis. 
const auto* data_shape = input_sinfo[0]->shape.as(); @@ -299,19 +299,19 @@ TVM_REGISTER_OP("relax.nn.pad") /* relax.nn.pixel_shuffle */ Expr pixel_shuffle(Expr data, int upscale_factor) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->upscale_factor = upscale_factor; static const Op& op = Op::Get("relax.nn.pixel_shuffle"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.pixel_shuffle", pixel_shuffle); -}); +} StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); int r = attrs->upscale_factor; ICHECK_GT(r, 0) << "Upscale factor must be positive"; @@ -325,7 +325,7 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx } const auto* shape = input->shape.as(); - Array in_shape = shape->values; + ffi::Array in_shape = shape->values; int channel_idx = ndim - 3; int h_idx = ndim - 2; @@ -345,7 +345,7 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx << "Number of input channels must be divisible by the square of the upscale factor"; // Output shape: - Array out_shape; + ffi::Array out_shape; for (int i = 0; i < ndim; ++i) { if (i == channel_idx) { out_shape.push_back(c_in / r_squared); @@ -370,7 +370,8 @@ TVM_REGISTER_OP("relax.nn.pixel_shuffle") /* relax.nn.batchnorm */ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, - const Array& input_sinfo, Array axes) { + const ffi::Array& input_sinfo, + ffi::Array axes) { Op op = Downcast(call->op); int n_input = op->arguments.size(); @@ -405,7 +406,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, } } - std::vector> axis_lengths; + std::vector> axis_lengths; 
axis_lengths.reserve(n_input); if (const auto* data_shape = data_sinfo->shape.as()) { std::vector lengths; @@ -442,7 +443,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // int axis, double epsilon, bool center, bool scale, double momentum, bool training) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -456,13 +457,13 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ std::move(moving_var)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.batch_norm", batch_norm); -}); +} StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, {attrs->axis}); @@ -478,9 +479,9 @@ StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutBatchNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutBatchNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 5; ++i) { @@ -502,7 +503,7 @@ InferLayoutOutput InferLayoutBatchNorm(const Call& call, layout = LayoutDecision(InitialLayout(ndim)); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, (attrs->axis + ndim) % ndim); return InferLayoutOutput( {layout, initial_layouts[1], 
initial_layouts[2], initial_layouts[3], initial_layouts[4]}, @@ -523,9 +524,9 @@ TVM_REGISTER_OP("relax.nn.batch_norm") /* relax.nn.layer_norm */ -Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, - bool scale) { - ObjectPtr attrs = make_object(); +Expr layer_norm(Expr data, Expr gamma, Expr beta, ffi::Array axes, double epsilon, + bool center, bool scale) { + ObjectPtr attrs = ffi::make_object(); attrs->axes = std::move(axes); attrs->epsilon = epsilon; attrs->center = center; @@ -535,13 +536,13 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.layer_norm", layer_norm); -}); +} StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); @@ -551,9 +552,9 @@ StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { : input_sinfo[0]; } -InferLayoutOutput InferLayoutLayerNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutLayerNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { @@ -566,7 +567,7 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); const auto* input_sinfo = 
GetStructInfoAs(call->args[0]); int ndim = input_sinfo->ndim; std::vector new_axis; @@ -592,8 +593,8 @@ TVM_REGISTER_OP("relax.nn.layer_norm") /* relax.nn.group_norm */ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, - Array axes, double epsilon, bool center, bool scale) { - ObjectPtr attrs = make_object(); + ffi::Array axes, double epsilon, bool center, bool scale) { + ObjectPtr attrs = ffi::make_object(); attrs->num_groups = num_groups; attrs->channel_axis = channel_axis; attrs->axes = std::move(axes); @@ -605,14 +606,14 @@ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_ax return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.group_norm", group_norm); -}); +} StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); TensorStructInfo data_sinfo = input_sinfo[0]; @@ -666,9 +667,9 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { return data_sinfo; } -InferLayoutOutput InferLayoutGroupNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutGroupNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { @@ -681,7 +682,7 @@ InferLayoutOutput InferLayoutGroupNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); 
std::vector new_axes; for (const auto& axis : attrs->axes) { new_axes.push_back(FindAxis(layout->layout, axis->value)); @@ -705,9 +706,9 @@ TVM_REGISTER_OP("relax.nn.group_norm") /* relax.nn.instance_norm */ -Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, ffi::Array axes, double epsilon, bool center, bool scale) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->channel_axis = std::move(channel_axis); attrs->axes = std::move(axes); attrs->epsilon = epsilon; @@ -718,14 +719,14 @@ Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array(call->op); - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; TensorStructInfo data_sinfo = input_sinfo[0]; @@ -769,9 +770,9 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx return data_sinfo; } -InferLayoutOutput InferLayoutInstanceNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutInstanceNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { @@ -784,7 +785,7 @@ InferLayoutOutput InferLayoutInstanceNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); std::vector new_axes; for (const auto& axis : attrs->axes) { new_axes.push_back(FindAxis(layout->layout, (axis->value))); @@ -807,8 +808,8 @@ TVM_REGISTER_OP("relax.nn.instance_norm") .set_attr("FPurity", Bool(true)); /* relax.nn.rms_norm */ -Expr rms_norm(Expr 
data, Expr weight, Array axes, double epsilon) { - ObjectPtr attrs = make_object(); +Expr rms_norm(Expr data, Expr weight, ffi::Array axes, double epsilon) { + ObjectPtr attrs = ffi::make_object(); attrs->axes = std::move(axes); attrs->epsilon = epsilon; @@ -816,13 +817,13 @@ Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon) { return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.rms_norm", rms_norm); -}); +} StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); @@ -832,9 +833,9 @@ StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { : input_sinfo[0]; } -InferLayoutOutput InferLayoutRMSNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutRMSNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 2; ++i) { @@ -847,7 +848,7 @@ InferLayoutOutput InferLayoutRMSNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); std::vector new_axes; for (const auto& axis : attrs->axes) { new_axes.push_back(FindAxis(layout->layout, axis->value)); @@ -869,17 +870,17 @@ TVM_REGISTER_OP("relax.nn.rms_norm") /* relax.nn.dropout */ Expr dropout(Expr data, double rate) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->rate 
= rate; static const Op& op = Op::Get("relax.nn.dropout"); return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.dropout", dropout); -}); +} StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -897,7 +898,7 @@ TVM_REGISTER_OP("relax.nn.dropout") /* relax.nn.cross_entropy_with_logits */ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo pred_sinfo = input_sinfo[0]; TensorStructInfo label_sinfo = input_sinfo[1]; @@ -905,7 +906,7 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo); // infer vdevice - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_sinfo, label_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_sinfo, label_sinfo); // infer ndim if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() && @@ -916,12 +917,12 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx << pred_sinfo->ndim << " while the ndim of labels is " << label_sinfo->ndim); } - Optional> pred_shape_value; + ffi::Optional> pred_shape_value; if (pred_sinfo->shape.defined()) { pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; } - Optional> label_shape_value; + ffi::Optional> label_shape_value; if (label_sinfo->shape.defined()) { label_shape_value = GetStructInfoAs(label_sinfo->shape.value())->values; } @@ -939,7 +940,7 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx } } } - return TensorStructInfo(ShapeExpr(Array()), dtype, vdevice); + return 
TensorStructInfo(ShapeExpr(ffi::Array()), dtype, vdevice); } Expr cross_entropy_with_logits(Expr predictions, Expr labels) { @@ -947,10 +948,10 @@ Expr cross_entropy_with_logits(Expr predictions, Expr labels) { return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.cross_entropy_with_logits", cross_entropy_with_logits); -}); +} TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") .set_num_inputs(2) @@ -961,9 +962,9 @@ TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") /* relax.nn.nll_loss */ -Expr nll_loss(Expr predictions, Expr targets, Optional weights, String reduction, +Expr nll_loss(Expr predictions, Expr targets, ffi::Optional weights, ffi::String reduction, int ignore_index) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); ICHECK(reduction == "none" || reduction == "sum" || reduction == "mean") << "The argument reduction of NLLLoss should be one of the following " @@ -982,10 +983,10 @@ Expr nll_loss(Expr predictions, Expr targets, Optional weights, String red } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.nll_loss", nll_loss); -}); +} StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (call->args.size() < 2 || call->args.size() > 3) { @@ -1020,12 +1021,12 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { // infer dtype, vdevice DataType output_dtype; - Optional vdevice; + ffi::Optional vdevice; if (wgt_sinfo != nullptr) { - output_dtype = InferBinaryArithOpOutDtype(call, ctx, GetRef(pred_sinfo), - GetRef(wgt_sinfo)); - vdevice = InferBinaryArithOpOutVDevice(call, ctx, GetRef(pred_sinfo), - GetRef(wgt_sinfo)); + output_dtype = InferBinaryArithOpOutDtype(call, ctx, ffi::GetRef(pred_sinfo), + ffi::GetRef(wgt_sinfo)); + vdevice = 
InferBinaryArithOpOutVDevice(call, ctx, ffi::GetRef(pred_sinfo), + ffi::GetRef(wgt_sinfo)); } else { output_dtype = pred_sinfo->dtype; vdevice = pred_sinfo->vdevice; @@ -1066,11 +1067,11 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } arith::Analyzer* analyzer = ctx->GetAnalyzer(); - Optional N; - Optional C; - Array output_shape; // N, d1, d2, ..., dk + ffi::Optional N; + ffi::Optional C; + ffi::Array output_shape; // N, d1, d2, ..., dk - Optional> pred_shape_value; + ffi::Optional> pred_shape_value; if (pred_sinfo->shape.defined()) { pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; } @@ -1085,7 +1086,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { ICHECK(pred_sinfo->ndim == static_cast(pred_shape_value.value().size())); N = pred_shape_value.value()[0]; C = pred_shape_value.value()[1]; - output_shape = Array(); + output_shape = ffi::Array(); output_shape.push_back(N.value()); for (size_t i = 2; i < pred_shape_value.value().size(); ++i) { output_shape.push_back(pred_shape_value.value()[i]); @@ -1093,7 +1094,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } } - Optional> tgt_shape_value; + ffi::Optional> tgt_shape_value; if (tgt_sinfo->shape.defined()) { tgt_shape_value = GetStructInfoAs(tgt_sinfo->shape.value())->values; } @@ -1148,7 +1149,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } if (wgt_sinfo != nullptr) { - Optional> wgt_shape_value; + ffi::Optional> wgt_shape_value; if (wgt_sinfo->shape.defined()) { wgt_shape_value = GetStructInfoAs(wgt_sinfo->shape.value())->values; } @@ -1166,7 +1167,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - String reduction = attrs->reduction; + ffi::String reduction = attrs->reduction; if (reduction == "none") { // () or (N,) or (N, d1, d2, ..., dk) @@ -1178,7 +1179,7 @@ StructInfo 
InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } } else { // sum or mean. output is scalar - return TensorStructInfo(/*shape=*/ShapeExpr(Array()), output_dtype, vdevice); + return TensorStructInfo(/*shape=*/ShapeExpr(ffi::Array()), output_dtype, vdevice); } } @@ -1187,7 +1188,7 @@ TVM_REGISTER_OP("relax.nn.nll_loss") .set_num_inputs(3) .add_argument("predictions", "Tensor", "The prediction tensor.") .add_argument("targets", "Tensor", "The target tensor.") - .add_argument("weights", "Optional", "The weight of each target values.") + .add_argument("weights", "ffi::Optional", "The weight of each target values.") .set_attr("FInferStructInfo", InferStructInfoNLLLoss) .set_attr("FPurity", Bool(true)); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 39f8c2d73800..989dfbb3f613 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -41,9 +41,9 @@ namespace relax { * \param RequireFloatDtype A boolean indicating if the input is required to have float dtype. */ #define RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(OpName, OpRegName, RequireFloatDtype) \ + RELAX_UNARY_OP_INTERFACE(OpName, OpRegName) \ RELAX_REGISTER_UNARY_OP(OpRegName).set_attr( \ - "FInferStructInfo", InferStructInfoUnaryArith); \ - RELAX_UNARY_OP_INTERFACE(OpName, OpRegName); + "FInferStructInfo", InferStructInfoUnaryArith) /*! \brief Rectified linear unit. */ Expr relu(Expr data); @@ -83,19 +83,19 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ int axis, double epsilon, bool center, bool scale, double momentum, bool training); /*! \brief Compute layer normalization. */ -Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, - bool scale); +Expr layer_norm(Expr data, Expr gamma, Expr beta, ffi::Array axes, double epsilon, + bool center, bool scale); /*! \brief Compute group normalization. 
*/ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, - Array axes, double epsilon, bool center, bool scale); + ffi::Array axes, double epsilon, bool center, bool scale); /*! \brief Compute instance normalization. */ -Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, ffi::Array axes, double epsilon, bool center, bool scale); /*! \brief Compute root mean square normalization. */ -Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon); +Expr rms_norm(Expr data, Expr weight, ffi::Array axes, double epsilon); /*! * \brief Applies the dropout operation to the input tensor. @@ -111,7 +111,7 @@ Expr dropout(Expr data, double rate); Expr cross_entropy_with_logits(Expr predictions, Expr labels); /*! \brief Negative log likelihood loss. */ -Expr nll_loss(Expr predictions, Expr targets, Optional weights, String reduction, +Expr nll_loss(Expr predictions, Expr targets, ffi::Optional weights, ffi::String reduction, int ignore_index); } // namespace relax diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 6a12a60a4ee9..584135520000 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -27,20 +27,21 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { Pool1DAttrs::RegisterReflection(); Pool2DAttrs::RegisterReflection(); Pool3DAttrs::RegisterReflection(); AdaptivePool1DAttrs::RegisterReflection(); AdaptivePool2DAttrs::RegisterReflection(); AdaptivePool3DAttrs::RegisterReflection(); -}); +} /* relax.nn.max_pool1d */ -Expr MakePool1d(String op_name, Expr data, Array pool_size, Array strides, - Array padding, Array dilation, bool ceil_mode, - bool count_include_pad, String layout, Optional out_layout) { +Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, ffi::Array dilation, + bool 
ceil_mode, bool count_include_pad, ffi::String layout, + ffi::Optional out_layout) { padding = GetCompletePadding1D(std::move(padding)); CHECK_EQ(pool_size.size(), 1) @@ -52,7 +53,7 @@ Expr MakePool1d(String op_name, Expr data, Array pool_size, Array(); + auto attrs = ffi::make_object(); attrs->pool_size = ConvertIntImmToInt64(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -65,17 +66,17 @@ Expr MakePool1d(String op_name, Expr data, Array pool_size, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr max_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool1d("relax.nn.max_pool1d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.max_pool1d", max_pool1d); -}); +} StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -88,13 +89,13 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); } - Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = attrs->pool_size[0]; @@ -112,13 +113,13 @@ StructInfo 
InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { } out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); - Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutPool1d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPool1d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -127,7 +128,7 @@ InferLayoutOutput InferLayoutPool1d(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(3), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(3), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); @@ -144,9 +145,10 @@ TVM_REGISTER_OP("relax.nn.max_pool1d") /* relax.nn.max_pool2d */ -Expr MakePool2d(String op_name, Expr data, Array pool_size, Array strides, - Array padding, Array dilation, bool ceil_mode, - bool count_include_pad, String layout, Optional out_layout) { +Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, ffi::Array dilation, + bool ceil_mode, bool count_include_pad, ffi::String layout, + ffi::Optional out_layout) { padding = GetCompletePadding2D(std::move(padding)); if (pool_size.size() == 1) { pool_size.push_back(pool_size[0]); @@ -167,7 +169,7 @@ Expr MakePool2d(String op_name, Expr data, Array pool_size, 
Array(); + auto attrs = ffi::make_object(); attrs->pool_size = ConvertIntImmToInt64(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -180,17 +182,17 @@ Expr MakePool2d(String op_name, Expr data, Array pool_size, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.max_pool2d", max_pool2d); -}); +} StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -203,13 +205,13 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCHW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); } - Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); PrimExpr input_h = data_NCHW_shape[2]; PrimExpr input_w = data_NCHW_shape[3]; @@ -233,13 +235,13 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); - Array out_shape = 
out2NCHW.BackwardShape(out_NCHW_shape); + ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutPool2d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPool2d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -248,14 +250,15 @@ InferLayoutOutput InferLayoutPool2d(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { tir::Layout in_layout(attrs->layout, DataType::Int(64)); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetStructInfo(call->args[0]); TensorStructInfo data_sinfo = data_si.as().value(); - Optional data_shape = GetRef(data_sinfo->shape.as()); + ffi::Optional data_shape = + ffi::GetRef(data_sinfo->shape.as()); if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) { // Not handling out_layout being different from in_layout now. Any use case ? 
new_attrs->layout = desired_layout.name(); @@ -282,9 +285,10 @@ TVM_REGISTER_OP("relax.nn.max_pool2d") /* relax.nn.max_pool3d */ -Expr MakePool3d(String op_name, Expr data, Array pool_size, Array strides, - Array padding, Array dilation, bool ceil_mode, - bool count_include_pad, String layout, Optional out_layout) { +Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, ffi::Array dilation, + bool ceil_mode, bool count_include_pad, ffi::String layout, + ffi::Optional out_layout) { padding = GetCompletePadding3D(std::move(padding)); if (pool_size.size() == 1) { pool_size.push_back(pool_size[0]); @@ -308,7 +312,7 @@ Expr MakePool3d(String op_name, Expr data, Array pool_size, Array(); + auto attrs = ffi::make_object(); attrs->pool_size = ConvertIntImmToInt64(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -321,17 +325,17 @@ Expr MakePool3d(String op_name, Expr data, Array pool_size, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr max_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool3d("relax.nn.max_pool3d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.max_pool3d", max_pool3d); -}); +} StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -344,13 +348,13 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCDHW", /*tensor_name=*/"output"); - Optional data_shape = + 
ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); } - Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); + ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); PrimExpr input_d = data_NCDHW_shape[2]; PrimExpr input_h = data_NCDHW_shape[3]; @@ -380,13 +384,13 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[1]) + 1); out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[2]) + 1); - Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); + ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutPool3d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPool3d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -395,7 +399,7 @@ InferLayoutOutput InferLayoutPool3d(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(5), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(5), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); @@ -411,17 +415,17 @@ TVM_REGISTER_OP("relax.nn.max_pool3d") .set_attr("FPurity", Bool(true)); /* 
relax.nn.avg_pool1d */ -Expr avg_pool1d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr avg_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool1d("relax.nn.avg_pool1d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.avg_pool1d", avg_pool1d); -}); +} TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_num_inputs(1) @@ -433,17 +437,17 @@ TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_attr("FPurity", Bool(true)); /* relax.nn.avg_pool2d */ -Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr avg_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.avg_pool2d", avg_pool2d); -}); +} TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_num_inputs(1) @@ -455,17 +459,17 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_attr("FPurity", Bool(true)); /* relax.nn.avg_pool3d */ -Expr avg_pool3d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr avg_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, + 
ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool3d("relax.nn.avg_pool3d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.avg_pool3d", avg_pool3d); -}); +} TVM_REGISTER_OP("relax.nn.avg_pool3d") .set_num_inputs(1) @@ -478,13 +482,13 @@ TVM_REGISTER_OP("relax.nn.avg_pool3d") /* relax.nn.adaptive_avg_pool1d */ -Expr adaptive_avg_pool1d(Expr data, Optional> output_size, String layout, - Optional out_layout) { - ObjectPtr attrs = make_object(); +Expr adaptive_avg_pool1d(Expr data, ffi::Optional> output_size, + ffi::String layout, ffi::Optional out_layout) { + ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); CHECK_EQ(_output_size.size(), 1) << "The output_size length is expected to be 1. 
However, the given output_size is " << _output_size; @@ -495,10 +499,10 @@ Expr adaptive_avg_pool1d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool1d", adaptive_avg_pool1d); -}); +} StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -511,7 +515,7 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder /*tgt_layout=*/"NCW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && @@ -522,19 +526,19 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder } } - Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); - Array out_NCW_shape(data_NCW_shape); + ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + ffi::Array out_NCW_shape(data_NCW_shape); if (attrs->output_size.defined()) { out_NCW_shape.Set(2, attrs->output_size.value()[0]); } - Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutAdaptiveAvgPool1D(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutAdaptiveAvgPool1D( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ 
-543,7 +547,7 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool1D(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(3), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(3), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); @@ -560,13 +564,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool1d") /* relax.nn.adaptive_avg_pool2d */ -Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String layout, - Optional out_layout) { - ObjectPtr attrs = make_object(); +Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_size, + ffi::String layout, ffi::Optional out_layout) { + ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); if (_output_size.size() == 1) { _output_size.push_back(_output_size[0]); } @@ -580,10 +584,10 @@ Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool2d", adaptive_avg_pool2d); -}); +} StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -596,7 +600,7 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder /*tgt_layout=*/"NCHW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { if 
(data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && @@ -607,20 +611,20 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder } } - Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); - Array out_NCHW_shape(data_NCHW_shape); + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + ffi::Array out_NCHW_shape(data_NCHW_shape); if (attrs->output_size.defined()) { out_NCHW_shape.Set(2, attrs->output_size.value()[0]); out_NCHW_shape.Set(3, attrs->output_size.value()[1]); } - Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutAdaptiveAvgPool2D( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -629,13 +633,14 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { tir::Layout in_layout(attrs->layout, DataType::Int(64)); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetStructInfo(call->args[0]); TensorStructInfo data_sinfo = data_si.as().value(); - Optional data_shape = GetRef(data_sinfo->shape.as()); + ffi::Optional data_shape = + ffi::GetRef(data_sinfo->shape.as()); if (CanProveLayoutTransform(in_layout, desired_layout, 
data_shape.value()->values)) { // Not handling out_layout being different from in_layout now. Any use case ? new_attrs->layout = desired_layout.name(); @@ -661,13 +666,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") /* relax.nn.adaptive_avg_pool3d */ -Expr adaptive_avg_pool3d(Expr data, Optional> output_size, String layout, - Optional out_layout) { - ObjectPtr attrs = make_object(); +Expr adaptive_avg_pool3d(Expr data, ffi::Optional> output_size, + ffi::String layout, ffi::Optional out_layout) { + ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); if (_output_size.size() == 1) { _output_size.push_back(_output_size[0]); } @@ -681,10 +686,10 @@ Expr adaptive_avg_pool3d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nn.adaptive_avg_pool3d", adaptive_avg_pool3d); -}); +} StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -697,7 +702,7 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder /*tgt_layout=*/"NCDHW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && @@ -708,21 +713,21 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder } } - Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); - Array out_NCDHW_shape(data_NCDHW_shape); + ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); + 
ffi::Array out_NCDHW_shape(data_NCDHW_shape); if (attrs->output_size.defined()) { out_NCDHW_shape.Set(2, attrs->output_size.value()[0]); out_NCDHW_shape.Set(3, attrs->output_size.value()[1]); out_NCDHW_shape.Set(4, attrs->output_size.value()[2]); } - Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); + ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutAdaptiveAvgPool3D(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutAdaptiveAvgPool3D( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -731,7 +736,7 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool3D(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(5), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(5), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h index 7fd66f2b44c3..c5435303e82b 100644 --- a/src/relax/op/nn/pooling.h +++ b/src/relax/op/nn/pooling.h @@ -33,18 +33,18 @@ namespace tvm { namespace relax { /*! \brief 2D maximum pooling operator. 
*/ -Expr max_pool2d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout); +Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout); /*! \brief 2D average pooling operator. */ -Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout); +Expr avg_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout); /*! \brief 2D adaptive average pooling operator. */ -Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String layout, - Optional out_layout); +Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_size, + ffi::String layout, ffi::Optional out_layout); } // namespace relax } // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 1e476eaf035a..54f9da4c786f 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -28,13 +28,13 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { CallTIRWithGradAttrs::RegisterReflection(); CallTIRInplaceAttrs::RegisterReflection(); CallInplacePackedAttrs::RegisterReflection(); ToVDeviceAttrs::RegisterReflection(); HintOnDeviceAttrs::RegisterReflection(); -}); +} bool EqualConstInt(const PrimExpr& lhs, int64_t value) { if (const int64_t* pvalue = tir::as_const_int(lhs)) { @@ -57,7 +57,7 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { } StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) { - return TupleStructInfo(Array()); + return TupleStructInfo(ffi::Array()); } StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) { 
@@ -112,26 +112,26 @@ StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.call_pure_packed") .set_num_inputs(-1) - .add_argument("args", "Array", + .add_argument("args", "ffi::Array", "The first argument is the function being called. The rest are the " "arguments to that function.") .set_attr("FInferStructInfo", InferStructInfoCallPurePacked) .set_attr("FPurity", Bool(true)); -Expr MakeCallPurePacked(const Expr& callee, Array args, const Attrs& attrs, - Array sinfo_args) { +Expr MakeCallPurePacked(const Expr& callee, ffi::Array args, const Attrs& attrs, + ffi::Array sinfo_args) { static const Op& op = Op::Get("relax.call_pure_packed"); - Array call_args = {callee}; + ffi::Array call_args = {callee}; for (auto arg : args) { call_args.push_back(arg); } return Call(op, call_args, attrs, sinfo_args); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.call_pure_packed", MakeCallPurePacked); -}); +} // call_inplace_packed @@ -227,7 +227,7 @@ StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.call_inplace_packed") .set_num_inputs(-1) .set_attrs_type() - .add_argument("args", "Array", + .add_argument("args", "ffi::Array", "The first argument is the function being called. 
The rest are the " "arguments to that function.") .set_attr("FInferStructInfo", InferStructInfoCallInplacePacked) @@ -237,21 +237,21 @@ TVM_REGISTER_OP("relax.call_inplace_packed") // side effects other than modifying the arguments specified as "inplace" .set_attr("FPurity", Bool(true)); -Expr MakeCallInplacePacked(Expr func, Array args, Array inplace_indices, - Array sinfo_args) { - ObjectPtr attrs = make_object(); - attrs->inplace_indices = Array(inplace_indices.begin(), inplace_indices.end()); +Expr MakeCallInplacePacked(Expr func, ffi::Array args, ffi::Array inplace_indices, + ffi::Array sinfo_args) { + ObjectPtr attrs = ffi::make_object(); + attrs->inplace_indices = ffi::Array(inplace_indices.begin(), inplace_indices.end()); static const Op& op = Op::Get("relax.call_inplace_packed"); - Array call_args = {func}; + ffi::Array call_args = {func}; call_args.insert(call_args.end(), args.begin(), args.end()); return Call(op, call_args, Attrs(attrs), sinfo_args); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.call_inplace_packed", MakeCallInplacePacked); -}); +} // call_tir @@ -285,9 +285,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \return The `arg_sinfo`, if it can be inferred from the arguments. * Otherwise, std::nullopt. */ -static Optional InferCallTIROutputStructInfoFromArguments( - StructInfo func_sinfo, StructInfo arg_sinfo, Optional packed_ints_sinfo, - Optional> opt_inplace_indices) { +static ffi::Optional InferCallTIROutputStructInfoFromArguments( + StructInfo func_sinfo, StructInfo arg_sinfo, ffi::Optional packed_ints_sinfo, + ffi::Optional> opt_inplace_indices) { auto opt_callee_sinfo = func_sinfo.as(); CHECK(opt_callee_sinfo) << "TypeError: " << "The first argument to `R.call_tir` must be a function, " @@ -368,16 +368,16 @@ static Optional InferCallTIROutputStructInfoFromArguments( // arguments are used. 
auto dummy_callee_sinfo = [&]() -> FuncStructInfo { - Array dummy_params(callee_params.begin(), - callee_params.begin() + num_input_arguments); + ffi::Array dummy_params(callee_params.begin(), + callee_params.begin() + num_input_arguments); for (size_t i = callee_params.size() - num_trailing_int_arguments; i < callee_params.size(); i++) { dummy_params.push_back(callee_params[i]); } - Array dummy_ret(callee_params.begin() + num_input_arguments, - callee_params.end() - num_trailing_int_arguments); + ffi::Array dummy_ret(callee_params.begin() + num_input_arguments, + callee_params.end() - num_trailing_int_arguments); if (opt_inplace_indices) { // For R.call_tir_inplace, the `inplace_indices` are used to @@ -405,8 +405,8 @@ static Optional InferCallTIROutputStructInfoFromArguments( return FuncStructInfo(dummy_params, dummy_out_sinfo); }(); - auto dummy_args = [&]() -> Array { - Array dummy_args = args->fields.Map( + auto dummy_args = [&]() -> ffi::Array { + ffi::Array dummy_args = args->fields.Map( [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg", sinfo); }); for (size_t i = 0; i < num_trailing_int_arguments; i++) { @@ -488,7 +488,7 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { << "R.call_tir should have exactly one `sinfo_args` parameter, " << "which defines the output of the PrimFunc."; - auto unwrap_binding = [&ctx](Expr expr) -> Optional { + auto unwrap_binding = [&ctx](Expr expr) -> ffi::Optional { if (auto var = expr.as()) { if (auto bound_value = ctx->LookupBinding(var.value())) { return bound_value.value(); @@ -519,7 +519,7 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { // and we don't know the value bound to that variable. For // example, if a relax function accepted a tuple as an parameter, // then provided that same tuple as an argument to call_tir. 
- Array tuple_elements; + ffi::Array tuple_elements; size_t num_fields = Downcast(arg_tuple->struct_info_)->fields.size(); for (size_t i = 0; i < num_fields; i++) { tuple_elements.push_back(TupleGetItem(arg_tuple, i)); @@ -546,7 +546,7 @@ void ValidateCallTIR(Call call) { auto callee = call->args[0]; Expr arg_tuple = call->args[1]; - auto packed_int_sinfo = [&]() -> Optional { + auto packed_int_sinfo = [&]() -> ffi::Optional { if (call->args.size() <= 2) { return std::nullopt; } else { @@ -554,7 +554,7 @@ void ValidateCallTIR(Call call) { } }(); - auto opt_inplace_indices = [&]() -> Optional> { + auto opt_inplace_indices = [&]() -> ffi::Optional> { if (const auto* attrs = call->attrs.as()) { return attrs->inplace_indices; } else { @@ -586,8 +586,8 @@ TVM_REGISTER_OP("relax.call_tir") .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, - Optional packed_ints) { +Expr MakeCallTIR(Expr func, Tuple args, ffi::Array out_sinfo_list, + ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. 
" @@ -613,10 +613,10 @@ Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, return call; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.call_tir", MakeCallTIR); -}); +} // call_tir_with_grad @@ -633,9 +633,9 @@ TVM_REGISTER_OP("relax.call_tir_with_grad") .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinfo_list, - String te_grad_name, Map te_grad_kwargs, - Optional packed_ints) { +Expr MakeCallTIRWithGrad(Expr func, Tuple args, ffi::Array out_sinfo_list, + ffi::String te_grad_name, ffi::Map te_grad_kwargs, + ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); CHECK(shape != nullptr) @@ -651,7 +651,7 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinf out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); } - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->te_grad_name = te_grad_name; attrs->te_grad_kwargs = te_grad_kwargs; @@ -666,10 +666,10 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinf return call; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.call_tir_with_grad", MakeCallTIRWithGrad); -}); +} // call_tir_inplace @@ -679,7 +679,7 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { // may result in an error if performed before normalization. 
call = Downcast(NormalizeCallTIR(ctx, std::move(call))); - Array sinfo_outputs = [&]() -> Array { + ffi::Array sinfo_outputs = [&]() -> ffi::Array { auto out_sinfo = call->sinfo_args[0]; if (auto* tuple_output = out_sinfo.as()) { return tuple_output->fields; @@ -778,8 +778,9 @@ TVM_REGISTER_OP("relax.call_tir_inplace") // arguments will no longer be live) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, - Array out_sinfo_list, Optional packed_ints) { +Expr MakeCallTIRInplace(Expr func, Tuple args, ffi::Array inplace_indices, + ffi::Array out_sinfo_list, + ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. " @@ -787,8 +788,8 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, << sinfo; } - ObjectPtr attrs = make_object(); - attrs->inplace_indices = Array(inplace_indices.begin(), inplace_indices.end()); + ObjectPtr attrs = ffi::make_object(); + attrs->inplace_indices = ffi::Array(inplace_indices.begin(), inplace_indices.end()); StructInfo out_sinfo{nullptr}; if (out_sinfo_list.size() == 1) { @@ -808,10 +809,10 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, return call; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.call_tir_inplace", MakeCallTIRInplace); -}); +} // call_dps_packed @@ -832,7 +833,7 @@ TVM_REGISTER_OP("relax.call_dps_packed") // little reason to use DPS with an impure op .set_attr("FPurity", Bool(true)); -Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_list) { +Expr MakeCallDPSPacked(Expr func, Tuple args, ffi::Array out_sinfo_list) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); CHECK(shape != nullptr) @@ -852,16 +853,80 @@ Expr 
MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_ return Call(op, {func, args}, {}, {out_sinfo}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.call_dps_packed", MakeCallDPSPacked); -}); +} + +// call_py_func + +StructInfo InferStructInfoCallPyFunc(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "sinfo_args should have exact 1 output struct info."); + } + return call->sinfo_args[0]; +} + +void ValidateCallPyFunc(Call call) { + // Validate that the function name is a string literal + auto func_name = call->args[0]; + CHECK(func_name->IsInstance()) + << "Operation " << call->op << " expects the first argument to be a string literal " + << "specifying the Python function name. However, the first argument " << func_name + << " is not a string literal."; + + // Validate that args is a tuple + Expr arg_tuple = call->args[1]; + CHECK(arg_tuple->struct_info_.as()) + << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " + << "However, the second argument " << arg_tuple << " has struct info " + << arg_tuple->struct_info_ << "."; + + CHECK(arg_tuple.as() || arg_tuple.as()) + << "Operation " << call->op << " must hold its arguments as an in-line tuple. 
" + << "However, " << call << " has arguments " << arg_tuple + << ", which is neither an in-line tuple, " + << "nor a variable binding that may be normalized to an in-line tuple."; +} + +TVM_REGISTER_OP("relax.call_py_func") + .set_num_inputs(2) + .add_argument("func_name", "StringImm", "The name of the Python function to call.") + .add_argument("args", "Tuple", "The input arguments.") + .set_attr("FInferStructInfo", InferStructInfoCallPyFunc) + .set_attr("FValidate", ValidateCallPyFunc) + .set_attr("FPurity", Bool(true)); + +Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_sinfo_list) { + for (const TensorStructInfo& sinfo : out_sinfo_list) { + const auto* shape = sinfo->shape.as(); + CHECK(shape != nullptr) << "out_sinfo of call_py_func should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; + } + + StructInfo out_sinfo{nullptr}; + if (out_sinfo_list.size() == 1) { + out_sinfo = out_sinfo_list[0]; + } else { + out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + } + + static const Op& op = Op::Get("relax.call_py_func"); + return Call(op, {func_name, args}, {}, {out_sinfo}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_py_func", MakeCallPyFunc); +} // call builtin StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() == 0) { // by default return void. 
- return TupleStructInfo(Array()); + return TupleStructInfo(ffi::Array()); } else { ICHECK_EQ(call->sinfo_args.size(), 1); return call->sinfo_args[0]; @@ -876,15 +941,15 @@ TVM_REGISTER_OP("relax.call_builtin_with_ctx") // Most builtins are pure, but some are not, like `vm.builtin.attention_kv_cache_append` .set_attr("FPurity", Bool(false)); -Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array sinfo_args) { +Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, ffi::Array sinfo_args) { static const Op& op = Op::Get("relax.call_builtin_with_ctx"); return Call(op, {func, args}, Attrs(), sinfo_args); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.call_builtin_with_ctx", MakeCallBuiltinWithCtx); -}); +} TVM_REGISTER_OP("relax.null_value") .set_num_inputs(0) @@ -896,24 +961,24 @@ Expr MakeCallNullValue() { return Call(op, {}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.null_value", MakeCallNullValue); -}); +} // print TVM_REGISTER_OP("relax.print") .set_num_inputs(-1) - .add_argument("vals", "Array", + .add_argument("vals", "ffi::Array", "The first value is Python-style format string to use to print. 
The others " "are values to print") .set_attr("FInferStructInfo", ReturnVoidStructInfo) .set_attr("FCallPacked", "relax.run.print") .set_attr("FPurity", Bool(false)); -Expr MakePrint(Array vals, StringImm format) { - Array params; +Expr MakePrint(ffi::Array vals, StringImm format) { + ffi::Array params; params.push_back(format); for (const auto val : vals) { params.push_back(val); @@ -922,10 +987,10 @@ Expr MakePrint(Array vals, StringImm format) { return Call(op, params); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.print", MakePrint); -}); +} // assert_op @@ -950,7 +1015,7 @@ StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.assert_op") .set_num_inputs(-1) - .add_argument("vals", "Array", + .add_argument("vals", "ffi::Array", "The first value is used as the assertion condition. The second value is " "Python-style format string to use for displaying an error message, if the " "assert fails. 
The others are used as format arguments if there is an error.") @@ -958,9 +1023,9 @@ TVM_REGISTER_OP("relax.assert_op") .set_attr("FCallPacked", "relax.run.assert_op") .set_attr("FPurity", Bool(false)); -Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { +Expr MakeAssertOp(Expr condition, ffi::Array vals, StringImm format) { static const Op& op = Op::Get("relax.assert_op"); - Array args = {condition}; + ffi::Array args = {condition}; args.push_back(format); for (auto val : vals) { args.push_back(val); @@ -968,10 +1033,10 @@ Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { return Call(op, args); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.assert_op", MakeAssertOp); -}); +} // make_closure @@ -987,10 +1052,10 @@ Expr MakeClosure(Expr func, Tuple args) { return Call(op, {func, args}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.make_closure", MakeClosure); -}); +} // invoke_closure @@ -1012,15 +1077,15 @@ TVM_REGISTER_OP("relax.invoke_closure") // Not all closures are pure. 
Use invoke_pure_closure for specifying purity .set_attr("FPurity", Bool(false)); -Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { +Expr InvokeClosure(Expr closure, Tuple args, ffi::Array sinfo_args) { static const Op& op = Op::Get("relax.invoke_closure"); return Call(op, {closure, args}, {}, sinfo_args); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.invoke_closure", InvokeClosure); -}); +} // invoke_pure_closure @@ -1031,15 +1096,15 @@ TVM_REGISTER_OP("relax.invoke_pure_closure") .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) .set_attr("FPurity", Bool(true)); -Expr InvokePureClosure(Expr closure, Tuple args, Array sinfo_args) { +Expr InvokePureClosure(Expr closure, Tuple args, ffi::Array sinfo_args) { static const Op& op = Op::Get("relax.invoke_pure_closure"); return Call(op, {closure, args}, {}, sinfo_args); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.invoke_pure_closure", InvokePureClosure); -}); +} // shape_of @@ -1054,10 +1119,10 @@ Expr MakeShapeOf(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.shape_of", MakeShapeOf); -}); +} // tensor_to_shape @@ -1091,10 +1156,10 @@ Expr MakeTensorToShape(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.tensor_to_shape", MakeTensorToShape); -}); +} // shape_to_tensor StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& ctx) { @@ -1118,10 +1183,10 @@ Expr MakeShapeToTensor(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("relax.op.shape_to_tensor", MakeShapeToTensor); -}); +} // alloc_tensor @@ -1132,7 +1197,7 @@ StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& c << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey(); DataType out_dtype; if (const auto* dtype_node = call->args[1].as()) { - const DataTypeImm dtype_imm = GetRef(dtype_node); + const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); out_dtype = dtype_imm->value; } return TensorStructInfo(call->args[0], out_dtype); @@ -1158,10 +1223,10 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind return Call(op, {shape, dtype, runtime_device_index, storage_scope}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.builtin.alloc_tensor", MakeAllocTensor); -}); +} // memory planning alloc_storage @@ -1186,10 +1251,10 @@ Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm stora return Call(op, {size, virtual_device_index, storage_scope, dtype}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.memory.alloc_storage", MakeAllocStorage); -}); +} // memory planning alloc_tensor @@ -1198,7 +1263,7 @@ StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& c << "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey(); DataType out_dtype; if (const auto* dtype_node = call->args[3].as()) { - const DataTypeImm dtype_imm = GetRef(dtype_node); + const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); out_dtype = dtype_imm->value; } return TensorStructInfo(call->args[2], out_dtype); @@ -1220,10 +1285,10 @@ Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.memory.alloc_tensor", MakeMemAllocTensor); -}); +} // memory planning kill_storage @@ -1239,10 +1304,10 @@ Expr MakeMemKillStorage(Expr storage) { return Call(op, {storage}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.memory.kill_storage", MakeMemKillStorage); -}); +} // memory planning kill_tensor @@ -1258,10 +1323,10 @@ Expr MakeMemKillTensor(Expr tensor) { return Call(op, {tensor}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.memory.kill_tensor", MakeMemKillTensor); -}); +} // vm alloc_storage @@ -1285,21 +1350,21 @@ Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm d return Call(op, {size, runtime_device_index, dtype, storage_scope}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.vm.alloc_storage", MakeVMAllocStorage); -}); +} // vm alloc_tensor StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) { DataType out_dtype; if (const auto* dtype_node = call->args[3].as()) { - const DataTypeImm dtype_imm = GetRef(dtype_node); + const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); out_dtype = dtype_imm->value; } if (const auto* output_shape = call->args[2].as()) { - return TensorStructInfo(GetRef(output_shape), out_dtype); + return TensorStructInfo(ffi::GetRef(output_shape), out_dtype); } else if (const auto* shape_sinfo = GetStructInfoAs(call->args[2])) { if (shape_sinfo->values.defined()) { return TensorStructInfo(ShapeExpr(shape_sinfo->values.value()), out_dtype); @@ -1326,10 +1391,10 @@ Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm d return Call(op, {storage, offset, shape, dtype}, 
Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.vm.alloc_tensor", MakeVMAllocTensor); -}); +} // vm kill_object @@ -1345,10 +1410,10 @@ Expr MakeVMKillObject(Expr obj) { return Call(op, {std::move(obj)}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.vm.kill_object", MakeVMKillObject); -}); +} // vm call_tir_dyn @@ -1366,10 +1431,10 @@ Expr MakeCallTIRDyn(Expr func, Tuple args) { return Call(op, {func, args}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.vm.call_tir_dyn", MakeCallTIRDyn); -}); +} // builtin stop_lift_params StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { @@ -1387,10 +1452,10 @@ Expr MakeStopLiftParams(Expr x) { return Call(op, {x}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.builtin.stop_lift_params", MakeStopLiftParams); -}); +} // to_vdevice @@ -1415,15 +1480,15 @@ TVM_REGISTER_OP("relax.to_vdevice") Expr MakeToVDevice(Expr data, VDevice dst_vdev) { static const Op& op = Op::Get("relax.to_vdevice"); - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dst_vdevice = dst_vdev; return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.to_vdevice", MakeToVDevice); -}); +} // hint_on_device @@ -1441,18 +1506,26 @@ TVM_REGISTER_OP("relax.hint_on_device") .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo) .set_attr("FPurity", Bool(true)); -Expr MakeHintOnDevice(Expr data, Device device) { +Expr MakeHintOnDevice(Expr data, Device device, ffi::String memory_scope = 
"global") { static const Op& op = Op::Get("relax.hint_on_device"); - ObjectPtr attrs = make_object(); - attrs->dev_type = static_cast(device.device_type); - attrs->dev_id = device.device_id; + ObjectPtr attrs = ffi::make_object(); + attrs->device_type = static_cast(device.device_type); + attrs->index = device.device_id; + attrs->memory_scope = memory_scope; return Call(op, {data}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.op.hint_on_device", MakeHintOnDevice); -}); + refl::GlobalDef().def_packed("relax.op.hint_on_device", [](ffi::PackedArgs args, ffi::Any* ret) { + if (args.size() == 3) { + *ret = MakeHintOnDevice(args[0].cast(), args[1].cast(), + args[2].cast()); + } else { + *ret = MakeHintOnDevice(args[0].cast(), args[1].cast()); + } + }); +} } // namespace relax } // namespace tvm diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index f439a345eb19..5b9ed1e5f529 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -24,9 +24,9 @@ namespace tvm { namespace relax { -Array GetCallArgs(const Call& call) { +ffi::Array GetCallArgs(const Call& call) { static const Op& call_tir_op = Op::Get("relax.call_tir"); - Array args; + ffi::Array args; if (call->op.same_as(call_tir_op)) { args = Downcast(call->args[1])->fields; } else { @@ -70,19 +70,19 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const } } -Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { +ffi::Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); Op op = Downcast(call->op); - Array input_tensor_sinfo; + ffi::Array input_tensor_sinfo; for (size_t i = 0; i < call->args.size(); ++i) { input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx)); } return input_tensor_sinfo; } -Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, 
- const Expr& tup) { +ffi::Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, + const Expr& tup) { const auto* tuple_sinfo = GetStructInfoAs(tup); if (tuple_sinfo == nullptr) { ctx->ReportFatal(Diagnostic::Error(call) @@ -91,7 +91,7 @@ Array GetTensorStructInfoFromTuple(const Call& call, const Blo << tup->struct_info_->GetTypeKey()); } - Array tensor_sinfo; + ffi::Array tensor_sinfo; tensor_sinfo.reserve(tuple_sinfo->fields.size()); for (StructInfo field_sinfo : tuple_sinfo->fields) { const auto* field_tensor_sinfo = field_sinfo.as(); @@ -101,14 +101,14 @@ Array GetTensorStructInfoFromTuple(const Call& call, const Blo << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " << tup->struct_info_); } - tensor_sinfo.push_back(GetRef(field_tensor_sinfo)); + tensor_sinfo.push_back(ffi::GetRef(field_tensor_sinfo)); } return tensor_sinfo; } -Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, - const Array& x1_shape, - const Array& x2_shape) { +ffi::Optional> InferBinaryBroadcastShape( + const Call& call, const BlockBuilder& ctx, const ffi::Array& x1_shape, + const ffi::Array& x2_shape) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); int x1_ndim = x1_shape.size(); int x2_ndim = x2_shape.size(); @@ -143,11 +143,11 @@ Optional> InferBinaryBroadcastShape(const Call& call, const Bloc for (; i <= max_ndim; ++i) { output_shape.push_back(longer_shape[max_ndim - i]); } - return Array(output_shape.rbegin(), output_shape.rend()); + return ffi::Array(output_shape.rbegin(), output_shape.rend()); } std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, - const Array& axes) { + const ffi::Array& axes) { ICHECK_NE(ndim, kUnknownNDim) << "The ndim is required to be known for this function."; std::vector appeared_dims_set; std::vector axes_non_neg; @@ -177,21 +177,21 @@ std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int nd return axes_non_neg; } 
-InferLayoutOutput InferLayoutUnaryEwise(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutUnaryEwise( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs)); } bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout, - Array shape) { + ffi::Array shape) { bool can_prove = true; try { tir::BijectiveLayout todesired(input_layout, desired_layout); - Array desired_shape = todesired.ForwardShape(shape); - Array back_shape = todesired.BackwardShape(desired_shape); + ffi::Array desired_shape = todesired.ForwardShape(shape); + ffi::Array back_shape = todesired.BackwardShape(desired_shape); arith::Analyzer analyzer; for (size_t i = 0; i < shape.size(); ++i) { if (tir::is_const_int(shape[i])) { diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 4da8b18fcb13..5a556cbd7413 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -71,7 +71,7 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const * \note This function require every input to be Tensor. The number of call arguments is required * to match the number of inputs of the op being called. */ -Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); +ffi::Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); /*! * \brief Get the tensor struct info of the unary operator input. @@ -93,8 +93,8 @@ inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const Bl * \return The tensor struct infos of tuple input. * \throw Throw exception if input expression is not a tuple. 
*/ -Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, - const Expr& tup); +ffi::Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, + const Expr& tup); namespace detail { /*! \brief Implementation helper for GetArgStructInfo */ @@ -176,13 +176,14 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c * be prepended with a prefix "relax.op." as the FFI identifier string for the make function, * \param OpRegName The identifier of the operator in the registry. */ -#define RELAX_UNARY_OP_INTERFACE(OpName, OpRegName) \ - Expr OpName(Expr x) { \ - static const Op& op = Op::Get("relax." OpRegName); \ - return Call(op, {std::move(x)}, Attrs(), {}); \ - } \ - TVM_FFI_STATIC_INIT_BLOCK( \ - { tvm::ffi::reflection::GlobalDef().def("relax.op." OpRegName, OpName); }) +#define RELAX_UNARY_OP_INTERFACE(OpName, OpRegName) \ + Expr OpName(Expr x) { \ + static const Op& op = Op::Get("relax." OpRegName); \ + return Call(op, {std::move(x)}, Attrs(), {}); \ + } \ + TVM_FFI_STATIC_INIT_BLOCK() { \ + tvm::ffi::reflection::GlobalDef().def("relax.op." OpRegName, OpName); \ + } /************ Utilities ************/ @@ -208,7 +209,7 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx << " requires the input tensor to have float dtype. However, the given input dtype is " << input_sinfo->dtype); } - auto output_sinfo = make_object(*input_sinfo.get()); + auto output_sinfo = ffi::make_object(*input_sinfo.get()); output_sinfo->dtype = f_compute_out_dtype(input_sinfo); return TensorStructInfo(output_sinfo); } @@ -257,9 +258,9 @@ StructInfo InferStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) * \param var_layout_map The layout of vars. * \return The inferred layout result. 
*/ -InferLayoutOutput InferLayoutUnaryEwise(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map); +InferLayoutOutput InferLayoutUnaryEwise( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map); /*! * \brief Get the element dtype from StructInfo @@ -318,7 +319,7 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& if (lhs_dtype.is_void() || rhs_dtype.is_void()) { return DataType::Void(); - } else if (lhs_dtype != rhs_dtype) { + } else if (lhs_dtype != rhs_dtype && !lhs_dtype.is_bool() && !rhs_dtype.is_bool()) { ctx->ReportFatal(Diagnostic::Error(call) << "TypeError: " << "Binary operators must have the same datatype for both operands. " @@ -338,10 +339,11 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& * \return The inferred output vdevice. * \throw Throw exception if the vdevice of two input TensorStructInfo don’t match */ -inline Optional InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx, - const StructInfo& lhs_sinfo, - const StructInfo& rhs_sinfo) { - auto get_vdevice = [&](const StructInfo& sinfo) -> Optional { +inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, + const BlockBuilder& ctx, + const StructInfo& lhs_sinfo, + const StructInfo& rhs_sinfo) { + auto get_vdevice = [&](const StructInfo& sinfo) -> ffi::Optional { if (const auto* tensor = sinfo.as()) { return tensor->vdevice; } else { @@ -349,6 +351,15 @@ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const Bl } }; + /* + * This is the case where the output VDevice defined by a customization pass. + * Like targets that supports mixed VDevices (like differed by memory_scope for Adreno) + * and have specialized derivation for output VDevice. 
+ */ + if (call->sinfo_args.size() > 0) { + return get_vdevice(call->sinfo_args[0]); + } + auto lhs_vdevice = get_vdevice(lhs_sinfo); auto rhs_vdevice = get_vdevice(rhs_sinfo); @@ -358,6 +369,7 @@ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const Bl if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) { return lhs_vdevice; } + if (lhs_vdevice.value() != rhs_vdevice.value()) { ctx->ReportFatal(Diagnostic::Error(call) << "TypeErorr: " @@ -378,9 +390,10 @@ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const Bl * \return The inferred output shape after broadcasting. Or `std::nullopt` if the output shape * cannot be determined due to symbolic broadcast. */ -Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, - const Array& x1_shape, - const Array& x2_shape); +ffi::Optional> InferBinaryBroadcastShape(const Call& call, + const BlockBuilder& ctx, + const ffi::Array& x1_shape, + const ffi::Array& x2_shape); /*! * \brief Convert all axes to non-negative indices, and meanwhile check if the given array of axes @@ -393,7 +406,7 @@ Optional> InferBinaryBroadcastShape(const Call& call, const Bloc * \throw Throw exception if there exists out-of-range axis index or repetitive indices. */ std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, - const Array& axes); + const ffi::Array& axes); /*! * \brief Convert the given axis to non-negative index. Meanwhile check if the axis is in range @@ -414,7 +427,7 @@ inline int NormalizeAxis(const Call& call, const BlockBuilder& ctx, int ndim, in * \param shape_values The given shape values. * \return The product of all the given shape values. */ -PrimExpr ComputeShapeProduct(const Array& shape_values); +PrimExpr ComputeShapeProduct(const ffi::Array& shape_values); /*! * \brief Check if the given permutation is identity permutation. 
@@ -428,7 +441,7 @@ bool IsIdentityPermutation(const std::vector& permutation); * \param int_imms The input IntImms to be converted. * \return The conversion result, where every IntImm has dtype int64 */ -inline Array ConvertIntImmToInt64(const Array& int_imms) { +inline ffi::Array ConvertIntImmToInt64(const ffi::Array& int_imms) { return int_imms.Map([](const IntImm& i) { return Downcast(cast(DataType::Int(64), i)); }); } @@ -442,7 +455,7 @@ inline Array ConvertIntImmToInt64(const Array& int_imms) { * \return The completed padding. * \throws Throws error if the input padding length is neither 1 or 2. */ -inline Array GetCompletePadding1D(Array padding) { +inline ffi::Array GetCompletePadding1D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0]}; } else if (padding.size() == 2) { @@ -463,7 +476,7 @@ inline Array GetCompletePadding1D(Array padding) { * \return The completed padding. * \throws Throws error if the input padding length is neither 1, 2 or 4. */ -inline Array GetCompletePadding2D(Array padding) { +inline ffi::Array GetCompletePadding2D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0], padding[0], padding[0]}; } else if (padding.size() == 2) { @@ -488,7 +501,7 @@ inline Array GetCompletePadding2D(Array padding) { * \return The completed padding. * \throws Throws error if the input padding length is neither 1, 3 or 6. */ -inline Array GetCompletePadding3D(Array padding) { +inline ffi::Array GetCompletePadding3D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0], padding[0], padding[0], padding[0], padding[0]}; } else if (padding.size() == 3) { @@ -514,11 +527,9 @@ inline Array GetCompletePadding3D(Array padding) { * \return The tensor layout and the bijective conversion in tir::Layout and tir::BijectiveLayout * accordingly. 
*/ -inline std::pair CheckTensorLayout(const Call& call, - const BlockBuilder& ctx, - const String& tensor_layout, - const String& tgt_layout, - const String& tensor_name) { +inline std::pair CheckTensorLayout( + const Call& call, const BlockBuilder& ctx, const ffi::String& tensor_layout, + const ffi::String& tgt_layout, const ffi::String& tensor_name) { tir::Layout _tensor_layout(tensor_layout, DataType::Int(64)); tir::BijectiveLayout tensor2tgt(_tensor_layout, tir::Layout(tgt_layout, DataType::Int(64))); if (!tensor2tgt.defined()) { @@ -539,9 +550,10 @@ inline std::pair CheckTensorLayout(const Call * \param layout The layout that the given tensor is expected to have. * \return The shape of the input tensor in ShapeExpr, or `std::nullopt` if the shape is unknown. */ -inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const BlockBuilder& ctx, - const TensorStructInfo& sinfo, - const tir::Layout& layout) { +inline ffi::Optional CheckNdimPerLayoutAndGetShape(const Call& call, + const BlockBuilder& ctx, + const TensorStructInfo& sinfo, + const tir::Layout& layout) { if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast(layout.ndim())) { ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", layout " << layout << " requires the input to be " @@ -549,7 +561,7 @@ inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const << sinfo->ndim); } if (const auto* shape_expr = sinfo->shape.as()) { - return GetRef(shape_expr); + return ffi::GetRef(shape_expr); } return std::nullopt; } @@ -568,7 +580,7 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind * \param call The call node * \return The arguments of the call */ -Array GetCallArgs(const Call& call); +ffi::Array GetCallArgs(const Call& call); /** * \brief Checks the given shape can be proved from the source layout to dst layout @@ -578,7 +590,7 @@ Array GetCallArgs(const Call& call); * \return true or false depending on the compatibility */ bool 
CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout, - Array shape); + ffi::Array shape); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 74ae8e9cbc5c..7051d2b1b975 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -61,7 +61,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, } // VDevice - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, lhs_sinfo, rhs_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, lhs_sinfo, rhs_sinfo); auto get_ndim = [&](const StructInfo& sinfo) -> int { if (sinfo.as()) { @@ -86,9 +86,9 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, // Shapes - auto get_shape = [](const StructInfo& sinfo) -> Optional> { + auto get_shape = [](const StructInfo& sinfo) -> ffi::Optional> { if (sinfo.as()) { - return Array{IntImm(DataType::Int(64), 1)}; + return ffi::Array{IntImm(DataType::Int(64), 1)}; } else if (const auto* tensor = sinfo.as()) { return tensor->GetShape(); } else { @@ -101,7 +101,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, auto lhs_shape = get_shape(lhs_sinfo); auto rhs_shape = get_shape(rhs_sinfo); if (lhs_shape && rhs_shape) { - Optional> output_shape = + ffi::Optional> output_shape = InferBinaryBroadcastShape(call, ctx, lhs_shape.value(), rhs_shape.value()); if (output_shape.defined()) { ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); @@ -109,7 +109,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, } } - auto get_shape_expr = [](const StructInfo& sinfo) -> Optional { + auto get_shape_expr = [](const StructInfo& sinfo) -> ffi::Optional { if (const auto* tensor = sinfo.as()) { return tensor->shape; } else { @@ -142,9 +142,9 @@ StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx const 
StructInfo& rhs_sinfo) { return DataType::Bool(); }); } -InferLayoutOutput InferLayoutBinaryEwise(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutBinaryEwise( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[0]); LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[1]); @@ -155,17 +155,22 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call, ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) << "Unknown dim tensors should not be handled by this function"; - Optional shape1 = GetRef(x1_sinfo->shape.as()); - Optional shape2 = GetRef(x2_sinfo->shape.as()); + ffi::Optional shape1 = ffi::GetRef(x1_sinfo->shape.as()); + ffi::Optional shape2 = ffi::GetRef(x2_sinfo->shape.as()); // Lets handle sub indexing as long as primal dims are matching - if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { - if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { - if (CanProveLayoutTransform(layout2->layout, layout1->layout, shape2.value()->values)) { - return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs)); - } - } else if (shape1.defined()) { - if (CanProveLayoutTransform(layout1->layout, layout2->layout, shape1.value()->values)) { - return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + if ((layout1->layout.ndim() != layout1->layout.ndim_primal()) || + (layout2->layout.ndim() != layout2->layout.ndim_primal())) { + if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { + if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { + if (CanProveLayoutTransform(InitialLayout(shape2.value()->values.size()), layout1->layout, + shape2.value()->values)) { + return InferLayoutOutput({layout1, layout1}, 
{layout1}, Attrs(call->attrs)); + } + } else if (shape1.defined()) { + if (CanProveLayoutTransform(InitialLayout(shape1.value()->values.size()), layout2->layout, + shape1.value()->values)) { + return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + } } } } diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index f612ec0598a9..b5650fad2735 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -42,8 +42,9 @@ namespace relax { static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {x1, x2}, Attrs(), {}); \ } \ - TVM_FFI_STATIC_INIT_BLOCK( \ - { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \ + TVM_FFI_STATIC_INIT_BLOCK() { \ + tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); \ + } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(2) \ .add_argument("x1", "Tensor", "The first input tensor.") \ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index a3bec83f749d..a9a0872d683a 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -35,36 +35,37 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { InitAttrs::RegisterReflection(); TriluAttrs::RegisterReflection(); -}); +} /* Initialization operators */ /* relax.full */ -Expr full(Variant> shape, Expr fill_value, Optional dtype) { +Expr full(ffi::Variant> shape, Expr fill_value, + ffi::Optional dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { - shape_in_expr = GetRef(expr); + shape_in_expr = ffi::GetRef(expr); } else if (const auto* _array = shape.as()) { - shape_in_expr = ShapeExpr(GetRef>(_array)); + shape_in_expr = ShapeExpr(ffi::GetRef>(_array)); } else { LOG(FATAL) << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. 
"; } - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.full"); return Call(op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.full", full); -}); +} StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -99,20 +100,20 @@ TVM_REGISTER_OP("relax.full") .set_attr("FPurity", Bool(true)); /* relax.full_like */ -Expr full_like(Expr x, Expr fill_value, Optional dtype) { - ObjectPtr attrs = make_object(); +Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.full_like"); return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.full_like", full_like); -}); +} StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo fill_value_sinfo = input_sinfo[1]; if (fill_value_sinfo->ndim != 0) { @@ -125,7 +126,7 @@ StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { if (attrs->dtype.is_void()) { return data_sinfo; } else { - auto output_sinfo = make_object(*data_sinfo.get()); + auto output_sinfo = ffi::make_object(*data_sinfo.get()); output_sinfo->dtype = attrs->dtype; return TensorStructInfo(output_sinfo); } @@ -164,7 +165,7 @@ StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder if (attrs->dtype.is_void()) { return data_sinfo; } 
else { - auto output_sinfo = make_object(*data_sinfo.get()); + auto output_sinfo = ffi::make_object(*data_sinfo.get()); output_sinfo->dtype = attrs->dtype; return TensorStructInfo(output_sinfo); } @@ -173,24 +174,24 @@ StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder /* relax.ones & relax.ones_like */ Expr ones(Expr shape, DataType dtype) { CHECK(!dtype.is_void()) << "Ones op expects the input dtype not to be void"; - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.ones"); return Call(op, {std::move(shape)}, Attrs(attrs), {}); } -Expr ones_like(Expr x, Optional dtype) { - ObjectPtr attrs = make_object(); +Expr ones_like(Expr x, ffi::Optional dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.ones_like"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.ones", ones).def("relax.op.ones_like", ones_like); -}); +} TVM_REGISTER_OP("relax.ones") .set_attrs_type() @@ -210,24 +211,24 @@ TVM_REGISTER_OP("relax.ones_like") /* relax.zeros & relax.zeros_like */ Expr zeros(Expr shape, DataType dtype) { CHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.zeros"); return Call(op, {std::move(shape)}, Attrs(attrs), {}); } -Expr zeros_like(Expr x, Optional dtype) { - ObjectPtr attrs = make_object(); +Expr zeros_like(Expr x, ffi::Optional dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.zeros_like"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { 
namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.zeros", zeros).def("relax.op.zeros_like", zeros_like); -}); +} TVM_REGISTER_OP("relax.zeros") .set_attrs_type() @@ -246,23 +247,23 @@ TVM_REGISTER_OP("relax.zeros_like") /* relax.eye & relax.eye_like */ Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.eye"); return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {}); } -Expr eye_like(Expr x, PrimValue k, Optional dtype) { - ObjectPtr attrs = make_object(); +Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.eye_like"); return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.eye", eye).def("relax.op.eye_like", eye_like); -}); +} StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { @@ -332,16 +333,16 @@ TVM_REGISTER_OP("relax.eye_like") /* relax.arange */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.arange"); return Call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.arange", arange); -}); +} StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { @@ -388,17 +389,17 @@ TVM_REGISTER_OP("relax.arange") /* relax.hamming_window */ Expr hamming_window(PrimValue window_size, PrimValue 
periodic, PrimValue alpha, PrimValue beta, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.hamming_window"); return Call(op, {std::move(window_size), std::move(periodic), std::move(alpha), std::move(beta)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.hamming_window", hamming_window); -}); +} StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ctx) { DataType dtype = call->attrs.as()->dtype; @@ -455,12 +456,12 @@ Expr triu(Expr x, Expr k) { Expr triu(Expr x, int k) { return triu(x, relax::PrimValue::Int64(k)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.op.tril", static_cast(tril)) .def("relax.op.triu", static_cast(triu)); -}); +} StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { auto [data_sinfo, offset] = GetArgStructInfo(call, ctx); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index f252eebf824f..284448111739 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -41,7 +41,8 @@ namespace relax { * If dtype is not given, it will by default use the dtype of fill_value. * \return The result tensor. */ -Expr full(Variant> shape, Expr fill_value, Optional dtype); +Expr full(ffi::Variant> shape, Expr fill_value, + ffi::Optional dtype); /*! * \brief Construct a tensor such that @@ -54,7 +55,7 @@ Expr full(Variant> shape, Expr fill_value, Optional dtype); +Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype); /*! * \brief Construct a tensor of all ones, with the input shape and dtype. @@ -72,7 +73,7 @@ Expr ones(Expr shape, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. 
*/ -Expr ones_like(Expr x, Optional dtype); +Expr ones_like(Expr x, ffi::Optional dtype); /*! * \brief Construct a tensor of all zeros, with the input shape and dtype. @@ -90,7 +91,7 @@ Expr zeros(Expr shape, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr zeros_like(Expr x, Optional dtype); +Expr zeros_like(Expr x, ffi::Optional dtype); /*! * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. @@ -114,7 +115,7 @@ Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr eye_like(Expr x, PrimValue k, Optional dtype); +Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype); /*! \brief Construct a tensor with evenly spaced elements. */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 89e7474c1335..f12be685bdbc 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -31,30 +31,30 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { AstypeAttrs::RegisterReflection(); WrapParamAttrs::RegisterReflection(); -}); +} /* relax.astype */ Expr astype(Expr x, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.astype"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.astype", astype); -}); +} StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - ObjectPtr new_sinfo = make_object(*sinfo.get()); + ObjectPtr new_sinfo = ffi::make_object(*sinfo.get()); new_sinfo->dtype 
= attrs->dtype; return TensorStructInfo(new_sinfo); } @@ -71,22 +71,22 @@ TVM_REGISTER_OP("relax.astype") /* relax.wrap_param */ Expr MakeWrapParam(Expr data, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.wrap_param"); return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.wrap_param", MakeWrapParam); -}); +} StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - ObjectPtr new_sinfo = make_object(*sinfo.get()); + ObjectPtr new_sinfo = ffi::make_object(*sinfo.get()); new_sinfo->dtype = attrs->dtype; return TensorStructInfo(new_sinfo); } diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index 6b0ca941f00c..52a218b730d0 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -37,10 +37,10 @@ Expr no_grad(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.grad.no_grad", no_grad); -}); +} StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[0]); @@ -58,10 +58,10 @@ Expr start_checkpoint(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.grad.start_checkpoint", start_checkpoint); -}); +} StructInfo InferStructInfoStartCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { @@ -83,10 +83,10 @@ Expr end_checkpoint(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.grad.end_checkpoint", end_checkpoint); -}); +} StructInfo InferStructInfoEndCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { @@ -103,9 +103,9 @@ TVM_REGISTER_OP("relax.grad.end_checkpoint") .set_attr("FPurity", Bool(true)); /* relax.grad.nll_loss_backward */ -Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optional weights, - String reduction, int ignore_index) { - ObjectPtr attrs = make_object(); +Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, + ffi::Optional weights, ffi::String reduction, int ignore_index) { + ObjectPtr attrs = ffi::make_object(); attrs->reduction = reduction; attrs->ignore_index = ignore_index; @@ -121,10 +121,10 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optiona } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.grad.nll_loss_backward", nll_loss_backward); -}); +} StructInfo InferStructInfoNLLLossBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -136,16 +136,16 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward") .add_argument("output_grad", "Tensor", "The output gradient.") .add_argument("predictions", "Tensor", "The prediction tensor.") .add_argument("targets", "Tensor", "The target tensor.") - .add_argument("weights", "Optional", "The weight of each target values.") + .add_argument("weights", "ffi::Optional", "The weight of each target values.") .set_attr("FInferStructInfo", InferStructInfoNLLLossBackward) .set_attr("FPurity", Bool(true)); /* relax.grad.max_pool2d_backward */ -Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, - Array strides, Array padding, Array dilation, - bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { - auto attrs = 
make_object(); +Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout) { + auto attrs = ffi::make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -158,10 +158,10 @@ Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.grad.max_pool2d_backward", max_pool2d_backward); -}); +} StructInfo InferStructInfoMaxPool2DBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -176,11 +176,11 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward") .set_attr("FPurity", Bool(true)); /* relax.grad.avg_pool2d_backward */ -Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, - Array strides, Array padding, Array dilation, - bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { - auto attrs = make_object(); +Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout) { + auto attrs = ffi::make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -193,10 +193,10 @@ Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("relax.op.grad.avg_pool2d_backward", avg_pool2d_backward); -}); +} StructInfo InferStructInfoAvgPool2DBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -212,18 +212,18 @@ TVM_REGISTER_OP("relax.grad.avg_pool2d_backward") /* relax.grad.take_backward */ -Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional axis) { - ObjectPtr attrs = make_object(); +Expr take_backward(Expr output_grad, Expr x, Expr indices, ffi::Optional axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.grad.take_backward"); return Call(op, {std::move(output_grad), std::move(x), std::move(indices)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.grad.take_backward", take_backward); -}); +} StructInfo InferStructInfoTakeBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); diff --git a/src/relax/op/tensor/grad.h b/src/relax/op/tensor/grad.h index b0a58f7e5c49..406d7a2f779e 100644 --- a/src/relax/op/tensor/grad.h +++ b/src/relax/op/tensor/grad.h @@ -41,26 +41,26 @@ Expr no_grad(Expr input); /*! \brief Backward operator of relax.nll_loss. All parameters except output_grad is the same as * relax.nll_loss. Returns the gradient w.r.t. predictions. */ -Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optional weights, - String reduction, int ignore_index); +Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, + ffi::Optional weights, ffi::String reduction, int ignore_index); /*! \brief Backward operator of relax.max_pool2d. All parameters except output_grad is the same as * relax.max_pool2d. Returns the gradient w.r.t. data. 
*/ -Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, - Array strides, Array padding, Array dilation, - bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout); +Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout); /*! \brief Backward operator of relax.avg_pool2d. All parameters except output_grad is the same as * relax.avg_pool2d. Returns the gradient w.r.t. data. */ -Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, - Array strides, Array padding, Array dilation, - bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout); +Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout); /*! \brief Backward operator of relax.take. All parameters except output_grad is the same as * relax.take. Returns the gradient w.r.t. data. 
*/ -Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional axis); +Expr take_backward(Expr output_grad, Expr x, Expr indices, ffi::Optional axis); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index dea79b804bb4..29bf767f9542 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -37,15 +37,15 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { TakeAttrs::RegisterReflection(); StridedSliceAttrs::RegisterReflection(); -}); +} /* relax.take */ -Expr take(Expr x, Expr indices, Optional axis, String mode) { - ObjectPtr attrs = make_object(); +Expr take(Expr x, Expr indices, ffi::Optional axis, ffi::String mode) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->mode = std::move(mode); @@ -53,10 +53,10 @@ Expr take(Expr x, Expr indices, Optional axis, String mode) { return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.take", take); -}); +} StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); @@ -70,7 +70,7 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { if (auto tensor_sinfo = sinfo.as()) { return tensor_sinfo.value(); } else if (auto prim_sinfo = sinfo.as()) { - return TensorStructInfo(ShapeExpr(Array{}), prim_sinfo->dtype); + return TensorStructInfo(ShapeExpr(ffi::Array{}), prim_sinfo->dtype); } else { ctx->ReportFatal(Diagnostic::Error(call) << "Operator " << call->op << " requires the indices argument to be " @@ -115,7 +115,7 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { data_sinfo->vdevice); } - Array output_shape; + ffi::Array output_shape; for (int i = 0; i < data_sinfo->ndim; i++) { if (i == axis) { for (int j = 0; j < 
indices_sinfo->ndim; j++) @@ -137,7 +137,7 @@ TVM_REGISTER_OP("relax.take") /* relax.strided_slice */ -Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strides, +Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, ffi::Optional strides, bool assume_inbound) { // Initial validation of the arguments. A more complete validation // will be done when inferring the StructInfo, but that requires the @@ -165,10 +165,10 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strid check_tuple("end", end); if (strides.defined()) check_tuple("strides", strides.value()); - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->assume_inbound = assume_inbound; - Array args = {x, axes, begin, end}; + ffi::Array args = {x, axes, begin, end}; if (strides.defined()) { args.push_back(strides.value()); } @@ -179,10 +179,10 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strid return call; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.strided_slice", strided_slice); -}); +} /* \brief Helper function to unpack a relax::Tuple * @@ -198,7 +198,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * a tuple from a `TensorStructInfo`.) * * \tparam PrimType The subtype of PrimExpr to extract. For example, - * extracting an `Array` + * extracting an `ffi::Array` * * \param sinfo The StructInfo to inspect * @@ -207,12 +207,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ */ template >> -Optional> UnpackTupleOfPrimValue(Optional sinfo) { +ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional sinfo) { if (!sinfo) return std::nullopt; // An ObjectStructInfo may contain a tuple of the desired type, but // it isn't yet known whether it does. Return early, as we cannot - // provide a known `Array` to the caller. + // provide a known `ffi::Array` to the caller. 
if (sinfo.as()) return std::nullopt; auto tuple = sinfo.as(); @@ -220,7 +220,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { << "The struct info " << sinfo << " cannot contain a tuple whose elements are " << PrimType::ContainerType::_type_key; - Array output; + ffi::Array output; for (size_t i = 0; i < tuple->fields.size(); i++) { auto field = tuple->fields[i]; @@ -235,7 +235,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { if (!prim_sinfo->value.defined()) return std::nullopt; - Optional element = prim_sinfo->value.as(); + ffi::Optional element = prim_sinfo->value.as(); if (!element) return std::nullopt; output.push_back(element.value()); @@ -257,7 +257,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { * a tuple from a `TensorStructInfo`.) * * \tparam PrimType The subtype of PrimExpr to extract. For example, - * extracting an `Array` + * extracting an `ffi::Array` * * \param expr The `relax::Expr` to inspect * @@ -266,7 +266,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { */ template >> -Optional> UnpackTupleOfPrimValue(Optional expr) { +ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional expr) { if (expr) { return UnpackTupleOfPrimValue(GetStructInfo(expr.value())); } else { @@ -285,7 +285,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx Expr axes = call->args[1]; Expr begin = call->args[2]; Expr end = call->args[3]; - Optional strides = [&]() -> Optional { + ffi::Optional strides = [&]() -> ffi::Optional { if (n_args > 4) { return call->args[4]; } else { @@ -296,7 +296,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx auto axes_sinfo = GetStructInfo(call->args[1]); auto begin_sinfo = GetStructInfo(call->args[2]); auto end_sinfo = GetStructInfo(call->args[3]); - auto strides_sinfo = [&]() -> Optional { + auto strides_sinfo = [&]() -> ffi::Optional { if (n_args > 4) { return GetStructInfo(call->args[4]); } else { @@ -342,7 +342,7 @@ StructInfo 
InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx const auto* data_sinfo = data->struct_info_.as(); DataType dtype = DataType::Void(); - Optional vdevice = std::nullopt; + ffi::Optional vdevice = std::nullopt; int ndim = kUnknownNDim; if (data_sinfo) { dtype = data_sinfo->dtype; @@ -350,7 +350,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx ndim = data_sinfo->ndim; } - Optional shape = [&]() -> Optional { + ffi::Optional shape = [&]() -> ffi::Optional { if (!data_sinfo) return std::nullopt; if (!data_sinfo->shape) return std::nullopt; @@ -378,14 +378,14 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple << ") and " << end_tuple.size() << " 'end' indices specified (" << end_tuple << ")"; - Array strides_tuple; + ffi::Array strides_tuple; if (strides.defined()) { auto opt_strides_tuple = UnpackTupleOfPrimValue(strides); if (!opt_strides_tuple) return std::nullopt; strides_tuple = opt_strides_tuple.value(); } else { - strides_tuple = Array(axes_tuple.size(), IntImm(DataType::Int(64), 1)); + strides_tuple = ffi::Array(axes_tuple.size(), IntImm(DataType::Int(64), 1)); } CHECK_EQ(axes_tuple.size(), strides_tuple.size()) @@ -406,7 +406,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, axes_tuple); auto attrs = call->attrs.as(); - Array output_shape = data_sinfo->GetShape().value(); + ffi::Array output_shape = data_sinfo->GetShape().value(); for (size_t i = 0; i < axes.size(); i++) { size_t axis = axes[i]; PrimExpr input_dim = output_shape[axis]; @@ -436,9 +436,9 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx } } -InferLayoutOutput InferLayoutStridedSlice(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput 
InferLayoutStridedSlice( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -460,9 +460,9 @@ InferLayoutOutput InferLayoutStridedSlice(const Call& call, << " requires slices to be along static axes. " << "However, expression " << call << " slices along non-static axes " << call->args[1]; - Array axes_tuple = opt_axes_tuple.value(); + ffi::Array axes_tuple = opt_axes_tuple.value(); - Array new_axes; + ffi::Array new_axes; for (const auto& axis : axes_tuple) { int new_axis = FindAxis(existing_layout->layout, axis->value); new_axes.push_back(relax::PrimValue::Int64(new_axis)); @@ -490,10 +490,10 @@ Expr dynamic_strided_slice(Expr x, // return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.dynamic_strided_slice", dynamic_strided_slice); -}); +} StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -515,7 +515,7 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& } int n_axis = data_sinfo->ndim; - auto diag_def = [&](const TensorStructInfoNode* sinfo, String name) { + auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name) { ICHECK(sinfo) << "Dynamic strided slice requires the input " << name << " to be have the struct info. 
Please try normalizing the inputs."; CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name @@ -524,7 +524,7 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ICHECK(shape) << "Dynamic strided slice requires the input " << name << " to have well-defined shape."; // NOTE(tvm-team): This strong restriction seems necessary for now until we have a generic - // solution in converting 1d Tensor with unknown num_elem to Array. + // solution in converting 1d Tensor with unknown num_elem to ffi::Array. const auto* num_elem = shape->values[0].as(); ICHECK(num_elem) << "Dynamic strided slice requires the input " << name << " to have a known integer shape value."; diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h index a45fb93792ed..0c5b45c68f2c 100644 --- a/src/relax/op/tensor/index.h +++ b/src/relax/op/tensor/index.h @@ -41,7 +41,7 @@ namespace relax { * \param mode The mode for handling out-of-bounds indices. * \return The taken result. */ -Expr take(Expr x, Expr indices, Optional axis, String mode = "fast"); +Expr take(Expr x, Expr indices, ffi::Optional axis, ffi::String mode = "fast"); /*! * \brief Strided slice of a tensor. @@ -55,8 +55,8 @@ Expr take(Expr x, Expr indices, Optional axis, String mode = "fast"); * \param assume_inbound Whether to assume the indices are in bound. 
* \return The sliced result */ -Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strides = std::nullopt, - bool assume_inbound = false); +Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, + ffi::Optional strides = std::nullopt, bool assume_inbound = false); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 7dd193ce37cb..01843ba0a3c0 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -85,7 +85,7 @@ std::tuple GetTensorArgInfoWithIndex(const Cal << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " elements"; } - return {GetRef(tensor_sinfo), GetRef(axis_sinfo)}; + return {ffi::GetRef(tensor_sinfo), ffi::GetRef(axis_sinfo)}; } DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; } @@ -103,7 +103,7 @@ tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); - tir::PrimFunc func(Array{dlpack_handle}, body, PrimType(field_dtype), {}, attrs); + tir::PrimFunc func(ffi::Array{dlpack_handle}, body, PrimType(field_dtype), {}, attrs); FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)}, PrimStructInfo(field_dtype)); diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index dcd2a1e24fca..06b7856dd239 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -34,28 +34,28 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { MatmulAttrs::RegisterReflection(); EinsumAttrs::RegisterReflection(); -}); +} /* relax.matmul */ -Expr matmul(Expr x1, Expr x2, Optional out_dtype) { - ObjectPtr attrs = make_object(); +Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->out_dtype = out_dtype.value_or(DataType::Void()); static 
const Op& op = Op::Get("relax.matmul"); return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.matmul", matmul); -}); +} StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); Expr lhs = call->args[0]; Expr rhs = call->args[1]; TensorStructInfo x1_sinfo = input_sinfo[0]; @@ -121,11 +121,11 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(out_dtype, output_ndim); } - Array x1_shape_prefix{x1_shape->values.begin(), - x1_shape->values.end() - 2 + x1_prepended}; - Array x2_shape_prefix{x2_shape->values.begin(), - x2_shape->values.end() - 2 + x2_appended}; - Optional> output_shape_prefix = + ffi::Array x1_shape_prefix{x1_shape->values.begin(), + x1_shape->values.end() - 2 + x1_prepended}; + ffi::Array x2_shape_prefix{x2_shape->values.begin(), + x2_shape->values.end() - 2 + x2_appended}; + ffi::Optional> output_shape_prefix = InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); if (!output_shape_prefix.defined()) { if (vdev.defined()) { @@ -146,7 +146,7 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { << x2_reduction_length << " are not equal."); } - Array output_shape = output_shape_prefix.value(); + ffi::Array output_shape = output_shape_prefix.value(); if (!x1_prepended) { output_shape.push_back(x1_shape->values[x1_ndim - 2]); } @@ -175,24 +175,24 @@ TVM_REGISTER_OP("relax.matmul") /* relax.einsum */ -Expr einsum(Expr operands, String subscripts) { - ObjectPtr attrs = make_object(); +Expr einsum(Expr operands, ffi::String subscripts) { + ObjectPtr attrs = ffi::make_object(); attrs->subscripts = std::move(subscripts); static const Op& op = Op::Get("relax.einsum"); return 
Call(op, {std::move(operands)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.einsum", einsum); -}); +} StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "Einsum op should take 1 argument"); } - Array operands_tensor_sinfo = + ffi::Array operands_tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); if (operands_tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) @@ -219,10 +219,10 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { } } - String subscripts = attrs->subscripts; + ffi::String subscripts = attrs->subscripts; DataType operand_dtype = operands_tensor_sinfo[0]->dtype; - std::vector> input_shapes; + std::vector> input_shapes; input_shapes.reserve(operands_tensor_sinfo.size()); for (TensorStructInfo tensor_sinfo : operands_tensor_sinfo) { @@ -246,7 +246,7 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { } } // Calculate output shape using InferEinsumShape in topi - Array oshape = topi::InferEinsumShape(subscripts, input_shapes); + ffi::Array oshape = topi::InferEinsumShape(subscripts, input_shapes); if (!vdevice_unknown) { return TensorStructInfo(ShapeExpr(oshape), operand_dtype, vdev); @@ -268,10 +268,10 @@ Expr outer(Expr x1, Expr x2) { return Call(op, {std::move(x1), std::move(x2)}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.outer", outer); -}); +} StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { auto input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -290,7 +290,7 @@ StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { if (!x1_shape || !x2_shape) { return TensorStructInfo(x1_sinfo->dtype, 2); } - Array 
output_shape = {x1_shape->values[0], x2_shape->values[0]}; + ffi::Array output_shape = {x1_shape->values[0], x2_shape->values[0]}; return TensorStructInfo(ShapeExpr(output_shape), x1_sinfo->dtype); } diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h index eb003fed1c76..ddfceae4dc35 100644 --- a/src/relax/op/tensor/linear_algebra.h +++ b/src/relax/op/tensor/linear_algebra.h @@ -41,7 +41,7 @@ namespace relax { * When it is not specified, the output dtype will be the same as input dtype. * \return The computed result. */ -Expr matmul(Expr x1, Expr x2, Optional out_dtype); +Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype); /*! * \brief Einstein summation on the operands. @@ -49,7 +49,7 @@ Expr matmul(Expr x1, Expr x2, Optional out_dtype); * \param subscripts The einsum expression string. * \return The computed result. */ -Expr einsum(Expr operands, String subscripts); +Expr einsum(Expr operands, ffi::String subscripts); /*! * \brief Compute the outer product of two input expressions. 
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 83b157034279..0310c7f46b0d 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -37,7 +37,7 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { ConcatAttrs::RegisterReflection(); ExpandDimsAttrs::RegisterReflection(); LayoutTransformAttrs::RegisterReflection(); @@ -56,7 +56,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ScatterNDAttrs::RegisterReflection(); SliceScatterAttrs::RegisterReflection(); OneHotAttrs::RegisterReflection(); -}); +} /* relax.broadcast_to */ Expr broadcast_to(Expr x, Expr shape) { @@ -64,10 +64,10 @@ Expr broadcast_to(Expr x, Expr shape) { return Call(op, {std::move(x), std::move(shape)}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.broadcast_to", broadcast_to); -}); +} StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -107,8 +107,8 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) } arith::Analyzer* analyzer = ctx->GetAnalyzer(); - Array old_shape_value = shape_sinfo->values.value(); - Array tgt_shape_value = tgt_shape_sinfo->values.value(); + ffi::Array old_shape_value = shape_sinfo->values.value(); + ffi::Array tgt_shape_value = tgt_shape_sinfo->values.value(); int old_ndim = old_shape_value.size(); int tgt_ndim = tgt_shape_value.size(); for (int i = 0; i < old_ndim; ++i) { @@ -141,22 +141,22 @@ TVM_REGISTER_OP("relax.broadcast_to") /* relax.concat */ -Expr concat(Expr tensors, Optional axis) { - ObjectPtr attrs = make_object(); +Expr concat(Expr tensors, ffi::Optional axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.concat"); return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.concat", concat); -}); +} -Optional> CheckConcatOutputShape(const Call& call, const BlockBuilder& ctx, - const std::vector>& shape_values, - int axis) { +ffi::Optional> CheckConcatOutputShape( + const Call& call, const BlockBuilder& ctx, + const std::vector>& shape_values, int axis) { bool shape_unknown = false; arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr concat_sum = [&]() { @@ -174,7 +174,7 @@ Optional> CheckConcatOutputShape(const Call& call, const BlockBu // General case, add up the dimensions along the specified axis. PrimExpr concat_sum = IntImm(DataType::Int(64), 0); - for (Array shape_value : shape_values) { + for (ffi::Array shape_value : shape_values) { concat_sum += shape_value[axis]; } return concat_sum; @@ -201,7 +201,7 @@ Optional> CheckConcatOutputShape(const Call& call, const BlockBu if (shape_unknown) { return std::nullopt; } - Array output_shape = shape_values[0]; + ffi::Array output_shape = shape_values[0]; output_shape.Set(axis, concat_sum); return output_shape; } @@ -210,7 +210,8 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); } - Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); + ffi::Array tensor_sinfo = + GetTensorStructInfoFromTuple(call, ctx, call->args[0]); if (tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) << "Concat op expects at least one tensor in the input Tuple. However, the " @@ -220,11 +221,11 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int output_ndim = attrs->axis.has_value() ? 
kUnknownNDim : 1; DataType output_dtype = DataType::Void(); - Optional vdev = std::nullopt; + ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; bool vdevice_unknown = false; - std::vector> shape_values; + std::vector> shape_values; shape_values.reserve(tensor_sinfo.size()); for (TensorStructInfo sinfo : tensor_sinfo) { @@ -310,7 +311,8 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { } // As long as the there is known shape value, we will do the best effort check to ensure safety. - Optional> output_shape = CheckConcatOutputShape(call, ctx, shape_values, axis); + ffi::Optional> output_shape = + CheckConcatOutputShape(call, ctx, shape_values, axis); if (shape_unknown || !output_shape.defined()) { if (!vdevice_unknown) { @@ -325,25 +327,68 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutConcat(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutConcat( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); ICHECK(attrs != nullptr) << "Invalid Call"; + NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); ICHECK(nlayout.IsNested()); ICHECK(nlayout.NestedArray()[0].IsLeaf()); int n_tensor = nlayout.NestedArray().size(); LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); - Array input_layouts, output_layouts; + + // We may expect mix of sub indexed and regular layouts here + // Pick the first sub indexed layout and try to prove it for all tensors + // On any failre select first occuring regular layout for all + auto nlayout_array = nlayout.NestedArray(); + for (auto n_layout : nlayout_array) { + ICHECK(n_layout.IsLeaf()); + LayoutDecision in_layout = n_layout.LeafValue(); + if (in_layout->layout.ndim() != 
in_layout->layout.ndim_primal()) { + const auto* tuple_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tuple_sinfo != nullptr) + << " expects the input to be a Tuple of Tensors. However, the given input is " + << call->args[0]->struct_info_->GetTypeKey(); + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + StructInfo field_sinfo = tuple_sinfo->fields[i]; + const auto* field_tensor_sinfo = field_sinfo.as(); + ICHECK(field_tensor_sinfo != nullptr) + << call->op + << " expects the input to be a Tuple of Tensors. However, the given input is " + << call->args[0]->struct_info_; + auto t_sinfo = ffi::GetRef(field_tensor_sinfo); + ffi::Optional t_shape = + ffi::GetRef(t_sinfo->shape.as()); + LayoutDecision curr_layout = nlayout_array[i].LeafValue(); + if (!CanProveLayoutTransform(curr_layout->layout, in_layout->layout, + t_shape.value()->values)) { + // Some tensor unhappy with sub indexed layout, lets pick first regular layout + for (auto pick_layout : nlayout_array) { + if (pick_layout.LeafValue()->layout.ndim() == + pick_layout.LeafValue()->layout.ndim_primal()) { + in_layout = pick_layout.LeafValue(); + break; + } + } + break; + } + } + layout = in_layout; + break; + } + } + + ffi::Array input_layouts, output_layouts; for (int i = 0; i < n_tensor; ++i) { input_layouts.push_back(layout); } output_layouts.push_back(layout); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, attrs->axis.value_or(0)); return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs)); } @@ -359,18 +404,18 @@ TVM_REGISTER_OP("relax.concat") /* relax.expand_dims */ -Expr expand_dims(Expr x, Array axis) { - ObjectPtr attrs = make_object(); +Expr expand_dims(Expr x, ffi::Array axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.expand_dims"); return Call(op, {std::move(x)}, Attrs{attrs}, {}); } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.expand_dims", expand_dims); -}); +} StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -411,9 +456,9 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutExpandDims(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutExpandDims( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); ICHECK(attrs != nullptr) << "Invalid Call"; @@ -462,7 +507,7 @@ TVM_REGISTER_OP("relax.expand_dims") .set_attr("FPurity", Bool(true)); // Helper function for flatten and reshape. 
-PrimExpr ComputeShapeProduct(const Array& shape_values) { +PrimExpr ComputeShapeProduct(const ffi::Array& shape_values) { PrimExpr shape_prod = IntImm(DataType::Int(64), 1); for (PrimExpr value : shape_values) { shape_prod *= value; @@ -476,10 +521,10 @@ Expr flatten(Expr x) { return Call(op, {std::move(x)}, {}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.flatten", flatten); -}); +} StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -514,10 +559,10 @@ Expr index_tensor(Expr first, Expr tensors) { return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.index_tensor", index_tensor); -}); +} StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -525,7 +570,8 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - Array indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); + ffi::Array indices_sinfo = + GetTensorStructInfoFromTuple(call, ctx, call->args[1]); if (indices_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) @@ -534,7 +580,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) DataType output_dtype = data_sinfo->dtype; int n_indices = static_cast(indices_sinfo.size()); - Optional vdev = data_sinfo->vdevice; + ffi::Optional vdev = data_sinfo->vdevice; // Indices must be integers for (int i = 0; i < n_indices; ++i) { @@ -555,7 +601,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) arith::Analyzer* analyzer = ctx->GetAnalyzer(); bool all_index_have_shape_value = true; - 
std::vector> index_shapes; + std::vector> index_shapes; int max_index_ndim = 0; for (const auto& s : indices_sinfo) { @@ -571,12 +617,12 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } } - Optional> broadcast_shape; + ffi::Optional> broadcast_shape; bool shape_unknown = !all_index_have_shape_value; if (all_index_have_shape_value) { // initialise broadcast result with 1's - Array out_shape; + ffi::Array out_shape; for (int i = 0; i < max_index_ndim; ++i) { out_shape.push_back(IntImm(DataType::Int(64), 1)); } @@ -636,7 +682,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) if (broadcast_shape.defined()) { const auto* data_shape_expr = data_sinfo->shape.as(); if (data_shape_expr) { - Array result_shape = broadcast_shape.value(); + ffi::Array result_shape = broadcast_shape.value(); for (int i = n_indices; i < data_sinfo->ndim; ++i) { result_shape.push_back(data_shape_expr->values[i]); } @@ -657,10 +703,10 @@ TVM_REGISTER_OP("relax.index_tensor") /* relax.layout_transform */ -Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value, - Optional> axis_separators, - Optional> input_axis_separators) { - ObjectPtr attrs = make_object(); +Expr layout_transform(Expr x, tir::IndexMap index_map, ffi::Optional pad_value, + ffi::Optional> axis_separators, + ffi::Optional> input_axis_separators) { + ObjectPtr attrs = ffi::make_object(); attrs->index_map = std::move(index_map); attrs->pad_value = std::move(pad_value); attrs->axis_separators = std::move(axis_separators); @@ -670,16 +716,16 @@ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_v return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.layout_transform", layout_transform); -}); +} StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) { TensorStructInfo 
data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); tir::IndexMap index_map = attrs->index_map; - Optional optional_pad_value = attrs->pad_value; + ffi::Optional optional_pad_value = attrs->pad_value; // Check pad_value has same dtype as input. if (optional_pad_value.defined()) { @@ -717,7 +763,7 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& } arith::Analyzer analyzer; - Array output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer); + ffi::Array output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer); return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); } @@ -731,18 +777,18 @@ TVM_REGISTER_OP("relax.layout_transform") /* relax.permute_dims */ -Expr permute_dims(Expr x, Optional> axes) { - ObjectPtr attrs = make_object(); +Expr permute_dims(Expr x, ffi::Optional> axes) { + ObjectPtr attrs = ffi::make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("relax.permute_dims"); return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.permute_dims", permute_dims); -}); +} bool IsIdentityPermutation(const std::vector& permutation) { for (int i = 0; i < static_cast(permutation.size()); ++i) { @@ -798,9 +844,9 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutPermuteDims(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPermuteDims( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -817,7 +863,7 @@ InferLayoutOutput 
InferLayoutPermuteDims(const Call& call, existing_layout = LayoutDecision(InitialLayout(ndim)); } - Array order; + ffi::Array order; if (attrs->axes.defined()) { order = attrs->axes.value(); } else { @@ -830,13 +876,13 @@ InferLayoutOutput InferLayoutPermuteDims(const Call& call, for (const auto& axis : order) { order_str.push_back(axis->value + 'A'); } - String new_axes = + ffi::String new_axes = TransposeStrLike(InitialLayout(ndim).name(), existing_layout->layout, order_str); - Array new_order; + ffi::Array new_order; for (size_t i = 0; i < new_axes.size(); ++i) { new_order.push_back(Integer(new_axes.at(i) - 'A')); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axes = new_order; return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(ndim)}, Attrs(new_attrs)); } @@ -851,14 +897,15 @@ TVM_REGISTER_OP("relax.permute_dims") .set_attr("FPurity", Bool(true)); /* relax.reshape */ -Expr ConvertNewShapeToExpr(const Expr& data, const Variant>& shape) { +Expr ConvertNewShapeToExpr(const Expr& data, + const ffi::Variant>& shape) { const ffi::ArrayObj* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as()) { array = e->values.as(); // Other non-shape expressions are used directly. } else if (const auto* e = shape.as()) { - return GetRef(e); + return ffi::GetRef(e); // Process special values in constants and produce an expression. } else { array = shape.as(); @@ -874,7 +921,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const Variant CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " "Array of PrimExprs. However, the given new shape is " << shape; - PrimExpr len = GetRef(_len); + PrimExpr len = ffi::GetRef(_len); CHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " "integers. 
However, the give new shape is " << shape; @@ -895,7 +942,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const Variant } } - Array array_ref = GetRef>(array); + ffi::Array array_ref = ffi::GetRef>(array); // When there is no dimension to infer, just return the input array as ShapeExpr. if (dim_to_infer == -1 && zero_dims.empty()) { return ShapeExpr(array_ref); @@ -944,16 +991,16 @@ Expr ConvertNewShapeToExpr(const Expr& data, const Variant return ShapeExpr(array_ref); } -Expr reshape(Expr x, Variant> shape) { +Expr reshape(Expr x, ffi::Variant> shape) { Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); static const Op& op = Op::Get("relax.reshape"); return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.reshape", reshape); -}); +} StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -973,7 +1020,7 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { << call->args[1]->struct_info_->GetTypeKey()); } - Optional> old_shape_values; + ffi::Optional> old_shape_values; if (data_sinfo->shape.defined()) { const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); ICHECK_NOTNULL(old_shape_sinfo); @@ -1011,8 +1058,8 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ -Expr split(Expr x, Variant> indices_or_sections, int axis) { - ObjectPtr attrs = make_object(); +Expr split(Expr x, ffi::Variant> indices_or_sections, int axis) { + ObjectPtr attrs = ffi::make_object(); ObjectRef indices_or_sections_obj; if (const auto* indices = indices_or_sections.as()) { @@ -1022,7 +1069,7 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) "However, the given indices " << indices_or_sections << " contains some non-integer."; } - indices_or_sections_obj = ConvertIntImmToInt64(GetRef>(indices)); + indices_or_sections_obj = 
ConvertIntImmToInt64(ffi::GetRef>(indices)); } else if (const auto* n_section = indices_or_sections.as()) { CHECK_GT(n_section->value, 0) << "Split op expects the input number of sections to be a " "positive integer. However, the given number of sections is " @@ -1039,10 +1086,10 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.split", split); -}); +} StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -1051,7 +1098,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { int axis = data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); - if (auto opt_indices = attrs->indices_or_sections.as>()) { + if (auto opt_indices = attrs->indices_or_sections.as>()) { auto p_indices = opt_indices.value(); // When there is not index, return the input tensor's struct info. if (p_indices.size() == 0) { @@ -1059,7 +1106,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. 
if (data_shape == nullptr) { - return TupleStructInfo(Array( + return TupleStructInfo(ffi::Array( p_indices.size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice))); } @@ -1091,7 +1138,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { split_dim = tvm::max(split_dim, 0); split_dim = ctx->GetAnalyzer()->Simplify(split_dim); - Array shape = data_shape->values; + ffi::Array shape = data_shape->values; shape.Set(axis, split_dim); output_sinfo.push_back( TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice)); @@ -1106,7 +1153,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { - return TupleStructInfo(Array( + return TupleStructInfo(ffi::Array( n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice))); } ICHECK_NE(axis, -1); @@ -1114,7 +1161,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { split_len = ctx->GetAnalyzer()->Simplify(split_len); // Construct struct info for tensors except the last one. 
- Array shape = data_shape->values; + ffi::Array shape = data_shape->values; shape.Set(axis, split_len); std::vector output_sinfo( n_section - 1, TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice)); @@ -1131,9 +1178,9 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { throw; } -InferLayoutOutput InferLayoutSplit(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutSplit( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -1157,7 +1204,8 @@ InferLayoutOutput InferLayoutSplit(const Call& call, "output structinfo, but got " << si; auto sinfo = Downcast(si); - Optional shape_expr = GetRef(sinfo->shape.as()); + ffi::Optional shape_expr = + ffi::GetRef(sinfo->shape.as()); CHECK(shape_expr.defined()); auto shape_arr = shape_expr.value(); if (!CanProveLayoutTransform(InitialLayout(tensor_sinfo->ndim), existing_layout->layout, @@ -1168,10 +1216,10 @@ InferLayoutOutput InferLayoutSplit(const Call& call, } } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis); ICHECK(out_tuple != nullptr) << "Invalid Call"; - NLayout tuple_layouts(Array(out_tuple->fields.size(), existing_layout)); + NLayout tuple_layouts(ffi::Array(out_tuple->fields.size(), existing_layout)); return InferLayoutOutput({existing_layout}, {tuple_layouts}, Attrs(new_attrs)); } @@ -1186,18 +1234,18 @@ TVM_REGISTER_OP("relax.split") /* relax.squeeze */ -Expr squeeze(Expr x, Optional> axis) { - ObjectPtr attrs = make_object(); +Expr squeeze(Expr x, ffi::Optional> axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.squeeze"); return Call(op, {std::move(x)}, Attrs{attrs}, {}); } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.squeeze", squeeze); -}); +} StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -1210,7 +1258,7 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); } - Optional> shape_value; + ffi::Optional> shape_value; if (data_sinfo->shape.defined()) { shape_value = Downcast(data_sinfo->shape.value()->struct_info_)->values; } @@ -1229,15 +1277,10 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { // Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1. // When `axis` is given, the dim lengths at the axes must be integer 1 when it is not symbolic const auto* int_len = shape_value.value()[axes[i]].as(); - if (int_len != nullptr && int_len->value != 1) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Squeeze expects the input tensor shape values at the given axis " - "positions to be all 1. However, the tensor shape at axis " - << axes[i] << " is " << shape_value.value()[axes[i]] - << " which is not 1. If it is symbolic, please use MatchCast to cast it " - "to 1 before doing Squeeze."); + // If a dimension is not 1, silently skip it (no-op), matching PyTorch behavior. + if ((int_len != nullptr && int_len->value == 1) || int_len == nullptr) { + axis_removal_mask[axes[i]] = true; } - axis_removal_mask[axes[i]] = true; } } else { // When `axis` is not defined, squeeze all unit-length dimensions. 
@@ -1280,9 +1323,9 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutSqueeze(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutSqueeze( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -1295,7 +1338,7 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call, const auto* shape = tensor_sinfo->shape.as(); ICHECK(shape != nullptr) << "Only support static shape for now"; - Array axis; + ffi::Array axis; if (attrs->axis.defined()) { axis = attrs->axis.value(); } else { @@ -1322,8 +1365,9 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call, if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { existing_layout = LayoutDecision(InitialLayout(ndim)); } - String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout); - Array new_axis; + ffi::String new_axis_str = + TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout); + ffi::Array new_axis; for (size_t i = 0; i < new_axis_str.size(); ++i) { if (new_axis_str.at(i) == '1') { new_axis.push_back(Integer(i)); @@ -1333,7 +1377,7 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call, output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), '1'), output_layout.end()); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = new_axis; return InferLayoutOutput({existing_layout}, {LayoutDecision(Layout(output_layout))}, Attrs(new_attrs)); @@ -1349,7 +1393,8 @@ TVM_REGISTER_OP("relax.squeeze") .set_attr("FPurity", Bool(true)); void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, - const Array& data_shape, const Array& target_shape) { + const ffi::Array& data_shape, + const ffi::Array& target_shape) { 
arith::Analyzer* analyzer = ctx->GetAnalyzer(); int data_ndim = data_shape.size(); @@ -1388,22 +1433,22 @@ void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, /* relax.stack */ -Expr stack(Expr tensors, Optional axis) { - ObjectPtr attrs = make_object(); +Expr stack(Expr tensors, ffi::Optional axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.stack"); return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.stack", stack); -}); +} -Optional> CheckStackOutputShape(const Call& call, const BlockBuilder& ctx, - const std::vector>& shape_values, - int axis) { +ffi::Optional> CheckStackOutputShape( + const Call& call, const BlockBuilder& ctx, + const std::vector>& shape_values, int axis) { bool shape_unknown = false; arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1426,7 +1471,7 @@ Optional> CheckStackOutputShape(const Call& call, const BlockBui } // Insert new dimension at axis position - Array output_shape; + ffi::Array output_shape; for (int i = 0; i < axis; ++i) { output_shape.push_back(shape_values[0][i]); } @@ -1442,7 +1487,8 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { ctx->ReportFatal(Diagnostic::Error(call) << "Stack op should have 1 argument"); } - Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); + ffi::Array tensor_sinfo = + GetTensorStructInfoFromTuple(call, ctx, call->args[0]); if (tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) << "Stack op expects at least one tensor in the input Tuple. 
" @@ -1455,11 +1501,11 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { // Default axis is 0 if not specified int output_ndim = tensor_sinfo[0]->ndim + 1; // Stack adds one dimension DataType output_dtype = DataType::Void(); - Optional vdev = std::nullopt; + ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; bool vdevice_unknown = false; - std::vector> shape_values; + std::vector> shape_values; shape_values.reserve(tensor_sinfo.size()); for (TensorStructInfo sinfo : tensor_sinfo) { @@ -1522,7 +1568,7 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { } return TensorStructInfo(output_dtype, output_ndim); } - Array output_shape; + ffi::Array output_shape; for (int i = 0; i < axis; ++i) { output_shape.push_back(shape_values[0][i]); } @@ -1544,7 +1590,8 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(output_dtype, output_ndim); } - Optional> output_shape = CheckStackOutputShape(call, ctx, shape_values, axis); + ffi::Optional> output_shape = + CheckStackOutputShape(call, ctx, shape_values, axis); if (shape_unknown || !output_shape.defined()) { if (!vdevice_unknown) { return TensorStructInfo(output_dtype, output_ndim, vdev); @@ -1558,9 +1605,9 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutStack(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutStack( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -1571,7 +1618,7 @@ InferLayoutOutput InferLayoutStack(const Call& call, int n_tensor = nlayout.NestedArray().size(); LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); - Array input_layouts, output_layouts; + ffi::Array input_layouts, 
output_layouts; for (int i = 0; i < n_tensor; ++i) { input_layouts.push_back(layout); } @@ -1583,7 +1630,7 @@ InferLayoutOutput InferLayoutStack(const Call& call, Layout output_layout = Layout(layout_str); output_layouts.push_back(LayoutDecision(output_layout)); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = Integer(FindAxis(layout->layout, axis)); return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs)); } @@ -1603,23 +1650,23 @@ Expr collapse_sum_like(Expr data, Expr collapse_target) { return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.collapse_sum_like", collapse_sum_like); -}); +} StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo collapse_target_sinfo = input_sinfo[1]; DataType output_dtype = data_sinfo->dtype; - Optional> data_shape_value; + ffi::Optional> data_shape_value; if (data_sinfo->shape.defined()) { data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; } - Optional> collapse_target_shape_value; + ffi::Optional> collapse_target_shape_value; if (collapse_target_sinfo->shape.defined()) { collapse_target_shape_value = GetStructInfoAs(collapse_target_sinfo->shape.value())->values; @@ -1652,10 +1699,10 @@ Expr collapse_sum_to(Expr data, Expr shape) { return Call(op, {std::move(data), std::move(shape)}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.collapse_sum_to", collapse_sum_to); -}); +} StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ctx) { 
if (call->args.size() != 2) { @@ -1680,7 +1727,7 @@ StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ct DataType output_dtype = data_sinfo->dtype; - Optional> data_shape_value; + ffi::Optional> data_shape_value; if (data_sinfo->shape.defined()) { data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; } @@ -1700,8 +1747,8 @@ TVM_REGISTER_OP("relax.collapse_sum_to") /* relax.repeat */ -Expr repeat(Expr data, int repeats, Optional axis) { - auto attrs = make_object(); +Expr repeat(Expr data, int repeats, ffi::Optional axis) { + auto attrs = ffi::make_object(); attrs->repeats = std::move(repeats); attrs->axis = std::move(axis); @@ -1709,10 +1756,10 @@ Expr repeat(Expr data, int repeats, Optional axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.repeat", repeat); -}); +} StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1748,7 +1795,7 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { if (!attrs->axis.has_value()) { PrimExpr new_shape = analyzer->Simplify(ComputeShapeProduct(data_shape->values) * attrs->repeats); - return TensorStructInfo(ShapeExpr(Array({new_shape})), data_sinfo->dtype, + return TensorStructInfo(ShapeExpr(ffi::Array({new_shape})), data_sinfo->dtype, data_sinfo->vdevice); } @@ -1768,18 +1815,18 @@ TVM_REGISTER_OP("relax.repeat") /* relax.tile */ -Expr tile(Expr data, Array repeats) { - auto attrs = make_object(); +Expr tile(Expr data, ffi::Array repeats) { + auto attrs = ffi::make_object(); attrs->repeats = std::move(repeats); static const Op& op = Op::Get("relax.tile"); return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("relax.op.tile", tile); -}); +} StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1809,7 +1856,7 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { int out_ndim = std::max(l, ndim); int l_delta = out_ndim - l; int ndim_delta = out_ndim - ndim; - Array out_shape; + ffi::Array out_shape; for (int i = 0; i < out_ndim; ++i) { if (i < l_delta) { out_shape.push_back(data_shape->values[i - ndim_delta]); @@ -1835,16 +1882,16 @@ TVM_REGISTER_OP("relax.tile") /* relax.flip */ Expr flip(Expr data, Integer axis) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.flip"); return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.flip", flip); -}); +} StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -1874,16 +1921,16 @@ TVM_REGISTER_OP("relax.flip") /* relax.gather_elements */ Expr gather_elements(Expr data, Expr indices, int axis) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = Integer(axis); static const Op& op = Op::Get("relax.gather_elements"); return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.gather_elements", gather_elements); -}); +} StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -1945,16 +1992,16 @@ TVM_REGISTER_OP("relax.gather_elements") /* relax.gather_nd */ Expr gather_nd(Expr data, Expr indices, int batch_dims) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->batch_dims = 
Integer(batch_dims); static const Op& op = Op::Get("relax.gather_nd"); return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.gather_nd", gather_nd); -}); +} StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -2012,7 +2059,7 @@ StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { } // In this condition, all input shapes are known - Array out_shape; + ffi::Array out_shape; if (l > input_dims - batch_dims) { ctx->ReportFatal(Diagnostic::Error(call) << "GatherND requires the last dimension of indices to be less than or " @@ -2041,22 +2088,22 @@ TVM_REGISTER_OP("relax.gather_nd") /* relax.index_put */ Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->accumulate = std::move(accumulate); static const Op& op = Op::Get("relax.index_put"); return Call(op, {data, indices, values}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.index_put", index_put); -}); +} StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); const auto* values_sinfo = GetStructInfoAs(call->args[2]); - auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String type_key) { + auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name, ffi::String type_key) { if (sinfo == nullptr) { ctx->ReportFatal(Diagnostic::Error(call) << "IndexPut requires the input " << name @@ -2068,7 +2115,7 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { diag_def(values_sinfo, "values", call->args[2]->struct_info_->GetTypeKey()); // Handle indices: 
either a single tensor or a tuple of tensors - Array indices_tensors; + ffi::Array indices_tensors; if (const auto* tuple_sinfo = GetStructInfoAs(call->args[1])) { // Indices is a tuple of tensors @@ -2080,11 +2127,11 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { << "However, element " << i << " is " << tuple_sinfo->fields[i]->GetTypeKey()); } - indices_tensors.push_back(GetRef(tensor_sinfo)); + indices_tensors.push_back(ffi::GetRef(tensor_sinfo)); } } else if (const auto* tensor_sinfo = GetStructInfoAs(call->args[1])) { // Indices is a single tensor - indices_tensors.push_back(GetRef(tensor_sinfo)); + indices_tensors.push_back(ffi::GetRef(tensor_sinfo)); } else { ctx->ReportFatal(Diagnostic::Error(call) << "IndexPut requires indices to be a Tensor or a tuple of Tensors. " @@ -2096,12 +2143,19 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { } // Validate each index tensor + // Index tensors can be multi-dimensional for broadcasting + int max_index_ndim = -1; for (size_t i = 0; i < indices_tensors.size(); ++i) { const auto& tensor_sinfo = indices_tensors[i]; - if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim != 1) { - ctx->ReportFatal(Diagnostic::Error(call) - << "IndexPut requires each index tensor to be 1D. " - << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim); + if (!tensor_sinfo->IsUnknownNdim()) { + if (tensor_sinfo->ndim < 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "IndexPut requires each index tensor to have at least 1 dimension. 
" + << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim); + } + if (max_index_ndim < tensor_sinfo->ndim) { + max_index_ndim = tensor_sinfo->ndim; + } } if (tensor_sinfo->IsUnknownDtype()) { LOG(WARNING) << "Data type of index tensor " << i @@ -2113,6 +2167,23 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { } } + // Validate that index tensor shapes are broadcastable + if (max_index_ndim > 1) { + for (size_t i = 0; i < indices_tensors.size(); ++i) { + const auto& tensor_sinfo = indices_tensors[i]; + if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim > 1) { + // Check that multi-dimensional indices are broadcastable + const auto* shape = tensor_sinfo->shape.as(); + if (shape) { + // Verify trailing dimensions can broadcast + // For now, we accept any multi-dimensional index and rely on runtime validation + LOG(INFO) << "IndexPut: index tensor " << i << " has ndim=" << tensor_sinfo->ndim + << " for broadcasting"; + } + } + } + } + // Check that the number of index tensors matches data dimensions if (!data_sinfo->IsUnknownNdim() && indices_tensors.size() != static_cast(data_sinfo->ndim)) { @@ -2123,7 +2194,7 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { // Check data and values dtype compatibility if (data_sinfo->IsUnknownDtype() || values_sinfo->IsUnknownDtype()) { - auto diag_dtype = [&](const TensorStructInfoNode* sinfo, String name) { + auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) { if (sinfo->IsUnknownDtype()) { LOG(WARNING) << "Data type of " << name << " has not been specified. 
Assume it has an integer type."; @@ -2165,23 +2236,23 @@ TVM_REGISTER_OP("relax.index_put") /* relax.meshgrid */ -Expr meshgrid(Expr tensors, Optional indexing) { - ObjectPtr attrs = make_object(); +Expr meshgrid(Expr tensors, ffi::Optional indexing) { + ObjectPtr attrs = ffi::make_object(); attrs->indexing = indexing; static const Op& op = Op::Get("relax.meshgrid"); return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.meshgrid", meshgrid); -}); +} StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "meshgrid op expects 1 Tuple input argument."); } - Array input_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); + ffi::Array input_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); int n_inputs = input_sinfo.size(); @@ -2193,7 +2264,7 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { std::vector lengths; DataType common_dtype = DataType::Void(); bool shape_unknown = false; - Optional vdev = std::nullopt; + ffi::Optional vdev = std::nullopt; bool vdevice_unknown = false; for (int i = 0; i < n_inputs; ++i) { @@ -2233,14 +2304,14 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { } } - Array out_shape; + ffi::Array out_shape; if (!shape_unknown && lengths.size() == static_cast(n_inputs)) { for (const PrimExpr& dim : lengths) { out_shape.push_back(dim); } } - Array out_fields; + ffi::Array out_fields; for (int i = 0; i < n_inputs; ++i) { if (!out_shape.empty()) { if (!vdevice_unknown) { @@ -2270,18 +2341,18 @@ TVM_REGISTER_OP("relax.meshgrid") /* relax.scatter_elements */ -Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction) { - auto attrs = make_object(); +Expr scatter_elements(Expr data, Expr indices, Expr 
updates, int axis, ffi::String reduction) { + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->reduction = std::move(reduction); static const Op& op = Op::Get("relax.scatter_elements"); return Call(op, {data, indices, updates}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.scatter_elements", scatter_elements); -}); +} StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -2289,7 +2360,7 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& const auto* indices_sinfo = GetStructInfoAs(call->args[1]); const auto* updates_sinfo = GetStructInfoAs(call->args[2]); - auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String type_key) { + auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name, ffi::String type_key) { if (sinfo == nullptr) { ctx->ReportFatal(Diagnostic::Error(call) << "ScatterElements requires the input " << name @@ -2325,7 +2396,7 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& } if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { - auto diag_dtype = [&](const TensorStructInfoNode* sinfo, String name) { + auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) { if (sinfo->IsUnknownDtype()) { // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? 
LOG(WARNING) << "Data type of " << name @@ -2387,17 +2458,17 @@ TVM_REGISTER_OP("relax.scatter_elements") /* relax.scatter_nd */ -Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) { - auto attrs = make_object(); +Expr scatter_nd(Expr data, Expr indices, Expr updates, ffi::String reduction) { + auto attrs = ffi::make_object(); attrs->reduction = std::move(reduction); static const Op& op = Op::Get("relax.scatter_nd"); return Call(op, {data, indices, updates}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.scatter_nd", scatter_nd); -}); +} StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { // `call->args` contains: [data, indices, updates] @@ -2479,14 +2550,15 @@ StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { << "data: " << ShapeExpr(data_shape->values) << ", indices: " << ShapeExpr(indices_shape->values)); } - Array expected_updates_shape; + ffi::Array expected_updates_shape; for (size_t i = 0; i < indices_ndim - 1; i++) { expected_updates_shape.push_back(indices_shape->values[i]); } for (size_t i = k_dim->value; i < data_ndim; i++) { expected_updates_shape.push_back(data_shape->values[i]); } - auto check_shape = [&](const Array& expected, const Array& actual) { + auto check_shape = [&](const ffi::Array& expected, + const ffi::Array& actual) { if (expected.size() != actual.size()) { return false; } @@ -2524,16 +2596,16 @@ TVM_REGISTER_OP("relax.scatter_nd") /* relax.scatter_nd */ Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.slice_scatter"); return Call(op, {input, src, start, end, step}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; refl::GlobalDef().def("relax.op.slice_scatter", slice_scatter); -}); +} StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -2542,7 +2614,7 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx auto* attrs = call->attrs.as(); auto diag_tensor_check = [&](const TensorStructInfoNode* sinfo, const Expr& arg_expr, - String name) { + ffi::String name) { if (sinfo == nullptr) { ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter requires the input " << name << " to be a Tensor. However, the given one is " @@ -2576,7 +2648,7 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx } if (data_sinfo->IsUnknownDtype() || src_sinfo->IsUnknownDtype()) { - auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, String name) { + auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, ffi::String name) { if (sinfo->IsUnknownDtype()) { LOG(WARNING) << "SliceScatter: Data type of " << name << " has not been specified for call node " << call @@ -2681,7 +2753,7 @@ TVM_REGISTER_OP("relax.slice_scatter") /* relax.one_hot */ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->depth = depth; attrs->axis = axis; @@ -2697,10 +2769,10 @@ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, i return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); } // namespace relax -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.one_hot", one_hot); -}); +} StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx); @@ -2732,7 +2804,7 @@ StructInfo InferStructInfoOneHot(const Call& call, const 
BlockBuilder& ctx) { return TensorStructInfo(dtype, indices_sinfo->ndim + 1, indices_sinfo->vdevice); } - Array output_shape = indices_shape->values; + ffi::Array output_shape = indices_shape->values; int axis = attrs->axis; if (axis < 0) { axis += output_shape.size() + 1; diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index cc15d5d4ab76..84d53addcc69 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -44,7 +44,7 @@ Expr broadcast_to(Expr x, Expr shape); * If it is `std::nullopt`, the input tensor is required to be flattened before concatenation. * \return The concatenated tensor. */ -Expr concat(Expr tensors, Optional axis); +Expr concat(Expr tensors, ffi::Optional axis); /*! * \brief Insert new axes at the positions given by `axis`. @@ -52,7 +52,7 @@ Expr concat(Expr tensors, Optional axis); * \param axis The axes at which the input array are expanded. * \return The transformed result. */ -Expr expand_dims(Expr x, Array axis); +Expr expand_dims(Expr x, ffi::Array axis); /*! * \brief Flatten all the tensor dimensions into one. @@ -72,9 +72,9 @@ Expr flatten(Expr x); * \param input axis_separators Array of values for input buffer. * \return The transformed result. */ -Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value, - Optional> axis_separators, - Optional> input_axis_separators = std::nullopt); +Expr layout_transform(Expr x, tir::IndexMap index_map, ffi::Optional pad_value, + ffi::Optional> axis_separators, + ffi::Optional> input_axis_separators = std::nullopt); /*! * \brief Permutes the dimensions of an array. @@ -82,7 +82,7 @@ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_v * \param axes The target axes order, reverse order if not specified. * \return The transposed result. */ -Expr permute_dims(Expr x, Optional> axes); +Expr permute_dims(Expr x, ffi::Optional> axes); /*! 
* \brief Reshape the input array, supporting `-1` inference in the new @@ -92,7 +92,7 @@ Expr permute_dims(Expr x, Optional> axes); * It is required to be either an Array of PrimExpr, or a Shape in Relax * \return The reshaped result. */ -Expr reshape(Expr x, Variant> shape); +Expr reshape(Expr x, ffi::Variant> shape); /*! * \brief Split input tensor along axis by sections or indices. @@ -107,7 +107,7 @@ Expr reshape(Expr x, Variant> shape); * \param axis The axis over which to split. * \return The computed result. */ -Expr split(Expr x, Variant> indices_or_sections, int axis); +Expr split(Expr x, ffi::Variant> indices_or_sections, int axis); /*! * \brief Squeeze axes in the array. @@ -117,14 +117,14 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) * If any specified axis has dimension that does not equal 1, it is an error. * \return The squeezed result. */ -Expr squeeze(Expr x, Optional> axis); +Expr squeeze(Expr x, ffi::Optional> axis); /*! * \brief Stack tensors along the specified axis. * \param tensors The input tensors to be stacked. * \param axis The axis along which the tensors will be stacked. * \return The stacked result. */ -Expr stack(Expr tensors, Optional axis); +Expr stack(Expr tensors, ffi::Optional axis); /*! * \brief Return a summation of data to the shape of collapse_target. * For details, please see the operator `relax.collapse_sum_to`. @@ -154,7 +154,7 @@ Expr collapse_sum_to(Expr data, Expr shape); * from the backward. By default, use the flattened input array, and return a flat output array. * \return The computed result. */ -Expr repeat(Expr data, int repeats, Optional axis = std::nullopt); +Expr repeat(Expr data, int repeats, ffi::Optional axis = std::nullopt); /*! * \brief Construct an array by repeating data the number of times given by reps. @@ -171,7 +171,7 @@ Expr repeat(Expr data, int repeats, Optional axis = std::nullopt); * \param repeats The number of repetitions of data along each axis. * \return The computed result. 
*/ -Expr tile(Expr data, Array repeats); +Expr tile(Expr data, ffi::Array repeats); /*! * \brief Reverses the order of elements along given axis. @@ -238,7 +238,7 @@ Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false); * \param indexing Indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian indexing). * \return A tuple of tensors representing the coordinate grids. */ -Expr meshgrid(Expr tensors, Optional indexing = String("ij")); +Expr meshgrid(Expr tensors, ffi::Optional indexing = ffi::String("ij")); /*! * \brief Scatter updates into an array according to indices. @@ -250,7 +250,7 @@ Expr meshgrid(Expr tensors, Optional indexing = String("ij")); * either "update", "add", "mul", "mean", "max" or "min". * \return The computed result. */ -Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction); +Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, ffi::String reduction); /*! * \brief Scatter updates into an array according to indices. @@ -271,7 +271,7 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re * The shape of `updates` must match the shape of `indices` except for the last dimension, * which must match the slice shape at each index. */ -Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction); +Expr scatter_nd(Expr data, Expr indices, Expr updates, ffi::String reduction); /*! * \brief Embeds the values of the src tensor into input at the given dimension. 
diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index a51d85820e40..406868ab4bfc 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -34,22 +34,22 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ QuantizeAttrs::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { QuantizeAttrs::RegisterReflection(); } /* relax.quantize */ Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.quantize"); return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.quantize", quantize); -}); +} StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -93,7 +93,7 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { } auto check_param_size = [&](const TensorStructInfo& param_sinfo, - const TensorStructInfo& data_sinfo, String param_name) { + const TensorStructInfo& data_sinfo, ffi::String param_name) { const PrimExpr& param_dim = param_sinfo->GetShape().value()[0]; const PrimExpr& input_dim = data_sinfo->GetShape().value()[axis]; if (!ctx->GetAnalyzer()->CanProveEqual(param_dim, input_dim)) { @@ -108,7 +108,7 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale"); if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point"); - auto output_sinfo = make_object(*input_sinfo.get()); + auto output_sinfo = ffi::make_object(*input_sinfo.get()); output_sinfo->dtype = attrs->out_dtype; return TensorStructInfo(output_sinfo); } @@ -125,17 +125,17 @@ 
TVM_REGISTER_OP("relax.quantize") /* relax.dequantize */ Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.dequantize"); return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.dequantize", dequantize); -}); +} StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -181,7 +181,7 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) } auto check_param_size = [&](const TensorStructInfo& param_sinfo, - const TensorStructInfo& data_sinfo, String param_name) { + const TensorStructInfo& data_sinfo, ffi::String param_name) { const PrimExpr& param_dim = param_sinfo->GetShape().value()[0]; const PrimExpr& input_dim = data_sinfo->GetShape().value()[axis]; if (!ctx->GetAnalyzer()->CanProveEqual(param_dim, input_dim)) { @@ -196,7 +196,7 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale"); if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point"); - auto output_sinfo = make_object(*input_sinfo.get()); + auto output_sinfo = ffi::make_object(*input_sinfo.get()); output_sinfo->dtype = attrs->out_dtype; return TensorStructInfo(output_sinfo); } diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 803e0a654d1c..ca5635baa74b 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -32,12 +32,12 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ MultinomialFromUniformAttrs::RegisterReflection(); }); 
+TVM_FFI_STATIC_INIT_BLOCK() { MultinomialFromUniformAttrs::RegisterReflection(); } /* relax.multinomial_from_uniform */ Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.multinomial_from_uniform"); @@ -45,10 +45,10 @@ Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indice Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.multinomial_from_uniform", multinomial_from_uniform); -}); +} StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index d1ebae3a4fdc..0cd221d53d1c 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -32,28 +32,28 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { ArgmaxArgminAttrs::RegisterReflection(); BucketizeAttrs::RegisterReflection(); -}); +} /* relax.bucketize */ Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->out_int32 = std::move(out_int32); attrs->right = std::move(right); static const Op& op = Op::Get("relax.bucketize"); return Call(op, {std::move(input_tensor), std::move(boundaries)}, Attrs(attrs), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.bucketize", bucketize); -}); +} StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo input_tensor_info = 
input_sinfo[0]; TensorStructInfo boundaries_info = input_sinfo[1]; @@ -93,13 +93,13 @@ Expr where(Expr condition, Expr x1, Expr x2) { return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.where", where); -}); +} StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo cond_sinfo = input_sinfo[0]; TensorStructInfo x1_sinfo = input_sinfo[1]; TensorStructInfo x2_sinfo = input_sinfo[2]; @@ -139,7 +139,7 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { const auto* x2_shape = x2_sinfo->shape.as(); if (cond_shape && x1_shape && x2_shape) { // Step 1. Compute the broadcasted shape of x1's and x2's - Optional> broadcasted_shape = + ffi::Optional> broadcasted_shape = InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); if (!broadcasted_shape.defined()) { if (vdev.defined()) { @@ -220,12 +220,13 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx const auto* data_shape = data_sinfo->shape.as(); if (data_shape == nullptr) { if (!attrs->axis.has_value() && attrs->keepdims && out_ndim != kUnknownNDim) { - return TensorStructInfo(ShapeExpr(Array(out_ndim, IntImm(out_dtype, /*value=*/1))), - out_dtype, data_sinfo->vdevice); + return TensorStructInfo( + ShapeExpr(ffi::Array(out_ndim, IntImm(out_dtype, /*value=*/1))), out_dtype, + data_sinfo->vdevice); } else { - return out_ndim == 0 - ? TensorStructInfo(ShapeExpr(Array()), out_dtype, data_sinfo->vdevice) - : TensorStructInfo(out_dtype, out_ndim, data_sinfo->vdevice); + return out_ndim == 0 ? 
TensorStructInfo(ShapeExpr(ffi::Array()), out_dtype, + data_sinfo->vdevice) + : TensorStructInfo(out_dtype, out_ndim, data_sinfo->vdevice); } } @@ -233,7 +234,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx out_dtype = data_shape->values[0]->dtype; } - Array out_shape; + ffi::Array out_shape; out_shape.reserve(out_ndim); for (int i = 0; i < data_sinfo->ndim; ++i) { if (attrs->axis.has_value() && i != axis) { @@ -247,15 +248,16 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx } #define RELAX_REGISTER_ARGMAX_ARGMIN_OP(OpName) \ - Expr OpName(Expr x, Optional axis, bool keepdims) { \ - ObjectPtr attrs = make_object(); \ + Expr OpName(Expr x, ffi::Optional axis, bool keepdims) { \ + ObjectPtr attrs = ffi::make_object(); \ attrs->axis = std::move(axis); \ attrs->keepdims = std::move(keepdims); \ static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {std::move(x)}, Attrs(attrs)); \ } \ - TVM_FFI_STATIC_INIT_BLOCK( \ - { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \ + TVM_FFI_STATIC_INIT_BLOCK() { \ + tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); \ + } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h index 333b5afe76c7..d1cc6e39f43c 100644 --- a/src/relax/op/tensor/search.h +++ b/src/relax/op/tensor/search.h @@ -48,10 +48,10 @@ Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right); Expr where(Expr condition, Expr x1, Expr x2); /*! \brief Computes the argmax of tensor elements over given axis. */ -Expr argmax(Expr x, Optional axis, bool keepdims); +Expr argmax(Expr x, ffi::Optional axis, bool keepdims); /*! \brief Computes the argmin of tensor elements over given axis. 
*/ -Expr argmin(Expr x, Optional axis, bool keepdims); +Expr argmin(Expr x, ffi::Optional axis, bool keepdims); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index f0fc3871371c..d80c73b1317d 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -36,7 +36,7 @@ namespace relax { /* relax.unique */ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, - PrimValue return_counts, Optional axis) { + PrimValue return_counts, ffi::Optional axis) { static const Op& op = Op::Get("relax.unique"); Call call; if (!axis) { @@ -48,17 +48,17 @@ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_i return call; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.unique", unique); -}); +} StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = Downcast(call->args[0]->struct_info_); PrimValue axis, return_index, return_inverse, return_counts; if (call->args.size() == 6) { if (auto* prim_value_node = call->args[5].as()) { - axis = GetRef(prim_value_node); + axis = ffi::GetRef(prim_value_node); } } if (!data_sinfo->IsUnknownNdim() && axis.defined()) { @@ -79,7 +79,7 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { CHECK(value->IsInstance()) << value << " expects to be IntImm, but gets " << value->GetTypeKey(); const auto* val_node = value.as(); - auto val_imm = GetRef(val_node); + auto val_imm = ffi::GetRef(val_node); return val_imm->value; }; @@ -149,10 +149,10 @@ Expr nonzero(Expr x) { return Call(op, {std::move(x)}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.nonzero", nonzero); -}); +} StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { TensorStructInfo 
data_sinfo = GetInputTensorStructInfo(call, 0, ctx); diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h index 251dd1975e9f..4af7478d61ef 100644 --- a/src/relax/op/tensor/set.h +++ b/src/relax/op/tensor/set.h @@ -49,7 +49,7 @@ namespace relax { * Additional return values depend on `return_index`, `return_inverse`, and `return_counts`. */ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, - PrimValue return_counts, Optional axis); + PrimValue return_counts, ffi::Optional axis); /*! * \brief Returns the indices of the non-zero elements of the input tensor. diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 57e13fa26e01..db0bd8a8c700 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -31,16 +31,16 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { SortAttrs::RegisterReflection(); ArgsortAttrs::RegisterReflection(); TopKAttrs::RegisterReflection(); -}); +} /* relax.sort */ Expr sort(Expr data, int axis, bool descending) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->descending = std::move(descending); @@ -48,10 +48,10 @@ Expr sort(Expr data, int axis, bool descending) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.sort", sort); -}); +} StructInfo InferStructInfoSort(const Call& call, const BlockBuilder& ctx) { return GetUnaryInputTensorStructInfo(call, ctx); @@ -67,7 +67,7 @@ TVM_REGISTER_OP("relax.sort") /* relax.argsort */ Expr argsort(Expr data, int axis, bool descending, DataType dtype) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->descending = std::move(descending); attrs->dtype = std::move(dtype); @@ -76,10 +76,10 @@ Expr argsort(Expr data, int 
axis, bool descending, DataType dtype) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.argsort", argsort); -}); +} StructInfo InferStructInfoArgsort(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -100,8 +100,8 @@ TVM_REGISTER_OP("relax.argsort") /* relax.topk */ -Expr topk(Expr data, int k, int axis, String ret_type, bool largest, DataType dtype) { - auto attrs = make_object(); +Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DataType dtype) { + auto attrs = ffi::make_object(); attrs->k = std::move(k); attrs->axis = std::move(axis); attrs->ret_type = std::move(ret_type); @@ -112,10 +112,10 @@ Expr topk(Expr data, int k, int axis, String ret_type, bool largest, DataType dt return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.topk", topk); -}); +} StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -124,7 +124,7 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { DataType indices_type = attrs->dtype.is_void() ? 
data_sinfo->dtype : attrs->dtype; int ndim = data_sinfo->ndim; int k = attrs->k; - String ret_type = attrs->ret_type; + ffi::String ret_type = attrs->ret_type; int axis = attrs->axis; if (axis < 0 && ndim > 0) { axis += ndim; @@ -137,7 +137,7 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)); output_sinfos.push_back(TensorStructInfo(indices_type, data_sinfo->ndim, data_sinfo->vdevice)); } else { - Array out_shape = data_shape->values; + ffi::Array out_shape = data_shape->values; const auto* int_dim = out_shape[axis].as(); if (k > 0 && (int_dim == nullptr || k < int_dim->value)) { out_shape.Set(axis, k); diff --git a/src/relax/op/tensor/sorting.h b/src/relax/op/tensor/sorting.h index 8a785bc4e2b8..a4154ce416ad 100644 --- a/src/relax/op/tensor/sorting.h +++ b/src/relax/op/tensor/sorting.h @@ -63,7 +63,7 @@ Expr argsort(Expr data, int axis, bool descending, DataType dtype); * \param dtype The data type of the indices output. * \return The computed result. 
*/ -Expr topk(Expr data, int k, int axis, String ret_type, bool largest, DataType dtype); +Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DataType dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 700016b223ef..621c23d36310 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -32,10 +32,10 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { StatisticalAttrs::RegisterReflection(); ScanopAttrs::RegisterReflection(); -}); +} StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -69,16 +69,16 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) if (data_shape == nullptr) { if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { return TensorStructInfo( - ShapeExpr(Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), + ShapeExpr(ffi::Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), data_sinfo->dtype, data_sinfo->vdevice); } else { - return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array()), data_sinfo->dtype, + return out_ndim == 0 ? 
TensorStructInfo(ShapeExpr(ffi::Array()), data_sinfo->dtype, data_sinfo->vdevice) : TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice); } } - Array out_shape; + ffi::Array out_shape; out_shape.reserve(out_ndim); for (int i = 0; i < data_sinfo->ndim; ++i) { if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { @@ -91,9 +91,9 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutStatistical(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutStatistical( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -103,7 +103,7 @@ InferLayoutOutput InferLayoutStatistical(const Call& call, ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; int ndim = tensor_sinfo->ndim; - Array axis; + ffi::Array axis; if (attrs->axis.defined()) { axis = attrs->axis.value(); } else { @@ -131,7 +131,7 @@ InferLayoutOutput InferLayoutStatistical(const Call& call, [](unsigned char c) { return std::isdigit(c); }), new_axis_str.end()); - Array new_axis; + ffi::Array new_axis; for (size_t i = 0; i < new_axis_str.size(); ++i) { if (new_axis_str.at(i) == '#') { new_axis.push_back(Integer(i)); @@ -145,7 +145,7 @@ InferLayoutOutput InferLayoutStatistical(const Call& call, output_layout.push_back(output_layout_ref[i]); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = new_axis; return InferLayoutOutput({exisiting_layout}, {attrs->keepdims ? 
exisiting_layout : Layout(output_layout)}, @@ -168,7 +168,7 @@ StructInfo InferStructInfoScan(const Call& call, const BlockBuilder& ctx) { for (const auto v : data_shape->values) { flattened_d *= v; } - return TensorStructInfo(ShapeExpr(Array({flattened_d})), out_type, + return TensorStructInfo(ShapeExpr(ffi::Array({flattened_d})), out_type, data_sinfo->vdevice); } } @@ -181,8 +181,9 @@ StructInfo InferStructInfoScan(const Call& call, const BlockBuilder& ctx) { } /* relax.cumprod */ -Expr cumprod(Expr data, Optional axis, Optional dtype, Bool exclusive) { - auto attrs = make_object(); +Expr cumprod(Expr data, ffi::Optional axis, ffi::Optional dtype, + Bool exclusive) { + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->dtype = std::move(dtype.value_or(DataType::Void())); attrs->exclusive = std::move(exclusive); @@ -191,10 +192,10 @@ Expr cumprod(Expr data, Optional axis, Optional dtype, Bool e return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.cumprod", cumprod); -}); +} TVM_REGISTER_OP("relax.cumprod") .set_attrs_type() @@ -204,8 +205,8 @@ TVM_REGISTER_OP("relax.cumprod") .set_attr("FPurity", Bool(true)); /* relax.cumsum */ -Expr cumsum(Expr data, Optional axis, Optional dtype, Bool exclusive) { - auto attrs = make_object(); +Expr cumsum(Expr data, ffi::Optional axis, ffi::Optional dtype, Bool exclusive) { + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->dtype = std::move(dtype.value_or(DataType::Void())); attrs->exclusive = std::move(exclusive); @@ -214,10 +215,10 @@ Expr cumsum(Expr data, Optional axis, Optional dtype, Bool ex return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.cumsum", cumsum); -}); +} TVM_REGISTER_OP("relax.cumsum") 
.set_attrs_type() diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index e79ce1d4aeaa..a80ef728683a 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -43,15 +43,16 @@ namespace relax { * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. */ #define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName) \ - Expr OpName(Expr x, Optional> axis, bool keepdims) { \ - ObjectPtr attrs = make_object(); \ + Expr OpName(Expr x, ffi::Optional> axis, bool keepdims) { \ + ObjectPtr attrs = ffi::make_object(); \ attrs->axis = std::move(axis); \ attrs->keepdims = keepdims; \ static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {std::move(x)}, Attrs{attrs}, {}); \ } \ - TVM_FFI_STATIC_INIT_BLOCK( \ - { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \ + TVM_FFI_STATIC_INIT_BLOCK() { \ + tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); \ + } \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ @@ -67,22 +68,22 @@ namespace relax { * reduced are left in the result as dimensions with size one. With this option, the result will * broadcast correctly against the input tensor. \return The result after reduction. */ -Expr max(Expr x, Optional> axis, bool keepdims); +Expr max(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the mean of tensor elements over given axes. */ -Expr mean(Expr x, Optional> axis, bool keepdims); +Expr mean(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the min of tensor elements over given axes. */ -Expr min(Expr x, Optional> axis, bool keepdims); +Expr min(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the product of tensor elements over given axes. */ -Expr prod(Expr x, Optional> axis, bool keepdims); +Expr prod(Expr x, ffi::Optional> axis, bool keepdims); /*! 
\brief Computes the standard deviation of tensor elements over given axes. */ -Expr std(Expr x, Optional> axis, bool keepdims); +Expr std(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the sum of tensor elements over given axes. */ -Expr sum(Expr x, Optional> axis, bool keepdims); +Expr sum(Expr x, ffi::Optional> axis, bool keepdims); /*! * \brief Numpy style cumprod op. Return the cumulative inclusive product of the elements along @@ -97,8 +98,8 @@ Expr sum(Expr x, Optional> axis, bool keepdims); * \return The computed * result. */ -Expr cumprod(Expr data, Optional axis = std::nullopt, - Optional dtype = std::nullopt, Bool exclusive = Bool(false)); +Expr cumprod(Expr data, ffi::Optional axis = std::nullopt, + ffi::Optional dtype = std::nullopt, Bool exclusive = Bool(false)); /*! * \brief Numpy style cumsum op. Return the cumulative inclusive sum of the elements along @@ -112,11 +113,11 @@ Expr cumprod(Expr data, Optional axis = std::nullopt, * which the first element is not included. * \return The computed result. */ -Expr cumsum(Expr data, Optional axis = std::nullopt, - Optional dtype = std::nullopt, Bool exclusive = Bool(false)); +Expr cumsum(Expr data, ffi::Optional axis = std::nullopt, + ffi::Optional dtype = std::nullopt, Bool exclusive = Bool(false)); /*! \brief Computes the variance of tensor elements over given axes. 
*/ -Expr variance(Expr x, Optional> axis, bool keepdims); +Expr variance(Expr x, ffi::Optional> axis, bool keepdims); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index b60344e351d6..a38585cb507a 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -30,7 +30,7 @@ namespace tvm { namespace relax { StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo t1 = input_sinfo[0]; TensorStructInfo t2 = input_sinfo[1]; TensorStructInfo t3 = input_sinfo[2]; @@ -87,7 +87,7 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { auto* s3 = t3->shape.as(); arith::Analyzer* analyzer = ctx->GetAnalyzer(); if (s1 && s2 && s3) { - Array output_shape; + ffi::Array output_shape; for (int i = 0; i < ndim; ++i) { PrimExpr dim1 = s1->values[i]; PrimExpr dim2 = s2->values[i]; @@ -115,9 +115,9 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(output_dtype, ndim); } -InferLayoutOutput InferLayoutEwiseFMA(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutEwiseFMA( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout0 = GetLayoutDecision(var_layout_map, call->args[0]); @@ -145,10 +145,10 @@ Expr ewise_fma(Expr x1, Expr x2, Expr x3) { return Call(op, {x1, x2, x3}, Attrs(), {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.ewise_fma", ewise_fma); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 
ac7b995ff122..50f5ce2bf35f 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -87,10 +87,10 @@ Expr clip(Expr x, Expr min, Expr max) { return Call(op, {std::move(x), std::move(min), std::move(max)}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.op.clip", clip); -}); +} /***************** Check operators *****************/ diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h index 6984ba6304eb..1847ba3c365a 100644 --- a/src/relax/op/tensor/unary.h +++ b/src/relax/op/tensor/unary.h @@ -38,7 +38,7 @@ namespace relax { * (Only for unary arith operators since all check operators don't require float dtype.) */ #define RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName) \ - RELAX_UNARY_OP_INTERFACE(OpName, #OpName); \ + RELAX_UNARY_OP_INTERFACE(OpName, #OpName) \ RELAX_REGISTER_UNARY_OP(#OpName) #define RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(OpName, RequireFloatDtype) \ diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc new file mode 100644 index 000000000000..2a1ad8f40aa4 --- /dev/null +++ b/src/relax/op/vision/nms.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "nms.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +TVM_FFI_STATIC_INIT_BLOCK() { AllClassNonMaximumSuppressionAttrs::RegisterReflection(); } + +/* relax.vision.all_class_non_max_suppression */ + +Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxes_per_class, + Expr iou_threshold, Expr score_threshold, + ffi::String output_format) { + auto attrs = tvm::ffi::make_object(); + attrs->output_format = output_format; + + static const Op& op = Op::Get("relax.vision.all_class_non_max_suppression"); + return Call(op, + {std::move(boxes), std::move(scores), std::move(max_output_boxes_per_class), + std::move(iou_threshold), std::move(score_threshold)}, + Attrs(attrs), {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vision.all_class_non_max_suppression", + all_class_non_max_suppression); +} + +StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& ctx) { + tvm::ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto boxes_sinfo = input_sinfo[0]; + const auto scores_sinfo = input_sinfo[1]; + ICHECK(!boxes_sinfo->IsUnknownNdim()) << "Only support known ndim"; + ICHECK(!scores_sinfo->IsUnknownNdim()) << "Only support known ndim"; + ICHECK_EQ(boxes_sinfo->ndim, 3) << "AllClassNMS input boxes should be 3-D."; + ICHECK_EQ(scores_sinfo->ndim, 3) << "AllClassNMS input scores count should be 3-D."; + + const auto batch = boxes_sinfo->shape.as()->values[0]; + const auto num_classes = scores_sinfo->shape.as()->values[1]; + const auto num_boxes = boxes_sinfo->shape.as()->values[1]; + + auto vdev = input_sinfo[0]->vdevice; + const auto* attrs = call->attrs.as(); + if (attrs->output_format == "onnx") { + auto vdev = input_sinfo[0]->vdevice; + auto num_total_boxes = batch * num_classes * num_boxes; + tvm::ffi::Array oshape_values = 
{num_total_boxes, 3}; + ShapeExpr oshape(oshape_values); + tvm::ffi::Array counts_values = {1}; + ShapeExpr counts_shape(counts_values); + tvm::ffi::Array fields = {TensorStructInfo(oshape, DataType::Int(64), vdev), + TensorStructInfo(counts_shape, DataType::Int(64), vdev)}; + return TupleStructInfo(fields); + } + + auto num_total_boxes_per_batch = num_classes * num_boxes; + tvm::ffi::Array indices_values = {batch, num_total_boxes_per_batch, 2}; + ShapeExpr indices_shape(indices_values); + tvm::ffi::Array scores_values = {batch, num_total_boxes_per_batch}; + ShapeExpr scores_shape(scores_values); + tvm::ffi::Array counts_values = {batch}; + ShapeExpr counts_shape(counts_values); + tvm::ffi::Array fields = {TensorStructInfo(indices_shape, DataType::Int(64), vdev), + TensorStructInfo(scores_shape, DataType::Float(32), vdev), + TensorStructInfo(counts_shape, DataType::Int(64), vdev)}; + return TupleStructInfo(fields); +} + +TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression") + .set_attrs_type() + .set_num_inputs(5) + .add_argument("boxes", "Tensor", "The input boxes in the format [batch, num_boxes, 4].") + .add_argument("scores", "Tensor", + "Scores for each box and class in the format [batch, num_classes, num_boxes].") + .add_argument("max_output_boxes_per_class", "Tensor", + "The maximum number of output boxes per class.") + .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the overlap test.") + .add_argument("score_threshold", "Tensor", + "The score threshold to filter out low score boxes early.") + .set_attr("FInferStructInfo", InferStructInfoAllClassNMS) + .set_attr("FPurity", Bool(true)); + +} // namespace relax +} // namespace tvm diff --git a/ffi/tests/cpp/test_c_ffi_abi.cc b/src/relax/op/vision/nms.h similarity index 50% rename from ffi/tests/cpp/test_c_ffi_abi.cc rename to src/relax/op/vision/nms.h index 1efceef2971a..c86bf98c94d5 100644 --- a/ffi/tests/cpp/test_c_ffi_abi.cc +++ b/src/relax/op/vision/nms.h @@ -7,7 +7,7 @@ * 
"License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an @@ -16,16 +16,29 @@ * specific language governing permissions and limitations * under the License. */ -#include -#include +/*! + * \file nms.h + * \brief The functions to make Relax Non-maximum suppression operator calls. + */ + +#ifndef TVM_RELAX_OP_VISION_NMS_H_ +#define TVM_RELAX_OP_VISION_NMS_H_ + +#include +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { -namespace { +/*! \brief Compute All Class NonMaximumSuppression. */ +Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxes_per_class, + Expr iou_threshold, Expr score_threshold, + ffi::String output_format); -TEST(ABIHeaderAlignment, Default) { - TVMFFIObject value; - value.type_index = 10; - EXPECT_EQ(reinterpret_cast(&value)->type_index, 10); - static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 16 bytes"); -} +} // namespace relax +} // namespace tvm -} // namespace +#endif // TVM_RELAX_OP_VISION_NMS_H_ diff --git a/src/relax/testing/transform.cc b/src/relax/testing/transform.cc index 67660f665178..c8d7078258a8 100644 --- a/src/relax/testing/transform.cc +++ b/src/relax/testing/transform.cc @@ -36,10 +36,10 @@ tvm::transform::Pass ApplyEmptyCppMutator() { "relax.testing.ApplyEmptyCppMutator", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.testing.transform.ApplyEmptyCppMutator", ApplyEmptyCppMutator); -}); +} } // namespace testing } // namespace relax diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 49e5b862e900..26290775fe64 100644 --- a/src/relax/training/utils.cc +++ 
b/src/relax/training/utils.cc @@ -39,13 +39,13 @@ namespace relax { /*! \brief Append the loss function to the backbone function in an IRModule.*/ class AppendLossMutator : private ExprMutator { public: - static IRModule Transform(IRModule mod, String func_name, Function loss_function, - int num_backbone_outputs, Optional new_func_name) { + static IRModule Transform(IRModule mod, ffi::String func_name, Function loss_function, + int num_backbone_outputs, ffi::Optional new_func_name) { auto* old_func = mod->Lookup(func_name).as(); CHECK(old_func) << func_name << "is not a Relax Function"; // functions should be copied to satisfy the well-formed check - Function new_func = CopyWithNewVars(GetRef(old_func)); + Function new_func = CopyWithNewVars(ffi::GetRef(old_func)); Function new_loss_func = CopyWithNewVars(loss_function); AppendLossMutator mutator(mod, new_loss_func, num_backbone_outputs); @@ -53,7 +53,7 @@ class AppendLossMutator : private ExprMutator { WithAttr(Downcast(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol, new_func_name.value_or(func_name + "_loss")); - auto new_module = GetRef(mod.CopyOnWrite()); + auto new_module = ffi::GetRef(mod.CopyOnWrite()); auto new_var = GlobalVar(new_func_name.value_or(func_name + "_loss")); new_module->Add(new_var, new_func_transformed); return new_module; @@ -73,7 +73,7 @@ class AppendLossMutator : private ExprMutator { CheckAndRemapBackboneReturn(); CheckAndRemapLossParams(loss_function_->params); - Array new_params = func->params; + ffi::Array new_params = func->params; new_params.insert(new_params.end(), loss_function_->params.begin() + num_backbone_outputs_, loss_function_->params.end()); Expr new_body = this->VisitExpr(func->body); @@ -85,8 +85,8 @@ class AppendLossMutator : private ExprMutator { CHECK(seq_expr->blocks.size() == 1 && seq_expr->blocks[0]->IsInstance()) << "Backbone should have only one DataflowBlock"; - auto new_blocks = Array({this->VisitBindingBlock(seq_expr->blocks[0])}); - auto ret = 
Array({loss_body_->body}); + auto new_blocks = ffi::Array({this->VisitBindingBlock(seq_expr->blocks[0])}); + auto ret = ffi::Array({loss_body_->body}); ret.insert(ret.end(), backbone_return_arr_.begin() + num_backbone_outputs_, backbone_return_arr_.end()); return SeqExpr(new_blocks, ret.size() == 1 ? ret[0] : Tuple(ret)); @@ -118,22 +118,22 @@ class AppendLossMutator : private ExprMutator { CHECK(loss_body_->blocks.size() == 1 && loss_body_->blocks[0]->IsInstance()) << "The loss function should have only one DataflowBlock"; auto var_node = loss_body_->body.as(); - CHECK(var_node && IsScalarTensor(GetRef(var_node))) + CHECK(var_node && IsScalarTensor(ffi::GetRef(var_node))) << "The loss function must return a scalar(0-dim Tensor) Var"; } /*! - * \brief Convert the return value of the backbone to Array. The backbone should return one - * or a tuple of Vars. + * \brief Convert the return value of the backbone to ffi::Array. The backbone should return + * one or a tuple of Vars. */ void BackboneReturnToArr(const Expr& backbone_return) { if (auto* var = backbone_return.as()) { - backbone_return_arr_.push_back(GetRef(var)); + backbone_return_arr_.push_back(ffi::GetRef(var)); } else if (auto* tuple = backbone_return.as()) { for (auto i : tuple->fields) { auto var = i.as(); CHECK(var) << "The return value of the backbone should be either a Var or a Tuple of Vars"; - backbone_return_arr_.push_back(GetRef(var)); + backbone_return_arr_.push_back(ffi::GetRef(var)); } } else { LOG(FATAL) << "The return value of the backbone should be either a Var or a Tuple of Vars"; @@ -145,7 +145,7 @@ class AppendLossMutator : private ExprMutator { * and the elements in backbone_return_arr_ and loss_func_params have matched struct_info. Also * sets up var_remap_ from loss parameter Vars to backbone returned Vars. 
*/ - void CheckAndRemapLossParams(const Array& loss_func_params) { + void CheckAndRemapLossParams(const ffi::Array& loss_func_params) { static StructuralEqual checker; CHECK(static_cast(loss_func_params.size()) >= num_backbone_outputs_) << "The number of parameters of the loss function is " << loss_func_params.size() @@ -199,13 +199,13 @@ class AppendLossMutator : private ExprMutator { /*! \brief The body of the loss function */ SeqExpr loss_body_; /*! \brief The unpacked return values of the backbone. All return values should be Vars. */ - Array backbone_return_arr_; + ffi::Array backbone_return_arr_; }; namespace transform { -Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outputs, - Optional new_func_name) { +Pass AppendLoss(ffi::String func_name, Function loss_function, int num_backbone_outputs, + ffi::Optional new_func_name) { auto pass_func = [=](IRModule mod, PassContext pc) { return relax::AppendLossMutator::Transform(mod, func_name, loss_function, num_backbone_outputs, new_func_name); @@ -216,10 +216,10 @@ Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outpu /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.training.AppendLoss", AppendLoss); -}); +} } // namespace transform diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h index 1bfb20da3521..c22588804d08 100644 --- a/src/relax/training/utils.h +++ b/src/relax/training/utils.h @@ -50,8 +50,8 @@ namespace transform { * will be `func_name + "_loss"`. * \return The Pass. 
*/ -TVM_DLL Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outputs = 1, - Optional new_func_name = std::nullopt); +TVM_DLL Pass AppendLoss(ffi::String func_name, Function loss_function, int num_backbone_outputs = 1, + ffi::Optional new_func_name = std::nullopt); } // namespace transform } // namespace relax diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 55ca86c306eb..889272019174 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -40,7 +40,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns( +std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); std::unordered_set compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end()); @@ -73,24 +73,42 @@ std::tuple)>> Crea pat_permuted_matmul_on_rhs; PrimExpr symbolic_var_constraints = Bool(true); - if (auto upper_bounds = func->GetAttr>("tir_var_upper_bound")) { - Map name_lookup; + auto upper_bounds = func->GetAttr>("tir_var_upper_bound"); + auto lower_bounds = func->GetAttr>("tir_var_lower_bound"); + + if (upper_bounds || lower_bounds) { + ffi::Map name_lookup; for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) { name_lookup.Set(tir_var->name_hint, tir_var); symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var); } - for (const auto& [key, obj_bound] : upper_bounds.value()) { - auto tir_var_name = Downcast(key); - if (auto opt_var = name_lookup.Get(tir_var_name)) { - auto var = opt_var.value(); - auto expr_bound = Downcast(obj_bound); - symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound); + // Add lower bound constraints + if (lower_bounds) { + for (const auto& [key, obj_bound] : lower_bounds.value()) { + auto tir_var_name = Downcast(key); + if (auto opt_var = name_lookup.Get(tir_var_name)) { + auto var = opt_var.value(); + auto 
expr_bound = Downcast(obj_bound); + symbolic_var_constraints = symbolic_var_constraints && (expr_bound <= var); + } + } + } + + // Add upper bound constraints + if (upper_bounds) { + for (const auto& [key, obj_bound] : upper_bounds.value()) { + auto tir_var_name = Downcast(key); + if (auto opt_var = name_lookup.Get(tir_var_name)) { + auto var = opt_var.value(); + auto expr_bound = Downcast(obj_bound); + symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound); + } } } } - auto rewriter = [=](Expr expr, Map matches) -> Expr { + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { auto expr_a = matches[pat_a]; auto expr_b = matches[pat_b]; auto expr_c = matches[pat_c]; @@ -102,7 +120,7 @@ std::tuple)>> Crea return expr; } - auto get_shape = [](Expr expr) -> Optional> { + auto get_shape = [](Expr expr) -> ffi::Optional> { auto sinfo = expr->struct_info_.as(); if (sinfo) { return sinfo->GetShape(); @@ -214,10 +232,10 @@ Pass AdjustMatmulOrder() { return CreateFunctionPass(pass_func, 1, "AdjustMatmulOrder", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.AdjustMatmulOrder", AdjustMatmulOrder); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index d0b462bb1e5b..4e71e0c3eb43 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -52,18 +52,18 @@ class ExternFunctionRewriter : ExprMutator { } Expr VisitExpr_(const FunctionNode* func_node) override { - if (!func_node->GetAttr(attr::kCodegen) && - !func_node->GetAttr(attr::kComposite)) { + if (!func_node->GetAttr(attr::kCodegen) && + !func_node->GetAttr(attr::kComposite)) { return ExprMutator::VisitExpr_(func_node); } if (auto workspace = func_node->GetAttr(attr::kWorkspaceSize)) { // Append the workspace parameter to this function. 
- Array new_params = func_node->params; + ffi::Array new_params = func_node->params; auto sinfo = TensorStructInfo(ShapeExpr({Integer(max_workspace_size_)}), DataType::UInt(8)); Var workspace_param(name_sup_->FreshName("workspace"), sinfo); - if (func_node->GetAttr(attr::kCodegen)) { + if (func_node->GetAttr(attr::kCodegen)) { workspace_var_param_ = workspace_param; } @@ -81,7 +81,7 @@ class ExternFunctionRewriter : ExprMutator { if (auto var = new_op.as()) { if (auto callee = builder_->LookupBinding(var.value()); callee && callee->IsInstance() && - Downcast(callee.value())->GetAttr(attr::kComposite)) { + Downcast(callee.value())->GetAttr(attr::kComposite)) { // Append the workspace argument to this call. The callee should have been updated to accept // a workspace as the last parameter. auto new_args = call_node->args; @@ -127,13 +127,13 @@ class WorkspaceProvider : ExprMutator { WithAttr(f, tvm::attr::kGlobalSymbol, new_gvar->name_hint)); gvar_map_[gvar] = new_gvar; new_gvars_.insert(new_gvar); - builder_->GetContextIRModule()->Remove(GetRef(gvar)); + builder_->GetContextIRModule()->Remove(ffi::GetRef(gvar)); } for (const auto& [gvar, f] : mod_->functions) { workspace_var_main_ = Var(); - if (!f->IsInstance() || f->GetAttr(attr::kCodegen) || - f->GetAttr(attr::kComposite)) { + if (!f->IsInstance() || f->GetAttr(attr::kCodegen) || + f->GetAttr(attr::kComposite)) { continue; } auto func = Downcast(mod_->Lookup(gvar)); @@ -202,10 +202,10 @@ Pass AllocateWorkspace() { return CreateModulePass(pass_func, 0, "AllocateWorkspace", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.AllocateWorkspace", AllocateWorkspace); -}); +} } // namespace transform } // namespace tvm diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 4013d3aad17e..a612ef83bde0 100644 --- a/src/relax/transform/alter_op_impl.cc +++ 
b/src/relax/transform/alter_op_impl.cc @@ -43,17 +43,17 @@ using namespace tir; static constexpr const char* kOperatorName = "operator_name"; /*! \brief Construct ranges from shape dimensions */ -static Array ConstructRangeFromShape(const Array& shape) { +static ffi::Array ConstructRangeFromShape(const ffi::Array& shape) { return shape.Map([](const PrimExpr& dim) { return Range(tir::make_zero(dim.dtype()), dim); }); } -static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { +static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); ICHECK(shape.defined()); return shape.value(); } -static Array GetShapeFromTensor(const Expr& expr) { +static ffi::Array GetShapeFromTensor(const Expr& expr) { const auto& tensor_sinfo = Downcast(expr->struct_info_); return GetShapeFromTensorStructInfo(tensor_sinfo); } @@ -64,8 +64,8 @@ static IndexMap DeepCopyIndexMap(const IndexMap& index_map) { /*! \brief Checks if the \p transform is bijective on the shape of \p expr */ bool IsTransformBijective(const Expr& expr, const IndexMap& transform) { - Array input_shape = GetShapeFromTensor(expr); - Array initial_ranges = ConstructRangeFromShape(input_shape); + ffi::Array input_shape = GetShapeFromTensor(expr); + ffi::Array initial_ranges = ConstructRangeFromShape(input_shape); arith::Analyzer analyzer; auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges, &analyzer); (void)inverse; // to avoid unused variable warning; @@ -80,10 +80,12 @@ bool IsTransformBijective(const Expr& expr, const IndexMap& transform) { */ class AlterOpImplMutator : public ExprMutator { public: - AlterOpImplMutator(const IRModule& mod, const Map& op_impl_map, - const Map>& op_buffer_transforms_, - const Map>>>& axis_separators_, - const Map>>>& input_axis_separators_) + AlterOpImplMutator( + const IRModule& mod, const ffi::Map& op_impl_map, + const ffi::Map>& op_buffer_transforms_, + const 
ffi::Map>>>& axis_separators_, + const ffi::Map>>>& + input_axis_separators_) : ExprMutator(mod), mod_(mod), op_impl_map_(op_impl_map), @@ -119,7 +121,7 @@ class AlterOpImplMutator : public ExprMutator { ICHECK(call->args[0]->IsInstance()); const tir::PrimFunc& old_func = Downcast(mod_->Lookup(Downcast(call->args[0]))); - Optional maybe_op_kind = old_func->attrs.GetAttr(kOperatorName); + ffi::Optional maybe_op_kind = old_func->attrs.GetAttr(kOperatorName); // If the callee does not have kOperatorName attribute or no replacement is requested for // it, nothing to do here. @@ -128,9 +130,9 @@ class AlterOpImplMutator : public ExprMutator { const auto& replacement_func = op_impl_map_[op_kind]; - Array buffer_transforms; - Optional>> axis_separators; - Optional>> input_axis_separators; + ffi::Array buffer_transforms; + ffi::Optional>> axis_separators; + ffi::Optional>> input_axis_separators; if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind]; if (op_buffer_axis_separators__.count(op_kind)) axis_separators = op_buffer_axis_separators__[op_kind]; @@ -145,7 +147,7 @@ class AlterOpImplMutator : public ExprMutator { GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind); - auto call_tir_inputs_tuple = GetRef(call->args[1].as()); + auto call_tir_inputs_tuple = ffi::GetRef(call->args[1].as()); Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators, input_axis_separators); @@ -159,18 +161,18 @@ class AlterOpImplMutator : public ExprMutator { input_axis_separators); } - Array GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) { + ffi::Array GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) { if (const auto* tensor_sinfo = output_sinfo.as()) - return {GetRef(tensor_sinfo)}; + return {ffi::GetRef(tensor_sinfo)}; const auto* tuple_sinfo = output_sinfo.as(); ICHECK(tuple_sinfo); - Array arr_tensor_sinfo; + ffi::Array arr_tensor_sinfo; 
arr_tensor_sinfo.reserve(tuple_sinfo->fields.size()); for (const auto& sinfo : tuple_sinfo->fields) { const auto* tensor_sinfo = sinfo.as(); ICHECK(tensor_sinfo) << "Nested tuples in output of call_tir is not supported yet"; - arr_tensor_sinfo.push_back(GetRef(tensor_sinfo)); + arr_tensor_sinfo.push_back(ffi::GetRef(tensor_sinfo)); } return arr_tensor_sinfo; } @@ -183,16 +185,16 @@ class AlterOpImplMutator : public ExprMutator { } Expr TransformLayout(const Expr& expr, const IndexMap& index_map, - const Array& axis_separators, - const Array& input_axis_separators) { + const ffi::Array& axis_separators, + const ffi::Array& input_axis_separators) { if (IsScalarConstant(expr) || index_map.get() == nullptr) { return expr; } - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); // We want to avoid two layout_transform ops to share the same index map even if they are // identical. The scope of vars used in index map initial indices is local to the op. Not doing // so would confuse the structural equality check. - attrs->index_map = std::move(DeepCopyIndexMap(index_map)); + attrs->index_map = DeepCopyIndexMap(index_map); attrs->axis_separators = std::move(axis_separators); attrs->input_axis_separators = std::move(input_axis_separators); return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {}); @@ -202,13 +204,13 @@ class AlterOpImplMutator : public ExprMutator { * \brief Adds the \p remove_pad op to the module if it has not already been added before. * \returns The global var associated with the remove_pad PrimFunc. 
*/ - GlobalVar GetOrCreateRemovePadOp(const Array& old_shape, const DataType& dtype) { + GlobalVar GetOrCreateRemovePadOp(const ffi::Array& old_shape, const DataType& dtype) { int t_shape = old_shape.size(); if (remove_pad_map_.count(t_shape) != 0) { return remove_pad_map_[t_shape]; } // Create dynamic shapes for input and output tensors - Array dyn_padded_shape, dyn_old_shape; + ffi::Array dyn_padded_shape, dyn_old_shape; for (int i = 0; i < t_shape; i++) { tir::Var var1("p" + std::to_string(i), old_shape[i].dtype()); tir::Var var2("i" + std::to_string(i), old_shape[i].dtype()); @@ -221,12 +223,12 @@ class AlterOpImplMutator : public ExprMutator { // Output tensor of remove_pad op te::Tensor output_tensor = te::compute( dyn_old_shape, - [&placeholder_tensor](const Array& indices) { + [&placeholder_tensor](const ffi::Array& indices) { return placeholder_tensor(indices); }, "output", topi::kElementWise); - String op_name = "remove_pad"; + ffi::String op_name = "remove_pad"; // Create PrimFunc and add op_name to func.attrs PrimFunc remove_pad_with_frozen_layout = WithAttr(CreatePrimFunc({placeholder_tensor, output_tensor}), kOperatorName, op_name); @@ -242,13 +244,13 @@ class AlterOpImplMutator : public ExprMutator { Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map, const TensorStructInfo& old_tensor_sinfo, - const Array& axis_separator, - const Array& input_axis_separator) { + const ffi::Array& axis_separator, + const ffi::Array& input_axis_separator) { if (IsScalarConstant(expr) || index_map.get() == nullptr) { return expr; } - Array old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo); - Array initial_ranges = ConstructRangeFromShape(old_shape); + ffi::Array old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo); + ffi::Array initial_ranges = ConstructRangeFromShape(old_shape); arith::Analyzer analyzer; auto [inverse_index_map, padding_predicate] = index_map.NonSurjectiveInverse(initial_ranges, &analyzer); @@ -269,7 +271,8 @@ class 
AlterOpImplMutator : public ExprMutator { * \brief Adds the \p replacement_func to the module if it has not already been added before. * \returns The global var associated with the PrimFunc. */ - GlobalVar GetOrCreateGlobalVarForFunc(const PrimFunc& replacement_func, const String& op_kind) { + GlobalVar GetOrCreateGlobalVarForFunc(const PrimFunc& replacement_func, + const ffi::String& op_kind) { if (cache_.count(replacement_func) != 0) { return cache_[replacement_func]; } @@ -287,22 +290,22 @@ class AlterOpImplMutator : public ExprMutator { /*! * \brief Updates call inputs with layout transformed inputs */ - Tuple UpdateInputs(const Tuple& inputs, const Array& transforms, - const Optional>>& axis_separators, - const Optional>>& input_axis_separators) { + Tuple UpdateInputs(const Tuple& inputs, const ffi::Array& transforms, + const ffi::Optional>>& axis_separators, + const ffi::Optional>>& input_axis_separators) { if (transforms.empty()) return inputs; - Array updated_inputs; + ffi::Array updated_inputs; int index = 0; for (const auto& input : inputs->fields) { - Array axis_separator; - Array input_axis_separator; + ffi::Array axis_separator; + ffi::Array input_axis_separator; if (axis_separators.defined()) { - Array> axis_separators_value = axis_separators.value(); + ffi::Array> axis_separators_value = axis_separators.value(); axis_separator = axis_separators_value[index]; } if (input_axis_separators.defined()) { - Array> input_axis_separators_value = input_axis_separators.value(); + ffi::Array> input_axis_separators_value = input_axis_separators.value(); input_axis_separator = input_axis_separators_value[index]; } auto transform = transforms[index++]; @@ -314,7 +317,7 @@ class AlterOpImplMutator : public ExprMutator { /*! 
\brief Updates output struct info */ StructInfo UpdateStructInfo(const StructInfo& out_sinfo, - const Array& buffer_transforms) { + const ffi::Array& buffer_transforms) { if (buffer_transforms.empty()) return out_sinfo; if (out_sinfo->IsInstance()) @@ -327,7 +330,7 @@ class AlterOpImplMutator : public ExprMutator { << out_sinfo; const auto& tuple_sinfo = Downcast(out_sinfo); - Array sinfo_fields; + ffi::Array sinfo_fields; size_t first_output_index = buffer_transforms.size() - tuple_sinfo->fields.size(); size_t i = 0; for (const auto& si : tuple_sinfo->fields) { @@ -354,15 +357,16 @@ class AlterOpImplMutator : public ExprMutator { return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype); } - Expr TransformOutputs(const Expr& expr, const Array& buffer_transforms, - const StructInfo& old_struct_info, - const Optional>>& axis_separators, - const Optional>>& input_axis_separators) { + Expr TransformOutputs( + const Expr& expr, const ffi::Array& buffer_transforms, + const StructInfo& old_struct_info, + const ffi::Optional>>& axis_separators, + const ffi::Optional>>& input_axis_separators) { if (buffer_transforms.empty()) return expr; - Array old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info); + ffi::Array old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info); - Array axis_sep, input_axis_sep; + ffi::Array axis_sep, input_axis_sep; size_t num_outputs = old_output_sinfo.size(); if (num_outputs == 0) return expr; @@ -371,11 +375,11 @@ class AlterOpImplMutator : public ExprMutator { if (num_outputs == 1) { IndexMap output_map = buffer_transforms[first_output_index]; if (axis_separators.defined()) { - Array> axis_separators_value = axis_separators.value(); + ffi::Array> axis_separators_value = axis_separators.value(); axis_sep = axis_separators_value[first_output_index]; } if (input_axis_separators.defined()) { - Array> input_axis_separators_value = input_axis_separators.value(); + ffi::Array> input_axis_separators_value = 
input_axis_separators.value(); input_axis_sep = input_axis_separators_value[first_output_index]; } return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep, @@ -384,15 +388,15 @@ class AlterOpImplMutator : public ExprMutator { // In case of more than one output, we would have to get each item of the output tuple, // transform it and return a tuple of all transformed outputs. - Array transformed_outputs; + ffi::Array transformed_outputs; for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i) { const auto& output_map = buffer_transforms[i + first_output_index]; if (axis_separators.defined()) { - Array> axis_separators_value = axis_separators.value(); + ffi::Array> axis_separators_value = axis_separators.value(); axis_sep = axis_separators_value[i + first_output_index]; } if (input_axis_separators.defined()) { - Array> input_axis_separators_value = input_axis_separators.value(); + ffi::Array> input_axis_separators_value = input_axis_separators.value(); input_axis_sep = input_axis_separators_value[i + first_output_index]; } auto output = builder_->Normalize(TupleGetItem(expr, static_cast(i))); @@ -404,19 +408,21 @@ class AlterOpImplMutator : public ExprMutator { private: /*! \brief Cache to keep track of the GlobalVar associated with the new PrimFunc added */ - Map cache_; + ffi::Map cache_; /*! \brief Input IRModule */ const IRModule& mod_; /*! \brief Map from shape_dim.size to the remove_pad GlobalVar */ std::unordered_map remove_pad_map_; /*! \brief Map from kOperatorName attribute to the replacement PrimFunc */ - const Map& op_impl_map_; + const ffi::Map& op_impl_map_; /*! \brief Map from kOperatorName attribute to the layout transforms on i/o buffers */ - const Map>& op_buffer_transforms__; + const ffi::Map>& op_buffer_transforms__; /*! \brief Map from kOperatorName attribute to the axis separatos on i/o buffers */ - const Map>>>& op_buffer_axis_separators__; + const ffi::Map>>>& + op_buffer_axis_separators__; /*! 
\brief Map from kOperatorName attribute to the input axis separatos */ - const Map>>>& op_buffer_input_axis_separators__; + const ffi::Map>>>& + op_buffer_input_axis_separators__; const Op& call_tir_op_ = Op::Get("relax.call_tir"); const Op& layout_transform_op_ = Op::Get("relax.layout_transform"); @@ -424,10 +430,12 @@ class AlterOpImplMutator : public ExprMutator { namespace transform { -Pass AlterOpImpl(const Map& op_impl_map, - const Map>& op_buffer_transforms_, - const Map>>>& axis_separators_, - const Map>>>& input_axis_separators_) { +Pass AlterOpImpl( + const ffi::Map& op_impl_map, + const ffi::Map>& op_buffer_transforms_, + const ffi::Map>>>& axis_separators_, + const ffi::Map>>>& + input_axis_separators_) { auto pass_func = [=](IRModule mod, PassContext pc) { return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_, input_axis_separators_) @@ -439,10 +447,10 @@ Pass AlterOpImpl(const Map& op_impl_map, /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.AlterOpImpl", AlterOpImpl); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc index 58f22eb47ad4..f5b1061b6708 100644 --- a/src/relax/transform/annotate_tir_op_pattern.cc +++ b/src/relax/transform/annotate_tir_op_pattern.cc @@ -48,10 +48,10 @@ Pass AnnotateTIROpPattern() { return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.AnnotateTIROpPattern", AnnotateTIROpPattern); -}); +} } // namespace transform diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc index a7c8013a56fd..064ff015eedf 100644 --- 
a/src/relax/transform/attach_attr_layout_free_buffers.cc +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -70,9 +70,9 @@ class AttrAttacher : public ExprMutator { return call; } GlobalVar gv = Downcast(call->args[0]); - Array call_tir_args = Downcast(call->args[1])->fields; + ffi::Array call_tir_args = Downcast(call->args[1])->fields; // Compute the layout free buffers - Array layout_free_buffers; + ffi::Array layout_free_buffers; for (size_t i = 0; i < call_tir_args.size(); i++) { if (layout_free_exprs_.count(call_tir_args[i].get())) { layout_free_buffers.push_back(i); @@ -88,7 +88,7 @@ class AttrAttacher : public ExprMutator { // So we don't need to worry about the duplicate insertion GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint); // Create a new call node with the updated tir::PrimFunc - auto n = make_object(*op); + auto n = ffi::make_object(*op); n->args = {new_gv, Tuple(call_tir_args)}; return Call(n); } @@ -106,10 +106,10 @@ Pass AttachAttrLayoutFreeBuffers() { return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.AttachAttrLayoutFreeBuffers", AttachAttrLayoutFreeBuffers); -}); +} } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 9ef135608dc4..0079b504989a 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -34,25 +34,26 @@ namespace transform { Pass AttachGlobalSymbol() { auto pass_func = [=](IRModule mod, PassContext pc) { - String c_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix).value_or(""); + ffi::String c_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix).value_or(""); IRModule updates; - Map gvar_updates; + ffi::Map gvar_updates; for (const auto& [gvar, func] 
: mod->functions) { - Optional old_name = func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional old_name = func->GetAttr(tvm::attr::kGlobalSymbol); // TODO(tvm-team): re-enable once fix relax integration part // if (old_name) continue; - Optional new_name; + ffi::Optional new_name; BaseFunc new_func; if (auto* prim_func = func.as()) { new_name = c_prefix + gvar->name_hint; - new_func = WithAttr(GetRef(prim_func), tvm::attr::kGlobalSymbol, new_name); + new_func = + WithAttr(ffi::GetRef(prim_func), tvm::attr::kGlobalSymbol, new_name); } else if (auto* relax_func = func.as()) { new_name = gvar->name_hint; - new_func = WithAttr(GetRef(relax_func), tvm::attr::kGlobalSymbol, new_name); + new_func = WithAttr(ffi::GetRef(relax_func), tvm::attr::kGlobalSymbol, new_name); } if (new_name.has_value() && (!old_name.has_value() || old_name.value() != new_name.value())) { @@ -80,10 +81,10 @@ Pass AttachGlobalSymbol() { return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.AttachGlobalSymbol", AttachGlobalSymbol); -}); +} } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 13b138ecce55..4ad9b3ab5051 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -32,7 +32,7 @@ namespace tvm { namespace relax { void MatchSymbolicVar(const Expr& arg, const Expr& constant, - Map* symbolic_var_map, arith::Analyzer* analyzer_) { + ffi::Map* symbolic_var_map, arith::Analyzer* analyzer_) { auto opt_arg_sinfo = MatchStructInfo(arg); CHECK(opt_arg_sinfo) << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: " @@ -70,9 +70,9 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, const PrimExpr& const_dim = const_shape->values[i]; ICHECK(tir::is_const_int(const_dim)); 
if (const auto* shape_var = arg_shape->values[i].as()) { - auto it = symbolic_var_map->find(GetRef(shape_var)); + auto it = symbolic_var_map->find(ffi::GetRef(shape_var)); if (it == symbolic_var_map->end()) { - symbolic_var_map->Set(GetRef(shape_var), const_dim); + symbolic_var_map->Set(ffi::GetRef(shape_var), const_dim); } else { CHECK(analyzer_->CanProveEqual((*it).second, const_dim)) << "The shape of the bound parameter is expected to be " << (*it).second @@ -82,23 +82,23 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, } } -std::tuple, Map> NormalizeBindings( - const Function& func, const Map& untyped_params) { +std::tuple, ffi::Map> NormalizeBindings( + const Function& func, const ffi::Map& untyped_params) { ICHECK(func.defined()); ICHECK(untyped_params.defined()); // Map from string to the variable(s) with that name. - std::unordered_map> string_lookup; + std::unordered_map> string_lookup; std::unordered_set var_set; for (const auto& param : func->params) { string_lookup[param->name_hint()].push_back(param); var_set.insert(param.get()); } - Map relax_var_remap; + ffi::Map relax_var_remap; auto normalize_key = [&](ffi::Any obj) -> relax::Var { - if (auto opt_str = obj.as()) { + if (auto opt_str = obj.as()) { std::string str = opt_str.value(); auto it = string_lookup.find(str); CHECK(it != string_lookup.end()) @@ -131,7 +131,7 @@ std::tuple, Map> NormalizeBindings( auto normalize_value = [&](ffi::Any obj) -> relax::Expr { if (auto opt = obj.as()) { return opt.value(); - } else if (auto opt = obj.as()) { + } else if (auto opt = obj.as()) { return Constant(opt.value()); } else { LOG(FATAL) << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; @@ -143,7 +143,7 @@ std::tuple, Map> NormalizeBindings( } arith::Analyzer analyzer; - Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); + ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); // for (const auto& [bind_param, bind_expr] 
: relax_var_remap) { // MatchSymbolicVar(bind_param, bind_expr, &symbolic_var_map, &analyzer); @@ -158,7 +158,7 @@ std::tuple, Map> NormalizeBindings( * \param params params dict * \return Function */ -Function FunctionBindParams(Function func, const Map& untyped_params) { +Function FunctionBindParams(Function func, const ffi::Map& untyped_params) { auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params); Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); @@ -172,48 +172,49 @@ Function FunctionBindParams(Function func, const Map& untyped_pa * \param param The param dict * \return The module after binding params. */ -IRModule BindParam(IRModule m, String func_name, Map bind_params) { +IRModule BindParam(IRModule m, ffi::String func_name, ffi::Map bind_params) { IRModuleNode* new_module = m.CopyOnWrite(); - Map functions = m->functions; + ffi::Map functions = m->functions; for (const auto& func_pr : functions) { if (const auto* relax_f = func_pr.second.as()) { if (relax_f->GetLinkageType() == LinkageType::kExternal) { // Use global_symbol if it's external linkage - Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = + relax_f->GetAttr(tvm::attr::kGlobalSymbol); if (gsymbol.has_value() && gsymbol.value() == func_name) { - Function f_after_bind = FunctionBindParams(GetRef(relax_f), bind_params); + Function f_after_bind = FunctionBindParams(ffi::GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } else { // Use global var's name_hint if it's internal linkage if (func_pr.first->name_hint == func_name) { - Function f_after_bind = FunctionBindParams(GetRef(relax_f), bind_params); + Function f_after_bind = FunctionBindParams(ffi::GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } } } - return GetRef(new_module); + return ffi::GetRef(new_module); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("relax.FunctionBindParams", FunctionBindParams); -}); +} namespace transform { -Pass BindParams(String func_name, Map params) { +Pass BindParams(ffi::String func_name, ffi::Map params) { auto pass_func = [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); }; return CreateModulePass(pass_func, 0, "BindParams", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.BindParams", BindParams); -}); +} } // namespace transform diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 5ba25b7e16e1..04a4b0819cda 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -31,17 +31,17 @@ namespace tvm { namespace relax { -Function FunctionBindSymbolicVars(Function func, - Map, PrimExpr> obj_remap) { +Function FunctionBindSymbolicVars( + Function func, ffi::Map, PrimExpr> obj_remap) { // Early bail-out if no updates need to be made. if (obj_remap.empty()) { return func; } - Array old_symbolic_vars = DefinedSymbolicVars(func); + ffi::Array old_symbolic_vars = DefinedSymbolicVars(func); // Map from string to the variable(s) with that name. - std::unordered_map> string_lookup; + std::unordered_map> string_lookup; std::unordered_set symbolic_var_set; for (const auto& var : old_symbolic_vars) { string_lookup[var->name_hint].push_back(var); @@ -49,10 +49,10 @@ Function FunctionBindSymbolicVars(Function func, } // Replacement map to be used when rewriting the function. - Map var_remap; + ffi::Map var_remap; for (const auto& [key, replacement] : obj_remap) { if (auto opt = key.as()) { - String string_key = opt.value(); + ffi::String string_key = opt.value(); auto it = string_lookup.find(string_key); CHECK(it != string_lookup.end()) << "Function does not use symbolic var with name \"" << string_key << "\". 
" @@ -91,8 +91,8 @@ Function FunctionBindSymbolicVars(Function func, } namespace { -IRModule ModuleBindSymbolicVars(IRModule mod, - Map, PrimExpr> binding_map) { +IRModule ModuleBindSymbolicVars( + IRModule mod, ffi::Map, PrimExpr> binding_map) { std::unordered_set used; IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { @@ -100,7 +100,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, auto func = opt.value(); // Collect bindings that are used by this function. - auto func_binding_map = [&]() -> Map, PrimExpr> { + auto func_binding_map = [&]() -> ffi::Map, PrimExpr> { std::unordered_set var_names; std::unordered_set vars; for (const auto& var : DefinedSymbolicVars(func)) { @@ -108,10 +108,10 @@ IRModule ModuleBindSymbolicVars(IRModule mod, vars.insert(var.get()); } - Map, PrimExpr> out; + ffi::Map, PrimExpr> out; for (const auto& [key, replacement] : binding_map) { bool used_by_function = false; - if (auto opt = key.as()) { + if (auto opt = key.as()) { used_by_function = var_names.count(opt.value()); } else if (auto ptr = key.as()) { used_by_function = vars.count(ptr); @@ -134,7 +134,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, } } - Array unused; + ffi::Array unused; for (const auto& [key, replacement] : binding_map) { if (!used.count(key)) { unused.push_back(key); @@ -151,15 +151,15 @@ IRModule ModuleBindSymbolicVars(IRModule mod, } } // namespace -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.FunctionBindSymbolicVars", FunctionBindSymbolicVars); -}); +} namespace transform { -Pass BindSymbolicVars(Map, PrimExpr> binding_map, - Optional func_name) { +Pass BindSymbolicVars(ffi::Map, PrimExpr> binding_map, + ffi::Optional func_name) { auto pass_func = [=](IRModule mod, PassContext context) -> IRModule { if (func_name) { auto gvar = mod->GetGlobalVar(func_name.value()); @@ -177,10 +177,10 @@ Pass BindSymbolicVars(Map, PrimExpr> binding_map, return 
tvm::transform::CreateModulePass(pass_func, 1, "relax.BindSymbolicVars", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.BindSymbolicVars", BindSymbolicVars); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index 16b7348b8dc7..877f3d7dea35 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -36,11 +36,11 @@ namespace relax { class ModelParamBundler : public ExprMutator { public: - explicit ModelParamBundler(Optional param_tuple_name) + explicit ModelParamBundler(ffi::Optional param_tuple_name) : param_tuple_name_(param_tuple_name) {} Expr VisitExpr_(const FunctionNode* op) override { - Function func = GetRef(op); + Function func = ffi::GetRef(op); auto opt_num_input = func->attrs.GetAttr(attr::kNumInput); if (!opt_num_input) return func; auto signed_num_input = opt_num_input.value()->value; @@ -51,12 +51,12 @@ class ModelParamBundler : public ExprMutator { << "but only has " << func->params.size() << " parameters total."; size_t num_input = signed_num_input; - Array params; + ffi::Array params; for (size_t i = 0; i < num_input; i++) { params.push_back(func->params[i]); } - Array param_tuple; + ffi::Array param_tuple; for (size_t i = num_input; i < func->params.size(); i++) { param_tuple.push_back(GetStructInfo(func->params[i])); } @@ -74,7 +74,7 @@ class ModelParamBundler : public ExprMutator { } Expr VisitExpr_(const VarNode* op) override { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (auto it = var_to_expr_.find(var); it != var_to_expr_.end()) { return builder_->Emit((*it).second, op->name_hint()); } else { @@ -83,17 +83,17 @@ class ModelParamBundler : public ExprMutator { } private: - Optional param_tuple_name_; - Map var_to_expr_; + ffi::Optional param_tuple_name_; + ffi::Map var_to_expr_; }; 
-Function BundleModelParams(const Function& func, Optional param_tuple_name) { +Function BundleModelParams(const Function& func, ffi::Optional param_tuple_name) { ModelParamBundler mutator(param_tuple_name); return Downcast(mutator(func)); } namespace transform { -Pass BundleModelParams(Optional param_tuple_name) { +Pass BundleModelParams(ffi::Optional param_tuple_name) { auto pass_func = [=](IRModule mod, PassContext pc) { IRModule updates; @@ -116,10 +116,10 @@ Pass BundleModelParams(Optional param_tuple_name) { return CreateModulePass(pass_func, 1, "BundleModelParams", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.BundleModelParams", BundleModelParams); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index a47b9bfe5105..d4763b44b713 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -74,7 +74,7 @@ class CallTIRMutator : public ExprMutator { call->op == call_dps_packed_op) { bool is_inplace = (call->op == call_tir_inplace_op); const auto* inplace_attrs = call->attrs.as(); - Array outs; + ffi::Array outs; if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { // single output case const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); @@ -130,7 +130,7 @@ class CallTIRMutator : public ExprMutator { << expr->struct_info_; } - Array args; + ffi::Array args; if (call->args[1].as()) { args = Downcast(call->args[1])->fields; // for call_tir_inplace, don't reinsert in-place args, only the newly allocated ones @@ -167,7 +167,7 @@ class CallTIRMutator : public ExprMutator { return std::move(Tuple(outs)); } - return GetRef(call); + return ffi::GetRef(call); } /*! \brief The context IRModule. 
*/ @@ -184,10 +184,10 @@ Pass CallTIRRewrite() { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.CallTIRRewrite", CallTIRRewrite); -}); +} } // namespace transform diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 54c508ff2302..decbecd3098b 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -59,7 +59,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { << ", while the later definition of Relax variable " << binding->var << " instead implies that TIR variable " << tir_var << " is " << prim_expr; } else { - known_values_[tir_var] = KnownValue{prim_expr, GetRef(binding)}; + known_values_[tir_var] = KnownValue{prim_expr, ffi::GetRef(binding)}; } } ExprMutator::VisitBinding_(binding); @@ -76,7 +76,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b)) { - return GetRef(op); + return ffi::GetRef(op); } // The two branches may have had different TIR variables inlined. @@ -119,7 +119,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { if (known_values_.empty()) { return expr; } - PrimExpr output = tir::Substitute(expr, [this](const tir::Var& var) -> Optional { + PrimExpr output = tir::Substitute(expr, [this](const tir::Var& var) -> ffi::Optional { if (auto it = known_values_.find(var); it != known_values_.end()) { return it->second.expr; } else { @@ -144,10 +144,10 @@ class SymbolicVarCanonicalizer : public ExprMutator { }; struct CanonicalizationPlan { - Map replace_usage; - Map replace_binding; + ffi::Map replace_usage; + ffi::Map replace_binding; std::unordered_set bindings_to_remove; - Map inline_constant; + ffi::Map inline_constant; }; /*! 
\brief Utility class to identify usage location @@ -232,8 +232,8 @@ class CanonicalizePlanner : public ExprVisitor { void VisitExpr_(const FunctionNode* func) override { // for functions, treat any free vars as used outside their home DF block auto cache = current_block_; - current_block_ = Optional(); - auto free_vars = FreeVars(GetRef(func)); + current_block_ = ffi::Optional(); + auto free_vars = FreeVars(ffi::GetRef(func)); for (auto var : free_vars) { used_outside_home_dataflow_.insert(var); } @@ -244,26 +244,26 @@ class CanonicalizePlanner : public ExprVisitor { void VisitExpr_(const SeqExprNode* seq) override { // need to reset current_block_ for nested seq exprs (such as in If nodes) auto cache = current_block_; - current_block_ = Optional(); + current_block_ = ffi::Optional(); ExprVisitor::VisitExpr_(seq); current_block_ = cache; } void VisitBindingBlock_(const BindingBlockNode* block) override { CHECK(!current_block_.defined()) << "Forgetting to unset current block"; - current_block_ = GetRef(block); + current_block_ = ffi::GetRef(block); ExprVisitor::VisitBindingBlock_(block); - current_block_ = Optional(); + current_block_ = ffi::Optional(); } void VisitBindingBlock_(const DataflowBlockNode* block) override { CHECK(!current_block_.defined()) << "Forgetting to unset current block"; - current_block_ = GetRef(block); + current_block_ = ffi::GetRef(block); ExprVisitor::VisitBindingBlock_(block); - current_block_ = Optional(); + current_block_ = ffi::Optional(); } - Optional UnwrapKnownValue(Expr expr) { + ffi::Optional UnwrapKnownValue(Expr expr) { // If the expression is a variable, then it can be unwrapped into // its known value. auto unwrap_var = [this](Expr expr) -> Expr { @@ -299,7 +299,7 @@ class CanonicalizePlanner : public ExprVisitor { // If the expression is a Tuple, and each element is // `TupleGetItem(earlier_tuple, i)`, then this is just a copy of // `earlier_tuple`. 
- auto earlier_tuple = [&]() -> Optional { + auto earlier_tuple = [&]() -> ffi::Optional { auto expr_tuple = expr.as(); if (!expr_tuple) { return std::nullopt; @@ -385,14 +385,14 @@ class CanonicalizePlanner : public ExprVisitor { } void VisitExpr_(const VarNode* var) override { - auto var_ref = GetRef(var); + auto var_ref = ffi::GetRef(var); // if a var is used in a dataflow block but *not* the one // where it was defined, it also needs to be exposed, so also we treat that as // used outside of a dataflow block if (!inside_dataflow() || (def_blocks_.count(var_ref) && (current_block_.defined() && !current_block_.value().same_as(def_blocks_.at(var_ref))))) { - used_outside_home_dataflow_.insert(GetRef(var)); + used_outside_home_dataflow_.insert(ffi::GetRef(var)); } } @@ -400,12 +400,12 @@ class CanonicalizePlanner : public ExprVisitor { return current_block_.defined() && current_block_.value().as(); } - Optional current_block_; - Map def_blocks_; + ffi::Optional current_block_; + ffi::Map def_blocks_; - Map trivial_bindings_; - Map known_bindings_; - Map known_bound_to_constant_; + ffi::Map trivial_bindings_; + ffi::Map known_bindings_; + ffi::Map known_bound_to_constant_; std::unordered_set defined_inside_dataflow_; // Set of vars either used outside a dataflow block altogether or outside their // home dataflow block (the one where they were defined) @@ -440,7 +440,7 @@ class BindingCanonicalizer : public ExprMutator { } Expr VisitExpr_(const VarNode* var) override { - Var new_var = GetRef(var); + Var new_var = ffi::GetRef(var); while (auto opt = plan_.replace_usage.Get(new_var->vid)) { new_var = opt.value(); } @@ -470,7 +470,7 @@ class BindingCanonicalizer : public ExprMutator { // disqualify any vars that appear in the RHS // (for a function literal, consider only free vars) - Array rhs_vars; + ffi::Array rhs_vars; if (!value->IsInstance()) { rhs_vars = FreeVars(value); } else { @@ -494,12 +494,12 @@ class BindingCanonicalizer : public ExprMutator { // disqualify 
if the RHS is not a single dataflow var // or if the var has been output before if (const auto* rhs_var = value.as()) { - if (output_vars.count(GetRef(rhs_var))) { - disqualified_set.insert(GetRef(rhs_var)); + if (output_vars.count(ffi::GetRef(rhs_var))) { + disqualified_set.insert(ffi::GetRef(rhs_var)); } - output_vars.insert(GetRef(rhs_var)); + output_vars.insert(ffi::GetRef(rhs_var)); } else { - Array disqualified; + ffi::Array disqualified; // for function literal, consider only free vars if (value->IsInstance()) { disqualified = FreeVars(value); @@ -518,7 +518,7 @@ class BindingCanonicalizer : public ExprMutator { // second pass: for each binding where the LHS is a candidate, remove the binding. // If the RHS is a candidate, replace it with the definition - Array new_bindings; + ffi::Array new_bindings; bool changed = false; for (auto binding : new_block->bindings) { if (binding->var->IsInstance() && @@ -592,10 +592,10 @@ Pass CanonicalizeBindings() { "CanonicalizeBindings"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.CanonicalizeBindings", CanonicalizeBindings); -}); +} } // namespace transform diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 9c0318ee3926..c60864d671c5 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -39,13 +39,13 @@ namespace tvm { namespace relax { -using FCheck = ffi::TypedFunction, Array, Map)>; +using FCheck = ffi::TypedFunction, ffi::Array, ffi::Map)>; /*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose batch sizes are compatible are combined. 
*/ std::unordered_map> GroupShapes( - const std::vector>& shapes) { + const std::vector>& shapes) { std::unordered_map> indices_map; for (size_t i = 0; i < shapes.size(); ++i) { indices_map[shapes[i].size()].push_back(i); @@ -77,7 +77,7 @@ struct Patterns { struct SplitInfo { Var rhs; - Optional bias; + ffi::Optional bias; PrimExpr split_size; DFPattern pattern_to_replace; }; @@ -116,10 +116,10 @@ Patterns CreatePatterns(const BranchInfo& branch_info) { } /*! \brief Create a rewriter for the given parallel matmul branches. */ -ffi::TypedFunction(Map, Map)> GetRewriter( +ffi::TypedFunction(ffi::Map, ffi::Map)> GetRewriter( const Patterns& patterns, const BranchInfo& branch_info, FCheck check) { auto batch_dims_compatible = [](size_t rhs_dim, const std::vector& indices, - const std::vector>& rhs_shapes) { + const std::vector>& rhs_shapes) { arith::Analyzer ana; for (auto ind : indices) { ICHECK_EQ(static_cast(rhs_shapes[ind].size()), rhs_dim); @@ -133,17 +133,17 @@ ffi::TypedFunction(Map, Map)> GetRewri return true; }; - return [=](Map matchings, Map bindings) { - std::vector> rhs_shapes; + return [=](ffi::Map matchings, ffi::Map bindings) { + std::vector> rhs_shapes; for (const auto& rhs_pat : patterns.rhs) { auto rhs_shape_opt = GetTensorSInfo(matchings[rhs_pat])->GetShape(); if (!rhs_shape_opt) { - return Map{}; + return ffi::Map{}; } rhs_shapes.push_back(rhs_shape_opt.value()); } - Map replacements; + ffi::Map replacements; for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) { if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue; @@ -159,7 +159,7 @@ ffi::TypedFunction(Map, Map)> GetRewri std::vector splits; for (auto index : indices) { Var rhs = matchings[patterns.rhs[index]]; - Optional bias = std::nullopt; + ffi::Optional bias = std::nullopt; if (branch_info.bias_dim.has_value()) { bias = matchings[patterns.bias[index]]; } @@ -190,8 +190,8 @@ ffi::TypedFunction(Map, Map)> GetRewri continue; } - Array rhs; - Array 
bias; + ffi::Array rhs; + ffi::Array bias; for (const auto& split : splits) { rhs.push_back(split.rhs); if (split.bias) { @@ -228,7 +228,7 @@ ffi::TypedFunction(Map, Map)> GetRewri } int split_index = 0; - Array sections; + ffi::Array sections; for (size_t i = 0; i + 1 < splits.size(); i++) { auto width = splits[i].split_size.as(); ICHECK(width) << "InternalError: " @@ -388,10 +388,10 @@ Pass CombineParallelMatmul(FCheck check) { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.CombineParallelMatmul", CombineParallelMatmul); -}); +} } // namespace transform diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index c2cffd2c4439..7129af2236c1 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -87,10 +87,10 @@ Pass ComputePrimValue() { return CreateModulePass(pass_func, 0, "ComputePrimValue", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ComputePrimValue", ComputePrimValue); -}); +} } // namespace transform diff --git a/src/relax/transform/convert_dataflow.cc b/src/relax/transform/convert_dataflow.cc index 4fad1f831842..ac95acce63f2 100644 --- a/src/relax/transform/convert_dataflow.cc +++ b/src/relax/transform/convert_dataflow.cc @@ -39,7 +39,7 @@ class DataflowBlockExtractor : public ExprMutator { explicit DataflowBlockExtractor(size_t min_size) : ExprMutator(), min_size_(min_size) {} Expr VisitExpr_(const SeqExprNode* seq) override { - Array new_blocks; + ffi::Array new_blocks; Expr new_body = VisitExpr(seq->body); bool changed = !new_body.same_as(seq->body); @@ -49,15 +49,15 @@ class DataflowBlockExtractor : public ExprMutator { // make a dataflowblock. 
Because these bindings occur prior to // `dataflow_bindings`, this array may only be accumulated into // when `dataflow_bindings` is empty. - Array non_dataflow_bindings; + ffi::Array non_dataflow_bindings; // Current bindings that may legally be added to a DataflowBlock. - Array dataflow_bindings; + ffi::Array dataflow_bindings; // If present, a DataflowBlock whose bindings are currently in // `dataflow_bindings`. Used to propagate DataflowBlock to the // output, even if it doesn't meet the minimum size. - Optional input_dataflow_block; + ffi::Optional input_dataflow_block; // Handle any bindings currently in `dataflow_bindings`. These // are either pushed to their own block, or to the end of @@ -134,7 +134,7 @@ class DataflowBlockExtractor : public ExprMutator { if (changed) { return SeqExpr(new_blocks, new_body); } else { - return GetRef(seq); + return ffi::GetRef(seq); } } @@ -160,10 +160,10 @@ Pass ConvertToDataflow(int min_size) { return tvm::transform::Sequential({pass, CanonicalizeBindings()}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ConvertToDataflow", ConvertToDataflow); -}); +} } // namespace transform diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 2ba757c76a70..c543799e3b0d 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -78,12 +78,13 @@ using tir::Layout; */ class LayoutConvertMutator : public ExprMutator { public: - explicit LayoutConvertMutator(const Map>& desired_layouts) + explicit LayoutConvertMutator( + const ffi::Map>& desired_layouts) : desired_layouts_(desired_layouts) {} private: - Array LayoutToIntegers(const Layout& layout) { - Array ret; + ffi::Array LayoutToIntegers(const Layout& layout) { + ffi::Array ret; LayoutDecision src = InitialLayoutDecision(layout.ndim()); for (size_t i = 0; i < layout.ndim(); ++i) { 
ret.push_back(Integer(src->layout.IndexOf(layout[i]))); @@ -93,17 +94,17 @@ class LayoutConvertMutator : public ExprMutator { IndexMap LayoutIndexMap(int ndim, const Layout& src_layout, const Layout& desired_layout) { tir::BijectiveLayout todesired(src_layout, desired_layout); - Optional inverse_index_map; + ffi::Optional inverse_index_map; - Array initial_indices; - Array initial_indices_expr; + ffi::Array initial_indices; + ffi::Array initial_indices_expr; initial_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { auto var = tvm::tir::Var("i" + std::to_string(i), DataType::Int(32)); initial_indices.push_back(var); initial_indices_expr.push_back(var); } - Array desired_shape = todesired.ForwardIndex(initial_indices_expr); + ffi::Array desired_shape = todesired.ForwardIndex(initial_indices_expr); return IndexMap(initial_indices, desired_shape, std::move(inverse_index_map)); } @@ -125,9 +126,9 @@ class LayoutConvertMutator : public ExprMutator { } else { auto index_map = LayoutIndexMap(from.LeafValue()->layout.ndim(), from.LeafValue()->layout, to.LeafValue()->layout); - ObjectPtr attrs = make_object(); - Array axis_separator; - Array input_axis_separator; + ObjectPtr attrs = ffi::make_object(); + ffi::Array axis_separator; + ffi::Array input_axis_separator; attrs->index_map = Downcast(LoadJSON(SaveJSON(index_map))); attrs->axis_separators = std::move(axis_separator); attrs->input_axis_separators = std::move(input_axis_separator); @@ -141,9 +142,9 @@ class LayoutConvertMutator : public ExprMutator { std::array({GetNLayout(var_layout_map_, expr), to}), fvisitleaf); } - Array RewriteArgs(const Array& args, const Array& to) { - // The `Array args` array contains both tensor and - // non-tensor arguments, where the `Array to` array only + ffi::Array RewriteArgs(const ffi::Array& args, const ffi::Array& to) { + // The `ffi::Array args` array contains both tensor and + // non-tensor arguments, where the `ffi::Array to` array only // contains tensor arguments. 
The number of tensor arguments in // `args` should match the full extent of `to`. @@ -175,7 +176,7 @@ class LayoutConvertMutator : public ExprMutator { return RewriteExpr(var, InitialNLayout(var)); } - Expr VisitExpr_(const VarNode* op) final { return VisitVars_(GetRef(op)); } + Expr VisitExpr_(const VarNode* op) final { return VisitVars_(ffi::GetRef(op)); } bool HasUnknownDimTensor(const NLayout& nlayout) { bool find = false; @@ -186,7 +187,7 @@ class LayoutConvertMutator : public ExprMutator { return find; } - bool HasUnknownDimTensor(const Array& args) { + bool HasUnknownDimTensor(const ffi::Array& args) { for (const auto& arg : args) { if (IsNestedTensor(arg)) { if (HasUnknownDimTensor(GetNLayout(var_layout_map_, arg))) { @@ -197,17 +198,18 @@ class LayoutConvertMutator : public ExprMutator { return false; } - Optional GetInferLayoutInfo(const CallNode* call_node, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { + ffi::Optional GetInferLayoutInfo( + const CallNode* call_node, + const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const OpNode* op_node = call_node->op.as(); if (op_node == nullptr) return std::nullopt; - Op op = Downcast(GetRef(op_node)); + Op op = Downcast(ffi::GetRef(op_node)); const auto attr_map = Op::GetAttrMap("FRelaxInferLayout"); if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) { // If the op has FRelaxInferLayout, and all the input tensors have known ndim FRelaxInferLayout f = attr_map[op]; - return f(GetRef(call_node), desired_layouts, var_layout_map); + return f(ffi::GetRef(call_node), desired_layouts, var_layout_map); } else { // Otherwise, we use the default policy. 
return std::nullopt; @@ -215,9 +217,9 @@ class LayoutConvertMutator : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - Optional res = + ffi::Optional res = GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_); - ObjectPtr new_call = make_object(*call_node); + ObjectPtr new_call = ffi::make_object(*call_node); new_call->struct_info_ = std::nullopt; if (!res.defined() || (!IsNestedTensor(binding->var) && !binding->var->IsInstance())) { @@ -227,14 +229,14 @@ class LayoutConvertMutator : public ExprMutator { for (const auto& arg : call_node->args) { input_layout.push_back(InitialNLayout(arg)); } - Array new_args = RewriteArgs(call_node->args, std::move(input_layout)); + ffi::Array new_args = RewriteArgs(call_node->args, std::move(input_layout)); new_call->args = std::move(new_args); ReEmitBinding(binding, builder_->Normalize(Call(new_call))); // update the layout map var_layout_map_[binding->var] = InitialNLayout(binding->var); } else { // Convert the layout according to the inferred layout output. 
- Array new_args = RewriteArgs(call_node->args, res.value()->input_layouts); + ffi::Array new_args = RewriteArgs(call_node->args, res.value()->input_layouts); for (const auto& [i, arg] : res.value()->new_args) { new_args.Set(i->value, arg); } @@ -273,7 +275,7 @@ class LayoutConvertMutator : public ExprMutator { input_layout.push_back(InitialNLayout(field)); } } - Array new_fields = RewriteArgs(val->fields, std::move(input_layout)); + ffi::Array new_fields = RewriteArgs(val->fields, std::move(input_layout)); if (IsNestedTensor(binding->var)) { ReEmitBinding(binding, builder_->Normalize(Tuple(new_fields))); var_layout_map_[binding->var] = input_layout; @@ -322,7 +324,7 @@ class LayoutConvertMutator : public ExprMutator { binding->struct_info, std::array({from_layout, input_layout}), fvisitleaf); // re-emit old binding if nothing changes if (new_struct_info.same_as(binding->struct_info)) { - builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(ffi::GetRef(binding)); } else { Var new_var = builder_->EmitMatchCast(RewriteExpr(binding->value, input_layout), new_struct_info); @@ -332,18 +334,18 @@ class LayoutConvertMutator : public ExprMutator { } std::unordered_map var_layout_map_; - Map> desired_layouts_; + ffi::Map> desired_layouts_; }; // namespace relax DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block, - Map> desired_layouts) { + ffi::Map> desired_layouts) { LayoutConvertMutator mutator(desired_layouts); return Downcast(mutator.VisitBindingBlock(df_block)); } namespace transform { -Pass ConvertLayout(Map> desired_layouts) { +Pass ConvertLayout(ffi::Map> desired_layouts) { ffi::TypedFunction pass_func = [=](DataflowBlock df_block, IRModule m, PassContext pc) { return Downcast(ConvertLayoutPass(df_block, desired_layouts)); @@ -351,10 +353,10 @@ Pass ConvertLayout(Map> desired_layouts) { return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ConvertLayout", ConvertLayout); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index fa75669362ad..3b56d6ca1d81 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -48,7 +48,7 @@ std::unordered_map> AnalyzeLiveness(const DataflowBlock Binding b = block->bindings[i]; Var defined_var = b->var; Expr value = GetBoundValue(b); - Array used_vars; + ffi::Array used_vars; // for a function literal, we consider only the free vars // (those captured from the outer scope) if (value.as()) { @@ -105,7 +105,7 @@ class AliasAnalyzer { // (in the case of in-place ops) safe to overwrite. This may not be true of function args. std::pair>, std::unordered_map>>> - Analyze(const DataflowBlock& block, const Array& inputs) { + Analyze(const DataflowBlock& block, const ffi::Array& inputs) { for (auto input : inputs) { int curr_idx = get_fresh_idx(); alias_map_[input] = {curr_idx}; @@ -227,7 +227,7 @@ class AliasAnalyzer { // TODO(@slyubomirsky): We will probably want special handling for closures ret.insert(get_fresh_idx()); } else if (auto* target_var_node = value.as()) { - auto target_var = GetRef(target_var_node); + auto target_var = ffi::GetRef(target_var_node); if (alias_map_.count(target_var)) { ret.insert(alias_map_[target_var].begin(), alias_map_[target_var].end()); } else { @@ -324,7 +324,7 @@ std::unordered_set GatherCandidateSin // don't consider cases where we don't know the shape at compile time // (we will use the analyzer to do best-effort analysis where there are vars) if (tensor_info->shape.as()) { - return {GetRef(tensor_info)}; + return {ffi::GetRef(tensor_info)}; } else { return {}; } @@ -337,7 +337,7 @@ std::unordered_set GatherCandidateSin } // at least one field should be eligible to be done in-place if (!ret.empty()) { - ret.insert(GetRef(tuple_info)); + 
ret.insert(ffi::GetRef(tuple_info)); } return ret; } else { @@ -447,7 +447,7 @@ bool InplaceConditionsMet( const std::unordered_map>>& tuple_map, const std::unordered_set& currently_live, const Expr& target, int binding_idx) { if (auto* var_node = target.as()) { - auto current_var = GetRef(var_node); + auto current_var = ffi::GetRef(var_node); // if the var is live past this point, we can't use it for in-place computations anyway if (live_ranges.count(current_var)) { auto live_range = live_ranges.at(current_var); @@ -523,7 +523,7 @@ class InplaceOpportunityNode : public Object { public: // need to use Array for the benefit of the FFI Integer binding_idx; - Array arg_idxs; + ffi::Array arg_idxs; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -531,23 +531,21 @@ class InplaceOpportunityNode : public Object { .def_ro("binding_idx", &InplaceOpportunityNode::binding_idx) .def_ro("arg_idxs", &InplaceOpportunityNode::arg_idxs); } - - static constexpr const char* _type_key = "relax.transform.InplaceOpportunity"; - TVM_DECLARE_BASE_OBJECT_INFO(InplaceOpportunityNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("relax.transform.InplaceOpportunity", InplaceOpportunityNode, Object); }; -TVM_FFI_STATIC_INIT_BLOCK({ InplaceOpportunityNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { InplaceOpportunityNode::RegisterReflection(); } class InplaceOpportunity : public ObjectRef { public: - TVM_DLL InplaceOpportunity(const Integer& binding_idx, const Array& arg_idxs) { - auto node = make_object(); + TVM_DLL InplaceOpportunity(const Integer& binding_idx, const ffi::Array& arg_idxs) { + auto node = ffi::make_object(); node->binding_idx = binding_idx; node->arg_idxs = arg_idxs; data_ = std::move(node); } - TVM_DEFINE_OBJECT_REF_METHODS(InplaceOpportunity, ObjectRef, InplaceOpportunityNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(InplaceOpportunity, ObjectRef, InplaceOpportunityNode); }; // Check for in-place eligibility: @@ -564,7 +562,7 @@ 
class InplaceOpportunity : public ObjectRef { // The first element is the index of the *binding* in the block. // All remaining elements are the indices of *eligible arguments* in that call. std::pair, std::vector> -FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, +FindInplaceOpportunities(const DataflowBlock& block, const ffi::Array& inputs, const BlockBuilder& ctx) { auto live_ranges = AnalyzeLiveness(block); AliasAnalyzer analyzer; @@ -619,7 +617,7 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, if (auto* call_node = value.as()) { if (auto* op_node = call_node->op.as()) { - if (!OpSupportsInplace(GetRef(op_node))) { + if (!OpSupportsInplace(ffi::GetRef(op_node))) { continue; } @@ -669,14 +667,14 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, } // produce a list of candidates for this index - Array size_candidate_list; + ffi::Array size_candidate_list; for (auto candidate : candidates) { size_candidate_list.push_back(Integer(candidate)); } size_match_list.push_back(InplaceOpportunity(Integer(i), size_candidate_list)); // also gather up the exact match candidates if there are any - Array exact_candidate_list; + ffi::Array exact_candidate_list; for (auto candidate : candidates) { if (!exact_match_candidates.count(candidate)) { continue; @@ -695,10 +693,11 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, } // Replace buffers in a PrimFunc according to the mapping. 
-tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map& buffer_map) { +tir::Stmt RemapBuffers(const tir::Stmt& stmt, + const ffi::Map& buffer_map) { class BufferMapper : public tir::StmtExprMutator { public: - explicit BufferMapper(const Map& buffer_map) + explicit BufferMapper(const ffi::Map& buffer_map) : buffer_map_(buffer_map) {} tir::Stmt Remap(const tir::Stmt& stmt) { return VisitStmt(stmt); } @@ -766,7 +765,7 @@ tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map& buffer_map_; + const ffi::Map& buffer_map_; }; BufferMapper mapper(buffer_map); @@ -786,7 +785,7 @@ class ModuleInplaceTransformer : public ExprMutator { if (auto* func_node = kv.second.as()) { auto gv = kv.first; auto func_params = func_node->params; - auto function = Downcast(VisitExpr(GetRef(func_node))); + auto function = Downcast(VisitExpr(ffi::GetRef(func_node))); builder_->UpdateFunction(gv, function); } } @@ -810,14 +809,14 @@ class ModuleInplaceTransformer : public ExprMutator { // the only case we will override: we will visit all binding blocks // and replace any valid calls in them BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { - auto block = GetRef(op); + auto block = ffi::GetRef(op); auto old_idxs = inplace_idxs; // For now, only handle exact match cases. // Note: Not passing any input values for now, as we can't make any assumptions // about them. 
auto matches_found = FindInplaceOpportunities(block, {}, builder_); - Map> new_idxs; + ffi::Map> new_idxs; for (auto match : matches_found.second) { new_idxs.Set(block->bindings[match->binding_idx.IntValue()], match->arg_idxs); } @@ -838,7 +837,7 @@ class ModuleInplaceTransformer : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding) override { - auto binding_ref = GetRef(binding); + auto binding_ref = ffi::GetRef(binding); if (!inplace_idxs.count(binding_ref)) { ExprMutator::VisitBinding_(binding); return; @@ -848,7 +847,7 @@ class ModuleInplaceTransformer : public ExprMutator { } void VisitBinding_(const MatchCastNode* binding) override { - auto binding_ref = GetRef(binding); + auto binding_ref = ffi::GetRef(binding); if (!inplace_idxs.count(binding_ref)) { ExprMutator::VisitBinding_(binding); return; @@ -861,7 +860,7 @@ class ModuleInplaceTransformer : public ExprMutator { // Given the call and indices of arguments that could be done in-place, // replace the call with a call to an in-place PrimFunc. // (Made public for testing.) - Call CreateInplaceCall(const Call& call, const Array& inplace_indices) { + Call CreateInplaceCall(const Call& call, const ffi::Array& inplace_indices) { static const auto& legalize_map = Op::GetAttrMap("FLegalize"); static const auto& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); @@ -890,8 +889,8 @@ class ModuleInplaceTransformer : public ExprMutator { // 2. For each output var, replace its instances with the corresponding inplace index var // 3. Do the same for the *buffer vars* corresponding to the output vars // 4. 
Remove the output vars from the param list and buffer map - Map buffer_subst_map; - Map var_subst_map; + ffi::Map buffer_subst_map; + ffi::Map var_subst_map; for (size_t i = 0; i < num_outs; i++) { // we will substitute output i with the corresponding param indicated by inplace indices auto output_var = old_primfunc->params[num_params - num_outs + i]; @@ -907,12 +906,13 @@ class ModuleInplaceTransformer : public ExprMutator { // apply substitutions new_body = RemapBuffers(new_body, buffer_subst_map); - new_body = tir::Substitute(new_body, [&var_subst_map](const tir::Var& v) -> Optional { - if (var_subst_map.count(v)) { - return var_subst_map.at(v); - } - return Optional(); - }); + new_body = + tir::Substitute(new_body, [&var_subst_map](const tir::Var& v) -> ffi::Optional { + if (var_subst_map.count(v)) { + return var_subst_map.at(v); + } + return ffi::Optional(); + }); // remove the now-unused outputs from the buffer map auto new_buffer_map = old_primfunc->buffer_map; @@ -922,8 +922,8 @@ class ModuleInplaceTransformer : public ExprMutator { // now get rid of the last num_outputs arguments // (couldn't do earlier or else it would have thrown off the indexing) - Array new_params(old_primfunc->params.begin(), - old_primfunc->params.begin() + (num_params - num_outs)); + ffi::Array new_params(old_primfunc->params.begin(), + old_primfunc->params.begin() + (num_params - num_outs)); tir::PrimFunc new_primfunc(new_params, new_body, old_primfunc->ret_type, new_buffer_map, old_primfunc->attrs, old_primfunc->span); @@ -935,11 +935,11 @@ class ModuleInplaceTransformer : public ExprMutator { // update the call (change the op, update the argument, change the attrs) legalized_call_cow->op = call_tir_inplace_op; - Array new_args(legalized_call->args.begin(), legalized_call->args.end()); + ffi::Array new_args(legalized_call->args.begin(), legalized_call->args.end()); new_args.Set(0, new_gv); legalized_call_cow->args = new_args; - ObjectPtr attrs = make_object(); + ObjectPtr attrs = 
ffi::make_object(); attrs->inplace_indices = inplace_indices; legalized_call_cow->attrs = Attrs(attrs); @@ -952,43 +952,43 @@ class ModuleInplaceTransformer : public ExprMutator { private: const IRModule& mod_; // Keep track of legalizers we add so we can clean up at the end. - Array legalizers_added; + ffi::Array legalizers_added; // The current function's params will be treated as non-aliased // (we are assuming good behavior on the user's part). - Array func_params; + ffi::Array func_params; // map of eligible bindings to indices of arguments that can be used as the in-place target - Map> inplace_idxs; + ffi::Map> inplace_idxs; }; namespace transform { -Map> DataflowLivenessAnalysis(const DataflowBlock& block) { +ffi::Map> DataflowLivenessAnalysis(const DataflowBlock& block) { auto liveness_ranges = AnalyzeLiveness(block); - Map> ret; + ffi::Map> ret; for (auto kv : liveness_ranges) { ret.Set(kv.first, {kv.second.first, kv.second.second}); } return ret; } -Array DataflowAliasAnalysis(const DataflowBlock& block, Array inputs) { +ffi::Array DataflowAliasAnalysis(const DataflowBlock& block, ffi::Array inputs) { AliasAnalyzer analyzer; auto res = analyzer.Analyze(block, inputs); auto alias_sets = res.first; auto tuple_map = res.second; - Map> new_alias_sets; - Map>> new_tuple_map; + ffi::Map> new_alias_sets; + ffi::Map>> new_tuple_map; for (auto kv : alias_sets) { - Array aliases; + ffi::Array aliases; for (auto alias : kv.second) { aliases.push_back(alias); } new_alias_sets.Set(kv.first, aliases); } for (auto kv : tuple_map) { - Array> elem_aliases; + ffi::Array> elem_aliases; for (auto alias_set : kv.second) { - Array dim_aliases; + ffi::Array dim_aliases; for (auto alias : alias_set) { dim_aliases.push_back(alias); } @@ -1010,16 +1010,16 @@ tvm::transform::Pass DataflowUseInplaceCalls() { 0, "DataflowInsertInPlaceCalls", {}, false); } -Array> DataflowInplaceAnalysis(const DataflowBlock& block, - const Array& inputs, - const IRModule& mod) { +ffi::Array> 
DataflowInplaceAnalysis(const DataflowBlock& block, + const ffi::Array& inputs, + const IRModule& mod) { auto index_lists = relax::FindInplaceOpportunities(block, inputs, BlockBuilder::Create(mod)); - return {Array(index_lists.first.begin(), index_lists.first.end()), - Array(index_lists.second.begin(), index_lists.second.end())}; + return {ffi::Array(index_lists.first.begin(), index_lists.first.end()), + ffi::Array(index_lists.second.begin(), index_lists.second.end())}; } // these are exposed only for testing -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.testing.transform.DataflowLivenessAnalysis", DataflowLivenessAnalysis) @@ -1027,18 +1027,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("relax.testing.transform.DataflowInplaceAnalysis", DataflowInplaceAnalysis) .def("relax.testing.transform.SingleInplaceCall", [](const IRModule& mod, const Call& call, - const Array& inplace_indices) -> Array { + const ffi::Array& inplace_indices) -> ffi::Array { ModuleInplaceTransformer transformer(mod); auto ret_call = transformer.CreateInplaceCall(call, inplace_indices); - return Array{ret_call, transformer.CurrentMod()}; + return ffi::Array{ret_call, transformer.CurrentMod()}; }); -}); +} // actually exposed -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.DataflowUseInplaceCalls", DataflowUseInplaceCalls); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 59874e737778..fbb077ddf941 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -91,7 +91,8 @@ IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set return mod; } -IRModule DeadCodeElimination(const IRModule& arg_mod, Array entry_function_names) { +IRModule 
DeadCodeElimination(const IRModule& arg_mod, + ffi::Array entry_function_names) { IRModule mod = arg_mod; // S0: Make a list of all user-specified entry functions and @@ -134,17 +135,17 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array entry_functi namespace transform { -Pass DeadCodeElimination(Array entry_functions) { +Pass DeadCodeElimination(ffi::Array entry_functions) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::DeadCodeElimination(m, entry_functions); }; return CreateModulePass(pass_func, 1, "DeadCodeElimination", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.DeadCodeElimination", DeadCodeElimination); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index df57434ebb02..81d4d3881ede 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -36,9 +36,9 @@ TensorStructInfo MatchTensorStructInfo(Expr data) { return _sinfo.value(); } -Expr ExpandToMatchInput(Expr data, int ndim, Array axes) { +Expr ExpandToMatchInput(Expr data, int ndim, ffi::Array axes) { axes = GetOrderedPositiveAxes(axes, ndim); - Array expand_axes; + ffi::Array expand_axes; for (int i = 0, j = 0; i < ndim; ++i) { if (j < static_cast(axes.size()) && i == axes[j]->value) { ++j; @@ -89,7 +89,7 @@ Expr MutateBatchNormForTraining(Call call) { TensorStructInfo sinfo = MatchTensorStructInfo(data); - Array reduce_axes; + ffi::Array reduce_axes; for (int i = 0; i < sinfo->ndim; ++i) { if (i != attrs->axis) { reduce_axes.push_back(i); @@ -148,12 +148,12 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); Var call = builder->Emit(Call(call_pure_packed_op, {ExternFunc("vm.builtin.tensor_to_shape"), expr}, {}, - {GetRef(sinfo)})); + 
{ffi::GetRef(sinfo)})); // Operators like reshape take the output of `TensorToShape` as their output shape. // Because TOPI expects to have such output shape in symbolic shape at least (i.e., - // Array), we define symbolic variables and returns them as a ShapeExpr. - Array shape_var; + // ffi::Array), we define symbolic variables and returns them as a ShapeExpr. + ffi::Array shape_var; for (int i = 0; i < sinfo->ndim; i++) { shape_var.push_back(tir::Var("x", DataType::Int(64))); } @@ -233,7 +233,7 @@ Pass DecomposeOps() { /*required=*/{}); } -Pass DecomposeOpsForInference(Optional func_name) { +Pass DecomposeOpsForInference(ffi::Optional func_name) { if (func_name) { return ApplyPassToFunction(DecomposeOps(), func_name.value()); } else { @@ -241,7 +241,7 @@ Pass DecomposeOpsForInference(Optional func_name) { } } -Pass DecomposeOpsForTraining(Optional func_name) { +Pass DecomposeOpsForTraining(ffi::Optional func_name) { auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()}, "DecomposeOpsForTraining"); if (func_name) { @@ -251,12 +251,12 @@ Pass DecomposeOpsForTraining(Optional func_name) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.transform.DecomposeOpsForInference", DecomposeOpsForInference) .def("relax.transform.DecomposeOpsForTraining", DecomposeOpsForTraining); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 68e37970030a..e893b5151b52 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -48,7 +48,7 @@ namespace { */ struct ReplacementKey { tvm::relax::Expr bound_value; - tvm::Optional match_cast = std::nullopt; + tvm::ffi::Optional match_cast = std::nullopt; explicit ReplacementKey(const tvm::relax::Binding& binding) : bound_value(GetBoundValue(binding)) 
{ @@ -155,7 +155,7 @@ class CommonSubexprEliminator : public ExprMutator { // copy of the mutator, to avoid replacing a child-scope // expression with a parent-scope binding, or vice versa. if (expr_replacements_.size() || var_remap_.size()) { - return VisitWithCleanScope(GetRef(op)); + return VisitWithCleanScope(ffi::GetRef(op)); } else { return ExprMutator::VisitExpr_(op); } @@ -168,7 +168,7 @@ class CommonSubexprEliminator : public ExprMutator { if (op->cond.same_as(cond) && op->true_branch.same_as(true_branch) && op->false_branch.same_as(false_branch) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } else { return If(cond, true_branch, false_branch, op->span); } @@ -193,7 +193,7 @@ class CommonSubexprEliminator : public ExprMutator { static const auto& allocator_attr_map = Op::GetAttrMap("TAllocator"); if (const auto* call = expr.as()) { if (const auto* op = call->op.as()) { - bool is_allocator = allocator_attr_map.get(GetRef(op), Bool(false))->value; + bool is_allocator = allocator_attr_map.get(ffi::GetRef(op), Bool(false))->value; if (is_allocator) { return true; } @@ -222,10 +222,10 @@ Pass EliminateCommonSubexpr(bool call_only) { return CreateFunctionPass(pass_func, 1, "EliminateCommonSubexpr", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.EliminateCommonSubexpr", EliminateCommonSubexpr); -}); +} } // namespace transform diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index 70662396fe52..5504c2a59942 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -41,7 +41,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns( +std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); std::unordered_set 
compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end()); @@ -58,7 +58,7 @@ std::tuple)>> Crea auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs); - auto rewriter = [=](Expr expr, Map matches) -> Expr { + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { auto lhs = matches[pat_lhs]; auto rhs_a = matches[pat_rhs_a]; auto rhs_b = matches[pat_rhs_b]; @@ -105,10 +105,10 @@ Pass ExpandMatmulOfSum() { return CreateFunctionPass(pass_func, 1, "ExpandMatmulOfSum", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ExpandMatmulOfSum", ExpandMatmulOfSum); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc index 5b711b767562..0239652c791a 100644 --- a/src/relax/transform/expand_tuple_arguments.cc +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -32,8 +32,8 @@ namespace { template using PMap = std::unordered_map; -Optional ExpandParams(Function func) { - bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); +ffi::Optional ExpandParams(Function func) { + bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return std::nullopt; bool has_tuple_param = std::any_of( @@ -42,12 +42,12 @@ Optional ExpandParams(Function func) { if (!has_tuple_param) return std::nullopt; - Array params; - Array bindings; + ffi::Array params; + ffi::Array bindings; std::function expand_param = [&](const Var& param) { if (auto sinfo = param->struct_info_.as()) { - Array internal_tuple; + ffi::Array internal_tuple; for (size_t i = 0; i < sinfo->fields.size(); i++) { auto name = static_cast(std::stringstream() << param->name_hint() << "_" << i) @@ -89,7 +89,7 @@ class TupleExpander : public ExprMutator { if (auto gvar = node->op.as()) { if (auto it = replacements_.find(gvar.value()); it != replacements_.end()) { 
- Array new_args; + ffi::Array new_args; std::function expand_arg = [&](const Expr& arg) { if (auto sinfo = arg->struct_info_.as()) { @@ -179,10 +179,10 @@ Pass ExpandTupleArguments() { "ExpandTupleArguments"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ExpandTupleArguments", ExpandTupleArguments); -}); +} } // namespace transform diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index 819de35e20f0..6c213a9504a8 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -34,7 +34,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& meta_schedule::Builder builder = f_get_local_builder().cast(); ICHECK(builder.defined()) << "ValueError: The local builder is not defined!"; // fetch a local runner - meta_schedule::Runner runner{nullptr}; + meta_schedule::Runner runner{ffi::UnsafeInit()}; if (benchmark) { static const auto f_get_local_runner = tvm::ffi::Function::GetGlobalRequired("meta_schedule.runner.get_local_runner"); @@ -42,13 +42,13 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& ICHECK(runner.defined()) << "ValueError: The local runner is not defined!"; } // create an IRModule - IRModule mod = IRModule(Map( - {{GlobalVar("main"), WithAttr(prim_func, tvm::attr::kGlobalSymbol, String("main"))}})); + IRModule mod = IRModule(ffi::Map( + {{GlobalVar("main"), WithAttr(prim_func, tvm::attr::kGlobalSymbol, ffi::String("main"))}})); // fetch the number of physical cores static const auto f_cpu_count = tvm::ffi::Function::GetGlobalRequired("meta_schedule.cpu_count"); int num_threads = f_cpu_count(false).cast(); // store the results - Array results; + ffi::Array results; std::vector costs; // create a TuneContext meta_schedule::TuneContext task = meta_schedule::TuneContext( @@ -72,16 +72,16 @@ tir::PrimFunc 
FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& /*cost_model=*/std::nullopt); int fail_count = 0, max_fail_count = 100; while (valid_count > 0 && fail_count < max_fail_count) { - Optional> candidates = + ffi::Optional> candidates = task->search_strategy.value()->GenerateMeasureCandidates(); if (!candidates.defined()) break; - Array builder_inputs; + ffi::Array builder_inputs; for (const meta_schedule::MeasureCandidate& candidate : candidates.value()) { builder_inputs.push_back(meta_schedule::BuilderInput( /*mod=*/candidate->sch->mod(), /*target=*/target)); } - Array builder_results = builder->Build(builder_inputs); + ffi::Array builder_results = builder->Build(builder_inputs); ICHECK_EQ(builder_results.size(), candidates.value().size()); int idx = 0; bool no_valid = true; // whether there is no valid schedule in this iteration @@ -95,7 +95,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& } fail_count += no_valid; // increase fail_count if there is no valid schedule if (benchmark) { - Array runner_inputs; + ffi::Array runner_inputs; int idx = 0; for (const meta_schedule::BuilderResult& builder_result : builder_results) { if (!builder_result->error_msg.has_value()) { @@ -106,7 +106,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& } idx++; } - Array runner_futures = runner->Run(runner_inputs); + ffi::Array runner_futures = runner->Run(runner_inputs); for (const meta_schedule::RunnerFuture& runner_future : runner_futures) { meta_schedule::RunnerResult runner_result = runner_future->Result(); if (runner_result->error_msg.has_value()) { @@ -153,12 +153,13 @@ Pass FewShotTuning(int valid_count, bool benchmark) { tvm::Target target = tvm::Target::Current(); ICHECK(target.defined()) << "Target is not set in current context"; // generate the few shot tuned prim funcs. 
- Map result; + ffi::Map result; for (const auto& [gv, func] : m->functions) { if (func->IsInstance() && !func->HasNonzeroAttr(tir::attr::kIsScheduled)) { - result.Set(gv, FewShotTunePrimFunc(GetRef(func.as()), - target, valid_count, benchmark)); + result.Set(gv, + FewShotTunePrimFunc(ffi::GetRef(func.as()), + target, valid_count, benchmark)); } else { result.Set(gv, func); } @@ -173,10 +174,10 @@ Pass FewShotTuning(int valid_count, bool benchmark) { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.FewShotTuning", FewShotTuning); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index c1aee73cc258..b714d4924359 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -50,7 +50,7 @@ class ConstantFolder : public ExprMutator { * \note Only TensorStructInfo is supported at this moment. Return std::nullopt * if the input struct info is not TensorStructInfo. */ - static Optional MatchConstShape(const StructInfo& struct_info) { + static ffi::Optional MatchConstShape(const StructInfo& struct_info) { // Only support single output for call_tir at this moment. const auto* tensor_sinfo = struct_info.as(); if (tensor_sinfo == nullptr) { @@ -73,8 +73,9 @@ class ConstantFolder : public ExprMutator { * \brief Pattern match op to constant array arguments. * \return The constant array arguments, or nullopt if match fails. */ - static Optional> MatchConstArrayArgs(const Array& args) { - Array res; + static ffi::Optional> MatchConstArrayArgs( + const ffi::Array& args) { + ffi::Array res; for (auto arg : args) { auto* ptr = arg.as(); if (!ptr) return std::nullopt; @@ -87,12 +88,12 @@ class ConstantFolder : public ExprMutator { * \brief Pattern match op to a TIR function and look it up. * \return The TIR function, or nullopt if pattern match fails. 
*/ - Optional MatchPrimFunc(const Expr& op) { + ffi::Optional MatchPrimFunc(const Expr& op) { const GlobalVar& global_var = Downcast(op); // NOTE: as check works for nullptr(returns null) - Optional base_func = builder_->GetContextIRModule()->functions.Get(global_var); + ffi::Optional base_func = builder_->GetContextIRModule()->functions.Get(global_var); if (auto* pfunc = base_func.as()) { - return GetRef(pfunc); + return ffi::GetRef(pfunc); } return std::nullopt; } @@ -101,7 +102,7 @@ class ConstantFolder : public ExprMutator { * \brief Get a cached build version of func * \return The cached func, nullopt if func cannot be built. */ - Optional GetCachedBuild(tir::PrimFunc func) { + ffi::Optional GetCachedBuild(tir::PrimFunc func) { // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once // would be helpful for future cases where PrimFunc recursively call into each other Target eval_cpu_target{"llvm"}; @@ -110,7 +111,7 @@ class ConstantFolder : public ExprMutator { if (it != func_build_cache_.end()) { return it->second; } - Optional build_func = std::nullopt; + ffi::Optional build_func = std::nullopt; try { // Not all the primfunc can be directly built via llvm, for example, if a function is @@ -118,7 +119,7 @@ class ConstantFolder : public ExprMutator { // now // TODO(Hongyi): further check and narrow the scope of foldable function const auto pf = tvm::ffi::Function::GetGlobalRequired("tir.build"); - func = WithAttr(func, tvm::attr::kGlobalSymbol, String("tir_function")); + func = WithAttr(func, tvm::attr::kGlobalSymbol, ffi::String("tir_function")); ffi::Module rt_module = pf(func, eval_cpu_target).cast(); build_func = rt_module->GetFunction("tir_function"); } catch (const tvm::Error& err) { @@ -144,21 +145,22 @@ class ConstantFolder : public ExprMutator { // Try constant evaluate the function call // if failed return std::nullopt - Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array arr_args, - ffi::Shape shape, DataType ret_type) 
{ + ffi::Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, + ffi::Array arr_args, ffi::Shape shape, + DataType ret_type) { // obtain function from the cache. - Optional func = GetCachedBuild(tir_func); + ffi::Optional func = GetCachedBuild(tir_func); if (!func) return std::nullopt; // here the vector size has an additional + 1 because we need to put ret_tensor at the end std::vector packed_args(arr_args.size() + 1); DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0}; - runtime::NDArray ret_tensor = runtime::NDArray::Empty(shape, ret_type, cpu_dev); + runtime::Tensor ret_tensor = runtime::Tensor::Empty(shape, ret_type, cpu_dev); // avoid set rvalue ref which get de-allocated later, store args in a vector // where temp_args[i] are lvalue ref that is stable - std::vector temp_args(arr_args.begin(), arr_args.end()); + std::vector temp_args(arr_args.begin(), arr_args.end()); size_t arg_offset = 0; for (; arg_offset < arr_args.size(); ++arg_offset) { @@ -174,15 +176,15 @@ class ConstantFolder : public ExprMutator { } // Returns the folded expr if the call is successfully folded to constant, otherwise null. - Optional VisitCallTIR(Call call) { + ffi::Optional VisitCallTIR(Call call) { // call_tir needs to have at least three arguments ICHECK_GE(call->args.size(), 2); - Optional func = MatchPrimFunc(call->args[0]); + ffi::Optional func = MatchPrimFunc(call->args[0]); ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; - Optional> arr_args = + ffi::Optional> arr_args = MatchConstArrayArgs(call->args[1].as()->fields); ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; - Optional shape = MatchConstShape(call->sinfo_args[0]); + ffi::Optional shape = MatchConstShape(call->sinfo_args[0]); bool output_not_tuple = call->sinfo_args.size() == 1; // Pattern 0: call constant function, const argument with const shape. 
if (func && arr_args && shape && output_not_tuple) { @@ -216,7 +218,7 @@ class ConstantFolder : public ExprMutator { if (op_node == nullptr) { return post_call; } - auto op = GetRef(op_node); + auto op = ffi::GetRef(op_node); if (op.same_as(call_tir_op)) { return VisitCallTIR(post_call).value_or(post_call); @@ -230,10 +232,10 @@ class ConstantFolder : public ExprMutator { // // gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, R.shape([16, 16])) // - Array new_args; + ffi::Array new_args; for (auto arg : post_call->args) { if (arg->IsInstance()) { - Optional val = LookupBinding(Downcast(arg)); + ffi::Optional val = LookupBinding(Downcast(arg)); if (val.defined() && val.value()->IsInstance()) { new_args.push_back(val.value()); continue; @@ -254,7 +256,7 @@ class ConstantFolder : public ExprMutator { // If the legalized expression is call_tir, try to fold it. const CallNode* call = legalized_expr.as(); if (call && call->op.same_as(call_tir_op)) { - return VisitCallTIR(GetRef(call)).value_or(post_call); + return VisitCallTIR(ffi::GetRef(call)).value_or(post_call); } } else if (op->name == "relax.tensor_to_shape") { // Special handling for composite op "relax.tensor_to_shape" @@ -268,14 +270,14 @@ class ConstantFolder : public ExprMutator { Expr arg = post_call->args[0]; if (arg->IsInstance()) { Constant constant = Downcast(arg); - runtime::NDArray ndarray = constant->data; + runtime::Tensor ndarray = constant->data; ICHECK_EQ(ndarray->device.device_type, kDLCPU); - ICHECK(ndarray->strides == nullptr); + ICHECK(ndarray.IsContiguous()); ICHECK_EQ(ndarray->byte_offset, 0); ICHECK_EQ(ndarray->ndim, 1); const int64_t* data = static_cast(ndarray->data); int64_t num_elems = ndarray->shape[0]; - Array shape_values; + ffi::Array shape_values; for (int64_t i = 0; i < num_elems; i++) { shape_values.push_back(IntImm(DataType::Int(64), data[i])); } @@ -286,17 +288,17 @@ class ConstantFolder : public ExprMutator { // TODO(sunggg): revisit this when we extend ConstantFolding to 
fold ffi::Function. Expr arg = post_call->args[0]; ShapeExpr shape = Downcast(arg); - Array values = shape->values; - Array arr; + ffi::Array values = shape->values; + ffi::Array arr; bool is_known = true; for (size_t i = 0; i < values.size(); i++) { PrimExpr val = values[i]; - arr.push_back(GetRef(val.as())); + arr.push_back(ffi::GetRef(val.as())); is_known &= (val.dtype() == DataType::Int(64)); } if (is_known) { const auto func = tvm::ffi::Function::GetGlobalRequired("relax.run.shape_to_tensor"); - runtime::NDArray vals = func(arr).cast(); + runtime::Tensor vals = func(arr).cast(); return Constant(vals); } } @@ -306,7 +308,7 @@ class ConstantFolder : public ExprMutator { } Expr VisitExpr_(const VarNode* op) final { - Optional opt = LookupBinding(GetRef(op)); + ffi::Optional opt = LookupBinding(ffi::GetRef(op)); // `as` check checks if opt is not null and is instance of constant if (opt.as()) { return opt.value(); @@ -315,7 +317,7 @@ class ConstantFolder : public ExprMutator { } // cache for function build, via structural equality - std::unordered_map, StructuralHash, StructuralEqual> + std::unordered_map, StructuralHash, StructuralEqual> func_build_cache_; }; @@ -328,10 +330,10 @@ Pass FoldConstant() { return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.FoldConstant", FoldConstant); -}); +} } // namespace transform diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 4deb720342f2..561695787de8 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -48,10 +48,10 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { transform::FusionPatternNode::RegisterReflection(); transform::PatternCheckContextNode::RegisterReflection(); -}); +} /* Note on Fusing algorithm: @@ -120,10 +120,10 @@ class GraphCreator : public 
ExprVisitor { // true. const auto* func = it.second.as(); if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive) || - func->GetAttr(attr::kCodegen).has_value()) { + func->GetAttr(attr::kCodegen).has_value()) { continue; } - creator(GetRef(func)); + creator(ffi::GetRef(func)); } // The algorithm of the graph creator ensures that each created node will be added to the @@ -195,7 +195,7 @@ class GraphCreator : public ExprVisitor { static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); OpPatternKind pattern = OpPatternKind::kOpaque; - Array args = call->args; + ffi::Array args = call->args; // - If the op being called is a TIR PrimFunc, we get the function op pattern directly from the // function attribute and visit the arguments one by one. @@ -209,7 +209,7 @@ class GraphCreator : public ExprVisitor { // Override args for call_tir args = Downcast(call->args[1])->fields; - Optional opt_pattern = func->GetAttr("op_pattern"); + ffi::Optional opt_pattern = func->GetAttr("op_pattern"); if (opt_pattern.defined()) { pattern = static_cast(Downcast(opt_pattern)->value); } else { @@ -222,7 +222,7 @@ class GraphCreator : public ExprVisitor { for (const Expr& arg : args) { ICHECK(IsLeafOrTuple(arg)) << "FuseOps expects all relax::Call nodes to have non-nested arguments, " - << "but " << GetRef(call) << " has argument " << arg + << "but " << ffi::GetRef(call) << " has argument " << arg << ", which is neither a leaf node nor a relax::Tuple"; VisitLeaf(arg, binding_var_node, pattern); } @@ -297,7 +297,7 @@ class GraphCreator : public ExprVisitor { */ IndexedForwardGraph::Node* CreateNode(const Object* key) { ICHECK(graph_.node_map.find(key) == graph_.node_map.end()) - << "The object " << GetRef(key) << " appears at multiple definition sites."; + << "The object " << ffi::GetRef(key) << " appears at multiple definition sites."; auto* node = arena_->make(); graph_.node_map[key] = node; return node; @@ -312,12 +312,12 @@ class GraphCreator : public ExprVisitor 
{ void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) { auto it = graph_.node_map.find(key); ICHECK(it != graph_.node_map.end() && it->second == node) - << "Cannot add node " << GetRef(key) << " to the post-DFS order, " + << "Cannot add node " << ffi::GetRef(key) << " to the post-DFS order, " << "because the node for this object has not yet been created."; // We only set the reference of the node when adding it to the post-dfs order. Thus, if the // reference of a node is already set, it must have been appended to the post-dfs order. - ICHECK(node->ref == nullptr) << "Cannot add node " << GetRef(key) + ICHECK(node->ref == nullptr) << "Cannot add node " << ffi::GetRef(key) << " to the post-DFS order, " << "because it has already been added."; @@ -354,7 +354,7 @@ class GraphCreator : public ExprVisitor { */ void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) { ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end()) - << "The input node " << GetRef(node->ref) + << "The input node " << ffi::GetRef(node->ref) << " cannot have have its OpPatternKind set more than once."; initialized_nodes_.insert(node); node->pattern = pattern; @@ -481,7 +481,7 @@ class FunctionCreator : public ExprMutator { * It will become the value of the kComposite attribute of the created function. * \note The created function won't be returned immediately. It's stored in the `function_` field. */ - void CreateFunction(Map group_attrs) { + void CreateFunction(ffi::Map group_attrs) { // Step 1. Start constructing a new dataflow block. 
builder_->BeginDataflowBlock(); @@ -493,16 +493,16 @@ class FunctionCreator : public ExprMutator { ICHECK(!item_indices.empty()); int param_idx = tuple_param_idx_[tuple_arg]; Var param = params_[param_idx]; - String param_name = params_[param_idx]->name_hint(); + ffi::String param_name = params_[param_idx]->name_hint(); TupleStructInfo param_sinfo = Downcast(tuple_arg->struct_info_); - Array item_args; - Array item_params; + ffi::Array item_args; + ffi::Array item_params; item_args.reserve(item_indices.size()); item_params.reserve(item_indices.size()); for (int item_idx : item_indices) { Var item_param(param_name + "_" + std::to_string(item_idx), param_sinfo->fields[item_idx]); - item_args.push_back(TupleGetItem(GetRef(tuple_arg), item_idx)); + item_args.push_back(TupleGetItem(ffi::GetRef(tuple_arg), item_idx)); item_params.push_back(item_param); tuple_get_item_remap[tuple_arg][item_idx] = item_param; } @@ -513,7 +513,7 @@ class FunctionCreator : public ExprMutator { } // Step 3. Visit each binding and collect outputs one by one. - Array outputs(output_vars_.size(), Expr()); + ffi::Array outputs(output_vars_.size(), Expr()); for (const Binding& binding : bindings_) { // Special handing for TupleGetItem. if (const auto* var_binding = binding.as()) { @@ -561,7 +561,7 @@ class FunctionCreator : public ExprMutator { /*ret_struct_info=*/std::nullopt, // /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); - Array free_vars = + ffi::Array free_vars = FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); if (!free_vars.empty()) { params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); @@ -577,15 +577,15 @@ class FunctionCreator : public ExprMutator { } /*! \brief The original bindings of the function */ - Array bindings_; + ffi::Array bindings_; /*! \brief The parameters of the function */ - Array params_; + ffi::Array params_; /*! 
\brief The arguments to call the function on the caller side */ - Array arguments_; + ffi::Array arguments_; /*! \brief The name for the fused function */ - String name_hint_ = "fused"; + ffi::String name_hint_ = "fused"; /*! \brief The constructed Relax function */ - Optional function_ = std::nullopt; + ffi::Optional function_ = std::nullopt; private: std::optional GetOutputIndex(Var v) { @@ -612,8 +612,9 @@ class FunctionCreator : public ExprMutator { const auto* var = expr.as(); if ((var == nullptr || defined_vars_.count(var) == 0) && (lift_constant_ || !expr->IsInstance())) { - String name = var != nullptr ? var->name_hint() - : String("param_" + std::to_string(n_param_for_const_++)); + ffi::String name = var != nullptr + ? var->name_hint() + : ffi::String("param_" + std::to_string(n_param_for_const_++)); StructInfo param_sinfo = GetStructInfo(expr); if (!IsInlinableConstants(expr)) { Var param(std::move(name), GetStructInfo(expr)); @@ -719,8 +720,8 @@ class OperatorFusor : public ExprMutator { * \brief The main transformation on the IRModule * \return The new IRModule after transformation */ - IRModule Transform(const Array& entry_function_names = {}) { - Array entry_functions; + IRModule Transform(const ffi::Array& entry_function_names = {}) { + ffi::Array entry_functions; if (entry_function_names.empty()) { entry_functions = mod_->GetGlobalVars(); } else { @@ -733,7 +734,7 @@ class OperatorFusor : public ExprMutator { // Only visit Relax functions with neither attr::kPrimitive nor // attr::kCodegen. 
if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive) && - !func->GetAttr(attr::kCodegen).has_value()) { + !func->GetAttr(attr::kCodegen).has_value()) { auto updated_func = Downcast(VisitExpr(func)); builder_->UpdateFunction(gv, updated_func); } @@ -882,7 +883,7 @@ class OperatorFusor : public ExprMutator { * \param bindings The bindings to be collected * \note The function update is done by `AppendBinding(...)` */ - void CollectFuncBindings(const Array& bindings) { + void CollectFuncBindings(const ffi::Array& bindings) { for (const Binding& binding : bindings) { // If the binding is the only binding in its group, there is no need to create a new function. Group* group = GetGroupFromBinding(binding); @@ -898,7 +899,7 @@ class OperatorFusor : public ExprMutator { } } - void CollectFuncBoundary(const Array& bindings) { + void CollectFuncBoundary(const ffi::Array& bindings) { for (const Binding& binding : bindings) { // Step 1. Get current binding's group Group* cur_group = GetGroupFromBinding(binding); @@ -969,8 +970,8 @@ class OperatorFusor : public ExprMutator { * \param args The arguments to be updated * \return The updated arguments */ - Array UpdateArgs(const Array& args) { - Array new_args; + ffi::Array UpdateArgs(const ffi::Array& args) { + ffi::Array new_args; new_args.reserve(args.size()); for (const Expr& arg : args) { new_args.push_back(VisitExpr(arg)); @@ -980,7 +981,7 @@ class OperatorFusor : public ExprMutator { private: // Topologically sort bindings according to the group dependency relations. - Array TopoSortByGroupDep(const Array& bindings) { + ffi::Array TopoSortByGroupDep(const ffi::Array& bindings) { std::unordered_map> bindings_per_group; // The order to visit groups should respect the original order of bindings as much as possible. 
std::vector group_order; @@ -1003,7 +1004,7 @@ class OperatorFusor : public ExprMutator { } }; - Array sorted; + ffi::Array sorted; for (auto g : group_order) { dfs_visit(g, [&sorted, &bindings_per_group](Group* leaf) { @@ -1054,7 +1055,7 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { IRModule MakeGroupedFunctions( IRModule mod, const std::unordered_map& partition, - bool lift_constants, const Array& entry_function_names) { + bool lift_constants, const ffi::Array& entry_function_names) { return OperatorFusor(mod, partition, lift_constants).Transform(entry_function_names); } @@ -1069,19 +1070,20 @@ class PatternBasedPartitioner : ExprVisitor { using PatternCheckContext = transform::PatternCheckContext; using ExprVisitor::VisitExpr_; using FCheckMatch = ffi::TypedFunction; - using FAttrsGetter = ffi::TypedFunction(const Map&)>; + using FAttrsGetter = + ffi::TypedFunction(const ffi::Map&)>; - static GroupMap Run(String pattern_name, DFPattern pattern, - Map annotation_patterns, FCheckMatch check, Expr expr, - support::Arena* arena, FAttrsGetter attrs_getter) { + static GroupMap Run(ffi::String pattern_name, DFPattern pattern, + ffi::Map annotation_patterns, FCheckMatch check, + Expr expr, support::Arena* arena, FAttrsGetter attrs_getter) { PatternBasedPartitioner part(pattern_name, pattern, annotation_patterns, check, arena, attrs_getter); part.VisitExpr(expr); return part.group_map_; } - PatternBasedPartitioner(String pattern_name, DFPattern pattern, - Map annotation_patterns, FCheckMatch check, + PatternBasedPartitioner(ffi::String pattern_name, DFPattern pattern, + ffi::Map annotation_patterns, FCheckMatch check, support::Arena* arena, FAttrsGetter attrs_getter) : pat_name_(pattern_name), pat_(pattern), @@ -1091,7 +1093,7 @@ class PatternBasedPartitioner : ExprVisitor { attrs_getter_(attrs_getter) {} void VisitBindingBlock_(const DataflowBlockNode* block) final { - current_block_use_def_ = DataflowBlockUseDef(GetRef(block)); + 
current_block_use_def_ = DataflowBlockUseDef(ffi::GetRef(block)); ExprVisitor::VisitBindingBlock_(block); current_block_use_def_ = {}; } @@ -1112,14 +1114,14 @@ class PatternBasedPartitioner : ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { VisitVarDef(binding->var); - if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef(call), bindings_)) { + if (auto matches_opt = ExtractMatchedExpr(pat_, ffi::GetRef(call), bindings_)) { const auto& context = CreatePatternCheckContext(call, matches_opt.value()); if (check_ != nullptr && !check_(context)) { return; } for (const auto& [pat, match] : matches_opt.value()) { - if ((pat->IsInstance() && match != GetRef(call)) || + if ((pat->IsInstance() && match != ffi::GetRef(call)) || pat->IsInstance()) { auto g = GetGroup(match); if (g && g->FindRoot()->num_nodes > 1) { @@ -1164,7 +1166,7 @@ class PatternBasedPartitioner : ExprVisitor { // the previous group. For example, when there are two back-to-back conv2d ops, the output // of the first conv2d is matched to the input of the second conv2d via a wildcard pattern. // But we must avoid merging the first conv2d into the group of the second conv2d. - if ((pat->IsInstance() && match != GetRef(call)) || + if ((pat->IsInstance() && match != ffi::GetRef(call)) || pat->IsInstance()) { // Put the bound variable on the LHS into the same parent group. 
AddToGroup(value_to_bound_var_[match], parent_group); @@ -1196,28 +1198,28 @@ class PatternBasedPartitioner : ExprVisitor { } PatternCheckContext CreatePatternCheckContext(const CallNode* call, - const Map& matched_result) { - Map annotated_expr; + const ffi::Map& matched_result) { + ffi::Map annotated_expr; for (const auto& it : annotation_pat_) { if (matched_result.count(it.second)) { annotated_expr.Set(it.first, matched_result[it.second]); } } - Map matched_bindings; + ffi::Map matched_bindings; for (const auto& [pat, match] : matched_result) { if (pat->IsInstance() || pat->IsInstance()) { matched_bindings.Set(value_to_bound_var_[match], match); } } - return PatternCheckContext(GetRef(call), annotated_expr, matched_bindings, + return PatternCheckContext(ffi::GetRef(call), annotated_expr, matched_bindings, current_block_use_def_, value_to_bound_var_); } // check if a previous matched subgraph is subsumed by the current matched result - bool GraphSubsumedInMatchedValues(const Array& vars_in_graph, - const Map& matched_result) { + bool GraphSubsumedInMatchedValues(const ffi::Array& vars_in_graph, + const ffi::Map& matched_result) { std::set matched_vars; for (const auto& [pat, match] : matched_result) { if ((pat->IsInstance() || pat->IsInstance())) @@ -1230,17 +1232,17 @@ class PatternBasedPartitioner : ExprVisitor { return true; } - String pat_name_; + ffi::String pat_name_; DFPattern pat_; - Map annotation_pat_; + ffi::Map annotation_pat_; FCheckMatch check_; support::Arena* arena_; FAttrsGetter attrs_getter_; - Map bindings_; - Map value_to_bound_var_; - Map> current_block_use_def_; + ffi::Map bindings_; + ffi::Map value_to_bound_var_; + ffi::Map> current_block_use_def_; GroupMap group_map_; - std::map> vars_in_group_; + std::map> vars_in_group_; }; /*! 
@@ -1263,8 +1265,8 @@ class CompositeFunctionAnnotator : public ExprMutator { } const auto& base_func = (*it).second; if (const auto* func = base_func.as()) { - if (func->GetAttr(attr::kComposite).has_value() || - func->GetAttr(attr::kCodegen).has_value()) { + if (func->GetAttr(attr::kComposite).has_value() || + func->GetAttr(attr::kCodegen).has_value()) { continue; } @@ -1284,15 +1286,15 @@ class CompositeFunctionAnnotator : public ExprMutator { if (auto it = gvar_map_.find(gvar); it != gvar_map_.end()) { return Call(it->second, call_node->args); } - auto func = builder_->GetContextIRModule()->Lookup(GetRef(gvar)); - if (auto composite_name = func->GetAttr(attr::kComposite)) { + auto func = builder_->GetContextIRModule()->Lookup(ffi::GetRef(gvar)); + if (auto composite_name = func->GetAttr(attr::kComposite)) { auto new_func = Downcast(VisitExpr(func)); auto codegen_name = GetCodegenName(composite_name.value()); auto gsymbol = gvar->name_hint + "_" + codegen_name; new_func = WithAttrs(new_func, {{attr::kCodegen, codegen_name}, {tvm::attr::kGlobalSymbol, gsymbol}}); new_func = WithoutAttr(std::move(new_func), tvm::relax::attr::kPrimitive); - builder_->GetContextIRModule()->Remove(GetRef(gvar)); + builder_->GetContextIRModule()->Remove(ffi::GetRef(gvar)); auto new_gvar = builder_->AddFunction(new_func, gsymbol); gvar_map_[gvar] = new_gvar; return Call(new_gvar, call_node->args); @@ -1304,7 +1306,7 @@ class CompositeFunctionAnnotator : public ExprMutator { Expr VisitExpr_(const FunctionNode* func_node) final { Function f_inner = Downcast(ExprMutator::VisitExpr_(func_node)); - if (!func_node->GetAttr(attr::kComposite)) { + if (!func_node->GetAttr(attr::kComposite)) { // This lambda function doesn't have `attr::kComposite`, so it // was not produced by FuseOps. 
return f_inner; @@ -1312,8 +1314,8 @@ class CompositeFunctionAnnotator : public ExprMutator { f_inner = WithoutAttr(std::move(f_inner), tvm::relax::attr::kPrimitive); - Array param_vars; - Array params; + ffi::Array param_vars; + ffi::Array params; for (auto v : func_node->params) { Var new_v(v->name_hint(), GetStructInfo(v)); @@ -1341,13 +1343,13 @@ class CompositeFunctionAnnotator : public ExprMutator { std::unordered_map gvar_map_; }; -IRModule FuseOpsByPattern(const tvm::Array& patterns, IRModule mod, +IRModule FuseOpsByPattern(const tvm::ffi::Array& patterns, IRModule mod, bool bind_constants, bool annotate_codegen, - Array entry_function_names) { + ffi::Array entry_function_names) { support::Arena arena; for (const auto& pattern : patterns) { - Array entry_functions; + ffi::Array entry_functions; if (entry_function_names.size()) { for (const auto& name : entry_function_names) { auto gv = mod->GetGlobalVar(name); @@ -1363,8 +1365,8 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, } const FunctionNode* function = base_func.as(); if (function->GetAttr(attr::kPrimitive).value_or(false) || - function->GetAttr(attr::kComposite).has_value() || - function->GetAttr(attr::kCodegen).has_value()) { + function->GetAttr(attr::kComposite).has_value() || + function->GetAttr(attr::kCodegen).has_value()) { continue; } entry_functions.push_back(Downcast(base_func)); @@ -1379,7 +1381,7 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, CHECK(!group_map.count(key)) << "ValueError: " << "IRModule is invalid. 
" - << "The object " << GetRef(key) << " appears in multiple partitions, " + << "The object " << ffi::GetRef(key) << " appears in multiple partitions, " << "which can occur when the IRModule was not single-site assignment"; group_map.insert({key, value}); } @@ -1395,10 +1397,11 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, namespace transform { -FusionPattern::FusionPattern(String name, DFPattern pattern, - Map annotation_patterns, - Optional check, Optional attrs_getter) { - ObjectPtr n = make_object(); +FusionPattern::FusionPattern(ffi::String name, DFPattern pattern, + ffi::Map annotation_patterns, + ffi::Optional check, + ffi::Optional attrs_getter) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->pattern = std::move(pattern); n->annotation_patterns = std::move(annotation_patterns); @@ -1407,21 +1410,22 @@ FusionPattern::FusionPattern(String name, DFPattern pattern, data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.transform.FusionPattern", - [](String name, DFPattern pattern, Map annotation_patterns, - Optional check, Optional attrs_getter) { + [](ffi::String name, DFPattern pattern, ffi::Map annotation_patterns, + ffi::Optional check, ffi::Optional attrs_getter) { return FusionPattern(name, pattern, annotation_patterns, check, attrs_getter); }); -}); +} -PatternCheckContext::PatternCheckContext(Expr matched_expr, Map annotated_expr, - Map matched_bindings, - Map> var_usages, - Map value_to_bound_var) { - ObjectPtr n = make_object(); +PatternCheckContext::PatternCheckContext(Expr matched_expr, + ffi::Map annotated_expr, + ffi::Map matched_bindings, + ffi::Map> var_usages, + ffi::Map value_to_bound_var) { + ObjectPtr n = ffi::make_object(); n->matched_expr = std::move(matched_expr); n->annotated_expr = std::move(annotated_expr); n->matched_bindings = std::move(matched_bindings); @@ -1443,13 +1447,13 @@ Pass FuseOps(int 
fuse_opt_level) { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.FuseOps", FuseOps); -}); +} -Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants, - bool annotate_codegen, const Array& entry_function_names) { +Pass FuseOpsByPattern(const tvm::ffi::Array& patterns, bool bind_constants, + bool annotate_codegen, const ffi::Array& entry_function_names) { auto pass_func = // [=](IRModule m, PassContext pc) { return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen, @@ -1461,10 +1465,10 @@ Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_const /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.FuseOpsByPattern", FuseOpsByPattern); -}); +} } // namespace transform diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index db3916bc2210..549cd2197b4b 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -39,10 +39,10 @@ namespace tir { */ class SymbolicMatcher : ExprFunctor { public: - explicit SymbolicMatcher(arith::Analyzer* analyzer, Map* var_remap) + explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map* var_remap) : analyzer_(analyzer), var_remap_(var_remap) {} - void Match(const Array& params, const Array& args) { + void Match(const ffi::Array& params, const ffi::Array& args) { CHECK_EQ(params.size(), args.size()); for (size_t i = 0; i < params.size(); ++i) { Match(params[i], args[i]); @@ -66,15 +66,15 @@ class SymbolicMatcher : ExprFunctor(); \ - if (rhs) { \ - VisitExpr(op->a, rhs->a); \ - VisitExpr(op->b, rhs->b); \ - } else { \ - must_prove_ = must_prove_ && (GetRef(op) == other); \ - } \ +#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \ + void VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + 
if (rhs) { \ + VisitExpr(op->a, rhs->a); \ + VisitExpr(op->b, rhs->b); \ + } else { \ + must_prove_ = must_prove_ && (ffi::GetRef(op) == other); \ + } \ } TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode); @@ -98,7 +98,7 @@ class SymbolicMatcher : ExprFunctor(); if (!rhs || (op->value != rhs->value)) { - LOG(FATAL) << "Parameter expression " << GetRef(op) + LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " expected an integer argument with value " << op->value << ", " << "but was provided with the argument " << other; } @@ -107,7 +107,7 @@ class SymbolicMatcher : ExprFunctor(); if (!rhs || (op->value != rhs->value)) { - LOG(FATAL) << "Parameter expression " << GetRef(op) + LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " expected an float argument with value " << op->value << ", " << "but was provided with the argument " << other; } @@ -116,7 +116,7 @@ class SymbolicMatcher : ExprFunctor(); if (!rhs) { - LOG(FATAL) << "Parameter expression " << GetRef(op) << " expected an cast to " + LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " expected an cast to " << op->dtype << " as the argument, " << "but was provided with the argument " << other; } @@ -124,13 +124,14 @@ class SymbolicMatcher : ExprFunctor(op); + auto lhs = ffi::GetRef(op); if (lhs.same_as(rhs)) { // Reference identity, no further checks needed. 
} else if (op->dtype.code() != rhs->dtype.code()) { - LOG(FATAL) << "Parameter expression " << GetRef(op) << " with dtype " << op->dtype - << " cannot match to argument " << rhs << " with dtype " << rhs.dtype(); + LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " with dtype " + << op->dtype << " cannot match to argument " << rhs << " with dtype " + << rhs.dtype(); } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) { VisitExpr((*it).second, rhs); } else { @@ -144,12 +145,12 @@ class SymbolicMatcher : ExprFunctortrue_value, rhs->true_value); VisitExpr(op->false_value, rhs->false_value); } else { - must_prove_ = must_prove_ && (GetRef(op) == other); + must_prove_ = must_prove_ && (ffi::GetRef(op) == other); } } arith::Analyzer* analyzer_; - Map* var_remap_; + ffi::Map* var_remap_; PrimExpr must_prove_ = Bool(true); }; @@ -158,8 +159,8 @@ class SymbolicMatcher : ExprFunctor& buffer_map, - const Map& var_map) { + explicit FuseTIRBufferSubstitutor(const ffi::Map& buffer_map, + const ffi::Map& var_map) { buffer_remap_ = buffer_map; var_remap_ = var_map; for (const auto& [src, tgt] : buffer_map) { @@ -171,16 +172,16 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { Buffer SubstituteAllocatedBuffer(Buffer buffer) { ICHECK(buffer_remap_.find(buffer) == buffer_remap_.end()); - Array shape = + ffi::Array shape = MutateArray(buffer->shape, [this](const PrimExpr& expr) { return this->VisitExpr(expr); }); - Array strides = MutateArray( + ffi::Array strides = MutateArray( buffer->strides, [this](const PrimExpr& expr) { return this->VisitExpr(expr); }); PrimExpr elem_offset = this->VisitExpr(buffer->elem_offset); if (shape.same_as(buffer->shape) && strides.same_as(buffer->strides) && elem_offset.same_as(buffer->elem_offset)) { return buffer; } else { - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->shape = std::move(shape); n->strides = std::move(strides); n->elem_offset = std::move(elem_offset); @@ 
-192,10 +193,10 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { private: PrimExpr VisitExpr_(const VarNode* _op) final { - if (auto it = var_remap_.find(GetRef(_op)); it != var_remap_.end()) { + if (auto it = var_remap_.find(ffi::GetRef(_op)); it != var_remap_.end()) { return (*it).second; } else { - return GetRef(_op); + return ffi::GetRef(_op); } } @@ -206,7 +207,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { return load; } else { - auto n = make_object(*load.get()); + auto n = ffi::make_object(*load.get()); n->buffer = buffer; return BufferLoad(n); } @@ -219,7 +220,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { return store; } else { - auto n = make_object(*store.get()); + auto n = ffi::make_object(*store.get()); n->buffer = buffer; return BufferStore(n); } @@ -239,7 +240,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { region.same_as(match_buffer->source->region)) { return match_buffer; } else { - auto n = make_object(*match_buffer.get()); + auto n = ffi::make_object(*match_buffer.get()); n->buffer = tgt_buffer; n->source = BufferRegion(src_buffer, region); return MatchBufferRegion(n); @@ -257,15 +258,15 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { }; // Step 1. Mutate `match_buffers`. - Array match_buffers = + ffi::Array match_buffers = MutateArray(block->match_buffers, f_mutate_match_buffers); // Step 2. Mutate the read/write region. - Array reads = MutateArray(block->reads, f_mutate_read_write_region); - Array writes = MutateArray(block->writes, f_mutate_read_write_region); + ffi::Array reads = MutateArray(block->reads, f_mutate_read_write_region); + ffi::Array writes = MutateArray(block->writes, f_mutate_read_write_region); // Step 3. Mutate the Allocate Buffers. 
- Array alloc_buffers = MutateArray(block->alloc_buffers, [this](const Buffer& buffer) { - return SubstituteAllocatedBuffer(buffer); - }); + ffi::Array alloc_buffers = + MutateArray(block->alloc_buffers, + [this](const Buffer& buffer) { return SubstituteAllocatedBuffer(buffer); }); reads = UnionAccessRegion(reads); writes = UnionAccessRegion(writes); @@ -288,16 +289,16 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { private: /*! \brief Mapping from src buffer to tgt buffer. */ - Map buffer_remap_; + ffi::Map buffer_remap_; /*! \brief Mapping from src tir var to tgt var. */ - Map var_remap_; + ffi::Map var_remap_; - Array UnionAccessRegion(const Array& regions) const { + ffi::Array UnionAccessRegion(const ffi::Array& regions) const { // For now we only allow Buffer access the same elements. // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but need to union to `A[vi, vj]` // However, `A[vi, vj], A[vi, vj + 1]` is not allow for now. // Note: the order of return region should remain the same as the first occurrence of the region - Array ret; + ffi::Array ret; std::unordered_map buffer_region_set; for (const BufferRegion& region : regions) { @@ -343,7 +344,7 @@ class BlockNameDeduplicator : public tir::StmtMutator { Stmt VisitStmt_(const BlockNode* op) final { Block block = Downcast(tir::StmtMutator::VisitStmt_(op)); - String name = GetUniqueName(block->name_hint); + ffi::String name = GetUniqueName(block->name_hint); if (name == block->name_hint) { return block; @@ -355,29 +356,64 @@ class BlockNameDeduplicator : public tir::StmtMutator { } } - String GetUniqueName(const String& prefix) { - String unique_prefix = prefix; - auto it = name_count_.find(prefix); - while (name_count_.count(unique_prefix)) { - unique_prefix = prefix + "_" + std::to_string(++it->second); + ffi::String GetUniqueName(const ffi::String& prefix) { + std::string str_prefix = std::string(prefix); + + // Find where the trailing digits start + size_t base_len = str_prefix.length(); + 
while (base_len > 0 && std::isdigit(str_prefix[base_len - 1])) { + --base_len; + } + + std::string base_name; + int64_t start_num = 0; + bool has_suffix = base_len < str_prefix.length(); + + if (has_suffix) { + base_name = str_prefix.substr(0, base_len); + try { + start_num = std::stoll(str_prefix.substr(base_len)); + } catch (const std::out_of_range&) { + // Fallback: if the number is too large, treat the whole string as a base name. + has_suffix = false; + base_name = str_prefix; + } + } else { + base_name = str_prefix; + } + + // Check if the original name is available + ffi::String candidate = prefix; + if (!name_count_.count(candidate)) { + name_count_[candidate] = 0; + return candidate; + } + + // Generate unique name by incrementing the numeric suffix + int64_t counter = has_suffix ? start_num + 1 : 1; + while (true) { + candidate = ffi::String(base_name + std::to_string(counter)); + if (!name_count_.count(candidate)) { + name_count_[candidate] = 0; + return candidate; + } + ++counter; + ICHECK_GT(counter, 0) << "Counter overflow when generating unique block name for prefix: " + << prefix; } - name_count_[unique_prefix] = 0; - return unique_prefix; } - // TODO(relax-team): It should detects the number suffix and do renaming properly - // e.g. GetUniqueName("name1") should return "name2" instead of "name10". /*! \brief The count map to make block name unique. 
*/ - std::unordered_map name_count_; + std::unordered_map name_count_; }; } // namespace tir namespace relax { -static Array GetInplaceOutputIndices(const Array& inplace_indices, - int num_inputs) { - Array ret; +static ffi::Array GetInplaceOutputIndices(const ffi::Array& inplace_indices, + int num_inputs) { + ffi::Array ret; int last_idx = num_inputs; for (auto idx : inplace_indices) { int i = idx.IntValue(); @@ -396,7 +432,7 @@ static Array GetInplaceOutputIndices(const Array& inplace_indi class RelaxToTIRVarMapCollector : public ExprVisitor { public: explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {} - static Map Collect(const IRModule& mod, const Function& func) { + static ffi::Map Collect(const IRModule& mod, const Function& func) { RelaxToTIRVarMapCollector visitor(mod); visitor(func->body); return visitor.relax_to_tir_var_map_; @@ -414,7 +450,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " - << GetRef(call); + << ffi::GetRef(call); CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_); } @@ -426,7 +462,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { const auto& relax_args = Downcast(call->args[1])->fields; - Array relax_results; + ffi::Array relax_results; if (lhs_var->IsInstance()) { relax_results = Downcast(lhs_var)->fields; } else { @@ -437,7 +473,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { size_t num_inputs = relax_args.size(); size_t num_outputs = relax_results.size(); - Array output_idxs; + ffi::Array output_idxs; if (in_place) { const auto* attrs = call->attrs.as(); CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; @@ -479,7 +515,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { private: /*! 
\brief The IRModule */ const IRModule& mod_; - Map relax_to_tir_var_map_; + ffi::Map relax_to_tir_var_map_; Var current_var_; }; @@ -491,8 +527,8 @@ class FusedTIRConstructor : public ExprVisitor { * \param gv The global var of relax subfunction to be fused into one PrimFunc * \return The fused TIR PrimFunc and the in-place indices (non-empty for an in-place call) */ - static std::pair> GetFusedTIR(const IRModule& mod, - const GlobalVar& gv) { + static std::pair> GetFusedTIR(const IRModule& mod, + const GlobalVar& gv) { FusedTIRConstructor visitor(mod, gv->name_hint); BaseFunc f = mod->Lookup(gv); CHECK(f->IsInstance()) @@ -500,7 +536,7 @@ class FusedTIRConstructor : public ExprVisitor { CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) << "Expected a function with attr `kPrimitive`"; visitor(Downcast(f)); - Array inplace_indices; + ffi::Array inplace_indices; for (size_t idx : visitor.inplace_indices_) { inplace_indices.push_back(Integer(idx)); } @@ -508,18 +544,19 @@ class FusedTIRConstructor : public ExprVisitor { } private: - explicit FusedTIRConstructor(const IRModule& mod, const String& func_name) + explicit FusedTIRConstructor(const IRModule& mod, const ffi::String& func_name) : mod_(mod), func_name_(func_name) {} void VisitExpr_(const FunctionNode* func) final { - auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef(func)); - std::vector> prim_func_params; + auto relax_to_tir_var_map = + RelaxToTIRVarMapCollector::Collect(mod_, ffi::GetRef(func)); + std::vector> prim_func_params; for (const Var& relax_param : func->params) { size_t size_before = prim_func_params.size(); CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param)); - auto param_buffers = [&]() -> Array { - Array out; + auto param_buffers = [&]() -> ffi::Array { + ffi::Array out; for (size_t i = size_before; i < prim_func_params.size(); i++) { if (auto buf = prim_func_params[i].as()) { out.push_back(buf.value()); @@ -565,7 +602,7 @@ class 
FusedTIRConstructor : public ExprVisitor { ICHECK(it != func_info_.expr2buffers.end()) << "Fail to detect output buffers for function body"; - const Array& buffers = (*it).second; + const ffi::Array& buffers = (*it).second; // map of input buffers to indices (helpful for detecting in-place inputs) std::unordered_map buffer_to_idx; @@ -635,7 +672,7 @@ class FusedTIRConstructor : public ExprVisitor { ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " - << GetRef(call); + << ffi::GetRef(call); // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); @@ -659,7 +696,7 @@ class FusedTIRConstructor : public ExprVisitor { // Step 5. Map input arguments to buffer MapInputBuffer(prim_func, call->args[1]); - const Array>& output_buffer_shapes = GetCallTIROutputShapes(call); + const ffi::Array>& output_buffer_shapes = GetCallTIROutputShapes(call); AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes); @@ -696,14 +733,14 @@ class FusedTIRConstructor : public ExprVisitor { } end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_sinfo->fields[tuple_get_item->index]); func_info_.expr2buffers.Set( - GetRef(tuple_get_item), + ffi::GetRef(tuple_get_item), {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx}); } } void VisitExpr_(const TupleNode* tuple) final { ExprVisitor::VisitExpr_(tuple); - Array buffers; + ffi::Array buffers; for (const Expr& expr : tuple->fields) { auto it = func_info_.expr2buffers.find(expr); if (it != func_info_.expr2buffers.end()) { @@ -711,7 +748,7 @@ class FusedTIRConstructor : public ExprVisitor { } } if (!buffers.empty()) { - func_info_.expr2buffers.Set(GetRef(tuple), buffers); + func_info_.expr2buffers.Set(ffi::GetRef(tuple), buffers); } } @@ -723,7 +760,7 @@ class FusedTIRConstructor : public ExprVisitor { * \brief Get the number of outputs for a call_tir node. * \return The number of outputs. 
*/ - static Array> GetCallTIROutputShapes(const CallNode* call) { + static ffi::Array> GetCallTIROutputShapes(const CallNode* call) { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_)); @@ -734,7 +771,7 @@ class FusedTIRConstructor : public ExprVisitor { return shape_expr->values; }; if (const auto* tuple_sinfo = call->sinfo_args[0].as()) { - Array> shapes; + ffi::Array> shapes; for (const StructInfo& field : tuple_sinfo->fields) { const auto* tensor_sinfo = field.as(); CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of " @@ -754,11 +791,11 @@ class FusedTIRConstructor : public ExprVisitor { } /*! \brief Map old TIR func param buffer to new buffer, and then update `buffer_subst_map` */ - void MapArgsToBuffer(const Array args, const Array& buffers) { + void MapArgsToBuffer(const ffi::Array args, const ffi::Array& buffers) { size_t buffer_idx = 0; for (const Expr& arg : args) { if (const auto* v = arg.as()) { - auto it = func_info_.expr2buffers.find(GetRef(v)); + auto it = func_info_.expr2buffers.find(ffi::GetRef(v)); // Substitute the buffer with the already allocated one if it is an intermediate var if (it != func_info_.expr2buffers.end()) { for (const tir::Buffer& target_buffer : (*it).second) { @@ -781,8 +818,8 @@ class FusedTIRConstructor : public ExprVisitor { * \param output_size The number of output params. All output params are at the end of param list. 
*/ void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) { - Array arg_list; - Array buffer_list; + ffi::Array arg_list; + ffi::Array buffer_list; if (const auto* arg_tuple = args.as()) { arg_list = arg_tuple->fields; } else { @@ -799,14 +836,14 @@ class FusedTIRConstructor : public ExprVisitor { MapArgsToBuffer(arg_list, buffer_list); } - static Array GetPrimFuncOutputParams(const tir::PrimFunc& func, - const Array& output_indices) { + static ffi::Array GetPrimFuncOutputParams(const tir::PrimFunc& func, + const ffi::Array& output_indices) { size_t n = func->params.size(); int symbolic_var_index = -1; size_t output_size = output_indices.size(); ICHECK_GE(n, output_size); - Array ret; + ffi::Array ret; for (auto idx : output_indices) { int i = idx.IntValue(); const tir::Var& param = func->params[static_cast(i)]; @@ -835,15 +872,15 @@ class FusedTIRConstructor : public ExprVisitor { * \param output_shapes The shape of output params. */ void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc& func, - const Array>& output_shapes) { + const ffi::Array>& output_shapes) { bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace")); size_t n = func->params.size(); int num_inputs = Downcast(call->args[1])->fields.size(); size_t output_size = output_shapes.size(); ICHECK_GE(n, output_size); - Array output_buffers; - Array output_idxs; + ffi::Array output_buffers; + ffi::Array output_idxs; if (is_inplace) { const auto* attrs = call->attrs.as(); CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; @@ -854,7 +891,7 @@ class FusedTIRConstructor : public ExprVisitor { } } - Array output_params = GetPrimFuncOutputParams(func, output_idxs); + ffi::Array output_params = GetPrimFuncOutputParams(func, output_idxs); auto input_buffers = func_info_.expr2buffers.Get(call->args[1]); for (size_t i = 0; i < output_size; ++i) { const tir::Var& param = output_params[i]; @@ -868,8 +905,8 @@ class FusedTIRConstructor : public 
ExprVisitor { } auto unify_name_hints = [this, &buffer]() { - String base_name = buffer->name; - String unique_name = base_name + "_intermediate"; + ffi::String base_name = buffer->name; + ffi::String unique_name = base_name + "_intermediate"; size_t unique_id = 0; std::unordered_set names; @@ -883,7 +920,7 @@ class FusedTIRConstructor : public ExprVisitor { return unique_name; }; // Update buffer with new symbolic shape according to the sinfo - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->shape = output_shapes[i]; n->name = unify_name_hints(); tir::Buffer new_buffer(n); @@ -895,7 +932,7 @@ class FusedTIRConstructor : public ExprVisitor { func_info_.buffer_subst_map.Set(buffer, new_buffer); } // Update expr2buffers - func_info_.expr2buffers.Set(GetRef(call), output_buffers); + func_info_.expr2buffers.Set(ffi::GetRef(call), output_buffers); } /*! @@ -905,8 +942,8 @@ class FusedTIRConstructor : public ExprVisitor { * \param out The vector into which to collect the params/buffers */ static void CollectPrimFuncParams(const Var& relax_param, - std::vector>* out, - const Optional& tir_buffer_param) { + std::vector>* out, + const ffi::Optional& tir_buffer_param) { auto struct_info = GetStructInfo(relax_param); CHECK(!struct_info.as()) @@ -955,12 +992,12 @@ class FusedTIRConstructor : public ExprVisitor { * \return The fused TIR */ tir::PrimFunc ConstructFunc() { - Map attr_map; + ffi::Map attr_map; attr_map.Set(tir::attr::kNoAlias, true); tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap); ICHECK(func_info_.global_name != "fused"); // Remove output buffers from func_info_.alloc_buffers - Array alloc_buffers; + ffi::Array alloc_buffers; for (const tir::Buffer& buf : func_info_.alloc_buffers) { if (func_info_.output_buffers.count(buf.get()) == 0) { alloc_buffers.push_back(subst.SubstituteAllocatedBuffer(buf)); @@ -998,25 +1035,25 @@ class FusedTIRConstructor : public ExprVisitor { /*! 
\brief auxiliary information for FuseTIR */ struct FuseFuncInfo { /*! \brief The arguments for calling prim_func */ - Array arguments; + ffi::Array arguments; /*! * \brief The map from each dataflow var (intermediate var) to the corresponding buffers * allocated in the fused func */ - Map> expr2buffers; + ffi::Map> expr2buffers; /*! \brief The buffers to allocate in the fused func*/ - Array alloc_buffers; + ffi::Array alloc_buffers; /*! \brief The bodies of the original funcs, which is also the body of the fused func. */ - Array bodies; + ffi::Array bodies; /*! \brief The params of the fused function*/ - Array params; + ffi::Array params; /*! * \brief The map from buffer in original functions to corresponding buffer in the fused * function */ - Map buffer_subst_map; + ffi::Map buffer_subst_map; /*! \brief The `buffer_map` in the fused function*/ - Map buffer_map; + ffi::Map buffer_map; /*! \brief The output buffers in the function buffer_map*/ std::unordered_set output_buffers; /*! \brief The name of the fused function */ @@ -1028,7 +1065,7 @@ class FusedTIRConstructor : public ExprVisitor { * `symbolic_var_matcher`, and must be before it in the struct * order. */ - Map symbolic_var_remap; + ffi::Map symbolic_var_remap; /*! \brief The map from symbolic var to its value in the fused function * @@ -1046,7 +1083,7 @@ class FusedTIRConstructor : public ExprVisitor { /*! \brief The IRModule */ const IRModule& mod_; /*! \brief The name hint for the input func. */ - String func_name_; + ffi::String func_name_; /*! \brief The helper info to fuse TIR prim_func */ FuseFuncInfo func_info_; /*! 
\brief The tir function after fusion*/ @@ -1075,7 +1112,7 @@ class TIRFuseMutator : public ExprMutator { public: static IRModule Transform(IRModule mod) { // Collect all primitive relax functions - Map primitive_relax; + ffi::Map primitive_relax; for (const auto& gvar : mod->GetGlobalVars()) { const auto& base_func = mod->Lookup(gvar); // Only fuse primitive relax functions @@ -1134,7 +1171,7 @@ class TIRFuseMutator : public ExprMutator { struct Replacement { GlobalVar fused_tir_gvar; Function original_function; - Array inplace_indices; + ffi::Array inplace_indices; }; explicit TIRFuseMutator(std::unordered_map replacements) @@ -1145,14 +1182,14 @@ class TIRFuseMutator : public ExprMutator { // Get shape from call tir static Expr GetCallTIRShape(StructInfo sinfo) { if (auto* tuple = sinfo.as()) { - Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); + ffi::Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); return Tuple(fields); } else { auto* tensor = sinfo.as(); ICHECK(tensor) << "FuseTIR can only take tensor or tuple type"; auto* shape_expr = tensor->shape.as(); ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape"; - return GetRef(shape_expr); + return ffi::GetRef(shape_expr); } } @@ -1185,8 +1222,8 @@ class TIRFuseMutator : public ExprMutator { // Step a. Collect all relax/symbolic arguments. Tuple arguments // are not supported by PrimFunc, so this step verifies that // ExpandTupleArguments has already removed them. - Array arg_list; - Array tir_vars; + ffi::Array arg_list; + ffi::Array tir_vars; for (size_t i = 0; i < call->args.size(); ++i) { auto arg = call->args[i]; auto sinfo = GetStructInfo(arg); @@ -1221,7 +1258,7 @@ class TIRFuseMutator : public ExprMutator { } // Step b. 
Create call_tir or call_tir_inplace - Array call_args = {fused_tir_gv, Tuple(arg_list)}; + ffi::Array call_args = {fused_tir_gv, Tuple(arg_list)}; if (!tir_vars.empty()) { call_args.push_back(ShapeExpr(tir_vars)); } @@ -1229,7 +1266,7 @@ class TIRFuseMutator : public ExprMutator { Attrs call_attrs = call->attrs; if (replacement.inplace_indices.size()) { call_op = call_tir_inplace_op_; - auto inplace_attrs = make_object(); + auto inplace_attrs = ffi::make_object(); inplace_attrs->inplace_indices = replacement.inplace_indices; call_attrs = Attrs(inplace_attrs); } @@ -1268,10 +1305,10 @@ Pass FuseTIR() { "FuseTIR"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.FuseTIR", FuseTIR); -}); +} } // namespace transform diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index ff14dc9eef1e..15bf6a273a3f 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -160,7 +160,7 @@ class CheckpointCollector : private ExprMutator { ICHECK(var) << "The first argument of relax.grad.start_checkpoint and " "relax.grad.end_checkpoint should be a Var"; // var might already be remapped. Find the original var - auto orig_var = Downcast(ExprMutator::VisitExpr(GetRef(var))); + auto orig_var = Downcast(ExprMutator::VisitExpr(ffi::GetRef(var))); // Add remapping from binding->var to new_var if (!binding->var.as() && var->IsInstance()) { // For output binding, emit a dummy binding @@ -203,7 +203,7 @@ class CheckpointGenerator : private ExprMutator { * \param checkpoints The checkpointed vars. 
checkpoints being empty means all Vars are * checkpointed */ - CheckpointGenerator(const BlockBuilder& builder, const Array& orig_params, + CheckpointGenerator(const BlockBuilder& builder, const ffi::Array& orig_params, const DataflowBlock& forward_block, const VarIdSet& checkpoints) : builder_(builder) { // func params will always be checkpointed @@ -238,10 +238,10 @@ class CheckpointGenerator : private ExprMutator { using ExprMutator::VisitExpr_; // Visit the use-site of a defined Var - Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef(op)); } + Expr VisitExpr_(const VarNode* op) final { return VisitVar(ffi::GetRef(op)); } // Visit the use-site of a defined DataflowVar - Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVar(GetRef(op)); } + Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVar(ffi::GetRef(op)); } Expr VisitVar(const Var& var) { auto it = checkpoint_map_.find(var); @@ -258,7 +258,7 @@ class CheckpointGenerator : private ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { Expr new_op = this->VisitExpr(call_node->op); - tvm::Array call_args; + tvm::ffi::Array call_args; for (Expr arg : call_node->args) { Expr new_arg = this->VisitExpr(arg); call_args.push_back(new_arg); @@ -268,9 +268,9 @@ class CheckpointGenerator : private ExprMutator { BlockBuilder builder_; // The mapping from the forward vars to the checkpoint vars. - Map checkpoint_map_; + ffi::Map checkpoint_map_; // The mapping from the forward vars to their bindings, used to generate checkpoint bindings - Map binding_map_; + ffi::Map binding_map_; }; /*! @@ -294,8 +294,8 @@ class BackwardBindingGenerator : private ExprVisitor { * \return The return expr of new adjoint function. 
*/ static Expr Generate(const BlockBuilder& builder, const DataflowBlock& forward_block, - const Array& require_grads, const Var& target_var, - const Array& orig_params, const Expr& orig_return_value, + const ffi::Array& require_grads, const Var& target_var, + const ffi::Array& orig_params, const Expr& orig_return_value, const CheckpointCollector& cp_collector) { CheckpointGenerator checkpoint_generator(builder, orig_params, forward_block, cp_collector.checkpoints); @@ -358,7 +358,7 @@ class BackwardBindingGenerator : private ExprVisitor { // Support for checkpointing auto [checkpoint_var, checkpoint_call] = - checkpoint_generator_.UpdateBinding(binding->var, GetRef(call)); + checkpoint_generator_.UpdateBinding(binding->var, ffi::GetRef(call)); if (call_op == Op::Get("relax.call_tir")) { LOG(FATAL) << "Differentiation of call_tir op without registering corresponding gradient " @@ -384,7 +384,7 @@ class BackwardBindingGenerator : private ExprVisitor { } } } else { - const Array& partials = gradient_op_map[call_op]( + const ffi::Array& partials = gradient_op_map[call_op]( checkpoint_var, Downcast(checkpoint_call), adjoint_var, builder_); ICHECK(partials.size() == call->args.size()) << "partials number != inputs number"; for (size_t i = 0; i < partials.size(); ++i) { @@ -406,7 +406,7 @@ class BackwardBindingGenerator : private ExprVisitor { // b_adjoint += a_adjoint_var[0][0], c_adjoint += a_adjoint_var[0][1], // d_adjoint += a_adjoint_var[1] void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final { - UpdateAdjoint(GetRef(tuple), adjoint_var_map_[binding->var]); + UpdateAdjoint(ffi::GetRef(tuple), adjoint_var_map_[binding->var]); } // For TupleGetItem nodes, we do a partial update @@ -422,7 +422,7 @@ class BackwardBindingGenerator : private ExprVisitor { const Var& tuple_var = Downcast(tuple_get_item->tuple); if (adjoint_var_map_.count(tuple_var) == 0) { - auto nested_zeros = Downcast(NestedZeros(GetRef(tuple_sinfo))); + auto nested_zeros = 
Downcast(NestedZeros(ffi::GetRef(tuple_sinfo))); auto tuple_fields = nested_zeros->fields; tuple_fields.Set(tuple_get_item->index, adjoint_var_map_[binding->var]); EmitAdjoint(tuple_var, Tuple(tuple_fields), false); @@ -435,11 +435,11 @@ class BackwardBindingGenerator : private ExprVisitor { // For assign nodes, we add the adjoint of output to the adjoint of input void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* var) final { - UpdateAdjoint(GetRef(var), adjoint_var_map_[binding->var]); + UpdateAdjoint(ffi::GetRef(var), adjoint_var_map_[binding->var]); } void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { - UpdateAdjoint(GetRef(var), adjoint_var_map_[binding->var]); + UpdateAdjoint(ffi::GetRef(var), adjoint_var_map_[binding->var]); } // For constant nodes, we do not have to handle it because it does not contribute to the adjoint @@ -479,9 +479,9 @@ class BackwardBindingGenerator : private ExprVisitor { // Returns the new return value, which would be like: // Tuple(original_return_value, // Tuple(adjoint_of_require_grads_1, adjoint_of_require_grads_2, ...)) - Expr Epilogue(const Array& require_grads, const Expr& orig_return_value) { + Expr Epilogue(const ffi::Array& require_grads, const Expr& orig_return_value) { // create adjoint variables for inputs, and then bind adjoints - Array out_adjoints; + ffi::Array out_adjoints; for (Var var : require_grads) { // var might be wrapped in start_checkpoint or end_checkpoint, so we should find the original @@ -520,7 +520,7 @@ class BackwardBindingGenerator : private ExprVisitor { } static Expr AdjointMsgToExpr(AdjointMsg msg) { - return NestedMsgToExpr(msg, [](Optional leaf_expr) { + return NestedMsgToExpr(msg, [](ffi::Optional leaf_expr) { if (!leaf_expr.defined()) { LOG(FATAL) << "Null should not exist in AdjointMsg."; } @@ -559,7 +559,7 @@ class BackwardBindingGenerator : private ExprVisitor { ICHECK(GetStructInfoAs(r_leaf)) << "The leaf of adjoint should have StructInfo and 
be a Tensor."; Expr res = add(l_leaf, r_leaf); - UpdateStructInfo(res, GetRef(sinfo)); + UpdateStructInfo(res, ffi::GetRef(sinfo)); return res; }); return AdjointMsgToExpr(res); @@ -575,7 +575,7 @@ class BackwardBindingGenerator : private ExprVisitor { auto* sinfo = GetStructInfoAs(tuple); ICHECK(sinfo) << "The first argument of AddInTuple should have tuple struct info."; ICHECK(index >= 0 && index < static_cast(sinfo->fields.size())); - Array res; + ffi::Array res; for (size_t i = 0; i < sinfo->fields.size(); ++i) { Expr field; if (const auto* expr_tuple = tuple.as()) { @@ -594,7 +594,7 @@ class BackwardBindingGenerator : private ExprVisitor { // The block builder of the corresponding GradientMutator, to emit bindings BlockBuilder builder_; // Forward Var to its adjoint Var - Map adjoint_var_map_; + ffi::Map adjoint_var_map_; // information collected by CheckpointCollector CheckpointCollector cp_collector_; // The generator for checkpoint bindings @@ -603,13 +603,13 @@ class BackwardBindingGenerator : private ExprVisitor { class GradientMutator : private ExprMutator { public: - static IRModule Transform(IRModule mod, String func_name, Optional> require_grads, - int target_index) { + static IRModule Transform(IRModule mod, ffi::String func_name, + ffi::Optional> require_grads, int target_index) { // Step 1. Copy function auto* old_func = mod->Lookup(func_name).as(); CHECK(old_func) << func_name << "is not a Relax Function"; auto copier = FunctionCopier(); - auto new_func = copier.Copy(GetRef(old_func)); + auto new_func = copier.Copy(ffi::GetRef(old_func)); // Step 2. 
Handle the checkpoints and eliminate start_checkpoint and end_checkpoint ops auto cp_collector = CheckpointCollector(); @@ -630,7 +630,7 @@ class GradientMutator : private ExprMutator { } private: - GradientMutator(const IRModule& module, const Array& require_grads, int target_index, + GradientMutator(const IRModule& module, const ffi::Array& require_grads, int target_index, const CheckpointCollector& cp_collector) : ExprMutator(module), require_grads_(require_grads), @@ -638,7 +638,7 @@ class GradientMutator : private ExprMutator { target_index_(target_index) {} // Add the adjoint function of func to the IRModule using BlockBuilder - IRModule AddAdjointFunction(const Function& func, const String& func_name, + IRModule AddAdjointFunction(const Function& func, const ffi::String& func_name, bool remove_all_unused = true) { // Step 4.1 forward -> forward + backward auto new_func = Downcast(VisitExpr(func)); @@ -695,7 +695,7 @@ class GradientMutator : private ExprMutator { } // generate backward bindings and the return value - return_expr_ = BackwardBindingGenerator::Generate(builder_, GetRef(block), + return_expr_ = BackwardBindingGenerator::Generate(builder_, ffi::GetRef(block), require_grads_, target_var_, orig_params_, orig_return_expr_, cp_collector_); @@ -715,7 +715,7 @@ class GradientMutator : private ExprMutator { CHECK_EQ(target_index, 0) << "When the function has only one return value, target_index can " "only be 0. 
But the target_index specified is " << target_index; - target_var_ = GetRef(var); + target_var_ = ffi::GetRef(var); } else if (auto* tuple = e.as()) { CHECK(target_index >= 0 && target_index < static_cast(tuple->fields.size())) << "target_index should be in the range of the number of return values of the " @@ -725,7 +725,7 @@ class GradientMutator : private ExprMutator { auto* var = tuple->fields[target_index].as(); CHECK(var) << "Target must be a Var, but the specified target is " << tuple->fields[target_index]; - target_var_ = GetRef(var); + target_var_ = ffi::GetRef(var); } else { LOG(FATAL) << "The return value of the function must be Var or Tuple. However, the return " "value of the given function is " @@ -742,10 +742,11 @@ class GradientMutator : private ExprMutator { // 1. there should be no duplicate var // 2. every var should be a parameter or a intermediate var in the function // 3. the type of the input var should be Tensor of floating point dtype, or Tuple of that - static Array CheckAndMapRequireGrads(const Array& require_grads, - const Map& var_map, const String& func_name) { + static ffi::Array CheckAndMapRequireGrads(const ffi::Array& require_grads, + const ffi::Map& var_map, + const ffi::String& func_name) { VarIdSet var_set; - Array mapped_vars; + ffi::Array mapped_vars; for (const auto& var : require_grads) { auto it = var_map.find(var); CHECK(it != var_map.end()) << "There is no Var named " << var->name_hint() @@ -764,21 +765,22 @@ class GradientMutator : private ExprMutator { } // differentiation sources - Array require_grads_; + ffi::Array require_grads_; // information collected by CheckpointCollector CheckpointCollector cp_collector_; // the differentiation target int target_index_; Var target_var_; // the return value of the original function and the differentiated function - Array orig_params_; + ffi::Array orig_params_; Expr orig_return_expr_; Expr return_expr_; }; namespace transform { -Pass Gradient(String func_name, Optional> 
require_grads, int target_index) { +Pass Gradient(ffi::String func_name, ffi::Optional> require_grads, + int target_index) { auto pass_func = [=](IRModule mod, PassContext pc) { return relax::GradientMutator::Transform(mod, func_name, require_grads, target_index); }; @@ -788,10 +790,10 @@ Pass Gradient(String func_name, Optional> require_grads, int target_i /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.Gradient", Gradient); -}); +} } // namespace transform diff --git a/src/relax/transform/gradient_simplifier.cc b/src/relax/transform/gradient_simplifier.cc index 966e8b7ad692..5388e3706542 100644 --- a/src/relax/transform/gradient_simplifier.cc +++ b/src/relax/transform/gradient_simplifier.cc @@ -112,7 +112,7 @@ class GradientSimplifier : private ExprMutator { if (ndim == 1) { return expr; } - auto axes = Array(); + auto axes = ffi::Array(); for (int i = 0; i < ndim - 2; ++i) { axes.push_back(i); } @@ -140,7 +140,7 @@ class GradientSimplifier : private ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) { - auto result = ExprMutator::VisitExpr(GetRef(call_node)); + auto result = ExprMutator::VisitExpr(ffi::GetRef(call_node)); auto new_call_node = result.as(); auto reemit_and_return = [&]() { ReEmitBinding(binding, result); diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index 43bb40b4df4a..ac838d584821 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -31,33 +31,33 @@ NType NTypeFrom(const StructInfo& sinfo, DataType dtype) { else return NType(DLDataTypeToString(dtype)); }; - return MapToNestedMsg(sinfo, fmapleaf); + return MapToNestedMsg(sinfo, fmapleaf); } NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetStructInfo(expr), dtype); } NType NTypeMerge(const NType& a, const NType& b) { - auto fcombine = 
[&](const String& a_str, const String& b_str) -> String { + auto fcombine = [&](const ffi::String& a_str, const ffi::String& b_str) -> ffi::String { if (a_str == "") { return b_str; } else if (b_str == "") { return a_str; } - DataType a = DataType(StringToDLDataType(a_str)); - DataType b = DataType(StringToDLDataType(b_str)); + DataType a = DataType(ffi::StringToDLDataType(a_str)); + DataType b = DataType(ffi::StringToDLDataType(b_str)); ICHECK_EQ(a.code(), b.code()); ICHECK_EQ(a.lanes(), b.lanes()); return a.bits() > b.bits() ? a_str : b_str; }; - return CombineNestedMsg(a, b, fcombine); + return CombineNestedMsg(a, b, fcombine); } -Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { +ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { return {Integer(MixedPrecisionPolicyKind::kFollow), call}; } -Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { +ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { return {Integer(MixedPrecisionPolicyKind::kNever), call}; } diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h index a3a86dd2e0c3..e8ac586036a8 100644 --- a/src/relax/transform/infer_amp_utils.h +++ b/src/relax/transform/infer_amp_utils.h @@ -49,11 +49,11 @@ using TMixedPrecisionPolicy = int; // NType is the message we want to track for vars with nested tensorstructinfo // which represents the realization decision of the var. // The string is the name of the dtype decision. 
-using NType = NestedMsg; +using NType = NestedMsg; struct NTypeEqual { bool operator()(const NType& a, const NType& b) const { - auto dtype_equal = [](const String& a, const String& b) { return a == b; }; + auto dtype_equal = [](const ffi::String& a, const ffi::String& b) { return a == b; }; return Equal(a, b, dtype_equal); } }; @@ -74,9 +74,9 @@ using VarDTypeMap = std::unordered_map; using FInferMixedPrecision = ffi::TypedFunction; -Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype); +ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype); -Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype); +ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc index b2f647c5c229..bc572f8a5407 100644 --- a/src/relax/transform/infer_layout_utils.cc +++ b/src/relax/transform/infer_layout_utils.cc @@ -67,7 +67,7 @@ Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst) return Layout(axes); } -String TransposeStrLike(const String& input, const Layout& src, const Layout& dst) { +ffi::String TransposeStrLike(const ffi::String& input, const Layout& src, const Layout& dst) { ICHECK(src.ndim() == dst.ndim() && input.size() == src.ndim()) << "Layouts must have the same size"; std::string axes; @@ -120,7 +120,7 @@ LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const Expr& NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg) { auto fmapleaf = [&](const Expr& expr) -> NLayout { if (const auto* var = expr.as()) { - auto it = var_layout_map.find(GetRef(var)); + auto it = var_layout_map.find(ffi::GetRef(var)); if (it != var_layout_map.end()) { return (*it).second; } else { @@ -134,7 +134,8 @@ NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg) { return 
MapToNestedMsg(arg, fmapleaf); } -bool NoDesiredLayout(const Call& call, const Map>& desired_layouts) { +bool NoDesiredLayout(const Call& call, + const ffi::Map>& desired_layouts) { const OpNode* op_node = call->op.as(); if (op_node == nullptr) return false; const auto& it = desired_layouts.find(op_node->name); @@ -156,10 +157,10 @@ LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { LayoutDecisionNode::RegisterReflection(); InferLayoutOutputNode::RegisterReflection(); -}); +} } // namespace relax } // namespace tvm diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index 69148ce0601f..973e46b45c4e 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -69,15 +69,13 @@ class LayoutDecisionNode : public Object { .def_ro("is_unknown_dim", &LayoutDecisionNode::is_unknown_dim); } - TVM_DECLARE_BASE_OBJECT_INFO(LayoutDecisionNode, Object); - - static constexpr const char* _type_key = "relax.transform.LayoutDecision"; + TVM_FFI_DECLARE_OBJECT_INFO("relax.transform.LayoutDecision", LayoutDecisionNode, Object); }; class LayoutDecision : public ObjectRef { public: LayoutDecision(Layout layout, bool is_unknown_dim = false) { // NOLINT(*) - auto n = make_object(); + auto n = ffi::make_object(); n->layout = std::move(layout); n->is_unknown_dim = is_unknown_dim; data_ = n; @@ -92,7 +90,7 @@ class LayoutDecision : public ObjectRef { return operator->()->layout.name(); } - TVM_DEFINE_OBJECT_REF_METHODS(LayoutDecision, ObjectRef, LayoutDecisionNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LayoutDecision, ObjectRef, LayoutDecisionNode); }; using NLayout = NestedMsg; @@ -105,10 +103,10 @@ using NLayout = NestedMsg; */ class InferLayoutOutputNode : public Object { public: - Array input_layouts; - Array output_layouts; + ffi::Array input_layouts; + ffi::Array output_layouts; Attrs new_attrs; - Map 
new_args; + ffi::Map new_args; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -119,23 +117,21 @@ class InferLayoutOutputNode : public Object { .def_ro("new_args", &InferLayoutOutputNode::new_args); } - TVM_DECLARE_BASE_OBJECT_INFO(InferLayoutOutputNode, Object); - - static constexpr const char* _type_key = "relax.transform.InferLayoutOutput"; + TVM_FFI_DECLARE_OBJECT_INFO("relax.transform.InferLayoutOutput", InferLayoutOutputNode, Object); }; class InferLayoutOutput : public ObjectRef { public: - explicit InferLayoutOutput(Array input_layouts, Array output_layouts, - Attrs new_attrs, Map new_args = {}) { - auto n = make_object(); + explicit InferLayoutOutput(ffi::Array input_layouts, ffi::Array output_layouts, + Attrs new_attrs, ffi::Map new_args = {}) { + auto n = ffi::make_object(); n->input_layouts = std::move(input_layouts); n->output_layouts = std::move(output_layouts); n->new_attrs = std::move(new_attrs); n->new_args = std::move(new_args); data_ = n; } - TVM_DEFINE_OBJECT_REF_METHODS(InferLayoutOutput, ObjectRef, InferLayoutOutputNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(InferLayoutOutput, ObjectRef, InferLayoutOutputNode); }; struct NLayoutEqual { @@ -150,7 +146,7 @@ struct NLayoutEqual { } }; -using VarLayoutMap = Map; +using VarLayoutMap = ffi::Map; /*! * \brief Layout conversion interface. @@ -159,7 +155,7 @@ using VarLayoutMap = Map; * \param var_layout_map The layout of the variables. */ using FRelaxInferLayout = ffi::TypedFunction>& desired_layouts, + const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map)>; /*! @@ -225,7 +221,7 @@ Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst); * \param dst The destination layout. * \return The transposed input str. */ -String TransposeStrLike(const String& input, const Layout& src, const Layout& dst); +ffi::String TransposeStrLike(const ffi::String& input, const Layout& src, const Layout& dst); /*! 
* \brief Find axis in the dst layout. 0 represents the first axis, 1 represents the second axis, @@ -258,7 +254,8 @@ NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg); * \param desired_layouts The desired layouts of the operator. * \return True if the op is not in the desired layout. */ -bool NoDesiredLayout(const Call& call, const Map>& desired_layouts); +bool NoDesiredLayout(const Call& call, + const ffi::Map>& desired_layouts); /*! * \brief Let a tensor with ndim to follow the src layout decision. diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index 44363e19464f..f3f21cc7843d 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -35,7 +35,8 @@ namespace { class FunctionInliner : public ExprMutator { public: - explicit FunctionInliner(const Map, Function>& replacements) + explicit FunctionInliner( + const ffi::Map, Function>& replacements) : replacements_(replacements) {} using ExprMutator::VisitExpr_; @@ -80,7 +81,7 @@ class FunctionInliner : public ExprMutator { } private: - Optional GetFunction(const GlobalVar& gvar) const { + ffi::Optional GetFunction(const GlobalVar& gvar) const { if (auto opt = replacements_.Get(gvar)) { return opt; } else if (auto opt = replacements_.Get(gvar->name_hint)) { @@ -90,14 +91,14 @@ class FunctionInliner : public ExprMutator { } } - Expr InlinedCall(Function func, const Array& args) const { + Expr InlinedCall(Function func, const ffi::Array& args) const { // Ensures that the inlined instance does not have duplicate usage // with other inlined copies, or with the original callee. func = CopyWithNewVars(std::move(func)); - Array param_bindings; + ffi::Array param_bindings; - Map param_map; + ffi::Map param_map; for (size_t i = 0; i < args.size(); i++) { // Option 1: Use tvm::relax::Bind to substitute arguments into // the body. 
If the arguments contain DataflowVar instances, @@ -138,7 +139,7 @@ class FunctionInliner : public ExprMutator { return SeqExpr({binding_block}, body); } - const Map, Function>& replacements_; + const ffi::Map, Function>& replacements_; std::unordered_set inline_stack_; }; } // namespace @@ -149,8 +150,8 @@ class FunctionInliner : public ExprMutator { * \param params params dict * \return Function */ -Function FunctionInlineFunctions(Function func, - const Map, Function>& replacements) { +Function FunctionInlineFunctions( + Function func, const ffi::Map, Function>& replacements) { for (const auto& [key, func] : replacements) { if (auto ptr = key.as()) { CHECK(!replacements.count(ptr->name_hint)) @@ -165,20 +166,20 @@ Function FunctionInlineFunctions(Function func, return Downcast(mutator(std::move(func))); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.FunctionInlineFunctions", FunctionInlineFunctions); -}); +} namespace transform { Pass InlinePrivateFunctions() { auto pass_func = [=](IRModule mod, PassContext pc) { - Map, Function> replacements; + ffi::Map, Function> replacements; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto func = opt.value(); - bool is_private = !func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_private = !func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_private) { replacements.Set(gvar, func); } @@ -223,10 +224,10 @@ Pass InlinePrivateFunctions() { return tvm::transform::CreateModulePass(pass_func, 0, "InlinePrivateFunctions", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.InlinePrivateFunctions", InlinePrivateFunctions); -}); +} } // namespace transform diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index 8c3b76703d8e..e1e8a5d87998 100644 --- 
a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -169,7 +169,8 @@ class CollectLastUsage : public ExprVisitor { << "Operator " << val->op << " should have one argument, " << "but instead found " << val->args.size() << " arguments: " << val->args; auto killed_object = val->args[0].as(); - ICHECK(killed_object) << "Internal error: non-normalized expression " << GetRef(val); + ICHECK(killed_object) << "Internal error: non-normalized expression " + << ffi::GetRef(val); killed_objects_.insert(killed_object); } else { // Only recursively visit if it isn't one of the special cases. @@ -213,14 +214,14 @@ class CollectLastUsage : public ExprVisitor { class KillInserter : public ExprMutator { private: Expr VisitExpr_(const FunctionNode* op) override { - last_usage_ = CollectLastUsage::Collect(GetRef(op)); + last_usage_ = CollectLastUsage::Collect(ffi::GetRef(op)); auto mutated = ExprMutator::VisitExpr_(op); last_usage_.clear(); return mutated; } Expr VisitExpr_(const SeqExprNode* op) override { - last_usage_ = CollectLastUsage::Collect(GetRef(op)); + last_usage_ = CollectLastUsage::Collect(ffi::GetRef(op)); auto mutated = ExprMutator::VisitExpr_(op); last_usage_.clear(); return mutated; @@ -231,17 +232,17 @@ class KillInserter : public ExprMutator { if (auto it = last_usage_.find(binding->var.get()); it != last_usage_.end()) { static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor"); for (const auto& tensor_obj : it->second.tensors) { - builder_->Emit(Call(mem_kill_tensor, {GetRef(tensor_obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(mem_kill_tensor, {ffi::GetRef(tensor_obj)}), /*name_hint=*/"_"); } static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage"); for (const VarNode* storage_obj : it->second.storage) { - builder_->Emit(Call(mem_kill_storage, {GetRef(storage_obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(mem_kill_storage, {ffi::GetRef(storage_obj)}), /*name_hint=*/"_"); } static const 
Op& vm_kill_object = Op::Get("relax.vm.kill_object"); for (const VarNode* obj : it->second.objects) { - builder_->Emit(Call(vm_kill_object, {GetRef(obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(vm_kill_object, {ffi::GetRef(obj)}), /*name_hint=*/"_"); } } } @@ -266,10 +267,10 @@ Pass KillAfterLastUse() { return CreateFunctionPass(pass_func, /*opt_level=*/0, "KillAfterLastUse", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.KillAfterLastUse", KillAfterLastUse); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index 1fd82b1cc610..e77b0a266038 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -40,7 +40,7 @@ namespace { /* \brief Collect names of functions to be lifted out */ class LambdaNameCollector : ExprVisitor { public: - static std::unordered_map Collect(const IRModule& mod) { + static std::unordered_map Collect(const IRModule& mod) { LambdaNameCollector visitor; for (const auto& [gvar, base_func] : mod->functions) { @@ -60,8 +60,8 @@ class LambdaNameCollector : ExprVisitor { private: void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override { - if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { - String public_name = opt.value(); + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + ffi::String public_name = opt.value(); // If a kGlobalSymbol exists, we must use the name exactly as it // appears, with no modifications. 
Because these errors would @@ -102,21 +102,22 @@ class LambdaNameCollector : ExprVisitor { } // De-duplication of collected names - std::unordered_map Finalize() const { + std::unordered_map Finalize() const { // The functions which still must be assigned a name - std::unordered_map> remaining_to_name = lambda_location_; + std::unordered_map> remaining_to_name = + lambda_location_; // Collecting the functions that now have a name. - std::unordered_map lifted_names; + std::unordered_map lifted_names; // A lookup for names that are unavailable for use. - std::unordered_set unavailable_names = previous_global_vars_; + std::unordered_set unavailable_names = previous_global_vars_; // A helper function to generate de-duplicated names. The // `proposed_name_generation_func` should be a function with // signature: // - // Optional func(const FunctionNode*, const Array&) + // ffi::Optional func(const FunctionNode*, const ffi::Array&) // // The first argument will be the lambda function being lifted. // The second argument will be the nested location where that @@ -135,9 +136,10 @@ class LambdaNameCollector : ExprVisitor { return; } - std::unordered_map new_names; + std::unordered_map new_names; for (const auto& [func, location] : remaining_to_name) { - if (Optional opt_proposed_name = proposed_name_generation_func(func, location)) { + if (ffi::Optional opt_proposed_name = + proposed_name_generation_func(func, location)) { auto proposed_name = opt_proposed_name.value(); if (unavailable_names.count(proposed_name)) { @@ -163,7 +165,8 @@ class LambdaNameCollector : ExprVisitor { }; // 1. 
Start with any publicly explosed names from kGlobalSymbol - attempt_name_generation([&](const FunctionNode* func, const auto&) -> Optional { + attempt_name_generation([&](const FunctionNode* func, + const auto&) -> ffi::Optional { if (auto it = lifted_with_global_symbol_.find(func); it != lifted_with_global_symbol_.end()) { return it->second; } else { @@ -173,7 +176,7 @@ class LambdaNameCollector : ExprVisitor { // 2. Try concatenating the name of the relax variable with the // name of the function that contains it. - attempt_name_generation([&](const FunctionNode*, const auto& location) -> String { + attempt_name_generation([&](const FunctionNode*, const auto& location) -> ffi::String { std::stringstream stream; stream << location.front() << "_" << location.back(); return stream.str(); @@ -181,26 +184,27 @@ class LambdaNameCollector : ExprVisitor { // 3. Try concatenating the entire path together. Don't include // paths of length 2, as they would already be attempted earlier. - attempt_name_generation([&](const FunctionNode*, const auto& location) -> Optional { - if (location.size() == 2) return std::nullopt; - - std::stringstream stream; - bool is_first = true; - for (const auto& loc : location) { - if (is_first) { - is_first = false; - } else { - stream << "_"; - } - stream << loc; - } - return String(stream.str()); - }); + attempt_name_generation( + [&](const FunctionNode*, const auto& location) -> ffi::Optional { + if (location.size() == 2) return std::nullopt; + + std::stringstream stream; + bool is_first = true; + for (const auto& loc : location) { + if (is_first) { + is_first = false; + } else { + stream << "_"; + } + stream << loc; + } + return ffi::String(stream.str()); + }); // 4. Fallback. Count the number of times a relax variable with // that name was used. 
- std::unordered_map usage_count; - attempt_name_generation([&](const FunctionNode*, const auto& location) -> String { + std::unordered_map usage_count; + attempt_name_generation([&](const FunctionNode*, const auto& location) -> ffi::String { std::stringstream stream; stream << location.front() << "_" << location.back(); int usage = usage_count[stream.str()]++; @@ -215,11 +219,11 @@ class LambdaNameCollector : ExprVisitor { return lifted_names; } - Array name_stack_; - std::unordered_set previous_global_vars_; - std::unordered_map> new_public_names_; - std::unordered_map lifted_with_global_symbol_; - std::unordered_map> lambda_location_; + ffi::Array name_stack_; + std::unordered_set previous_global_vars_; + std::unordered_map> new_public_names_; + std::unordered_map lifted_with_global_symbol_; + std::unordered_map> lambda_location_; }; } // namespace @@ -255,9 +259,9 @@ class LambdaLifter : public ExprMutator { return ExprMutator::VisitExpr_(func_node); } - auto func = GetRef(func_node); + auto func = ffi::GetRef(func_node); - String lift_func_name = [&]() { + ffi::String lift_func_name = [&]() { auto it = lifted_names_.find(func_node); ICHECK(it != lifted_names_.end()) << "InternalError: " @@ -266,7 +270,7 @@ class LambdaLifter : public ExprMutator { return it->second; }(); - Array captured_vars; + ffi::Array captured_vars; bool is_recursive = false; bool is_closure = false; for (const auto& var : FreeVars(func)) { @@ -278,15 +282,15 @@ class LambdaLifter : public ExprMutator { } } - Array typed_captured_vars; - Map rebinding_map; + ffi::Array typed_captured_vars; + ffi::Map rebinding_map; for (auto free_var : captured_vars) { Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span); typed_captured_vars.push_back(var); rebinding_map.Set(free_var, var); } - tvm::Array lifted_func_params = + tvm::ffi::Array lifted_func_params = func_node->params.Map([this](Var var) { return VisitVarDef(var); }); for (const auto& var : typed_captured_vars) { 
lifted_func_params.push_back(var); @@ -323,7 +327,7 @@ class LambdaLifter : public ExprMutator { Function lifted_func; if (lifted_func_params.same_as(func_node->params) && body.same_as(func_node->body) && ret_struct_info.same_as(func_node->ret_struct_info)) { - lifted_func = GetRef(func_node); + lifted_func = ffi::GetRef(func_node); } else { lifted_func = Function(lifted_func_params, body, ret_struct_info, func_node->is_pure, func_node->attrs); @@ -354,7 +358,7 @@ class LambdaLifter : public ExprMutator { } Expr VisitExpr_(const CallNode* call_node) final { - auto call = GetRef(call_node); + auto call = ffi::GetRef(call_node); auto orig_sinfo = Downcast(call->struct_info_); @@ -393,7 +397,7 @@ class LambdaLifter : public ExprMutator { if (auto it = nested_closure_map_.find(var); it != nested_closure_map_.end()) { Call nested_call = it->second; - Array new_args = call->args; + ffi::Array new_args = call->args; for (const auto arg : nested_call->args) { new_args.push_back(arg); } @@ -407,7 +411,7 @@ class LambdaLifter : public ExprMutator { } Expr VisitExpr_(const VarNode* op) override { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (auto it = rebind_map_.find(var); it != rebind_map_.end()) { return it->second; } @@ -436,12 +440,12 @@ class LambdaLifter : public ExprMutator { } } else if (const auto* global_var = val.as()) { - if (closures_.count(GetRef(global_var))) { + if (closures_.count(ffi::GetRef(global_var))) { return true; } IRModule ctx_mod = builder_->GetContextIRModule(); ICHECK(ctx_mod->functions.size() > 0); - BaseFunc func = ctx_mod->Lookup(GetRef(global_var)); + BaseFunc func = ctx_mod->Lookup(ffi::GetRef(global_var)); const auto* func_node = func.as(); if (func_node) { return IsClosure(func_node->body); @@ -477,11 +481,11 @@ class LambdaLifter : public ExprMutator { private: std::unordered_map nested_closure_map_; std::unordered_map rebind_map_; - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> closures_; - Optional current_lambda_var_ 
= std::nullopt; + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> closures_; + ffi::Optional current_lambda_var_ = std::nullopt; IRModule mod_; - std::unordered_map lifted_names_; + std::unordered_map lifted_names_; /*! \brief Cache ops that would be used later to reduce lookup overhead. */ const Op& make_closure_op_ = Op::Get("relax.make_closure"); @@ -496,10 +500,10 @@ Pass LambdaLift() { return tvm::transform::CreateModulePass(pass_func, 1, "LambdaLift", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.LambdaLift", LambdaLift); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 9b59b680eceb..bc6f4530db59 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -69,15 +69,15 @@ class LazyInputMutator : public ExprMutator { FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, ObjectStructInfo())); - Array new_params(func->params.begin(), func->params.begin() + num_input_params); + ffi::Array new_params(func->params.begin(), func->params.begin() + num_input_params); new_params.push_back(fget_param); auto array_externally_visible_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo))); std::unordered_set externally_visible_vars(array_externally_visible_vars.begin(), array_externally_visible_vars.end()); - StructInfo new_ret_struct_info = - EraseToWellDefined(func->ret_struct_info, [&](const tir::Var& var) -> Optional { + StructInfo new_ret_struct_info = EraseToWellDefined( + func->ret_struct_info, [&](const tir::Var& var) -> ffi::Optional { if (externally_visible_vars.count(var)) { return var; } else { @@ -85,7 +85,7 @@ class LazyInputMutator : public ExprMutator { } }); - auto node = GetRef(func); + auto node = ffi::GetRef(func); node.CopyOnWrite()->params 
= new_params; node.CopyOnWrite()->ret_struct_info = new_ret_struct_info; node = WithAttr(node, attr::kNumInput, num_input_params + 1); @@ -98,7 +98,7 @@ class LazyInputMutator : public ExprMutator { Expr VisitExpr_(const VarNode* op) override { if (plan_) { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (auto it = plan_->param_lookup.find(var); it != plan_->param_lookup.end()) { auto untyped = builder_->Emit(relax::Call(plan_->fget_param, @@ -148,9 +148,10 @@ class LazyOutputMutator : public ExprMutator { define_lookup(0, func_body->body); } - Var fset_output("fset_output", - FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, - TupleStructInfo(Array{}), /* purity = */ false)); + Var fset_output( + "fset_output", + FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, + TupleStructInfo(ffi::Array{}), /* purity = */ false)); plan_ = FunctionPlan{std::move(output_lookup), fset_output}; std::optional num_input_params = GetNumInputParams(func); @@ -160,32 +161,32 @@ class LazyOutputMutator : public ExprMutator { fset_output); BindingBlock start_of_func = [&]() { - Array propagated_params; + ffi::Array propagated_params; for (auto param : func->params) { GenerateSetOutputCalls(param, [&](const auto& fset_output_call) { - Var void_output("_void", TupleStructInfo(Array{})); + Var void_output("_void", TupleStructInfo(ffi::Array{})); propagated_params.push_back(VarBinding(void_output, fset_output_call)); }); } return BindingBlock(propagated_params); }(); BindingBlock end_of_func = [&]() { - Array propagated_params; + ffi::Array propagated_params; for (const auto& [output_index, expr] : inline_outputs) { Call fset_output_call(fset_output, {PrimValue(IntImm(DataType::Int(64), output_index)), expr}); - Var void_output("_void", TupleStructInfo(Array{})); + Var void_output("_void", TupleStructInfo(ffi::Array{})); propagated_params.push_back(VarBinding(void_output, fset_output_call)); } return BindingBlock(propagated_params); }(); - 
Array new_blocks = func_body->blocks; + ffi::Array new_blocks = func_body->blocks; new_blocks.insert(new_blocks.begin(), start_of_func); new_blocks.push_back(end_of_func); - Expr new_body = SeqExpr(new_blocks, Tuple(Array{})); + Expr new_body = SeqExpr(new_blocks, Tuple(ffi::Array{})); - auto node = GetRef(func); + auto node = ffi::GetRef(func); { auto write_ptr = node.CopyOnWrite(); write_ptr->params = new_params; @@ -249,7 +250,7 @@ namespace transform { Pass LazyGetInput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { - if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { return func; } return WithLazyInputs(func); @@ -260,14 +261,14 @@ Pass LazyGetInput() { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.LazyGetInput", LazyGetInput); -}); +} Pass LazySetOutput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { - if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { return func; } return WithLazyOutputs(func); @@ -278,10 +279,10 @@ Pass LazySetOutput() { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.LazySetOutput", LazySetOutput); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 780de9f57029..75e0776418ed 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -31,6 +31,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -60,12 +62,19 @@ bool KnowAllShapeValues(const StructInfo& sinfo) { class LegalizeMutator : public ExprMutator { public: - explicit LegalizeMutator(const IRModule& mod, const Optional>& 
cmap, + explicit LegalizeMutator(const IRModule& mod, + const ffi::Optional>& cmap, + const ffi::Optional> skip_ops, bool enable_warning) : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { if (cmap) { cmap_ = cmap.value(); } + if (skip_ops.defined()) { + for (const auto name : skip_ops.value()) { + skip_ops_.insert(Op::Get(name)); + } + } } IRModule Transform() { @@ -130,14 +139,14 @@ class LegalizeMutator : public ExprMutator { Call WrapPureCall(const Call& ret) { static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); - Array ret_args = {ret->op}; + ffi::Array ret_args = {ret->op}; for (auto arg : ret->args) { ret_args.push_back(arg); } return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args); } - Optional GetTarget(const Array& sinfos) { + ffi::Optional GetTarget(const ffi::Array& sinfos) { for (auto sinfo : sinfos) { if (const auto* tinfo = sinfo.as()) { if (tinfo->vdevice.defined()) { @@ -236,7 +245,11 @@ class LegalizeMutator : public ExprMutator { if (op_node == nullptr) { return visited_call; } - auto op = GetRef(op_node); + auto op = ffi::GetRef(op_node); + + if (skip_ops_.find(op) != skip_ops_.end()) { + return visited_call; + } bool shapes_are_known_if_required = [&]() -> bool { bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value; @@ -312,7 +325,7 @@ class LegalizeMutator : public ExprMutator { legalization_func = legalize_map[op]; } else if (call_packed_map.count(op)) { // Third choice, use an explicit FCallPacked replacement. This does not require the shape - String packed_func_name = call_packed_map[op]; + ffi::String packed_func_name = call_packed_map[op]; legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr { return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetStructInfo(call)}); }; @@ -378,7 +391,7 @@ class LegalizeMutator : public ExprMutator { /*! \brief The context IRModule. */ IRModule mod_; /*! 
\brief The customized legalization function map. */ - Map cmap_; + ffi::Map cmap_; /*! \brief If VDevice annotations produced at least one PrimFunc with a Target attr*/ bool generated_tir_with_target_attr_{false}; /*! @@ -386,16 +399,21 @@ class LegalizeMutator : public ExprMutator { * legalization function is not registered. */ bool enable_warning_; + /*! + * \brief List of ops to be skipped from legalization + */ + std::set skip_ops_; }; namespace transform { -Pass LegalizeOps(Optional> cmap, bool enable_warning) { +Pass LegalizeOps(ffi::Optional> cmap, + ffi::Optional> skip_ops, bool enable_warning) { auto pass_func = [=](IRModule mod, PassContext pc) { bool apply_legalize_ops = pc->GetConfig("relax.transform.apply_legalize_ops").value_or(Bool(true))->value; if (apply_legalize_ops) { - mod = LegalizeMutator(mod, cmap, enable_warning).Transform(); + mod = LegalizeMutator(mod, cmap, skip_ops, enable_warning).Transform(); } return mod; }; @@ -405,10 +423,10 @@ Pass LegalizeOps(Optional> cmap, bool enable_warning) /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.LegalizeOps", LegalizeOps); -}); +} } // namespace transform diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 40a1c307cee5..f7c49d0da8df 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -64,20 +64,20 @@ struct BaseCollectInfo { * model weights, and computed tensors that require neither model * weights nor runtime arguments (e.g. `R.zeros([16], "float16")`). */ - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> requires_compile_time_param; /*! 
\brief Variables that are required at runtime */ - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> required_at_runtime; protected: - Array GetCompileTimeOutputsHelper(const Array& params) const { + ffi::Array GetCompileTimeOutputsHelper(const ffi::Array& params) const { // The output of the compile-time function is in the following order: // 1) Any parameter that is required at runtime in the original order, followed by, // 2) Any binding that is computable at compile-time and required at runtime in the original // order. - Array output; + ffi::Array output; for (const auto& param : params) { if (required_at_runtime.count(param)) { output.push_back(param); @@ -93,11 +93,12 @@ struct BaseCollectInfo { return output; } - Function MakeCompileTimeFunctionHelper(const Array params, const Array& bindings, - const Array& output_symbolic_vars, - const Array& outputs) const { - Array output_var_binding; - Array output_exprs; + Function MakeCompileTimeFunctionHelper(const ffi::Array params, + const ffi::Array& bindings, + const ffi::Array& output_symbolic_vars, + const ffi::Array& outputs) const { + ffi::Array output_var_binding; + ffi::Array output_exprs; if (output_symbolic_vars.size()) { output_exprs.push_back( ShapeExpr(output_symbolic_vars.Map([](tir::Var var) -> PrimExpr { return var; }))); @@ -131,14 +132,14 @@ struct BaseCollectInfo { struct GlobalCollectInfo : public BaseCollectInfo { // The original functions - Array orig_functions; + ffi::Array orig_functions; // The parameters of the compile-time function. - Array params; + ffi::Array params; // The cross-function mapping between variables. - Map var_remap; + ffi::Map var_remap; // The cross-function between between TIR variables. 
- Map tir_var_remap; - Array GetPropagatedSymbolicVariables() const { + ffi::Map tir_var_remap; + ffi::Array GetPropagatedSymbolicVariables() const { auto vars_from_original_params = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); auto vars_from_transformed_params = [&]() -> std::unordered_set { @@ -147,7 +148,7 @@ struct GlobalCollectInfo : public BaseCollectInfo { return {tir_vars.begin(), tir_vars.end()}; }(); - Array output; + ffi::Array output; for (const auto& tir_var : vars_from_original_params) { if (required_at_runtime.count(tir_var) && !vars_from_transformed_params.count(tir_var)) { output.push_back(tir_var); @@ -160,7 +161,7 @@ struct GlobalCollectInfo : public BaseCollectInfo { return MakeCompileTimeFunctionHelper(params, computable_at_compile_time, GetPropagatedSymbolicVariables(), GetCompileTimeOutputs()); } - Array GetCompileTimeOutputs() const { return GetCompileTimeOutputsHelper(params); } + ffi::Array GetCompileTimeOutputs() const { return GetCompileTimeOutputsHelper(params); } }; struct LocalCollectInfo : public BaseCollectInfo { /* \brief The analyzed function */ @@ -171,15 +172,16 @@ struct LocalCollectInfo : public BaseCollectInfo { GlobalCollectInfo* global_info = nullptr; - Array GetCompileTimeInputs() const { - return Array(orig_func->params.begin() + num_runtime_params, orig_func->params.end()); + ffi::Array GetCompileTimeInputs() const { + return ffi::Array(orig_func->params.begin() + num_runtime_params, orig_func->params.end()); } - Array GetRuntimeInputs() const { - return Array(orig_func->params.begin(), orig_func->params.begin() + num_runtime_params); + ffi::Array GetRuntimeInputs() const { + return ffi::Array(orig_func->params.begin(), + orig_func->params.begin() + num_runtime_params); } - Array GetPropagatedSymbolicVariables() const { + ffi::Array GetPropagatedSymbolicVariables() const { auto vars_from_any_param = DefinableTIRVarsInStructInfo(TupleStructInfo(orig_func->params.Map(GetStructInfo))); @@ -195,7 
+197,7 @@ struct LocalCollectInfo : public BaseCollectInfo { return {tir_var_vec.begin(), tir_var_vec.end()}; }(); - Array output; + ffi::Array output; for (const auto& tir_var : vars_from_any_param) { if (required_at_runtime.count(tir_var) && !vars_from_runtime_params.count(tir_var) && !vars_from_transformed_params.count(tir_var)) { @@ -205,7 +207,7 @@ struct LocalCollectInfo : public BaseCollectInfo { return output; } - Array GetCompileTimeOutputs() const { + ffi::Array GetCompileTimeOutputs() const { return GetCompileTimeOutputsHelper(GetCompileTimeInputs()); } @@ -216,29 +218,29 @@ struct LocalCollectInfo : public BaseCollectInfo { } Function MakeRuntimeFunction() const { - Array bindings; + ffi::Array bindings; // Any parameter that isn't available until runtime must be an // input, along with any output from the compile-time function. // Compile-time outputs must have a fresh non-dataflow var to // serve as the parameter. This trivial binding will later be // removed with CanonicalizeBindings. - Array params = GetRuntimeInputs(); + ffi::Array params = GetRuntimeInputs(); auto propagated_tir_vars = [&]() { - Array local_tir_vars = GetPropagatedSymbolicVariables(); + ffi::Array local_tir_vars = GetPropagatedSymbolicVariables(); if (!global_info) { return local_tir_vars; } // When global lifting is enabled, the compile-time outputs are the global outputs, but the // variables in the global outputs to the local variables. 
- Map reverse_map; + ffi::Map reverse_map; for (const auto& var : local_tir_vars) { if (auto it = global_info->tir_var_remap.find(var); it != global_info->tir_var_remap.end()) { reverse_map.Set(Downcast((*it).second), var); } } - Array global_tir_vars = global_info->GetPropagatedSymbolicVariables(); + ffi::Array global_tir_vars = global_info->GetPropagatedSymbolicVariables(); global_tir_vars = global_tir_vars.Map([&](const tir::Var& var) { if (auto it = reverse_map.find(var); it != reverse_map.end()) { return Downcast((*it).second); @@ -256,20 +258,20 @@ struct LocalCollectInfo : public BaseCollectInfo { Var shape_expr("vars_from_compile_time_params", shape_sinfo); params.push_back(shape_expr); } - Array compile_time_outputs = [&]() { - Array local_outputs = GetCompileTimeOutputs(); + ffi::Array compile_time_outputs = [&]() { + ffi::Array local_outputs = GetCompileTimeOutputs(); if (!global_info) { return local_outputs; } // When global lifting is enabled, the compile-time outputs are the global outputs, but the // variables in the global outputs to the local variables. 
- Map reverse_map; + ffi::Map reverse_map; for (const auto& var : local_outputs) { if (auto it = global_info->var_remap.find(var); it != global_info->var_remap.end()) { reverse_map.Set(Downcast((*it).second), var); } } - Array global_outputs = global_info->GetCompileTimeOutputs(); + ffi::Array global_outputs = global_info->GetCompileTimeOutputs(); global_outputs = global_outputs.Map([&](const Var& var) { if (auto it = reverse_map.find(var); it != reverse_map.end()) { return Downcast((*it).second); @@ -378,7 +380,7 @@ class BaseLiftableBindingCollector : public ExprVisitor { return true; } - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> liftable_vars_; + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> liftable_vars_; bool is_in_dataflow_block_{false}; }; @@ -389,32 +391,31 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector { visitor(func); visitor.info_.orig_func = func; - auto set_union = - [&](std::unordered_set, ObjectPtrHash, ObjectPtrEqual>& - target_set, - const std::unordered_set, ObjectPtrHash, ObjectPtrEqual>& - source_set, - const Map& var_remap, const Map& tir_var_remap) { - // In-place update the set in global info by unioning with the local set, variable - // mappings are applied. 
- for (const auto& relax_or_tir_var : source_set) { - if (relax_or_tir_var.as()) { - if (auto it = var_remap.find(Downcast(relax_or_tir_var)); - it != var_remap.end()) { - target_set.insert(Downcast((*it).second)); - } else { - target_set.insert(Downcast(relax_or_tir_var)); - } - } else { - if (auto it = tir_var_remap.find(Downcast(relax_or_tir_var)); - it != tir_var_remap.end()) { - target_set.insert(Downcast((*it).second)); - } else { - target_set.insert(Downcast(relax_or_tir_var)); - } - } + auto set_union = [&](std::unordered_set, ObjectPtrHash, + ObjectPtrEqual>& target_set, + const std::unordered_set, ObjectPtrHash, + ObjectPtrEqual>& source_set, + const ffi::Map& var_remap, + const ffi::Map& tir_var_remap) { + // In-place update the set in global info by unioning with the local set, variable + // mappings are applied. + for (const auto& relax_or_tir_var : source_set) { + if (relax_or_tir_var.as()) { + if (auto it = var_remap.find(Downcast(relax_or_tir_var)); it != var_remap.end()) { + target_set.insert(Downcast((*it).second)); + } else { + target_set.insert(Downcast(relax_or_tir_var)); } - }; + } else { + if (auto it = tir_var_remap.find(Downcast(relax_or_tir_var)); + it != tir_var_remap.end()) { + target_set.insert(Downcast((*it).second)); + } else { + target_set.insert(Downcast(relax_or_tir_var)); + } + } + } + }; if (global_info) { set_union(global_info->requires_compile_time_param, visitor.info_.requires_compile_time_param, @@ -508,8 +509,8 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector { /*! \brief Visitor to find the correspondence between parameters in multiple functions. 
*/ class ParamRemapper : private ExprFunctor { public: - static std::pair, Map> GetParamMapping( - const Array& functions) { + static std::pair, ffi::Map> GetParamMapping( + const ffi::Array& functions) { ParamRemapper mapper; if (functions.size()) { auto num_inputs_0 = functions[0]->GetAttr(attr::kNumInput).value()->value; @@ -536,15 +537,15 @@ class ParamRemapper : private ExprFunctor { private: void VisitExpr_(const VarNode* lhs_var, const Expr& rhs_expr) final { auto rhs_var = Downcast(rhs_expr); - if (auto it = var_remap_.find(GetRef(lhs_var)); it != var_remap_.end()) { + if (auto it = var_remap_.find(ffi::GetRef(lhs_var)); it != var_remap_.end()) { CHECK((*it).second.same_as(rhs_var)); } else { - var_remap_.Set(GetRef(lhs_var), rhs_var); + var_remap_.Set(ffi::GetRef(lhs_var), rhs_var); } CHECK(tvm::ffi::StructuralEqual::Equal(lhs_var->struct_info_, rhs_var->struct_info_, /*map_free_vars=*/true)) << "The struct info of the parameters should be the same for all target functions"; - auto lhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(GetRef(lhs_var))); + auto lhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(ffi::GetRef(lhs_var))); auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr)); ICHECK_EQ(lhs_tir_vars.size(), rhs_tir_vars.size()); for (size_t i = 0; i < lhs_tir_vars.size(); i++) { @@ -556,15 +557,15 @@ class ParamRemapper : private ExprFunctor { } } - Map var_remap_; - Map tir_var_remap_; + ffi::Map var_remap_; + ffi::Map tir_var_remap_; }; class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { public: - static GlobalCollectInfo Collect(const Array& functions, - const Map& var_remap, - const Map& tir_var_remap) { + static GlobalCollectInfo Collect(const ffi::Array& functions, + const ffi::Map& var_remap, + const ffi::Map& tir_var_remap) { GlobalLiftableBindingCollector collector(var_remap, tir_var_remap); ICHECK(functions.size()); for (const auto& func : functions) { @@ -574,9 +575,9 @@ class 
GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { } collector(func); } - Array params(functions[0]->params.begin() + - functions[0]->GetAttr(attr::kNumInput).value()->value, - functions[0]->params.end()); + ffi::Array params(functions[0]->params.begin() + + functions[0]->GetAttr(attr::kNumInput).value()->value, + functions[0]->params.end()); // todo(@tvm-team): use c++20 designated initializers when windows CI supports it GlobalCollectInfo info = GlobalCollectInfo(); info.orig_functions = functions; @@ -611,8 +612,8 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { } private: - GlobalLiftableBindingCollector(const Map& var_remap, - const Map tir_var_remap) + GlobalLiftableBindingCollector(const ffi::Map& var_remap, + const ffi::Map tir_var_remap) : var_remap_(var_remap), tir_var_remap_(tir_var_remap) {} void VisitBinding(const Binding& binding) override { CHECK(!binding->IsInstance()) << "MatchCast is not supported in global lifting"; @@ -633,9 +634,9 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { // The cross-function mapping between variables. This is initialized with the mapping from the // function parameters, and is updated with the mapping between binding variables asthe collector // visits the bindings. - Map var_remap_; + ffi::Map var_remap_; // The cross-function between between TIR variables. - Map tir_var_remap_; + ffi::Map tir_var_remap_; std::vector unified_bindings_; // The mapping between the unified bindings and the original bindings in different functions. 
// The unified binding is the binding with all variables replaced by the unified variables as @@ -678,7 +679,7 @@ class ConsumeBundledParams : public ExprMutator { builder_->Emit( Call(call_pure_packed, {builtin_tuple_reset_item, tuple_get_item->tuple, PrimValue(tuple_get_item->index)}, - tvm::Attrs(), {TupleStructInfo(Array{})})); + tvm::Attrs(), {TupleStructInfo(ffi::Array{})})); } else { ExprMutator::VisitBinding_(binding, tuple_get_item); } @@ -700,10 +701,10 @@ class ConsumeBundledParams : public ExprMutator { }; std::vector> GetTargetFunctions( - const IRModule& mod, const Variant>& shared_transform) { + const IRModule& mod, const ffi::Variant>& shared_transform) { std::vector> target_functions; - if (shared_transform.as>().value_or(Array{}).size()) { - auto names = shared_transform.as>().value(); + if (shared_transform.as>().value_or(ffi::Array{}).size()) { + auto names = shared_transform.as>().value(); for (const auto& name : names) { auto gvar = mod->global_var_map_.Get(name); CHECK(gvar) << "When LiftTransformParams is called with a list of function names, " @@ -752,11 +753,11 @@ std::vector> GetTargetFunctions( namespace transform { -Pass PartitionTransformParams(Variant> shared_transform) { +Pass PartitionTransformParams(ffi::Variant> shared_transform) { auto pass_func = [=](IRModule mod, PassContext pc) { std::optional global_collect_info; - CHECK((shared_transform.as() || shared_transform.as>())) + CHECK((shared_transform.as() || shared_transform.as>())) << "shared_transform should be a boolean or an array of function names"; auto target_functions = GetTargetFunctions(mod, shared_transform); @@ -783,7 +784,7 @@ Pass PartitionTransformParams(Variant> shared_transform) { updated_runtime_functions->Add(gvar, new_runtime_func); } - Map lifted_transform_functions; + ffi::Map lifted_transform_functions; if (global_collect_info.has_value()) { auto global_transform = global_collect_info.value().MakeCompileTimeFunc(); 
lifted_transform_functions.Set("transform_params", global_transform); @@ -818,7 +819,7 @@ Pass PartitionTransformParams(Variant> shared_transform) { return tvm::transform::CreateModulePass(pass_func, 1, "PartitionTransformParams", {}); } -Pass LiftTransformParams(Variant> shared_transform) { +Pass LiftTransformParams(ffi::Variant> shared_transform) { // A post-proc utility as as the third step in LiftTransformParams // // 1. PartitionTransformParams: Partition each function into a @@ -867,10 +868,10 @@ Pass LiftTransformParams(Variant> shared_transform) { "LiftTransformParams"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.LiftTransformParams", LiftTransformParams); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 36911cd094d8..d1e61b1c5748 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -38,14 +38,14 @@ class Mutator : public ExprMutator { if (op->op.same_as(alloc_tensor_op)) { CHECK_EQ(op->args.size(), 4) << "Op " << op->op << " should have three arguments, " << "[shape, dtype, runtime_device_index, storage_scope]. 
" - << "However, received " << GetRef(op); + << "However, received " << ffi::GetRef(op); auto shape_arg = op->args[0]; auto dtype = Downcast(op->args[1]); PrimValue runtime_device_index = Downcast(op->args[2]); StringImm storage_scope = Downcast(op->args[3]); - auto shape = [&]() -> Array { + auto shape = [&]() -> ffi::Array { if (auto ptr = shape_arg.as()) { return ptr->values; } @@ -100,10 +100,10 @@ Pass LowerAllocTensor() { return CreateFunctionPass(pass_func, /*opt_level=*/0, "LowerAllocTensor", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.LowerAllocTensor", LowerAllocTensor); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 025e91c3c3ab..e8a9b74d94c4 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -166,14 +166,14 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } private: - Optional GetCodegenName(const Expr& callee) { + ffi::Optional GetCodegenName(const Expr& callee) { auto const* gvar = callee.as(); if (!gvar) { return std::nullopt; } auto composite_name_opt = - mod_->Lookup(GetRef(gvar))->GetAttr(attr::kComposite); + mod_->Lookup(ffi::GetRef(gvar))->GetAttr(attr::kComposite); if (!composite_name_opt) { return std::nullopt; } @@ -181,16 +181,16 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { return relax::GetCodegenName(composite_name_opt.value()); } - Optional GetCodegenName(Group* group) { + ffi::Optional GetCodegenName(Group* group) { if (auto opt_str = group->attrs.Get(attr::kCodegen)) { - return Downcast(opt_str.value()); + return Downcast(opt_str.value()); } return std::nullopt; } Group* CreateNewGroup(const CallNode* call) { Group* group = arena_->make(); - if (Optional codegen_name = GetCodegenName(call->op)) { + if 
(ffi::Optional codegen_name = GetCodegenName(call->op)) { group->attrs.Set(attr::kCodegen, codegen_name.value()); } return group; @@ -220,7 +220,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } } - std::unordered_set GetParentGroupDependencies(const Array& args) { + std::unordered_set GetParentGroupDependencies(const ffi::Array& args) { // Collect groups that parent groups depend on std::unordered_set dependencies; @@ -233,7 +233,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { return dependencies; } - void UpdateGroupDependencies(Group* group, const Array& args) { + void UpdateGroupDependencies(Group* group, const ffi::Array& args) { Group* group_root = group->FindRoot(); std::function visit_expr = [&](Expr expr) { @@ -269,7 +269,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } std::vector GetGroupsToMerge(const CallNode* call) { - Optional codegen_name = GetCodegenName(call->op); + ffi::Optional codegen_name = GetCodegenName(call->op); if (!codegen_name.has_value()) { return {}; } @@ -279,7 +279,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { for (const auto& arg : call->args) { auto arg_group = memo_[arg]; - Optional arg_codegen_name = GetCodegenName(arg_group); + ffi::Optional arg_codegen_name = GetCodegenName(arg_group); if (arg_codegen_name == codegen_name && !parent_dependencies.count(arg_group->FindRoot())) { // If there is a parent group with the same target, which none of the parent dependency // groups depends on, merging "this" call node into the parent group will not form a cyclic @@ -308,7 +308,7 @@ class CompositeInliner : public ExprMutator { using ExprMutator::VisitExpr_; Function Run(Function func) { - inlined_functions_ = Map(); + inlined_functions_ = ffi::Map(); auto new_body = VisitExpr(ToNonDataflow(func->body)); auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs, func->span); @@ -319,7 +319,7 @@ class 
CompositeInliner : public ExprMutator { if (call->op->IsInstance()) { auto gvar = Downcast(call->op); auto func = Downcast(mod_->Lookup(gvar)); - if (func->GetAttr(attr::kComposite)) { + if (func->GetAttr(attr::kComposite)) { if (!inlined_functions_.count(func)) { auto new_func = CopyWithNewVars(func); new_func = WithoutAttr(new_func, tvm::relax::attr::kPrimitive); @@ -334,7 +334,7 @@ class CompositeInliner : public ExprMutator { private: IRModule mod_; - Map inlined_functions_; + ffi::Map inlined_functions_; }; /*! @@ -361,7 +361,7 @@ class CompositeFunctionAnnotator : public ExprMutator { if (call->op->IsInstance()) { GlobalVar cur_var = Downcast(call->op); auto func = Downcast(mod_->Lookup(cur_var)); - if (auto codegen_name = func->GetAttr(attr::kCodegen)) { + if (auto codegen_name = func->GetAttr(attr::kCodegen)) { GlobalVar new_var; if (var_map_.count(cur_var) > 0) { // if we visited before, we don't need to create the new function, @@ -374,7 +374,7 @@ class CompositeFunctionAnnotator : public ExprMutator { builder_->GetContextIRModule()->Remove(old_var); // rename the function. 
- String new_func_name = cur_var->name_hint + "_" + codegen_name.value(); + ffi::String new_func_name = cur_var->name_hint + "_" + codegen_name.value(); Function new_func = inliner.Run(Downcast(func)); new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, new_func_name); new_func = WithoutAttr(std::move(new_func), tvm::relax::attr::kPrimitive); @@ -388,7 +388,7 @@ class CompositeFunctionAnnotator : public ExprMutator { return Call(new_var, call->args); } } - return GetRef(call); + return ffi::GetRef(call); } private: @@ -422,10 +422,10 @@ Pass MergeCompositeFunctions() { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.MergeCompositeFunctions", MergeCompositeFunctions); -}); +} } // namespace transform diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index acad7d154402..dd5b93267476 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -35,9 +35,10 @@ namespace transform { class MetaScheduleTuner { public: - explicit MetaScheduleTuner(Target target, String work_dir, Integer max_trials_global, - Integer max_trials_per_task, Optional> op_names, - Map params = {}) + explicit MetaScheduleTuner(Target target, ffi::String work_dir, Integer max_trials_global, + Integer max_trials_per_task, + ffi::Optional> op_names, + ffi::Map params = {}) : target_(target), work_dir_(work_dir), max_trials_global_(max_trials_global), @@ -64,15 +65,15 @@ class MetaScheduleTuner { private: Target target_; - String work_dir_; + ffi::String work_dir_; Integer max_trials_global_; Integer max_trials_per_task_; - Optional> op_names_; - Map params_; + ffi::Optional> op_names_; + ffi::Map params_; tvm::ffi::Function normalize_mod_func_; }; -Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = false) { +Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_warning = false) { 
using tvm::meta_schedule::Database; Target target = Target::Current(false); const std::optional normalize_mod_func_ = @@ -80,28 +81,29 @@ Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not found."; auto pass_func = [=](IRModule mod, PassContext ctx) { - Database database{nullptr}; + Database database{ffi::UnsafeInit()}; if (Database::Current().defined()) { database = Database::Current().value(); } else { ICHECK(work_dir.has_value()); - String path_workload = work_dir.value() + "/database_workload.json"; - String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; + std::filesystem::create_directories(work_dir.value().c_str()); + ffi::String path_workload = work_dir.value() + "/database_workload.json"; + ffi::String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload << ", Tuning records at: " << path_tuning_record; database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true); } - Map result; - auto mod_eq_structural = meta_schedule::ModuleEquality::Create("ignore-ndarray"); + ffi::Map result; + auto mod_eq_structural = meta_schedule::ModuleEquality::Create("ignore-tensor"); for (const auto& iter : mod->functions) { GlobalVar gv = iter.first; BaseFunc base_func = iter.second; if (const auto* prim_func_node = base_func.as()) { - tir::PrimFunc prim_func = GetRef(prim_func_node); + tir::PrimFunc prim_func = ffi::GetRef(prim_func_node); IRModule tir_mod = (*normalize_mod_func_)(prim_func).cast(); - if (Optional opt_record = + if (ffi::Optional opt_record = database->QueryTuningRecord(tir_mod, target, gv->name_hint)) { meta_schedule::TuningRecord record = opt_record.value(); tir::Schedule sch{nullptr}; @@ -146,10 +148,10 @@ Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = return CreateModulePass(pass_func, 0, 
"MetaScheduleApplyDatabase", {}); } -Pass MetaScheduleTuneIRMod(Map params, String work_dir, +Pass MetaScheduleTuneIRMod(ffi::Map params, ffi::String work_dir, Integer max_trials_global, - Optional max_trials_per_task = std::nullopt, - Optional> op_names = std::nullopt) { + ffi::Optional max_trials_per_task = std::nullopt, + ffi::Optional> op_names = std::nullopt) { Target target = Target::Current(false); auto pass_func = [=](IRModule m, PassContext ctx) { auto max_trials_task = max_trials_per_task.value_or(max_trials_global); @@ -162,7 +164,7 @@ Pass MetaScheduleTuneIRMod(Map params, String work_dir /*traceable*/ true); } -Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { +Pass MetaScheduleTuneTIR(ffi::String work_dir, Integer max_trials_global) { Target target = Target::Current(false); ffi::TypedFunction pass_func = [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { @@ -176,13 +178,13 @@ Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { /*traceable*/ true); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.transform.MetaScheduleApplyDatabase", MetaScheduleApplyDatabase) .def("relax.transform.MetaScheduleTuneIRMod", MetaScheduleTuneIRMod) .def("relax.transform.MetaScheduleTuneTIR", MetaScheduleTuneTIR); -}); +} } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 8bd740009ef8..e764e333f721 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -46,7 +46,7 @@ class NormalizeMutator : public ExprMutatorBase { Expr body = this->VisitWithNewScope(op->body, op->params); if (body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); } @@ -58,13 +58,13 @@ class NormalizeMutator : public ExprMutatorBase { Expr false_b 
= this->VisitWithNewScope(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return If(guard, true_b, false_b, op->span); } } - Expr VisitWithNewScope(const Expr& expr, Optional> params = std::nullopt) { + Expr VisitWithNewScope(const Expr& expr, ffi::Optional> params = std::nullopt) { builder_->BeginBindingBlock(); if (params.defined()) { builder_->BeginScope(params); @@ -82,7 +82,7 @@ class NormalizeMutator : public ExprMutatorBase { Expr VisitExpr_(const SeqExprNode* op) final { bool all_blocks_unchanged = true; - Array blocks; + ffi::Array blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); if (!new_block->bindings.empty()) { @@ -100,7 +100,7 @@ class NormalizeMutator : public ExprMutatorBase { } if (all_blocks_unchanged && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return SeqExpr(blocks, body); } @@ -151,7 +151,7 @@ class NormalizeMutator : public ExprMutatorBase { } if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(ffi::GetRef(binding)); } else { builder_->EmitNormalized(VarBinding(binding->var, new_value)); } @@ -161,7 +161,7 @@ class NormalizeMutator : public ExprMutatorBase { Expr new_value = this->VisitExpr(binding->value); if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(ffi::GetRef(binding)); } else { builder_->EmitNormalized( MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->struct_info)); @@ -219,7 +219,7 @@ class GlobalVarNormalizer : private ExprMutator { /*! \brief Check if any function needs to be renamed. 
*/ bool NeedRename() { for (const auto& [gvar, func] : module_->functions) { - auto global_symbol = func->GetAttr("global_symbol"); + auto global_symbol = func->GetAttr("global_symbol"); if (global_symbol && global_symbol.value() != gvar->name_hint) { return true; } @@ -230,7 +230,7 @@ class GlobalVarNormalizer : private ExprMutator { /*! \brief Add public functions to the builder, and update the name supplier. */ void AddPublicFunctions() { for (const auto& [gvar, func] : module_->functions) { - auto global_symbol = func->GetAttr("global_symbol"); + auto global_symbol = func->GetAttr("global_symbol"); if (!global_symbol) { continue; } @@ -250,7 +250,7 @@ class GlobalVarNormalizer : private ExprMutator { */ void AddPrivateFunctions() { for (auto [gvar, func] : module_->functions) { - auto global_symbol = func->GetAttr("global_symbol"); + auto global_symbol = func->GetAttr("global_symbol"); if (global_symbol) { continue; } @@ -262,13 +262,13 @@ class GlobalVarNormalizer : private ExprMutator { } Expr VisitExpr_(const GlobalVarNode* op) final { - ICHECK(gvar_map_.count(GetRef(op))); - return gvar_map_[GetRef(op)]; + ICHECK(gvar_map_.count(ffi::GetRef(op))); + return gvar_map_[ffi::GetRef(op)]; } IRModule module_; NameSupply name_supply_; - Map gvar_map_; + ffi::Map gvar_map_; }; namespace transform { @@ -280,10 +280,10 @@ Pass Normalize() { return CreateFunctionPass(pass_func, 1, "Normalize", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.Normalize", Normalize); -}); +} Pass NormalizeGlobalVar() { auto pass_func = [=](IRModule mod, PassContext pc) { @@ -294,10 +294,10 @@ Pass NormalizeGlobalVar() { /*pass_name=*/"NormalizeGlobalVar", /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.NormalizeGlobalVar", NormalizeGlobalVar); -}); +} } // namespace transform diff 
--git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 96885eb255ca..7f1042d57ecc 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -54,8 +54,9 @@ class VDeviceLookup { VDevice operator()(Attrs hint_on_device_attrs) { auto attrs = hint_on_device_attrs.as(); ICHECK(attrs); - int32_t device_type = attrs->dev_type; - int32_t device_id = attrs->dev_id; + int32_t device_type = attrs->device_type; + int32_t device_id = attrs->index; + ffi::String memory_scope = attrs->memory_scope; CHECK(opt_vdevices_.defined()) << "ValueError: The target VDevice in the GlobalInfos was not found."; @@ -66,7 +67,8 @@ class VDeviceLookup { for (auto vdevice : vdevices) { int dev_type = vdevice->target->GetTargetDeviceType(); - if (dev_type == device_type && vdevice->vdevice_id == device_id) { + if (dev_type == device_type && vdevice->vdevice_id == device_id && + memory_scope == vdevice->memory_scope) { return vdevice; } } @@ -77,12 +79,12 @@ class VDeviceLookup { } private: - Optional> opt_vdevices_ = std::nullopt; + ffi::Optional> opt_vdevices_ = std::nullopt; }; class DeviceHintCollector : ExprVisitor { public: - static std::tuple, Map> Collect(IRModule mod) { + static std::tuple, ffi::Map> Collect(IRModule mod) { DeviceHintCollector visitor{VDeviceLookup(mod)}; for (const auto& [gvar, base_func] : mod->functions) { @@ -178,7 +180,7 @@ class DeviceHintCollector : ExprVisitor { } } - Optional LookupBinding(const Expr& expr) const { + ffi::Optional LookupBinding(const Expr& expr) const { if (auto var = expr.as()) { if (auto bound = binding_lookup_.Get(var.value())) { return bound.value(); @@ -194,14 +196,14 @@ class DeviceHintCollector : ExprVisitor { // A lookup of variable bindings, used to unwrap the variable // bindings in functions that return a tuple. - Map binding_lookup_; + ffi::Map binding_lookup_; // A map from Var to the VDevice they are known to occur on. 
This // only contains variables whose location is explicitly known // (e.g. output of `R.hint_on_device`, variables with explicit // `VDevice` in their struct info), and does not include variables // whose location is (e.g. input of `R.hint_on_device`). - Map known_vdevice_; + ffi::Map known_vdevice_; // A map from Var to the VDevice they are expected to occur on. If // a variable appears in both `known_vdevice_` and @@ -213,7 +215,7 @@ class DeviceHintCollector : ExprVisitor { // Therefore, we only determine that `A` is located on "cuda:0" if // no other annotation has already provided a known location for // `A`. - Map hint_on_device_inputs_; + ffi::Map hint_on_device_inputs_; // The `R.hint_on_device` operator. const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); @@ -223,7 +225,7 @@ class DeviceHintCollector : ExprVisitor { // same VDevice. class VDeviceSetCollector : ExprVisitor { public: - static Map> Collect(IRModule mod) { + static ffi::Map> Collect(IRModule mod) { VDeviceSetCollector visitor; for (const auto& [gvar, base_func] : mod->functions) { if (auto func = base_func.as()) { @@ -249,13 +251,13 @@ class VDeviceSetCollector : ExprVisitor { void VisitExpr_(const VarNode* op) override { if (current_binding_) { - auto var = GetRef(op); + auto var = ffi::GetRef(op); var_to_co_located_vars_[current_binding_.value()].push_back(var); var_to_co_located_vars_[var].push_back(current_binding_.value()); } } - Optional current_binding_ = std::nullopt; + ffi::Optional current_binding_ = std::nullopt; // Lookup from relax variable to the set of relax variables which // must be located on the same device. For example, a trivial @@ -267,18 +269,18 @@ class VDeviceSetCollector : ExprVisitor { // `relax::Call` operation must be located on the same device, with // the exception of `R.hint_on_device` and `R.to_vdevice`, which may // introduce a transfer across devices. 
- std::unordered_map> var_to_co_located_vars_; + std::unordered_map> var_to_co_located_vars_; const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); }; -Map InferVDevice(IRModule mod) { +ffi::Map InferVDevice(IRModule mod) { auto [explicit_annotations, hint_on_device_args] = DeviceHintCollector::Collect(mod); auto co_located_var_lookup = VDeviceSetCollector::Collect(mod); - Map known_vdevice; + ffi::Map known_vdevice; std::vector to_visit; // A helper function to propagate all `known_vdevice` entries based @@ -324,7 +326,7 @@ Map InferVDevice(IRModule mod) { // Update the module to include the inferred VDevice annotations. class VDeviceStructInfoUpdater : ExprMutator { public: - static IRModule Apply(IRModule mod, Map vdevice_map) { + static IRModule Apply(IRModule mod, ffi::Map vdevice_map) { VDeviceStructInfoUpdater mutator(VDeviceLookup(mod), vdevice_map); IRModule updates; @@ -346,7 +348,7 @@ class VDeviceStructInfoUpdater : ExprMutator { } private: - VDeviceStructInfoUpdater(VDeviceLookup vdevice_lookup, Map vdevice_map) + VDeviceStructInfoUpdater(VDeviceLookup vdevice_lookup, ffi::Map vdevice_map) : vdevice_lookup_(vdevice_lookup), vdevice_map_(vdevice_map) {} Var VisitVarDef(const Var& old_var) override { @@ -390,14 +392,14 @@ class VDeviceStructInfoUpdater : ExprMutator { if (input_vdevice.defined() && input_vdevice.value() == output_vdevice) { return arg; } else { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dst_vdevice = output_vdevice; return Call(to_vdevice_op_, {arg}, Attrs(attrs), {}); } } VDeviceLookup vdevice_lookup_; - Map vdevice_map_; + ffi::Map vdevice_map_; const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); }; @@ -416,10 +418,10 @@ Pass RealizeVDevice() { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("relax.transform.RealizeVDevice", RealizeVDevice); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc index d8bb6465da05..aaa38fcda7ce 100644 --- a/src/relax/transform/remove_purity_checking.cc +++ b/src/relax/transform/remove_purity_checking.cc @@ -49,13 +49,13 @@ class PurityRemover : public ExprMutator { Expr VisitExpr_(const CallNode* call) override { if (call->op == call_pure_packed_op_) { - auto ret = Call(call->args[0], Array(call->args.begin() + 1, call->args.end()), + auto ret = Call(call->args[0], ffi::Array(call->args.begin() + 1, call->args.end()), call->attrs, call->sinfo_args); return VisitExpr(ret); } if (call->op == call_inplace_packed_op_) { // call_inplace_packed has its own attrs so we don't pass those down - auto ret = Call(call->args[0], Array(call->args.begin() + 1, call->args.end()), + auto ret = Call(call->args[0], ffi::Array(call->args.begin() + 1, call->args.end()), tvm::Attrs(), call->sinfo_args); return VisitExpr(ret); } @@ -68,7 +68,7 @@ class PurityRemover : public ExprMutator { Expr VisitExpr_(const FunctionNode* func) override { // handling inner functions: we will remove purity annotations from them too - return RemovePurity(GetRef(func)); + return RemovePurity(ffi::GetRef(func)); } private: @@ -89,10 +89,10 @@ Pass RemovePurityChecking() { return CreateFunctionPass(pass_func, 0, "RemovePurityChecking", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.RemovePurityChecking", RemovePurityChecking); -}); +} } // namespace transform diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 26145cde1d48..83170abd635b 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -44,7 +44,7 @@ class 
PartialTupleUsageCollector : ExprVisitor { PMap num_outputs; for (const auto& [gvar, base_func] : mod->functions) { - bool is_exposed = base_func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_exposed = base_func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (!is_exposed) { if (auto relax_func = base_func.as()) { @@ -98,21 +98,21 @@ class PartialTupleUsageCollector : ExprVisitor { CHECK_GE(op->index, 0) << "IndexError: " << "Indices for TupleGetItem must be non-negative, " - << "but expression " << GetRef(op) << " uses a tuple index of " - << op->index; + << "but expression " << ffi::GetRef(op) + << " uses a tuple index of " << op->index; size_t index = op->index; CHECK_LT(index, used_indices.size()) << "IndexError: " << "Indices for TupleGetItem must be less than the size of the tuple, " - << "but expression " << GetRef(op) << " uses a tuple index of " << op->index + << "but expression " << ffi::GetRef(op) << " uses a tuple index of " << op->index << " for a tuple of size " << used_indices.size(); used_indices[index] = true; } } void VisitExpr_(const VarNode* op) override { - if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef(op))) { + if (auto* usage_mask_ptr = GetCalleeUsageMask(ffi::GetRef(op))) { auto& usage_mask = *usage_mask_ptr; for (size_t i = 0; i < usage_mask.size(); i++) { usage_mask[i] = true; @@ -138,7 +138,7 @@ class PartialTupleUsageCollector : ExprVisitor { } Expr UnwrapBindings(Expr expr) const { - auto get_bound_value = [&](const Expr& expr) -> Optional { + auto get_bound_value = [&](const Expr& expr) -> ffi::Optional { if (auto var = expr.as()) { if (auto known_binding = known_bindings_.Get(var.value())) { return known_binding.value(); @@ -153,7 +153,7 @@ class PartialTupleUsageCollector : ExprVisitor { return expr; } - Map known_bindings_; + ffi::Map known_bindings_; PMap> output_usage_mask_; }; @@ -164,7 +164,7 @@ Function UpdateCallee(Function func, const std::vector& usage_mask) { ICHECK(old_ret_sinfo) << "All 
functions returning non-tuple outputs " << "should have been pruned already by PartialTupleUsageCollector"; - Array outputs; + ffi::Array outputs; // This helper variable will be removed by the post-proc of // CanonicalizeBindings and DeadCodeElimination. @@ -267,7 +267,7 @@ Pass RemoveUnusedOutputs() { num_outputs_used += used; } - Array new_results; + ffi::Array new_results; int new_result_index = 0; for (size_t i = 0; i < usage_mask.size(); i++) { if (usage_mask[i]) { @@ -337,10 +337,10 @@ Pass RemoveUnusedOutputs() { "RemoveUnusedOutputs"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.RemoveUnusedOutputs", RemoveUnusedOutputs); -}); +} } // namespace transform diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index 2e88ebe417b3..5003dec8a8d2 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -51,11 +51,11 @@ struct CalleeAnalysis { * * \return The arguments to be used for the modified function */ - std::function(Array)> arg_updater; + std::function(ffi::Array)> arg_updater; }; std::optional AnalyzeCallee(Function func) { - bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return std::nullopt; auto free_relax_vars = [&]() -> PSet { @@ -66,7 +66,7 @@ std::optional AnalyzeCallee(Function func) { std::vector parameter_mask; parameter_mask.reserve(func->params.size()); - Array params; + ffi::Array params; for (const auto& param : func->params) { bool is_used = free_relax_vars.count(param); parameter_mask.push_back(is_used); @@ -93,7 +93,7 @@ std::optional AnalyzeCallee(Function func) { }(); // Use an array to define the order of the symbolic variables - Array free_tir_vars; + ffi::Array free_tir_vars; for (const auto& tir_var : 
FreeSymbolicVars(func->body)) { if (!defined_tir_params.count(tir_var)) { free_tir_vars.push_back(tir_var); @@ -110,12 +110,12 @@ std::optional AnalyzeCallee(Function func) { Downcast(func->struct_info_)->purity); auto arg_updater = [parameter_mask, old_relax_params = func->params, - free_tir_vars](Array old_args) -> Array { + free_tir_vars](ffi::Array old_args) -> ffi::Array { ICHECK_EQ(old_args.size(), parameter_mask.size()) << "Call provides " << old_args.size() << ", but the callee accepts " << parameter_mask.size() << " parameters"; - Array new_args; + ffi::Array new_args; for (size_t i = 0; i < old_args.size(); i++) { if (parameter_mask.at(i)) { new_args.push_back(old_args[i]); @@ -123,7 +123,7 @@ std::optional AnalyzeCallee(Function func) { } if (free_tir_vars.size()) { - Map old_binding; + ffi::Map old_binding; for (size_t i = 0; i < old_relax_params.size(); i++) { old_binding.Set(old_relax_params[i], old_args[i]); } @@ -251,10 +251,10 @@ Pass RemoveUnusedParameters() { return CreateModulePass(pass_func, 0, "RemoveUnusedParameters", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.RemoveUnusedParameters", RemoveUnusedParameters); -}); +} } // namespace transform diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index b97981a7f4e5..73bc1853816e 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -41,7 +41,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns() { +std::tuple)>> CreatePatterns() { // TODO(Lunderberg): Allow pattern-matching to handle a flexible // number of arguments, each of which matches the same type of // pattern. 
@@ -73,7 +73,7 @@ std::tuple)>> Crea auto make_pattern_with_num_concat = [&](size_t num_concat) -> DFPattern { ICHECK_LT(num_concat, pat_permute_dims.size()); auto concat_tuple = TuplePattern( - Array(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat)); + ffi::Array(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat)); return IsOp("relax.concat")(concat_tuple); }; @@ -82,7 +82,7 @@ std::tuple)>> Crea pat_concat = pat_concat | make_pattern_with_num_concat(i); } - auto get_permute_dims_optional_axes = [](const Expr& expr) -> Optional> { + auto get_permute_dims_optional_axes = [](const Expr& expr) -> ffi::Optional> { auto call = expr.as(); ICHECK(call); auto attrs = call->attrs.as(); @@ -92,12 +92,12 @@ std::tuple)>> Crea }; auto get_permute_dims_axes = - [get_permute_dims_optional_axes](const Expr& expr) -> Array { + [get_permute_dims_optional_axes](const Expr& expr) -> ffi::Array { if (auto opt_axes = get_permute_dims_optional_axes(expr)) { return opt_axes.value(); } else { auto call = Downcast(expr); - Array permutation; + ffi::Array permutation; auto arg_sinfo = call->args[0]->struct_info_.as(); CHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, " << "but argument " << call->args[0] << " has struct info " @@ -111,7 +111,7 @@ std::tuple)>> Crea } }; - auto permute_dims_axes_are_compatible = [&](const Array& permute_dims) -> bool { + auto permute_dims_axes_are_compatible = [&](const ffi::Array& permute_dims) -> bool { auto first_axes = get_permute_dims_axes(permute_dims[0]); for (size_t i_arg = 1; i_arg < permute_dims.size(); i_arg++) { auto i_axes = get_permute_dims_axes(permute_dims[i_arg]); @@ -127,9 +127,9 @@ std::tuple)>> Crea return true; }; - auto rewriter = [=](Expr expr, Map matches) -> Expr { - Array args; - Array all_permute_dims; + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { + ffi::Array args; + ffi::Array all_permute_dims; for (size_t i = 0; i < max_concat; i++) { if (auto 
permute_dim_expr = matches.Get(pat_permute_dims[i])) { all_permute_dims.push_back(permute_dim_expr.value()); @@ -145,7 +145,8 @@ std::tuple)>> Crea if (!permute_dims_axes_are_compatible(all_permute_dims)) { return expr; } - Optional> permute_axes = get_permute_dims_optional_axes(all_permute_dims[0]); + ffi::Optional> permute_axes = + get_permute_dims_optional_axes(all_permute_dims[0]); Call concat_call = Downcast(matches[pat_concat]); auto concat_attrs = concat_call->attrs.as(); @@ -174,11 +175,11 @@ Pass ReorderPermuteDimsAfterConcat() { return CreateFunctionPass(pass_func, 1, "ReorderPermuteDimsAfterConcat", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ReorderPermuteDimsAfterConcat", ReorderPermuteDimsAfterConcat); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index eebec15f52ce..25f245101b1b 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -41,7 +41,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns() { +std::tuple)>> CreatePatterns() { auto pat_lhs = WildcardPattern(); auto pat_weights = WildcardPattern(); @@ -50,7 +50,7 @@ std::tuple)>> Crea auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs); - auto rewriter = [=](Expr expr, Map matches) -> Expr { + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { auto lhs = matches[pat_lhs]; auto weights = matches[pat_weights]; auto indices = matches[pat_indices]; @@ -114,7 +114,7 @@ std::tuple)>> Crea // indices.shape = [batch1] // reordered_weight.shape = [infeatures, table_size, outfeatures] - auto reordered_weight = permute_dims(weights, Array{Integer(1), Integer(0), Integer(2)}); + auto reordered_weight = permute_dims(weights, ffi::Array{Integer(1), Integer(0), Integer(2)}); // 
fused_weight.shape = [infeatures, table_size * outfeatures] auto fused_weight = reshape(reordered_weight, ShapeExpr({weight_shape[1], weight_shape[0] * weight_shape[2]})); @@ -157,10 +157,10 @@ Pass ReorderTakeAfterMatmul() { return CreateFunctionPass(pass_func, 1, "ReorderTakeAfterMatmul", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ReorderTakeAfterMatmul", ReorderTakeAfterMatmul); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/replace_global_vars.cc b/src/relax/transform/replace_global_vars.cc index ea5d5e18d8ff..48548de887cd 100644 --- a/src/relax/transform/replace_global_vars.cc +++ b/src/relax/transform/replace_global_vars.cc @@ -37,12 +37,12 @@ namespace { using tvm::transform::GlobalVarReplacer; struct Mutator : ExprMutator { - Map replacements; - explicit Mutator(Map replacements) : replacements(replacements) {} + ffi::Map replacements; + explicit Mutator(ffi::Map replacements) : replacements(replacements) {} using ExprMutator::VisitExpr_; Expr VisitExpr_(const GlobalVarNode* node) override { - auto gvar = GetRef(node); + auto gvar = ffi::GetRef(node); return replacements.Get(gvar).value_or(gvar); } }; @@ -51,14 +51,14 @@ struct Mutator : ExprMutator { TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) .set_dispatch([](const ObjectRef& func, - Map replacements) -> BaseFunc { + ffi::Map replacements) -> BaseFunc { Mutator mutator(replacements); auto new_func = Downcast(mutator(Downcast(func))); // If the function is externally exposed, and is being replaced // by a GlobalVar with a new name, then the function's // kGlobalSymbol must be updated to match. 
- if (auto opt = new_func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto opt = new_func->GetAttr(tvm::attr::kGlobalSymbol)) { auto name = opt.value(); for (const auto& [before, after] : replacements) { if (before->name_hint == name) { @@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) .set_dispatch([](const ObjectRef& func, - Map) -> BaseFunc { + ffi::Map) -> BaseFunc { return Downcast(func); }); diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index b1faf5c09271..8ecfabd7c27a 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -89,7 +89,7 @@ struct LiftedFunctionRewritePlan { // The corresponding binding vars in the original function of the inputs of the lifted function std::vector inputs; // The tir vars in the original function that are propagated to the lifted function - Optional propogated_tir_vars = std::nullopt; + ffi::Optional propogated_tir_vars = std::nullopt; }; /*! \brief Builder of the lifted function for cuda graph capturing or allocations */ @@ -123,22 +123,22 @@ class FuncBuilder : public ExprMutator { /*! 
\brief Build the new function */ Function Build() { - Array params; - Optional shape_expr = std::nullopt; + ffi::Array params; + ffi::Optional shape_expr = std::nullopt; if (shape_expr_inputs_.size()) { - Array tir_vars; + ffi::Array tir_vars; for (const auto* var : shape_expr_inputs_) { - auto new_var = GetRef(var).copy_with_suffix(""); - tir_var_remap_.Set(GetRef(var), new_var); + auto new_var = ffi::GetRef(var).copy_with_suffix(""); + tir_var_remap_.Set(ffi::GetRef(var), new_var); tir_vars.push_back(new_var); } shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars)); } // Set up the parameters for (const auto* input : inputs_) { - auto new_var = Var( - input->name_hint(), - VisitExprDepStructInfoField(Downcast>(input->struct_info_).value())); + auto new_var = Var(input->name_hint(), + VisitExprDepStructInfoField( + Downcast>(input->struct_info_).value())); var_remap_[input->vid] = new_var; params.push_back(new_var); } @@ -151,14 +151,14 @@ class FuncBuilder : public ExprMutator { VisitBinding_(binding); } // Set up the outputs - Array outputs; + ffi::Array outputs; for (const auto* var : outputs_) { outputs.push_back(VisitExpr_(var)); } auto output = builder_->Emit(Tuple(outputs)); auto block = builder_->EndBlock(); auto body = builder_->Normalize(SeqExpr({block}, output)); - Map attrs; + ffi::Map attrs; attrs.Set(relax::attr::kForcePure, true); auto func = Function(params, body, Downcast(output->struct_info_.value()), /*is_pure=*/true, /*attrs=*/DictAttrs(attrs)); @@ -171,7 +171,7 @@ class FuncBuilder : public ExprMutator { support::OrderedSet outputs_; support::OrderedSet shape_expr_inputs_; std::vector bindings_; - Map tir_var_remap_; + ffi::Map tir_var_remap_; }; // Collect the storage objects that are used as the function output @@ -250,7 +250,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { func->attrs.GetAttr(attr::kNumInput).value_or(Integer(func->params.size())); auto capture_symbolic_var_name_hints = ExtractSymbolicVarHints(func); for (int i 
= 0; i < static_cast(func->params.size()); ++i) { - Array symbolic_vars = DefinableTIRVarsInStructInfo( + ffi::Array symbolic_vars = DefinableTIRVarsInStructInfo( Downcast(func->params[i]->struct_info_.value())); if (i < num_inputs.IntValue()) { for (const auto& symbolic_var : symbolic_vars) { @@ -278,9 +278,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor { plan->is_alloc = is_alloc; plan->lifted_bindings = std::move(region->bindings_); if (region->shape_expr_inputs_.size()) { - Array tir_vars; + ffi::Array tir_vars; for (const auto* var : region->shape_expr_inputs_) { - tir_vars.push_back(GetRef(var)); + tir_vars.push_back(ffi::GetRef(var)); } plan->propogated_tir_vars = ShapeExpr(tir_vars); } @@ -306,10 +306,11 @@ class CUDAGraphRewritePlanner : public ExprVisitor { * \brief Extract the name hints of the symbolic variables that are allowed to be captured * from the function attributes. */ - std::unordered_set ExtractSymbolicVarHints(const Function& func) { + std::unordered_set ExtractSymbolicVarHints(const Function& func) { auto symbolic_var_names = - func->attrs.GetAttr>("relax.rewrite_cuda_graph.capture_symbolic_vars") - .value_or(Array()); + func->attrs + .GetAttr>("relax.rewrite_cuda_graph.capture_symbolic_vars") + .value_or(ffi::Array()); return {symbolic_var_names.begin(), symbolic_var_names.end()}; } @@ -365,7 +366,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor { const auto* call_gv = call->op.as(); bool call_prim_func = - call_gv ? mod_->Lookup(GetRef(call_gv))->IsInstance() : false; + call_gv ? mod_->Lookup(ffi::GetRef(call_gv))->IsInstance() + : false; // Check whether the call can be lifted to the capture function. It requires all the arguments // to be static and the call to be a kernel launch or a pure operation (e.g. memory view). 
@@ -399,8 +401,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor { if (const auto* op = call->op.as()) { return !support::StartsWith(op->name, "relax.memory") && !support::StartsWith(op->name, "relax.builtin") && op->name != "relax.reshape" && - !GetRef(op).same_as(null_value_op) && - !GetRef(op).same_as(call_builtin_with_ctx_op); + !ffi::GetRef(op).same_as(null_value_op) && + !ffi::GetRef(op).same_as(call_builtin_with_ctx_op); } return false; }(); @@ -442,7 +444,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { - if (IsStatic(GetRef(var))) { + if (IsStatic(ffi::GetRef(var))) { AddStaticBinding(binding, false); MarkAsFuncInput({var}); } else { @@ -525,7 +527,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } template - bool IsStatic(const Array& exprs, std::vector* vars_collector = nullptr, + bool IsStatic(const ffi::Array& exprs, std::vector* vars_collector = nullptr, std::vector* tir_vars_collector = nullptr) { bool result = true; for (const auto& expr : exprs) { @@ -657,7 +659,7 @@ Function MergeAllocationPlans(const std::vector& all bool operator<(const StorageRecord& other) const { return size < other.size; } }; // Using an (ordered) map to make sure the result is deterministic - std::map>> storage_records; + std::map>> storage_records; static const auto& mem_alloc_storage_op = Op::Get("relax.memory.alloc_storage"); // Collect the storage records for each storage scope. 
Storage records are stored separately @@ -675,7 +677,7 @@ Function MergeAllocationPlans(const std::vector& all int64_t virtual_device_id = Downcast(Downcast(alloc_storage->args[1])->value)->value; ICHECK_EQ(virtual_device_id, 0); - String storage_scope = Downcast(alloc_storage->args[2])->value; + ffi::String storage_scope = Downcast(alloc_storage->args[2])->value; auto [it, _] = storage_records.try_emplace(storage_scope, alloc_plans.size()); it->second[plan_id].emplace_back(StorageRecord{size, binding, plan}); } @@ -791,7 +793,7 @@ class CUDAGraphRewriter : public ExprMutator { plan->func, current_func_.value()->name_hint + "_cuda_graph_capture"); StructInfo call_sinfo = plan->func->ret_struct_info; // Arguments of the lifted function - Array args; + ffi::Array args; for (const auto& arg : plan->inputs) { args.push_back(VisitExpr_(arg)); } @@ -803,7 +805,7 @@ class CUDAGraphRewriter : public ExprMutator { const auto& shape_expr = plan->func->params.back(); auto symbolic_params = Downcast(shape_expr->struct_info_.value())->values.value(); - Map tir_var_remap; + ffi::Map tir_var_remap; ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size()); for (int i = 0; i < static_cast(symbolic_params.size()); ++i) { tir_var_remap.Set(Downcast(symbolic_params[i]), propogated_tir_vars->values[i]); @@ -811,8 +813,8 @@ class CUDAGraphRewriter : public ExprMutator { call_sinfo = Bind(call_sinfo, tir_var_remap); } // Arguments of builtin_run_or_capture - Array tuple_arg_fields{gv_func, Tuple(args), - PrimValue(IntImm(DataType::Int(64), index_capture_++))}; + ffi::Array tuple_arg_fields{gv_func, Tuple(args), + PrimValue(IntImm(DataType::Int(64), index_capture_++))}; if (plan->propogated_tir_vars.defined()) { // The shape expr is explicitly passed twice, one as the last argument of the lifted // function, one as the last argument of builtin_run_or_capture as the cache key. 
Explicitly @@ -857,7 +859,7 @@ class CUDAGraphRewriter : public ExprMutator { // the original var definition is not visited yet. return EmitRedef(op, it->second); } - return GetRef(op); + return ffi::GetRef(op); } Var EmitRedef(const VarNode* var, const Expr& redef) { @@ -872,8 +874,8 @@ class CUDAGraphRewriter : public ExprMutator { int index_alloc_ = 0; int index_capture_ = 0; support::Arena arena_; - Optional gv_global_alloc_ = std::nullopt; - Optional current_func_ = std::nullopt; + ffi::Optional gv_global_alloc_ = std::nullopt; + ffi::Optional current_func_ = std::nullopt; }; IRModule RewriteCUDAGraph(IRModule mod) { @@ -898,10 +900,10 @@ Pass RewriteCUDAGraph() { return CreateModulePass(pass_func, 0, "RewriteCUDAGraph", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.RewriteCUDAGraph", RewriteCUDAGraph); -}); +} } // namespace transform diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index a9e5e8b3c5ff..fdaa2b927e2e 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -69,7 +69,7 @@ class DataflowReshapeRewriter : public ExprMutator { // We only rewrite the bindings that are not dataflow output (which means they are not // externally referenced) if (!binding->var->IsInstance()) { - this->builder_->EmitNormalized(GetRef(binding)); + this->builder_->EmitNormalized(ffi::GetRef(binding)); } else { ExprMutator::VisitBinding_(binding); } @@ -78,7 +78,7 @@ class DataflowReshapeRewriter : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { static const Op& call_tir_op = Op::Get("relax.call_tir"); if (call->op != call_tir_op) { - return GetRef(call); + return ffi::GetRef(call); } // We bring the calls of reshape PrimFunc back to calls of high-level @@ -94,13 +94,13 @@ class DataflowReshapeRewriter : public ExprMutator { // then 
flattens the tuple input so that the fused TIR reshape function ends up having // multiple input buffers. But only one of them should be accessed and reshaped. if (used_tensor_arg_indices.size() != 1) { - return GetRef(call); + return ffi::GetRef(call); } auto arg = arg_tuple[used_tensor_arg_indices[0]]; if (!IsCallingTIRReshape(call, arg)) { - return GetRef(call); + return ffi::GetRef(call); } TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); @@ -111,7 +111,7 @@ class DataflowReshapeRewriter : public ExprMutator { const GlobalVar& global_var = Downcast(call->args[0]); const auto* func = mod_->functions.Get(global_var).value().as(); ICHECK_NOTNULL(func); - if (!HasReshapePattern(GetRef(func))) { + if (!HasReshapePattern(ffi::GetRef(func))) { return false; } @@ -130,7 +130,7 @@ class DataflowReshapeRewriter : public ExprMutator { if (inp_sinfo->IsUnknownNdim() || res_sinfo->IsUnknownNdim()) { return false; } - auto product = [](Array args) -> PrimExpr { + auto product = [](ffi::Array args) -> PrimExpr { PrimExpr p; if (args.empty()) { // Scalar tensors may be empty indicating a single element. 
@@ -166,10 +166,10 @@ Pass RewriteDataflowReshape() { return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.RewriteDataflowReshape", RewriteDataflowReshape); -}); +} } // namespace transform diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 0cc0a070aac5..71d557d031cf 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -37,12 +37,12 @@ namespace relax { class CodeGenRunner : ExprMutator { public: - using OptionMap = Map; + using OptionMap = ffi::Map; explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {} - IRModule Run(Optional> target_options, - Array entry_function_names) { + IRModule Run(ffi::Optional> target_options, + ffi::Array entry_function_names) { IRModule mod = builder_->GetContextIRModule(); support::OrderedSet entry_functions; @@ -59,7 +59,8 @@ class CodeGenRunner : ExprMutator { std::vector attr_entry_functions; for (const auto& [gv, func] : mod->functions) { if (func->GetLinkageType() == LinkageType::kExternal && - !func->GetAttr(attr::kCodegen) && func->IsInstance()) { + !func->GetAttr(attr::kCodegen) && + func->IsInstance()) { attr_entry_functions.push_back(gv); } } @@ -80,7 +81,7 @@ class CodeGenRunner : ExprMutator { auto out_mod = builder_->GetContextIRModule(); if (ext_mods.size()) { - if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { + if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { auto old_ext_mods = opt_old_ext_mods.value(); ext_mods.insert(ext_mods.begin(), old_ext_mods.begin(), old_ext_mods.end()); } @@ -89,7 +90,7 @@ class CodeGenRunner : ExprMutator { if (constant_names.size()) { // Some backends (e.g. 
TensorRT) expect constants to be passed when they are instantiated - Map constants; + ffi::Map constants; for (const auto& [constant, name] : constant_names) { ICHECK(!constants.count(name)) << "More than one constant with the name " << name; constants.Set(name, constant->data); @@ -106,11 +107,11 @@ class CodeGenRunner : ExprMutator { Expr VisitExpr_(const CallNode* call_node) override { auto call = Downcast(ExprMutator::VisitExpr_(call_node)); if (auto const* gvar_node = call_node->op.as()) { - const GlobalVar gvar = GetRef(gvar_node); + const GlobalVar gvar = ffi::GetRef(gvar_node); auto create_call_dps_packed = [call_node, this](Expr extern_func, StructInfo ret_struct_info) { - Array new_args({extern_func}); + ffi::Array new_args({extern_func}); new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); }))); static const Op& call_op = Op::Get("relax.call_dps_packed"); @@ -139,7 +140,7 @@ class CodeGenRunner : ExprMutator { } } } - Array new_args; + ffi::Array new_args; for (const auto& arg : call_node->args) { new_args.push_back(VisitExpr(arg)); } @@ -148,8 +149,8 @@ class CodeGenRunner : ExprMutator { } Expr VisitExpr_(const FunctionNode* func_node) override { - Function func = GetRef(func_node); - auto opt_codegen = func->GetAttr(attr::kCodegen); + Function func = ffi::GetRef(func_node); + auto opt_codegen = func->GetAttr(attr::kCodegen); if (opt_codegen) { auto ext_symbol = GetExtSymbol(func); size_t count = 0; @@ -168,8 +169,9 @@ class CodeGenRunner : ExprMutator { } private: - Array InvokeCodegen(IRModule mod, Map target_options) { - std::unordered_map> target_functions; + ffi::Array InvokeCodegen(IRModule mod, + ffi::Map target_options) { + std::unordered_map> target_functions; for (const auto& entry : mod->functions) { if (entry.second->IsInstance()) { @@ -178,26 +180,26 @@ class CodeGenRunner : ExprMutator { PostOrderVisit(entry.second, [&target_functions](Expr e) { if (e->IsInstance()) { auto f = Downcast(e); - if (auto 
target_opt = f->GetAttr(attr::kCodegen)) { - String target = target_opt.value(); + if (auto target_opt = f->GetAttr(attr::kCodegen)) { + ffi::String target = target_opt.value(); target_functions[target].push_back(f); } } }); } - Array ext_mods; + ffi::Array ext_mods; for (const auto& [target, functions] : target_functions) { OptionMap options = target_options.Get(target).value_or(OptionMap()); // Start the codegen process. // Get the codegen with its ffi key. - String codegen_name = "relax.ext." + target; + ffi::String codegen_name = "relax.ext." + target; const auto codegen = tvm::ffi::Function::GetGlobal(codegen_name); ICHECK(codegen.has_value()) << "Codegen is not found: " << codegen_name << "\n"; - Array compiled_functions = - (*codegen)(functions, options, constant_names).cast>(); + ffi::Array compiled_functions = + (*codegen)(functions, options, constant_names).cast>(); ext_mods.insert(ext_mods.end(), compiled_functions.begin(), compiled_functions.end()); } @@ -205,7 +207,7 @@ class CodeGenRunner : ExprMutator { } /*! \brief The names of all constants in the original module. */ - Map constant_names; + ffi::Map constant_names; /*! \brief Extern funcs for each global variable. 
*/ std::unordered_map extern_funcs_; }; @@ -213,18 +215,19 @@ class CodeGenRunner : ExprMutator { } // namespace relax namespace transform { -Pass RunCodegen(Optional>> target_options, - Array entry_functions) { +Pass RunCodegen( + ffi::Optional>> target_options, + ffi::Array entry_functions) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::CodeGenRunner(m).Run(target_options, entry_functions); }; return CreateModulePass(pass_func, 0, "RunCodegen", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.RunCodegen", RunCodegen); -}); +} } // namespace transform } // namespace tvm diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc new file mode 100644 index 000000000000..6258e14b666d --- /dev/null +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/specialize_tir_params.cc + * \brief Update PrimFunc buffers based on updated scope (or structure) info. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using tvm::tir::Buffer; + +static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + +class SpecializeTIRCallArgs : ExprMutator { + public: + IRModule Run(IRModule mod) { + mod_ = mod; + for (const auto& [gv, func] : mod->functions) { + if (func->IsInstance()) { + const auto& base_func = mod->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + return mod_; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op == call_tir_op) { + return SpecializeTirPrimFunc(call); + } + return call; + } + + private: + Expr SpecializeTirPrimFunc(Call call) { + auto gv = Downcast(call->args[0]); + auto pfunc = Downcast(mod_->Lookup(gv)); + auto args = Downcast(call->args[1])->fields; + ffi::Map> param_map; + + for (size_t i = 0; i < args.size(); ++i) { + auto sinfo = GetStructInfo(args[i]); + CHECK(sinfo->IsInstance()) + << "Expected Tensor struct Info for call :" << call->op; + auto tensor_sinfo = Downcast(sinfo); + CHECK(tensor_sinfo->shape.defined()) << "Shape undefined for call:" << call->args[0]; + ffi::String scope = "global"; + if (tensor_sinfo->vdevice.defined()) { + scope = tensor_sinfo->vdevice.value()->memory_scope; + } + ffi::String name; + if (args[i]->IsInstance()) { + name = Downcast(args[i])->name_hint(); + } else { + name = 
std::string({static_cast('A' + i)}); + } + + const Buffer& buffer = tir::decl_buffer(GetShapeFromTensorStructInfo(tensor_sinfo), + tensor_sinfo->dtype, name, scope); + param_map.Set(pfunc->params[i], buffer); + } + ffi::String scope = "global"; + auto out_sinfo = call->sinfo_args[0]; + if (out_sinfo->IsInstance()) { + auto sinfo = Downcast(out_sinfo); + if (sinfo->vdevice.defined()) { + scope = sinfo->vdevice.value()->memory_scope; + } + const Buffer& buffer = + tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + param_map.Set(pfunc->params[pfunc->params.size() - 1], buffer); + } else { + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + ffi::Array sinfo_fields; + int index = 0; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + if (sinfo->vdevice.defined()) { + scope = sinfo->vdevice.value()->memory_scope; + } + const Buffer& buffer = + tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + param_map.Set(pfunc->params[args.size() + index], buffer); + index++; + } + } + + auto new_pfunc = Specialize(pfunc, param_map); + for (const auto& [var, buffer] : new_pfunc->buffer_map) { + auto* ptr = buffer->data->type_annotation.as(); + ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + } + auto new_prim_func = WithAttr(new_pfunc, "scoped", Integer(1)); + updates_->Add(gv, new_prim_func); + return call; + } + IRModule mod_; + IRModule updates_; +}; + +namespace transform { + +Pass SpecializePrimFuncBasedOnCallSite() { + auto pass_func = [=](IRModule mod, PassContext pc) { + return relax::SpecializeTIRCallArgs().Run(mod); + }; + return 
CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"SpecializePrimFuncBasedOnCallSite", + /*required=*/{}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.SpecializePrimFuncBasedOnCallSite", + SpecializePrimFuncBasedOnCallSite); +} +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 41528c7d8690..00c6efb192a3 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -50,7 +50,7 @@ using relax::TIRPattern; class ForMatcher : public TensorizeComparator { public: using SymbolMap = std::unordered_map; - explicit ForMatcher(const tir::PrimFunc& pattern, const Array& pattern_vars) + explicit ForMatcher(const tir::PrimFunc& pattern, const ffi::Array& pattern_vars) : TensorizeComparator(IRModule({{GlobalVar(""), pattern}}), false), pattern_(pattern) { for (const auto& pattern_var : pattern_vars) { this->pattern_vars_.insert(pattern_var); @@ -61,7 +61,7 @@ class ForMatcher : public TensorizeComparator { bool Match(const For& top) { const ForNode* pattern_top = pattern_->body.as()->block->body.as(); ICHECK(pattern_top) << "Invalid pattern function"; - if (!VisitStmt(top, GetRef(pattern_top))) { + if (!VisitStmt(top, ffi::GetRef(pattern_top))) { return false; } // Get evaluated symbols, buffers from the pattern. 
@@ -82,7 +82,7 @@ class ForMatcher : public TensorizeComparator { private: using ExprComparator::VisitExpr_; - Optional QueryEvaluatedSymbols(const Var& var) { + ffi::Optional QueryEvaluatedSymbols(const Var& var) { for (const SymbolMap& symbol_map : evaluated_symbols) { auto it = symbol_map.find(var); if (it != symbol_map.end()) { @@ -94,16 +94,16 @@ class ForMatcher : public TensorizeComparator { bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final { if (const auto* op = rhs.as()) { - if (pattern_vars_.count(GetRef(op))) { + if (pattern_vars_.count(ffi::GetRef(op))) { // special case for pattern vars const auto* lhs_ptr = lhs.as(); if (lhs_ptr == nullptr) { if (lhs->IsInstance() || lhs->IsInstance()) { - Optional value = QueryEvaluatedSymbols(GetRef(op)); + ffi::Optional value = QueryEvaluatedSymbols(ffi::GetRef(op)); if (value.defined()) { if (!analyzer_.CanProveEqual(lhs, value.value())) return false; } else { - evaluated_symbols.back()[GetRef(op)] = lhs; + evaluated_symbols.back()[ffi::GetRef(op)] = lhs; } return true; } else { @@ -116,7 +116,7 @@ class ForMatcher : public TensorizeComparator { if (const auto* rhs_ptr = rhs.as()) { const auto* operand_a = rhs_ptr->a.as(); const auto* operand_b = rhs_ptr->b.as(); - if (operand_a != nullptr && pattern_vars_.count(GetRef(operand_a))) { + if (operand_a != nullptr && pattern_vars_.count(ffi::GetRef(operand_a))) { // pattern var is on the left evaluated_symbols.push_back(SymbolMap()); bool match = VisitExpr(lhs, rhs_ptr->b); @@ -124,11 +124,12 @@ class ForMatcher : public TensorizeComparator { evaluated_symbols.pop_back(); if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); - evaluated_symbols.back()[GetRef(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 1); + evaluated_symbols.back()[ffi::GetRef(operand_a)] = + MakeConstScalar(rhs_ptr->b.dtype(), 1); return true; } } - if (operand_b != nullptr && pattern_vars_.count(GetRef(operand_b))) { + if (operand_b != nullptr && 
pattern_vars_.count(ffi::GetRef(operand_b))) { // pattern var is on the right evaluated_symbols.push_back(SymbolMap()); bool match = VisitExpr(lhs, rhs_ptr->a); @@ -136,7 +137,8 @@ class ForMatcher : public TensorizeComparator { evaluated_symbols.pop_back(); if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); - evaluated_symbols.back()[GetRef(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 1); + evaluated_symbols.back()[ffi::GetRef(operand_b)] = + MakeConstScalar(rhs_ptr->a.dtype(), 1); return true; } } @@ -145,7 +147,7 @@ class ForMatcher : public TensorizeComparator { if (const auto* rhs_ptr = rhs.as()) { const auto* operand_a = rhs_ptr->a.as(); const auto* operand_b = rhs_ptr->b.as(); - if (operand_a != nullptr && pattern_vars_.count(GetRef(operand_a))) { + if (operand_a != nullptr && pattern_vars_.count(ffi::GetRef(operand_a))) { // pattern var is on the left evaluated_symbols.push_back(SymbolMap()); bool match = VisitExpr(lhs, rhs_ptr->b); @@ -153,11 +155,12 @@ class ForMatcher : public TensorizeComparator { evaluated_symbols.pop_back(); if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); - evaluated_symbols.back()[GetRef(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 0); + evaluated_symbols.back()[ffi::GetRef(operand_a)] = + MakeConstScalar(rhs_ptr->b.dtype(), 0); return true; } } - if (operand_b != nullptr && pattern_vars_.count(GetRef(operand_b))) { + if (operand_b != nullptr && pattern_vars_.count(ffi::GetRef(operand_b))) { // pattern var is on the right evaluated_symbols.push_back(SymbolMap()); bool match = VisitExpr(lhs, rhs_ptr->a); @@ -165,7 +168,8 @@ class ForMatcher : public TensorizeComparator { evaluated_symbols.pop_back(); if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); - evaluated_symbols.back()[GetRef(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 0); + evaluated_symbols.back()[ffi::GetRef(operand_b)] = + 
MakeConstScalar(rhs_ptr->a.dtype(), 0); return true; } } @@ -241,8 +245,8 @@ class ForMatcher : public TensorizeComparator { bool VisitStmt_(const tir::ForNode* op, const Stmt& other) final { const auto* rhs = other.as(); - loop_stack_lhs_.push_back(GetRef(op)); - loop_stack_rhs_.push_back(GetRef(rhs)); + loop_stack_lhs_.push_back(ffi::GetRef(op)); + loop_stack_rhs_.push_back(ffi::GetRef(rhs)); // The body of loop must be loop or BlockRealize if (!op->body->IsInstance() && !op->body->IsInstance()) { return false; @@ -351,7 +355,7 @@ class ForMatcher : public TensorizeComparator { } template - bool CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp) { + bool CompareArray(const ffi::Array& lhs, const ffi::Array& rhs, F Self::*cmp) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) { @@ -369,7 +373,7 @@ class ForMatcher : public TensorizeComparator { /*! \brief Analyze the function and match it with a list of patterns */ class TIRPatternMatcher { public: - static Array Match(Array patterns, Stmt body) { + static ffi::Array Match(ffi::Array patterns, Stmt body) { TIRPatternMatcher matcher(patterns); matcher.OpMatternMatch(body); if (matcher.fail_) return {}; @@ -377,13 +381,13 @@ class TIRPatternMatcher { } private: - explicit TIRPatternMatcher(Array patterns) : patterns_(patterns) {} + explicit TIRPatternMatcher(ffi::Array patterns) : patterns_(patterns) {} // Find an op that matches this block bool BlockPatternMatch(const For& top) { for (const TIRPattern& pattern : patterns_) { tir::PrimFunc pattern_func = pattern; - Array pattern_symbolic_vars; + ffi::Array pattern_symbolic_vars; int buffer_count = pattern_func->buffer_map.size(); for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { pattern_symbolic_vars.push_back(pattern_func->params[i]); @@ -391,7 +395,7 @@ class TIRPatternMatcher { ForMatcher block_matcher(pattern_func, pattern_symbolic_vars); if 
(block_matcher.Match(top)) { // We have found a match - Array symbol_values; + ffi::Array symbol_values; for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { symbol_values.push_back(block_matcher.evaluated_symbols.back()[pattern_func->params[i]]); } @@ -406,7 +410,7 @@ class TIRPatternMatcher { // For each block in the body, try to find its corresponding pattern one by one void OpMatternMatch(const Stmt& body) { - Array blocks; + ffi::Array blocks; if (body->IsInstance()) { // {for} blocks = {body}; @@ -418,7 +422,7 @@ class TIRPatternMatcher { } for (const Stmt& stmt : blocks) { const ForNode* loop = stmt.as(); - if (loop == nullptr || !BlockPatternMatch(GetRef(loop))) { + if (loop == nullptr || !BlockPatternMatch(ffi::GetRef(loop))) { break; } } @@ -429,9 +433,9 @@ class TIRPatternMatcher { /*! \brief Indicate whether we fail to match.*/ bool fail_ = false; /*! \brief The patterns we match the target stmt to.*/ - Array patterns_; + ffi::Array patterns_; /*! \brief The results of the matching process.*/ - Array match_results_; + ffi::Array match_results_; }; /*! \brief helper class to partition a function into 2 parts. Return function information which we @@ -444,7 +448,7 @@ class FunctionPartitioner : public StmtExprVisitor { /*! \brief alloc_buffers for the second function */ std::unordered_set allocs2; /*! \brief whether the current block is in the first function */ - Map block_partition; + ffi::Map block_partition; /*! \brief input buffers for the first function */ std::unordered_set input1; /*! 
\brief input buffers for the second function */ @@ -485,7 +489,7 @@ class FunctionPartitioner : public StmtExprVisitor { input2.insert(write->buffer); } } - block_partition.Set(GetRef(op), Bool(is_matching_)); + block_partition.Set(ffi::GetRef(op), Bool(is_matching_)); } // The number of matched ops in the function size_t num_matched_ops_; @@ -496,7 +500,7 @@ class FunctionPartitioner : public StmtExprVisitor { class BlockRemover : public StmtExprMutator { public: static Stmt RemoveBlockByPartition( - Stmt stmt, const Map& block_partition, + Stmt stmt, const ffi::Map& block_partition, const std::unordered_set& allocs, bool is_library_part) { BlockRemover remover(block_partition, allocs, is_library_part); @@ -504,24 +508,24 @@ class BlockRemover : public StmtExprMutator { } private: - BlockRemover(const Map& block_partition, + BlockRemover(const ffi::Map& block_partition, const std::unordered_set& allocs, bool is_library_part) : block_partition(block_partition), allocs_(allocs), is_library_part_(is_library_part) {} Stmt VisitStmt_(const BlockNode* op) final { Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - ObjectPtr n = make_object(*block.operator->()); + ObjectPtr n = ffi::make_object(*block.operator->()); if (op->name_hint != "root") { - ICHECK(block_partition.count(GetRef(op))); - bool block_is_library = block_partition[GetRef(op)]->value; + ICHECK(block_partition.count(ffi::GetRef(op))); + bool block_is_library = block_partition[ffi::GetRef(op)]->value; if (!(is_library_part_ ^ block_is_library)) { n->body = block->body; } else { erased_ = true; } } - Array alloc_buffers; + ffi::Array alloc_buffers; for (const Buffer& b : block->alloc_buffers) { if (allocs_.count(b)) { alloc_buffers.push_back(b); @@ -532,7 +536,7 @@ class BlockRemover : public StmtExprMutator { } Stmt VisitStmt_(const SeqStmtNode* op) final { - Array seq; + ffi::Array seq; for (const Stmt& s : op->seq) { Stmt new_s = VisitStmt(s); if (erased_) { @@ -545,7 +549,7 @@ class BlockRemover 
: public StmtExprMutator { } bool erased_ = false; - Map block_partition; + ffi::Map block_partition; std::unordered_set allocs_; bool is_library_part_ = false; }; @@ -560,22 +564,21 @@ class BlockRemover : public StmtExprMutator { * \return A pair of functions, the first one is the library kernel and the second one is the * rest. */ -std::pair> SplitFunctions(PrimFunc func, - std::vector>* arg_partition, - Array patterns, - FCodegen f_codegen) { +std::pair> SplitFunctions( + PrimFunc func, std::vector>* arg_partition, ffi::Array patterns, + FCodegen f_codegen) { // Step 1. Find the library kernel and the rest. Stmt body = func->body.as()->block->body; - Array match_results = + ffi::Array match_results = TIRPatternMatcher::Match(patterns, func->body.as()->block->body); if (match_results.empty()) { return {func, std::nullopt}; } - Array codegen_result = f_codegen(match_results); + ffi::Array codegen_result = f_codegen(match_results); ICHECK(codegen_result.size() == 3); - String library_code = Downcast(codegen_result[0]); + ffi::String library_code = Downcast(codegen_result[0]); int num_matched_ops = Downcast(codegen_result[1])->value; - Array func1_args = Downcast>(codegen_result[2]); + ffi::Array func1_args = Downcast>(codegen_result[2]); if (num_matched_ops == 0) { return {func, std::nullopt}; } @@ -601,7 +604,7 @@ std::pair> SplitFunctions(PrimFunc func, Stmt body2 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition, partitioner.allocs2, false); // Step 3. Craft the first function. 
- Array new_params1; + ffi::Array new_params1; std::vector arg_partition1; ICHECK_LE(func1_args.size(), partitioner.input1.size()); for (const auto& buffer : func1_args) { @@ -616,7 +619,7 @@ std::pair> SplitFunctions(PrimFunc func, } arg_partition->push_back(arg_partition1); new_params1.push_back(Var("output", DataType::Handle())); - Map new_buffer_map1; + ffi::Map new_buffer_map1; for (const auto& kv : func->buffer_map) { if (partitioner.input1.count(kv.second)) { new_buffer_map1.Set(kv.first, kv.second); @@ -626,7 +629,7 @@ std::pair> SplitFunctions(PrimFunc func, PrimFunc func1 = PrimFunc(new_params1, body1, func->ret_type, new_buffer_map1, func->attrs); func1 = WithAttr(func1, kLibraryKernel, library_code); // Step 4. Craft the second function. - Array new_params2; + ffi::Array new_params2; std::vector arg_partition2; new_params2.push_back(Var("input", DataType::Handle())); for (int i = 0; i < static_cast(func->params.size()); i++) { @@ -639,7 +642,7 @@ std::pair> SplitFunctions(PrimFunc func, } } arg_partition->push_back(arg_partition2); - Map new_buffer_map2; + ffi::Map new_buffer_map2; new_buffer_map2.Set(new_params2[0], partitioner.intermediate_buffer); for (const auto& kv : func->buffer_map) { if (partitioner.input2.count(kv.second)) { @@ -659,18 +662,18 @@ void StringReplace(std::string* subject, const std::string& search, const std::s } } -tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symbol) { +tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, ffi::String global_symbol) { using namespace tvm::tir; - Optional library_code = pf->attrs.GetAttr(kLibraryKernel); + ffi::Optional library_code = pf->attrs.GetAttr(kLibraryKernel); if (!library_code.has_value()) { - return GetRef(pf); + return ffi::GetRef(pf); } std::string source = library_code.value(); StringReplace(&source, "{global_symbol}", global_symbol); ExternFunc ret(global_symbol); - ret = WithAttrs(std::move(ret), Map{ - {String(kCSource), String(source)}, - 
{String(kCSourceFmt), String(kCSourceFmtCuda)}, + ret = WithAttrs(std::move(ret), ffi::Map{ + {ffi::String(kCSource), ffi::String(source)}, + {ffi::String(kCSourceFmt), ffi::String(kCSourceFmtCuda)}, }); return ret; } @@ -678,13 +681,14 @@ tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symb /*! \brief Emit 2 calls to the library kernel and the rest of the function. */ class SplitMutator : public ExprMutator { public: - SplitMutator(const tvm::IRModule& mod, Array patterns, FCodegen fcodegen) + SplitMutator(const tvm::IRModule& mod, ffi::Array patterns, FCodegen fcodegen) : ExprMutator(mod), mod_(mod), patterns_(patterns), fcodegen_(fcodegen) {} - static IRModule Transform(const IRModule& mod, Array patterns, FCodegen fcodegen) { + static IRModule Transform(const IRModule& mod, ffi::Array patterns, + FCodegen fcodegen) { SplitMutator mutator(mod, patterns, fcodegen); for (auto& kv : mod->functions) { if (auto* func = kv.second.as()) { - Function new_func = Downcast(mutator(GetRef(func))); + Function new_func = Downcast(mutator(ffi::GetRef(func))); mutator.builder_->UpdateFunction(kv.first, new_func); } } @@ -694,7 +698,7 @@ class SplitMutator : public ExprMutator { private: using ExprMutator::VisitExpr_; - inline Array GetCallTIRArgs(Expr args) { + inline ffi::Array GetCallTIRArgs(Expr args) { if (args.as()) { return args.as()->fields; } else { @@ -710,22 +714,22 @@ class SplitMutator : public ExprMutator { // the first argument is the function to be called const auto* gv_ptr = call->args[0].as(); if (gv_ptr == nullptr) return call; - GlobalVar gv = GetRef(gv_ptr); + GlobalVar gv = ffi::GetRef(gv_ptr); // retrieve the function from the module and split it tir::PrimFunc func = Downcast(mod_->Lookup(gv)); std::vector> arg_partition; // split the function into two functions, one for the library kernel and one for the rest. 
- std::pair> split_funcs = + std::pair> split_funcs = tir::SplitFunctions(func, &arg_partition, patterns_, fcodegen_); if (!split_funcs.second.defined()) { // no need to split, the function itself a library kernel tvm::BaseFunc lib_func = CodegenWithLibrary(split_funcs.first.get(), gv->name_hint); - if (lib_func->IsInstance()) return GetRef(op); + if (lib_func->IsInstance()) return ffi::GetRef(op); // Update the function in the module with the library kernel ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); // emit the call to the library kernel - ObjectPtr new_call = make_object(*call.operator->()); + ObjectPtr new_call = ffi::make_object(*call.operator->()); new_call->op = this->call_dps_packed_; new_call->args = {lib_func, call->args[1]}; return Call(new_call); @@ -734,13 +738,13 @@ class SplitMutator : public ExprMutator { tir::PrimFunc func2 = tir::RenewDefs(split_funcs.second.value()); ICHECK(arg_partition.size() == 2); // emit the first call to the library kernel - Array args1; + ffi::Array args1; for (int p : arg_partition[0]) { args1.push_back(GetCallTIRArgs(call->args[1])[p]); } // replace the function in the module with the library kernel tvm::BaseFunc lib_func = CodegenWithLibrary(func1.get(), gv->name_hint); - if (lib_func->IsInstance()) return GetRef(op); + if (lib_func->IsInstance()) return ffi::GetRef(op); ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); tir::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); @@ -749,7 +753,7 @@ class SplitMutator : public ExprMutator { {TensorStructInfo(ShapeExpr(intermediate_buffer->shape), dtype)}); Var call_var1 = builder_->Emit(call1); // emit the second call to the rest of the function - Array args2; + ffi::Array args2; args2.push_back(call_var1); for (int p : arg_partition[1]) { args2.push_back(GetCallTIRArgs(call->args[1])[p]); @@ -762,12 +766,12 @@ class SplitMutator : public ExprMutator { const Op& call_dps_packed_ = 
Op::Get("relax.call_dps_packed"); tvm::IRModule mod_; - Array patterns_; + ffi::Array patterns_; FCodegen fcodegen_; }; namespace transform { -Pass SplitCallTIRByPattern(Array patterns, FCodegen fcodegen) { +Pass SplitCallTIRByPattern(ffi::Array patterns, FCodegen fcodegen) { auto pass_func = // [=](IRModule m, PassContext pc) { return SplitMutator::Transform(m, patterns, fcodegen); }; return CreateModulePass(/*pass_function=*/pass_func, // @@ -775,10 +779,10 @@ Pass SplitCallTIRByPattern(Array patterns, FCodegen fcodegen) { /*pass_name=*/"SplitCallTIRByPattern", // /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.SplitCallTIRByPattern", SplitCallTIRByPattern); -}); +} } // namespace transform diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 3fa9d52147d3..1da49c1d7de3 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -35,7 +35,7 @@ namespace tir { class SplitPrimFuncLayoutRewrite : public StmtMutator { public: explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {} - std::tuple, PrimFunc> Transform(const PrimFunc& func) { + std::tuple, PrimFunc> Transform(const PrimFunc& func) { ICHECK(func->body.as()) << "The body of the primfunc should be a root block."; const auto& block = func->body.as()->block; visit_root_block(block.get()); @@ -58,8 +58,8 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite."; // Step 2: Create the params for the new PrimFunc - Array params; - Map buffer_map; + ffi::Array params; + ffi::Map buffer_map; for (const auto& info : rewrite_infos_) { params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle())); @@ -76,16 +76,16 @@ class SplitPrimFuncLayoutRewrite : 
public StmtMutator { Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0] : SeqStmt(layout_rewrite_preproc_stmts_); body = BlockRealize( - /*iter_values=*/Array(), + /*iter_values=*/ffi::Array(), /*predicate=*/const_true(), /*block=*/ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"root", body)); - Map dict; + ffi::Map dict; for (const auto& [key, original_value] : original_func_->attrs->dict) { if (key == "global_symbol") { - dict.Set(key, Downcast(original_value) + "_weight_prepack"); + dict.Set(key, Downcast(original_value) + "_weight_prepack"); } else if (key != "layout_free_buffers") { dict.Set(key, original_value); } @@ -98,8 +98,8 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { PrimFunc create_compute_func() const { // Step 1: Create the params for the new PrimFunc - Array params = original_func_->params; - Map buffer_map = original_func_->buffer_map; + ffi::Array params = original_func_->params; + ffi::Map buffer_map = original_func_->buffer_map; for (const auto& info : rewrite_infos_) { const Var& param = params[info.buffer_index]; ICHECK(buffer_map[param] == info.pre_rewrite_buffer); @@ -109,7 +109,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { // Step 2: Create the body for the new PrimFunc Stmt body = compute_stmts_.size() == 1 ? 
compute_stmts_[0] : SeqStmt(compute_stmts_); Block original_block = original_func_->body.as()->block; - Array alloc_buffers; + ffi::Array alloc_buffers; for (const auto& buffer : original_block->alloc_buffers) { auto it = std::find_if(rewrite_infos_.begin(), rewrite_infos_.end(), @@ -120,7 +120,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { } body = BlockRealize( - /*iter_values=*/Array(), + /*iter_values=*/ffi::Array(), /*predicate=*/const_true(), /*block=*/ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, @@ -128,10 +128,10 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { /*init=*/std::nullopt, /*alloc_buffers=*/alloc_buffers)); - Map dict; + ffi::Map dict; for (const auto& [key, original_value] : original_func_->attrs->dict) { if (key == "global_symbol") { - dict.Set(key, Downcast(original_value) + "_prepacked"); + dict.Set(key, Downcast(original_value) + "_prepacked"); } else if (key != "layout_free_buffers") { dict.Set(key, original_value); } @@ -199,7 +199,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { auto new_annotations = op->annotations; new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc); - auto n = make_object(*block.get()); + auto n = ffi::make_object(*block.get()); n->annotations = new_annotations; return Block(n); } @@ -216,9 +216,9 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { private: /*! \brief The stmts that are used for layout rewrite preproc*/ - Array layout_rewrite_preproc_stmts_; + ffi::Array layout_rewrite_preproc_stmts_; /*! \brief The stmts that are other than layout rewrite preproc*/ - Array compute_stmts_; + ffi::Array compute_stmts_; /*! \brief Whether the current subtree is a layout rewrite preproc subtree. 
-1: visited a non-layout rewrite preproc block @@ -290,9 +290,9 @@ class SplitLayoutRewritePreproc : public ExprMutator { const auto& rewrite_infos = rewrite_infos_it->second; // Step 5: Emit the preproc call - Array call_tir_args = Downcast(call->args[1])->fields; - Array preproc_args; - Array preproc_sinfo_list; + ffi::Array call_tir_args = Downcast(call->args[1])->fields; + ffi::Array preproc_args; + ffi::Array preproc_sinfo_list; for (const auto& info : rewrite_infos) { preproc_args.push_back(call_tir_args[info.buffer_index]); tir::Buffer rewritten_buffer = info.post_rewrite_buffer; @@ -341,9 +341,9 @@ Pass SplitLayoutRewritePreproc() { return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()}, "SplitLayoutRewritePreproc"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.SplitLayoutRewritePreproc", SplitLayoutRewritePreproc); -}); +} } // namespace transform } // namespace tvm diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index f2e185ebd2d4..fc3c2259ff9a 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -119,8 +119,8 @@ class StorageTokenNode : public Object { } } - static constexpr const char* _type_key = "relax.transform.StorageToken"; - TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("relax.transform.StorageToken", StorageTokenNode, Object); }; /*! @@ -129,7 +129,7 @@ class StorageTokenNode : public Object { */ class StorageToken : public ObjectRef { public: - explicit StorageToken(Array shape, DataType dtype, std::string storage_scope) { + explicit StorageToken(ffi::Array shape, DataType dtype, std::string storage_scope) { // Compute the tensor size from the shape. 
int64_t const_coeff = dtype.bytes() * dtype.lanes(); PrimExpr size = tir::make_const(DataType::Int(64), 1); @@ -142,13 +142,13 @@ class StorageToken : public ObjectRef { } size = tir::make_const(DataType::Int(64), const_coeff) * size; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->bytes = size; n->dtype = dtype; n->storage_scope = std::move(storage_scope); data_ = std::move(n); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(StorageToken, ObjectRef, StorageTokenNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StorageToken, ObjectRef, StorageTokenNode); }; // We use NestedMsg to store the tokens used by each Expr. @@ -170,7 +170,7 @@ class TokenAllocator1D { * \return The request result token. Return std::nullopt if there is no * appropriate available token in the pool. */ - Optional RequestReuse(StorageToken prototype) { + ffi::Optional RequestReuse(StorageToken prototype) { // Step 0. Sanity check: the prototype token is supposed not to be allocated with actual storage ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before."; // If the prototype has no reference at all, feel free to allocate new storage. @@ -326,7 +326,7 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { } void VisitExpr_(const TupleNode* tuple) final { - Array tokens; + ffi::Array tokens; tokens.reserve(tuple->fields.size()); for (const Expr& field : tuple->fields) { Tokens field_tokens = GetTokens(field); @@ -343,7 +343,7 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { return; } ICHECK(tokens.IsNested()); - Array field_tokens = tokens.NestedArray(); + ffi::Array field_tokens = tokens.NestedArray(); ICHECK_GT(static_cast(field_tokens.size()), tuple_item->index); ICHECK_GE(tuple_item->index, 0); SetTokens(tuple_item, field_tokens[tuple_item->index]); @@ -365,38 +365,52 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { }; /*! 
- * \brief Set the upper bound of the TIR variables that appear in + * \brief Set the range constraints of the TIR variables that appear in * the input function signature in the analyzer. * \param func The function to be analyzed. * \param ana The analyzer which contains the TIR var upper bounds. * \param dom_map The domain map of the TIR variables. */ -void SetTIRVarUpperBound(Function func, arith::Analyzer* ana, - Map* dom_map) { - // Use the attribute-annotated TIR var upper bounds as the TIR var values for +void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana, + ffi::Map* dom_map) { + // Use the attribute-annotated TIR var bounds as the TIR var values for // memory planning. - // NOTE: we only apply the annotated upper bounds to the TIR variables that + // NOTE: we only apply the annotated bounds to the TIR variables that // appear in the **function signature**. - Map var_upper_bound_attr_raw = - func->GetAttr>("tir_var_upper_bound").value_or(Map()); - Array non_negative_var_attr_raw = - func->GetAttr>("tir_non_negative_var").value_or(Array()); - std::unordered_map var_upper_bound_attr; - std::unordered_set non_negative_var_attr; + ffi::Map var_upper_bound_attr_raw = + func->GetAttr>("tir_var_upper_bound") + .value_or(ffi::Map()); + ffi::Map var_lower_bound_attr_raw = + func->GetAttr>("tir_var_lower_bound") + .value_or(ffi::Map()); + ffi::Array non_negative_var_attr_raw = + func->GetAttr>("tir_non_negative_var") + .value_or(ffi::Array()); + std::unordered_map var_upper_bound_attr; + std::unordered_map var_lower_bound_attr; + std::unordered_set non_negative_var_attr; // We manually check the value type to ensure the values are all positive IntImm. 
for (auto [key, value] : var_upper_bound_attr_raw) { var_upper_bound_attr[key] = value; } - for (const String& var_name : non_negative_var_attr_raw) { + for (auto [key, value] : var_lower_bound_attr_raw) { + var_lower_bound_attr[key] = value; + } + for (const ffi::String& var_name : non_negative_var_attr_raw) { non_negative_var_attr.insert(var_name); } - Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); + ffi::Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); for (const tir::Var& tir_var : var_in_signature) { - auto it = var_upper_bound_attr.find(tir_var->name_hint); - if (it != var_upper_bound_attr.end()) { - tvm::Range range = - tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0), - tvm::IntImm(DataType::Int(64), (*it).second->value + 1)); + auto it_upper = var_upper_bound_attr.find(tir_var->name_hint); + auto it_lower = var_lower_bound_attr.find(tir_var->name_hint); + + if (it_upper != var_upper_bound_attr.end() || it_lower != var_lower_bound_attr.end()) { + int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0; + int64_t upper = (it_upper != var_upper_bound_attr.end()) + ? it_upper->second->value + : std::numeric_limits::max(); + tvm::Range range = tvm::Range::FromMinExtent( + tvm::IntImm(DataType::Int(64), lower), tvm::IntImm(DataType::Int(64), upper - lower + 1)); ana->Bind(tir_var, range); dom_map->Set(tir_var, arith::IntSet::FromRange(range)); } else if (non_negative_var_attr.count(tir_var->name_hint)) { @@ -414,10 +428,10 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana, * \return The upper-bounded shape. When a dimension's upper bound * cannot be determined, we keep the dimension unchanged. */ -Array GetUpperBoundShape(Array shape, arith::Analyzer* ana, - const Map& dom_map) { +ffi::Array GetUpperBoundShape(ffi::Array shape, arith::Analyzer* ana, + const ffi::Map& dom_map) { // Use the upper bounds of TIR vars as their values. 
- Array upper_bounded_shape; + ffi::Array upper_bounded_shape; upper_bounded_shape.reserve(shape.size()); for (const PrimExpr& dim_len : shape) { int64_t max_bound = ana->const_int_bound(dim_len)->max_value; @@ -436,7 +450,7 @@ Array GetUpperBoundShape(Array shape, arith::Analyzer* ana, } /*! \brief Check if a shape is static (a.k.a., has no TIR variable). */ -bool IsStaticShape(Array shape) { +bool IsStaticShape(ffi::Array shape) { for (const PrimExpr& dim : shape) { const auto* int_len = dim.as(); if (!int_len) { @@ -471,7 +485,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { if (func == nullptr) { continue; } - initializer(GetRef(func)); + initializer(ffi::GetRef(func)); } return initializer.token_map_; } @@ -483,8 +497,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { : ctx_mod_(ctx_mod), analyzer_(analyzer) {} void VisitExpr_(const FunctionNode* func) final { - // Set the upper bound of TIR variables in the analyzer. - SetTIRVarUpperBound(GetRef(func), analyzer_, &dom_map_); + // Set the range constraints of TIR variables in the analyzer. + SetTIRVarRangeConstraints(ffi::GetRef(func), analyzer_, &dom_map_); // Recurse into the function to get its tokens. Tokens body_tokens = GetTokens(func->body); // Discard the tokens used by the function return value, as they are external referenced. @@ -513,7 +527,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { // potential external reference. if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance() || call->op == call_tir_dyn_op) { - Array args = + ffi::Array args = call->op == call_tir_dyn_op ? 
Downcast(call->args[1])->fields : call->args; ICHECK(!block_stack_.empty()); for (const Expr& arg : call->args) { @@ -559,7 +573,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { if (global_var == nullptr) { return false; } - auto func_it = ctx_mod_->functions.find(GetRef(global_var)); + auto func_it = ctx_mod_->functions.find(ffi::GetRef(global_var)); if (func_it == ctx_mod_->functions.end()) { return false; } @@ -587,7 +601,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { // Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic // if the upper bounds of some variables are not provided. - Array upper_bounded_shape = GetUpperBoundShape(shape->values, analyzer_, dom_map_); + ffi::Array upper_bounded_shape = + GetUpperBoundShape(shape->values, analyzer_, dom_map_); // Create and set token. StringImm storage_scope = Downcast(call->args[3]); @@ -664,7 +679,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { /*! \brief The arithmetic analyzer. */ arith::Analyzer* analyzer_; /*! \brief The domain map of dynamic TIR variables for analysis. */ - Map dom_map_; + ffi::Map dom_map_; /*! \brief The mapping from each token to the binding block where it is created. */ std::unordered_map token2block_; /*! \brief The mapping from each token to the Exprs that are using this token. */ @@ -780,7 +795,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { /*! \brief Request a storage reuse, or allocate storage if no appropriate storage is reusable. 
*/ StorageToken RequestReuseOrAlloc(StorageToken prototype) { - Optional token = allocator_.RequestReuse(prototype); + ffi::Optional token = allocator_.RequestReuse(prototype); if (!token.defined()) { return allocator_.Alloc(prototype, this->n_storage_++); } else { @@ -840,7 +855,7 @@ class StorageAllocationRewriter : public ExprMutator { plan_dynamic_output_ = static_cast( func_->GetAttr(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value); if (plan_dynamic_output_) { - SetTIRVarUpperBound(GetRef(func_), &ana_, &dom_map_); + SetTIRVarRangeConstraints(ffi::GetRef(func_), &ana_, &dom_map_); } token2storage_var_.clear(); Function func = Downcast(this->VisitExpr_(func_)); @@ -903,7 +918,7 @@ class StorageAllocationRewriter : public ExprMutator { ICHECK_NOTNULL(sinfo); const auto* shape = sinfo->shape.as(); ICHECK_NOTNULL(shape); - Array upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_, dom_map_); + ffi::Array upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_, dom_map_); if (!IsStaticShape(shape->values)) { ICHECK(!sinfo->IsUnknownDtype()); ICHECK_EQ(sinfo->dtype, Downcast(call->args[1])->value); @@ -920,7 +935,7 @@ class StorageAllocationRewriter : public ExprMutator { Var storage = builder_->Emit(alloc_storage, "storage"); return Call(mem_alloc_tensor, {storage, // /*offset=*/PrimValue::Int64(0), - /*shape=*/GetRef(shape), // + /*shape=*/ffi::GetRef(shape), // /*dtype=*/DataTypeImm(sinfo->dtype)}); } } @@ -931,7 +946,7 @@ class StorageAllocationRewriter : public ExprMutator { /*! \brief The arithmetic analyzer. */ arith::Analyzer ana_; /*! \brief The domain map of dynamic TIR variables for analysis. */ - Map dom_map_; + ffi::Map dom_map_; /*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */ bool plan_dynamic_output_; /*! 
@@ -970,10 +985,10 @@ Pass StaticPlanBlockMemory() { return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.StaticPlanBlockMemory", StaticPlanBlockMemory); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index 90b343faa628..66a148e593ca 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -44,7 +44,7 @@ int GetMixedPrecisionInfo(const CallNode* call_node) { if (op_node == nullptr) { return -1; } - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); auto attr_map = Op::GetAttrMap("TMixedPrecisionPolicy"); return attr_map.count(op) ? attr_map[op] : MixedPrecisionPolicyKind::kNever; } @@ -146,12 +146,12 @@ class DTypeDecisionCollector : public ExprVisitor { } // merge the message for all vars in the expr list - void RequireArgsToType(Array args, Array to) { + void RequireArgsToType(ffi::Array args, ffi::Array to) { ICHECK(args.size() == to.size()) << "Invalid target dtypes"; for (size_t i = 0; i < args.size(); ++i) { auto fvisitleaf = [&](const Expr& expr, NType to) { if (const auto* var = expr.as()) { - UpdateVarDTypeMap(GetRef(var), to); + UpdateVarDTypeMap(ffi::GetRef(var), to); } else if (expr->IsInstance()) { // Constant can be casted anyway, so we don't need to do anything here return; @@ -164,7 +164,7 @@ class DTypeDecisionCollector : public ExprVisitor { } // merge the message for all vars in the expr list - void RequireArgsToType(Array args, DataType to) { + void RequireArgsToType(ffi::Array args, DataType to) { std::vector arg_arr; std::vector to_arr; for (const Expr& arg : args) { @@ -178,7 +178,7 @@ class DTypeDecisionCollector : public ExprVisitor { } void VisitVars_(const VarNode* op) { - Var var = GetRef(op); + Var var = 
ffi::GetRef(op); if (IsNestedTensor(var)) { // require the var to be fp32 (its original dtype) UpdateVarDTypeMap(var, NTypeFrom(var, fp32_)); @@ -239,7 +239,7 @@ class DTypeDecisionCollector : public ExprVisitor { } if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -258,7 +258,7 @@ class DTypeDecisionCollector : public ExprVisitor { this->VisitExpr(op->cond); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -301,7 +301,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } } - Array RemapArgs(const Array& args) { + ffi::Array RemapArgs(const ffi::Array& args) { return args.Map([this](Expr arg) { return VarReplacer::Replace(arg, var_remap_); }); } @@ -317,13 +317,13 @@ class ToMixedPrecisionRewriter : public ExprMutator { // We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as int32, float64 is not // supported to be rewritten if (tensor->dtype != fp16_ && tensor->dtype != fp32_) return expr; - return astype(expr, DataType(StringToDLDataType(to[0].LeafValue()))); + return astype(expr, DataType(ffi::StringToDLDataType(to[0].LeafValue()))); }; - return TransformTupleLeaf(expr, std::array({to}), fvisitleaf); + return TransformTupleLeaf(expr, std::array({to}), fvisitleaf); } - Array RewriteArgs(const Array& args, DataType to) { - Array new_args; + ffi::Array RewriteArgs(const ffi::Array& args, DataType to) { + ffi::Array new_args; for (const Expr& arg : args) { if (IsNestedTensor(arg)) { new_args.push_back(RewriteExpr(arg, NTypeFrom(arg, to))); @@ -344,7 +344,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { return true; } - bool AllFP16Castable(const Array& args) { + bool AllFP16Castable(const ffi::Array& args) { auto is_fp16 = [](StructInfo sinfo) { if (auto tensor_sinfo = sinfo.as(); tensor_sinfo && tensor_sinfo->dtype 
== DataType::Float(16)) { @@ -413,11 +413,11 @@ class ToMixedPrecisionRewriter : public ExprMutator { auto it = only_fp16_map_->find(var); if (it == only_fp16_map_->end()) return; // Get the to dtype, cast to fp16 if the var is fp16 only, otherwise do nothing - auto fcombine = [](const String& from, const String& required) -> String { + auto fcombine = [](const ffi::String& from, const ffi::String& required) -> ffi::String { return required == "float16" ? required : from; }; NType from = NTypeFrom(cur_var); - NType to = CombineNestedMsg(from, it->second, fcombine); + NType to = CombineNestedMsg(from, it->second, fcombine); Expr rewrite = RewriteExpr(cur_var, to); // If cur_var is not rewritten, we don't need to emit a new var if (!rewrite.same_as(cur_var)) { @@ -439,7 +439,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (!builder_->CurrentBlockIsDataFlow()) { return ExprMutator::VisitExpr_(op); } - return VisitVar_(GetRef(op)); + return VisitVar_(ffi::GetRef(op)); } Var VisitVarDef(const Var& var) { return GetRemapped(var); } @@ -464,14 +464,14 @@ class ToMixedPrecisionRewriter : public ExprMutator { // var = Call(op) const auto* op_node = call_node->op.as(); ICHECK(op_node != nullptr); - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); if (wrap_param_op.same_as(op)) { // wrap_param ReEmitBinding(binding, call_node->args[0]); return; } - Call new_call = GetRef(call_node); + Call new_call = ffi::GetRef(call_node); // We first to remap the args to the current vars according to the var_remap_ new_call.CopyOnWrite()->args = RemapArgs(new_call->args); @@ -493,7 +493,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { // cast back to the original datatype. 
if (!new_call->args.same_as(call_node->args)) { - Array new_typed_args; + ffi::Array new_typed_args; for (size_t i = 0; i < call_node->args.size(); i++) { auto arg = new_call->args[i]; auto old_ntype = NTypeFrom(call_node->args[i]); @@ -532,7 +532,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { ExprMutator::VisitBinding_(binding, tuple_node); return; } - ObjectPtr new_tuple = make_object(*tuple_node); + ObjectPtr new_tuple = ffi::make_object(*tuple_node); new_tuple->fields = RemapArgs(tuple_node->fields); new_tuple->struct_info_ = std::nullopt; Expr new_value = builder_->Normalize(Tuple(new_tuple)); @@ -552,7 +552,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { return; } ObjectPtr new_tuple_get_item = - make_object(*tuple_get_item_node); + ffi::make_object(*tuple_get_item_node); new_tuple_get_item->tuple = RemapArgs({tuple_get_item_node->tuple})[0]; new_tuple_get_item->struct_info_ = std::nullopt; Expr new_value = TupleGetItem(new_tuple_get_item); @@ -593,14 +593,14 @@ class ToMixedPrecisionRewriter : public ExprMutator { DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1); DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); DataType output_dtype_; - Array params_; + ffi::Array params_; std::unordered_set fp16_input_names_; const Op& wrap_param_op = Op::Get("relax.wrap_param"); }; Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, - Optional> fp16_input_names) { + ffi::Optional> fp16_input_names) { VarDTypeMap only_fp16_map = DTypeDecisionCollector::Collect(f, out_dtype); std::unordered_set fp16_input_names_set; if (fp16_input_names) { @@ -612,17 +612,18 @@ Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, namespace transform { -Pass ToMixedPrecision(const DataType& out_dtype, Optional> fp16_input_names) { +Pass ToMixedPrecision(const DataType& out_dtype, + ffi::Optional> fp16_input_names) { auto pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(ToMixedPrecision(f, 
out_dtype, fp16_input_names)); }; return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ToMixedPrecision", ToMixedPrecision); -}); +} } // namespace transform diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index 5f87c4a6be72..b9345744320c 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -62,10 +62,10 @@ Pass ToNonDataflow() { return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.ToNonDataflow", ToNonDataflow); -}); +} } // namespace transform diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc index c9f11b32bee7..114af668b980 100644 --- a/src/relax/transform/topological_sort.cc +++ b/src/relax/transform/topological_sort.cc @@ -149,7 +149,7 @@ class BindingOrderCollector : ExprVisitor { } void VisitExpr_(const VarNode* op) override { - Var upstream_requirement = GetRef(op); + Var upstream_requirement = ffi::GetRef(op); auto downstream_user = current_binding_; dependencies_.downstream_users[upstream_requirement].push_back(downstream_user); @@ -167,7 +167,7 @@ class TopologicalSorter : public ExprMutator { Expr VisitExpr_(const FunctionNode* op) override { auto cached = dependencies_; - dependencies_ = BindingOrderCollector::Collect(GetRef(op)); + dependencies_ = BindingOrderCollector::Collect(ffi::GetRef(op)); if (starting_location_ == StartingLocation::FromOutputs) { std::reverse(dependencies_.binding_order.begin(), dependencies_.binding_order.end()); @@ -184,7 +184,7 @@ class TopologicalSorter : public ExprMutator { } BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { - auto block = GetRef(op); + auto block = 
ffi::GetRef(op); // A map from not-yet-defined variables to the binding that will // define the variable. Items are removed from this map as they @@ -309,13 +309,13 @@ class TopologicalSorter : public ExprMutator { << "no bindings should remain to emit. " << "However, bindings " << [&]() { - Array arr; + ffi::Array arr; for (const auto& [var, binding] : to_emit) { arr.push_back(var); } return arr; }() << " still remain after emitting " - << Array(new_bindings.begin(), new_bindings.end()) + << ffi::Array(new_bindings.begin(), new_bindings.end()) .Map([](const Binding& binding) { return binding->var; }); if (starting_location_ == StartingLocation::FromOutputs) { @@ -343,10 +343,11 @@ Pass TopologicalSort(TraversalOrder order, StartingLocation starting_location) { return relax::transform::CreateFunctionPass(pass_func, 0, "TopologicalSort", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "relax.transform.TopologicalSort", [](String order_str, String direction_str) -> Pass { + "relax.transform.TopologicalSort", + [](ffi::String order_str, ffi::String direction_str) -> Pass { TraversalOrder order = [&]() { if (order_str == "depth-first") { return TraversalOrder::DepthFirst; @@ -373,7 +374,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return TopologicalSort(order, starting_location); }); -}); +} } // namespace transform diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index 85acec6942da..071e5bf4c991 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -40,14 +40,14 @@ namespace relax { namespace { class ParamStructInfoMutator : public ExprMutator { public: - explicit ParamStructInfoMutator(ffi::TypedFunction(Var)> sinfo_func) + explicit ParamStructInfoMutator(ffi::TypedFunction(Var)> sinfo_func) : sinfo_func_(sinfo_func) {} using ExprMutator::VisitExpr_; using 
ExprMutator::VisitVarDef_; Expr VisitExpr_(const FunctionNode* op) override { - auto func = GetRef(op); + auto func = ffi::GetRef(op); auto params = op->params.Map([this](Var param) { if (auto new_sinfo = sinfo_func_(param)) { @@ -65,12 +65,12 @@ class ParamStructInfoMutator : public ExprMutator { return ExprMutator::VisitExpr_(func.get()); } - ffi::TypedFunction(Var)> sinfo_func_; + ffi::TypedFunction(Var)> sinfo_func_; }; } // namespace namespace transform { -Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_func) { +Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_func) { auto pass_func = [=](IRModule mod, PassContext pc) { ParamStructInfoMutator mutator(sinfo_func); @@ -105,10 +105,10 @@ Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_f return tvm::transform::CreateModulePass(pass_func, 1, "UpdateParamStructInfo", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.UpdateParamStructInfo", UpdateParamStructInfo); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc index fc7d8941fe51..a6cbb83b8c73 100644 --- a/src/relax/transform/update_vdevice.cc +++ b/src/relax/transform/update_vdevice.cc @@ -35,7 +35,7 @@ class VDeviceMutator : public ExprMutator { public: VDeviceMutator(const IRModule& mod, VDevice new_vdevice, int64_t index) : ExprMutator(mod), mod_(mod), new_vdevice_(new_vdevice) { - Array vdevices = mod->global_infos["vdevice"]; + ffi::Array vdevices = mod->global_infos["vdevice"]; old_vdevice_ = Downcast(vdevices[index]); } @@ -74,7 +74,7 @@ class VDeviceMutator : public ExprMutator { builder_->UpdateFunction(gv, update_func); } } - Array new_vdevices; + ffi::Array new_vdevices; for (auto vdev : mod_->global_infos["vdevice"]) { if (vdev == old_vdevice_) { new_vdevices.push_back(new_vdevice_); @@ -107,10 +107,10 @@ Pass UpdateVDevice(VDevice 
new_vdevice, int64_t index) { /*pass_name=*/"UpdateVDevice", /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.UpdateVDevice", UpdateVDevice); -}); +} } // namespace transform } // namespace relax diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc index 19e93bbc0c0e..580b3892e57b 100644 --- a/src/relax/transform/utils.cc +++ b/src/relax/transform/utils.cc @@ -44,15 +44,15 @@ bool IsNestedTensor(const StructInfo& sinfo) { bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetStructInfo(expr)); } Function ComposeFunctions(Function func_a, Function func_b) { - Array bindings; + ffi::Array bindings; Var func_a_output("func_a_output", func_a->ret_struct_info); bindings.push_back(VarBinding(func_a_output, func_a->body)); - auto func_a_outputs = [&]() -> Array { + auto func_a_outputs = [&]() -> ffi::Array { if (auto func_a_output_tuple = func_a->ret_struct_info.as()) { - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < func_a_output_tuple->fields.size(); i++) { outputs.push_back(TupleGetItem(func_a_output, i)); } diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 009d00260781..91d75079f73d 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -84,8 +84,8 @@ class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor(vn))); - return memo_[GetRef(vn)]; + ICHECK(memo_.count(ffi::GetRef(vn))); + return memo_[ffi::GetRef(vn)]; } virtual OutputType VisitBinding_(const VarBindingNode* binding) { @@ -115,7 +115,7 @@ class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor entry_funcs); +TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, ffi::Array entry_funcs); /*! * \brief Get the external symbol of the Relax function name. @@ -124,7 +124,7 @@ TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, Array entry_fu * \return An external symbol. 
*/ inline std::string GetExtSymbol(const Function& func) { - const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.has_value()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } @@ -142,7 +142,7 @@ inline std::string GetExtSymbol(const Function& func) { */ IRModule MakeGroupedFunctions( IRModule mod, const std::unordered_map& partition, - bool lift_constants = true, const Array& entry_function_names = {}); + bool lift_constants = true, const ffi::Array& entry_function_names = {}); /*! * \brief Check if the given StructInfo is a scalar tensor. The sinfo should be an instance of @@ -172,7 +172,7 @@ bool IsScalarTensor(const Expr& expr); template bool IsNestedTensorConditioned(const StructInfo& sinfo, FType f_condition) { if (const auto* tensor_sinfo = sinfo.as()) { - return f_condition(GetRef(tensor_sinfo)); + return f_condition(ffi::GetRef(tensor_sinfo)); } else if (const auto* tuple_sinfo = sinfo.as()) { return !std::any_of( tuple_sinfo->fields.begin(), tuple_sinfo->fields.end(), @@ -209,7 +209,7 @@ class VarReplacer : public ExprMutator { private: Expr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto it = var_remap_.find(var->vid); return it == var_remap_.end() ? var : it->second; } @@ -241,19 +241,19 @@ class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { // 1. Visit and replace all tir::Vars at the definition point // 2. Revisit the function again and update the use side. 
PrimExpr VisitExpr_(const tir::VarNode* op) final { - auto it = var_map_.find(GetRef(op)); + auto it = var_map_.find(ffi::GetRef(op)); if (it != var_map_.end()) { return (*it).second; } else { - auto n = make_object(*op); + auto n = ffi::make_object(*op); tir::Var v(n); - var_map_.Set(GetRef(op), v); + var_map_.Set(ffi::GetRef(op), v); return v; } } Expr VisitExpr_(const FunctionNode* op) { - tvm::Array params; + tvm::ffi::Array params; bool all_params_unchanged = true; for (Var param : op->params) { Var new_param = this->VisitVarDef(param); @@ -267,14 +267,14 @@ class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { Expr body = this->VisitWithNewScope(op->body, params); if (all_params_unchanged && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto new_ret_sinfo = this->VisitExprDepStructInfoField(op->ret_struct_info); return Function(params, body, new_ret_sinfo, op->is_pure, op->attrs); } } - Map var_map_; + ffi::Map var_map_; }; /*! @@ -286,7 +286,7 @@ class FunctionCopier : public SymbolicVarRenewMutator { public: FunctionCopier() = default; Function Copy(Function func) { return Downcast(VisitExpr(func)); } - Map GetVarMap() { return relax_var_map_; } + ffi::Map GetVarMap() { return relax_var_map_; } private: using relax::ExprMutator::VisitExpr; @@ -295,7 +295,7 @@ class FunctionCopier : public SymbolicVarRenewMutator { Var new_var = SymbolicVarRenewMutator::VisitVarDef_(var); Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span); var_remap_[var->vid] = copied_var; - relax_var_map_.Set(GetRef(var), copied_var); + relax_var_map_.Set(ffi::GetRef(var), copied_var); return copied_var; } @@ -303,11 +303,11 @@ class FunctionCopier : public SymbolicVarRenewMutator { Var new_var = SymbolicVarRenewMutator::VisitVarDef_(var); Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span); var_remap_[var->vid] = copied_var; - relax_var_map_.Set(GetRef(var), 
copied_var); + relax_var_map_.Set(ffi::GetRef(var), copied_var); return copied_var; } - Map relax_var_map_; + ffi::Map relax_var_map_; }; /*! @@ -319,7 +319,7 @@ class FunctionCopier : public SymbolicVarRenewMutator { */ template inline Constant MakeConstantScalar(T value, DataType dtype) { - runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0}); + runtime::Tensor arr = runtime::Tensor::Empty({}, dtype, {kDLCPU, 0}); if (dtype == DataType::Float(32)) { *static_cast(arr->data) = static_cast(value); } else if (dtype == DataType::Float(64)) { @@ -328,7 +328,7 @@ inline Constant MakeConstantScalar(T value, DataType dtype) { *static_cast(arr->data) = static_cast(value); } else if (dtype == DataType::Int(64)) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(1)) { + } else if (dtype == DataType::Bool()) { *static_cast(arr->data) = static_cast(value); } else if (dtype == DataType::UInt(8)) { *static_cast(arr->data) = static_cast(value); @@ -360,7 +360,7 @@ inline Constant MakeConstantScalar(T value, DataType dtype) { return Constant(arr); } -inline Array GetOrderedPositiveAxes(const Array& axes, int ndim) { +inline ffi::Array GetOrderedPositiveAxes(const ffi::Array& axes, int ndim) { std::vector ret; ret.reserve(axes.size()); for (const auto& axis : axes) { @@ -376,7 +376,7 @@ inline Array GetOrderedPositiveAxes(const Array& axes, int ndi return support::AsArray(ret); } -inline String GetCodegenName(const std::string& composite_name) { +inline ffi::String GetCodegenName(const std::string& composite_name) { auto delim_pos = composite_name.find("."); ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should " "start with a compiler name followed by period."; @@ -384,9 +384,9 @@ inline String GetCodegenName(const std::string& composite_name) { } inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) { - Array vdevices = mod->global_infos["vdevice"]; + ffi::Array vdevices = 
mod->global_infos["vdevice"]; for (int i = 0; i < static_cast(vdevices.size()); ++i) { - if (vdevices[i] == vdevice) { + if (vdevices[i].same_as(vdevice)) { return i; } } @@ -434,7 +434,8 @@ Expr CanonicalizeBindings(Expr expr); * * \ret The updated function. */ -Function BundleModelParams(const Function& func, Optional param_tuple_name = std::nullopt); +Function BundleModelParams(const Function& func, + ffi::Optional param_tuple_name = std::nullopt); /*! \brief Compose two functions * diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 92747a2515d5..37e53a614ff0 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -31,15 +31,15 @@ namespace relax { /*! \brief Helper to implement bind params.*/ class ExprBinder : public ExprMutator { public: - explicit ExprBinder(const tvm::Map& args_map, - const tvm::Map& symbolic_var_map) + explicit ExprBinder(const tvm::ffi::Map& args_map, + const tvm::ffi::Map& symbolic_var_map) : args_map_(args_map), symbolic_var_map_(symbolic_var_map) {} private: using ExprMutator::VisitExpr_; Expr VisitExpr_(const FunctionNode* op) final { - tvm::Array params; + tvm::ffi::Array params; bool all_params_unchanged = true; for (const Var& param : op->params) { if (args_map_.count(param)) { @@ -58,7 +58,7 @@ class ExprBinder : public ExprMutator { // FuncStructInfo does not depend on Expr if (all_params_unchanged && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { // purity won't be affected, no need to update annotation return Function(params, body, VisitExprDepStructInfoField(op->ret_struct_info), op->is_pure, @@ -67,7 +67,7 @@ class ExprBinder : public ExprMutator { } Expr VisitExpr_(const VarNode* op) final { - auto id = GetRef(op); + auto id = ffi::GetRef(op); auto it = args_map_.find(id); if (it != args_map_.end()) { return (*it).second; @@ -86,8 +86,8 @@ class ExprBinder : public ExprMutator { } private: - const tvm::Map& args_map_; - const tvm::Map& symbolic_var_map_; + const tvm::ffi::Map& 
args_map_; + const tvm::ffi::Map& symbolic_var_map_; }; /*! @@ -97,18 +97,19 @@ class ExprBinder : public ExprMutator { * \param symbolic_var_map The map from symbolic var to the expr it binds to * \return The result expr after bind params */ -Expr Bind(const Expr& expr, const tvm::Map& binds, - const tvm::Map& symbolic_var_map) { +Expr Bind(const Expr& expr, const tvm::ffi::Map& binds, + const tvm::ffi::Map& symbolic_var_map) { return ExprBinder(binds, symbolic_var_map).VisitExpr(expr); } -StructInfo Bind(const StructInfo& sinfo, const tvm::Map& symbolic_var_map) { +StructInfo Bind(const StructInfo& sinfo, + const tvm::ffi::Map& symbolic_var_map) { return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo); } -tvm::Map InferSymbolicVarMap( - const tvm::Map& relax_var_remap, arith::Analyzer* analyzer) { - tvm::Map tir_var_remap; +tvm::ffi::Map InferSymbolicVarMap( + const tvm::ffi::Map& relax_var_remap, arith::Analyzer* analyzer) { + tvm::ffi::Map tir_var_remap; auto bind_from_prim_expr = [&tir_var_remap](const PrimExpr& var_shape, const PrimExpr& expr_shape) { @@ -218,7 +219,7 @@ bool IsLeafOrTuple(const Expr& expr) { bool IsImpureCall(const Call& call) { if (auto op_ptr = call->op.as()) { - auto op = GetRef(op_ptr); + auto op = ffi::GetRef(op_ptr); static auto purity_map = Op::GetAttrMap("FPurity"); ICHECK(purity_map.count(op)) << "Cannot find the registered purity of this op: " << op->name; return !(purity_map[op]->value); @@ -246,10 +247,10 @@ Expr GetBoundValue(const Binding& b) { */ Function CopyWithNewVars(Function func) { return FunctionCopier().Copy(func); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.CopyWithNewVars", CopyWithNewVars); -}); +} } // namespace relax } // namespace tvm diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index 2c02fb556c73..918d55107793 100644 --- a/src/runtime/const_loader_module.cc +++ 
b/src/runtime/const_loader_module.cc @@ -19,7 +19,7 @@ /*! * \file src/runtime/const_loader_module.cc - * \brief A wrapper for initializing imported modules using constant NDArray. This + * \brief A wrapper for initializing imported modules using constant Tensor. This * module is intended to be used by various runtime in the TVM stack, i.e. * graph executor, relax VM, AOT runtime, and various user defined runtimes. It * paves the way to separate the code and metedata, which makes compilation @@ -34,7 +34,7 @@ #include #include #include -#include +#include #include @@ -48,9 +48,9 @@ namespace runtime { class ConstLoaderModuleObj : public ffi::ModuleObj { public: ConstLoaderModuleObj( - const std::unordered_map& const_var_ndarray, + const std::unordered_map& const_var_tensor, const std::unordered_map>& const_vars_by_symbol) - : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) { + : const_var_tensor_(const_var_tensor), const_vars_by_symbol_(const_vars_by_symbol) { VLOG(1) << "Creating ConstLoaderModule"; // Only the related submodules are cached to reduce the number of runtime // symbol lookup for initialization. Otherwise, symbols/primitives in the @@ -59,7 +59,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { for (const auto& var : kv.second) { VLOG(1) << "ConstLoaderModuleNode has constant '" << var << "' for function '" << kv.first << "'"; - ICHECK_GT(const_var_ndarray_.count(var), 0) + ICHECK_GT(const_var_tensor_.count(var), 0) << "ConstLoaderModuleNode is missing entry for constant '" << var << "' for function '" << kv.first << "'"; } @@ -67,7 +67,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { } } - ffi::Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")"; // Initialize and memoize the module. // Usually, we have some warmup runs. 
The module initialization should be @@ -78,10 +78,10 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { } ObjectRef _self = ffi::GetRef(this); - if (name == "get_const_var_ndarray") { + if (name == "get_const_var_tensor") { return ffi::Function([_self, this](ffi::PackedArgs args, ffi::Any* rv) { - Map ret_map; - for (const auto& kv : const_var_ndarray_) { + ffi::Map ret_map; + for (const auto& kv : const_var_tensor_) { ret_map.Set(kv.first, kv.second); } *rv = ret_map; @@ -107,17 +107,17 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { /*! * \brief Get the list of constants that is required by the given module. * \param symbol The symbol that is being queried. - * \return The list of needed NDArray. + * \return The list of needed Tensor. */ - Array GetRequiredConstants(const std::string& symbol) { - Array ret; + ffi::Array GetRequiredConstants(const std::string& symbol) { + ffi::Array ret; ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No constants known for function '" << symbol << "'"; std::vector vars = const_vars_by_symbol_[symbol]; for (const auto& var : vars) { - ICHECK_GT(const_var_ndarray_.count(var), 0U) + ICHECK_GT(const_var_tensor_.count(var), 0U) << "No such constant variable '" << var << "' for function '" << symbol << "'"; - ret.push_back(const_var_ndarray_[var]); + ret.push_back(const_var_tensor_[var]); } return ret; } @@ -139,7 +139,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { for (const Any& it : this->imports_) { // Get the initialization function from the imported modules. std::string init_name = "__init_" + symbol; - Optional init = it.cast()->GetFunction(init_name, false); + ffi::Optional init = it.cast()->GetFunction(init_name, false); if (init.has_value()) { auto md = GetRequiredConstants(symbol); // Initialize the module with constants. 
@@ -157,20 +157,20 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { dmlc::Stream* stream = &ms; std::vector variables; - std::vector const_var_ndarray; - for (const auto& it : const_var_ndarray_) { - String var_name = it.first; + std::vector const_var_tensor; + for (const auto& it : const_var_tensor_) { + ffi::String var_name = it.first; variables.push_back(var_name); - const_var_ndarray.push_back(it.second); + const_var_tensor.push_back(it.second); } // Save all variables in the function. stream->Write(variables); // Save all constant data. - uint64_t sz = static_cast(const_var_ndarray.size()); + uint64_t sz = static_cast(const_var_tensor.size()); stream->Write(sz); for (uint64_t i = 0; i < sz; i++) { - const_var_ndarray[i].Save(stream); + const_var_tensor[i].Save(stream); } // Save the symbol to list of required constant variables mapping @@ -202,17 +202,17 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { ICHECK_EQ(static_cast(sz), variables.size()) << "The number of variables and ndarray counts must match"; // Load the list of ndarray. 
- std::vector arrays; + std::vector arrays; for (uint64_t i = 0; i < sz; i++) { - NDArray temp; + Tensor temp; temp.Load(stream); arrays.push_back(temp); } - std::unordered_map const_var_ndarray; + std::unordered_map const_var_tensor; for (uint64_t i = 0; i < sz; i++) { - ICHECK_EQ(const_var_ndarray.count(variables[i]), 0U); - const_var_ndarray[variables[i]] = arrays[i]; + ICHECK_EQ(const_var_tensor.count(variables[i]), 0U); + const_var_tensor[variables[i]] = arrays[i]; } // Load the symbol to list of required constant variables mapping @@ -232,7 +232,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { const_vars_by_symbol[symbols[i]] = const_vars[i]; } - auto n = make_object(const_var_ndarray, const_vars_by_symbol); + auto n = ffi::make_object(const_var_tensor, const_vars_by_symbol); return ffi::Module(n); } @@ -242,24 +242,24 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { * modules using execution engine. */ std::unordered_map initialized_; - /*! \brief Variable name to NDArray mapping. */ - std::unordered_map const_var_ndarray_; + /*! \brief Variable name to Tensor mapping. */ + std::unordered_map const_var_tensor_; /*! \brief Symbol name to required constant variables mapping. 
*/ std::unordered_map> const_vars_by_symbol_; }; ffi::Module ConstLoaderModuleCreate( - const std::unordered_map& const_var_ndarray, + const std::unordered_map& const_var_tensor, const std::unordered_map>& const_vars_by_symbol) { - auto n = make_object(const_var_ndarray, const_vars_by_symbol); + auto n = ffi::make_object(const_var_tensor, const_vars_by_symbol); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ffi.Module.load_from_bytes.const_loader", ConstLoaderModuleObj::LoadFromBytes); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h index c093818763d8..30bddc7b377a 100644 --- a/src/runtime/const_loader_module.h +++ b/src/runtime/const_loader_module.h @@ -25,7 +25,7 @@ #ifndef TVM_RUNTIME_CONST_LOADER_MODULE_H_ #define TVM_RUNTIME_CONST_LOADER_MODULE_H_ -#include +#include #include #include @@ -37,14 +37,14 @@ namespace runtime { /*! * \brief Create a ConstLoader module object. * - * \param const_var_ndarray Maps consts var name to NDArray containing data for the var. + * \param const_var_tensor Maps consts var name to Tensor containing data for the var. * \param const_vars_by_symbol Maps the name of a module init function to a list of names of * const vars whose data will be passed to that init function. * * \return The created ConstLoaderModule. 
*/ ffi::Module ConstLoaderModuleCreate( - const std::unordered_map& const_var_ndarray, + const std::unordered_map& const_var_tensor, const std::unordered_map>& const_vars_by_symbol); } // namespace runtime diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc index a38072dec1cd..4be9d57811b3 100644 --- a/src/runtime/contrib/amx/amx_config.cc +++ b/src/runtime/contrib/amx/amx_config.cc @@ -76,7 +76,7 @@ void init_tile_config(__tilecfg_u* dst, uint16_t cols, uint8_t rows) { _tile_loadconfig(dst->a); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("runtime.amx_tileconfig", [](ffi::PackedArgs args, ffi::Any* rv) { int rows = args[0].cast(); @@ -89,10 +89,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = 1; return; }); -}); +} // register a global packed function in c++,to init the system for AMX config -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("runtime.amx_init", [](ffi::PackedArgs args, ffi::Any* rv) { // -----------Detect and request for AMX control---------------------- @@ -134,7 +134,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = 1; return; }); -}); +} #endif } // namespace runtime diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 3de9e85a57c5..b090f0ccfbda 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include "../json/json_node.h" #include "../json/json_runtime.h" @@ -61,7 +61,7 @@ class ACLRuntime : public JSONRuntimeBase { * \param const_names The names of each constant in the sub-graph. 
*/ explicit ACLRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} /*! @@ -77,7 +77,7 @@ class ACLRuntime : public JSONRuntimeBase { * * \param consts The constant params from compiled model. */ - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; SetupConstants(consts); @@ -588,19 +588,19 @@ class ACLRuntime : public JSONRuntimeBase { } #endif }; -ffi::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module ACLRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.arm_compute_lib_runtime_create", ACLRuntimeCreate) .def("ffi.Module.load_from_bytes.arm_compute_lib", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 9080eeb9bb34..499330cd0b5b 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include @@ -88,12 +88,12 @@ ThreadingConfig getDefaultThreadingConfig() { class BNNSJSONRuntime : public JSONRuntimeBase { public: BNNSJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : 
JSONRuntimeBase(symbol_name, graph_json, const_names) {} const char* kind() const override { return "bnns_json"; } - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; @@ -367,7 +367,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { dst_view.get_bnns_view()}; // BNNS limitation: MatMul use reverse dims values. However strides are calculated correctly - // based on BNNSNDArrayDescriptor::layout value. + // based on BNNSTensorDescriptor::layout value. std::reverse(layerParameters.iA_desc.size, layerParameters.iA_desc.size + 3); std::reverse(layerParameters.iB_desc.size, layerParameters.iB_desc.size + 3); std::reverse(layerParameters.o_desc.size, layerParameters.o_desc.size + 3); @@ -557,18 +557,18 @@ class BNNSJSONRuntime : public JSONRuntimeBase { std::vector tensors_eid_; }; -ffi::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module BNNSJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.BNNSJSONRuntimeCreate", BNNSJSONRuntimeCreate) .def("ffi.Module.load_from_bytes.bnns_json", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h index f395561a7f6c..1997e0a84d71 100644 --- a/src/runtime/contrib/bnns/bnns_wrp.h +++ b/src/runtime/contrib/bnns/bnns_wrp.h @@ -62,7 +62,7 @@ class Tensor { auto rank = shape.size(); ICHECK(rank < BNNS_MAX_TENSOR_DIMENSION); - desc_ = {BNNSNDArrayFlags(0), + desc_ = 
{BNNSTensorFlags(0), getPlainLayout(rank), {}, // shape {}, // strides @@ -107,7 +107,7 @@ class Tensor { is_external_data = true; } - const BNNSNDArrayDescriptor& get_desc() const { return desc_; } + const BNNSTensorDescriptor& get_desc() const { return desc_; } static BNNSDataLayout getPlainLayout(size_t rank) { ICHECK(rank <= BNNS_MAX_TENSOR_DIMENSION); @@ -116,9 +116,9 @@ class Tensor { static size_t getRank(BNNSDataLayout layout) { return (layout & 0xF0000) >> 16; } - static size_t getRank(BNNSNDArrayDescriptor desc) { return getRank(desc.layout); } + static size_t getRank(BNNSTensorDescriptor desc) { return getRank(desc.layout); } - static size_t getSize(BNNSNDArrayDescriptor desc) { + static size_t getSize(BNNSTensorDescriptor desc) { auto rank = getRank(desc); return std::accumulate(desc.size, desc.size + rank, 1, std::multiplies()); } @@ -127,13 +127,13 @@ class Tensor { static size_t getElementSize(Dtype dtype) { return (dtype & 0xFFFF) / 8; } /** return size of element in bytes */ - static size_t getElementSize(const BNNSNDArrayDescriptor& desc) { + static size_t getElementSize(const BNNSTensorDescriptor& desc) { return getElementSize(desc.data_type); } private: bool is_external_data = false; - BNNSNDArrayDescriptor desc_; + BNNSTensorDescriptor desc_; }; using TensorPtr = std::shared_ptr; @@ -291,14 +291,14 @@ class TView { operator bool() const { return origin_ != nullptr; } /** Get BNNS descriptor for particular View. Batch and Party attributed are ignored. 
*/ - const BNNSNDArrayDescriptor& get_bnns_view() const { return view_desc_; } + const BNNSTensorDescriptor& get_bnns_view() const { return view_desc_; } private: /** Original tensor object to view on */ TensorPtr origin_; /** Batched view parameters */ - BNNSNDArrayDescriptor view_desc_ = {}; + BNNSTensorDescriptor view_desc_ = {}; size_t batch_size_ = 1; size_t batch_stride_ = 0; diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index 8d74ce855c31..85899b64f480 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -124,7 +124,7 @@ struct CblasDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.cblas.matmul", @@ -157,6 +157,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); } }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/dnnl_blas.cc b/src/runtime/contrib/cblas/dnnl_blas.cc index 59400f19dd2f..9862a37301d3 100644 --- a/src/runtime/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/contrib/cblas/dnnl_blas.cc @@ -47,13 +47,13 @@ struct DNNLSgemmOp { }; // matrix multiplication for row major -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.dnnl.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); CallGemm(args, ret, DNNLSgemmOp()); }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc index 19ce6ceb9b07..be8db227e554 100644 --- a/src/runtime/contrib/cblas/mkl.cc +++ b/src/runtime/contrib/cblas/mkl.cc @@ -155,7 +155,7 @@ struct MKLDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.mkl.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); @@ -166,10 +166,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ else CallGemm(args, ret, MKLDgemmOp()); }); -}); +} // integer matrix multiplication for row major -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.mkl.matmul_u8s8s32", @@ -202,6 +202,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ CallBatchGemm(args, ret, MKLDgemmBatchIterativeOp()); } }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 9d13e427b24a..d1cf6b2808b0 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -149,7 +149,7 @@ class CLMLRuntime : public JSONRuntimeBase { * \param const_names The names of each constant in the sub-graph. */ explicit CLMLRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names), clml_symbol(symbol_name) {} ~CLMLRuntime() { @@ -201,7 +201,7 @@ class CLMLRuntime : public JSONRuntimeBase { * * \param consts The constant params from compiled model. 
*/ - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; SetupConstants(consts); @@ -270,7 +270,7 @@ class CLMLRuntime : public JSONRuntimeBase { "same by exporting CLML_DISABLE_RECORDABLE_QUEUE at runtime."; } cl_command_queue queue = CLML_QUEUE; - Map dump_tensors; + ffi::Map dump_tensors; std::ostringstream os; dmlc::JSONWriter writer(&os); writer.BeginObject(); @@ -293,7 +293,7 @@ class CLMLRuntime : public JSONRuntimeBase { // Dump tensor to CPU std::vector shape = node.GetOpShape()[0]; DLDataType tvm_dtype = node.GetOpDataType()[0]; - NDArray narr = NDArray::Empty(ffi::Shape(shape), tvm_dtype, {kDLCPU, 0}); + Tensor narr = Tensor::Empty(ffi::Shape(shape), tvm_dtype, {kDLCPU, 0}); CopyDataFromCLMLTensor(clml_desc, narr.operator->()->data); // Naming convention @@ -315,7 +315,7 @@ class CLMLRuntime : public JSONRuntimeBase { const auto f = tvm::ffi::Function::GetGlobal("runtime.SaveParams"); if (f.has_value()) { - std::string dump_bytes = (*f)(dump_tensors); + std::string dump_bytes = (*f)(dump_tensors).cast(); std::ostringstream oss; /*TODO(Siva) HEX encoding doubles the size, look for better encode that can cross the RPC. 
*/ for (size_t i = 0; i < dump_bytes.size(); ++i) { @@ -349,12 +349,12 @@ class CLMLRuntime : public JSONRuntimeBase { evts.resize(evts.size() + 1); evt = &(evts.back()); } - std::unordered_map metrics; + std::unordered_map metrics; std::string shape_str; std::vector shape = nodes_[nid].GetOpShape()[0]; DLDataType tvm_dtype = nodes_[nid].GetOpDataType()[0]; shape_str.append(profiling::ShapeString(shape, tvm_dtype)); - metrics["Argument Shapes"] = String(shape_str); + metrics["Argument Shapes"] = ffi::String(shape_str); prof->StartCall("CopyIn", cws->tentry->device, metrics); CLML_CALL(clEnqueueCopyMLTensorDataQCOM, queue, layer_.in_placeholder[nid]->tensor, @@ -366,7 +366,7 @@ class CLMLRuntime : public JSONRuntimeBase { } for (size_t i = 0; i < this->layer_.function.size(); ++i) { - std::unordered_map metrics; + std::unordered_map metrics; auto node = this->layer_.op_node_map[this->layer_.function[i]].second; std::string shape_str; for (uint32_t j = 0; j < node.GetInputs().size(); ++j) { @@ -380,7 +380,7 @@ class CLMLRuntime : public JSONRuntimeBase { std::vector shape = node.GetOpShape()[0]; DLDataType tvm_dtype = node.GetOpDataType()[0]; shape_str.append(profiling::ShapeString(shape, tvm_dtype)); - metrics["Argument Shapes"] = String(shape_str); + metrics["Argument Shapes"] = ffi::String(shape_str); // Launch call prof->StartCall(clml_symbol + "-" + this->layer_.layer_names[i], cws->tentry->device, @@ -407,12 +407,12 @@ class CLMLRuntime : public JSONRuntimeBase { evt = &(evts.back()); } - std::unordered_map metrics; + std::unordered_map metrics; std::string shape_str; std::vector shape = nodes_[eid].GetOpShape()[0]; DLDataType tvm_dtype = nodes_[eid].GetOpDataType()[0]; shape_str.append(profiling::ShapeString(shape, tvm_dtype)); - metrics["Argument Shapes"] = String(shape_str); + metrics["Argument Shapes"] = ffi::String(shape_str); prof->StartCall("CopyOut", cws->tentry->device, metrics); CLML_CALL(clEnqueueCopyMLTensorDataQCOM, queue, 
layer_.outputs[i]->tensor, @@ -466,7 +466,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); int dtype_size = cl_dtype == CL_FLOAT ? 4 : 2; void* tmpptr = reinterpret_cast(malloc(isize * dtype_size)); - TVMArrayCopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), + Tensor::CopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), isize * dtype_size); CopyDataToCLMLTensor(layer_.inputs[nid], tmpptr); free(tmpptr); @@ -481,7 +481,7 @@ class CLMLRuntime : public JSONRuntimeBase { if (cws->workspace->IsProfiling(cws->tentry->device)) { Timer t; auto f = tvm::ffi::Function::GetGlobal(std::string("profiling.timer.opencl")); - t = f->operator()(cws->tentry->device); + t = f->operator()(cws->tentry->device).cast(); t->Start(); queue = CLML_QUEUE; evts.resize(evts.size() + 1); @@ -502,7 +502,7 @@ class CLMLRuntime : public JSONRuntimeBase { if (cws->workspace->IsProfiling(cws->tentry->device)) { Timer t; auto f = tvm::ffi::Function::GetGlobal(std::string("profiling.timer.opencl")); - t = f->operator()(cws->tentry->device); + t = f->operator()(cws->tentry->device).cast(); t->Start(); queue = CLML_QUEUE; evts.resize(evts.size() + 1); @@ -553,7 +553,7 @@ class CLMLRuntime : public JSONRuntimeBase { void* tmpptr = reinterpret_cast(malloc(osize * dtype_size)); CopyDataFromCLMLTensor(layer_.outputs[0], tmpptr); - TVMArrayCopyFromBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), + Tensor::CopyFromBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), osize * dtype_size); free(tmpptr); } @@ -1826,18 +1826,18 @@ class CLMLRuntime : public JSONRuntimeBase { std::string clml_symbol; }; -ffi::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module CLMLRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = 
ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.clml_runtime_create", CLMLRuntimeCreate) .def("ffi.Module.load_from_bytes.clml", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/clml/clml_runtime.h b/src/runtime/contrib/clml/clml_runtime.h index 4431b63cafcc..716ea4665ea4 100644 --- a/src/runtime/contrib/clml/clml_runtime.h +++ b/src/runtime/contrib/clml/clml_runtime.h @@ -33,8 +33,8 @@ #include #include #include -#include #include +#include #include #include @@ -253,11 +253,11 @@ struct CachedLayer { std::map> op_node_map; /* The input tensor map */ std::map> inputs; - /* A place holder Tensor representing TVM NDArray as CLML Tensor */ + /* A place holder Tensor representing TVM Tensor as CLML Tensor */ std::map> in_placeholder; /* The Output tensor map */ std::vector> outputs; - /* A place holder Tensor representing TVM NDArray as CLML Tensor */ + /* A place holder Tensor representing TVM Tensor as CLML Tensor */ std::vector> out_placeholder; /* Tensor shape exception list while returning from CLML Subgraph */ std::map> out_shapes; diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index 257b624bbf2b..9aa8cf839e4c 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include @@ -67,12 +67,12 @@ class CoreMLModel { */ void SetInput(const std::string& key, DLTensor* data_in); /*! - * \brief Return NDArray for given output index. + * \brief Return Tensor for given output index. * \param index The output index. * - * \return NDArray corresponding to given output node index. + * \return Tensor corresponding to given output node index. 
*/ - NDArray GetOutput(int index) const; + Tensor GetOutput(int index) const; /*! * \brief Return the number of outputs * @@ -104,7 +104,7 @@ class CoreMLRuntime : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual Optional GetFunction(const String& name); + virtual ffi::Optional GetFunction(const ffi::String& name); /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 8e0b2542b443..c3ac6185d98f 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -60,14 +60,14 @@ MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil]; - ICHECK(data_in->strides == NULL); + ICHECK(ffi::IsContiguous(*data_in)); memcpy(dest.dataPointer, data_in->data, size); NSString* nsKey = [NSString stringWithUTF8String:key.c_str()]; [input_dict_ setObject:dest forKey:nsKey]; } -NDArray CoreMLModel::GetOutput(int index) const { +Tensor CoreMLModel::GetOutput(int index) const { MLModelDescription* model_desc = model_.modelDescription; NSString* metadata = [model_desc metadata][MLModelDescriptionKey]; NSData* data = [metadata dataUsingEncoding:NSUTF8StringEncoding]; @@ -103,7 +103,7 @@ .device_type = kDLCPU, .device_id = 0, }; - NDArray ret = NDArray::Empty(shape, dtype, cpu_dev); + Tensor ret = Tensor::Empty(shape, dtype, cpu_dev); ret.CopyFromBytes(src.dataPointer, size); return ret; @@ -129,7 +129,7 @@ model_ = std::unique_ptr(new CoreMLModel(url)); } -Optional CoreMLRuntime::GetFunction(const String& name) { +ffi::Optional CoreMLRuntime::GetFunction(const ffi::String& name) { // Return member functions during query. 
if (name == "invoke" || name == "run") { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { model_->Invoke(); }); @@ -153,14 +153,13 @@ NSDictionary* json = [NSJSONSerialization JSONObjectWithData:data options:NSJSONReadingAllowFragments error:nil]; - NSArray* input_names = json[@"inputs"]; + NSffi::Array* input_names = json[@"inputs"]; // Copy input tensors to corresponding data entries. for (auto i = 0; i < args.size() - 1; ++i) { - ICHECK(args[i].type_code() == kTVMDLTensorHandle || - args[i].type_code() == kTVMNDArrayHandle) - << "Expect NDArray or DLTensor as inputs\n"; - if (args[i].type_code() == kTVMDLTensorHandle || args[i].type_code() == kTVMNDArrayHandle) { + ICHECK(args[i].type_code() == kTVMDLTensorHandle || args[i].type_code() == kTVMTensorHandle) + << "Expect Tensor or DLTensor as inputs\n"; + if (args[i].type_code() == kTVMDLTensorHandle || args[i].type_code() == kTVMTensorHandle) { model_->SetInput([input_names[i] UTF8String], args[i]); } else { LOG(FATAL) << "Not implemented"; @@ -171,12 +170,12 @@ model_->Invoke(); // TODO: Support multiple outputs. 
- NDArray out = model_->GetOutput(0); + Tensor out = model_->GetOutput(0); if (args[args.size() - 1].type_code() == kTVMDLTensorHandle) { DLTensor* arg = args[args.size() - 1]; out.CopyTo(arg); } else { - NDArray arg = args[args.size() - 1]; + Tensor arg = args[args.size() - 1]; out.CopyTo(arg); } *rv = out; @@ -187,17 +186,17 @@ } ffi::Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_path) { - auto exec = make_object(); + auto exec = ffi::make_object(); exec->Init(symbol, model_path); return ffi::Module(exec); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.coreml_runtime.create", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = CoreMLRuntimeCreate(args[0], args[1]); }); -}); +} ffi::Bytes CoreMLRuntime::SaveToBytes() const { std::string buffer; @@ -251,15 +250,15 @@ BOOL res = [dirWrapper writeToURL:url options:0 originalContentsURL:nil error:nil]; ICHECK(res) << "Failed to create model directory " << [model_path UTF8String]; - auto exec = make_object(); + auto exec = ffi::make_object(); exec->Init(symbol, [model_path UTF8String]); return ffi::Module(exec); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ffi.Module.load_from_bytes.coreml", CoreMLRuntimeLoadFromBytes); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 13f958744e61..715172ecd8f9 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -516,7 +516,7 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t } // matrix multiplication for row major -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.cublas.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -541,10 
+541,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ CallGemmEx(args, ret, entry_ptr->handle); } }); -}); +} #if CUDART_VERSION >= 10010 -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.cublaslt.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -558,14 +558,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ cublasLtHandle_t ltHandle; CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, A->device.device_id)); + static_cast(TVMFFIEnvGetStream(kDLCUDA, A->device.device_id)); CallLtIgemm(args, ret, ltHandle, stream); CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); }); -}); +} #endif // CUDART_VERSION >= 10010 -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.cublas.batch_matmul", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -589,7 +589,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ CallBatchGemmEx(args, ret, entry_ptr->handle); } }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 0416391303ad..70521c1d7399 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include @@ -46,12 +46,12 @@ using namespace tvm::runtime::json; class CublasJSONRuntime : public JSONRuntimeBase { public: CublasJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - void Init(const Array& consts) override {} + void Init(const ffi::Array& consts) override {} - ffi::Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { // 
JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since CublasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function // and directly call cuBLAS on the inputs from ffi::PackedArgs. @@ -76,8 +76,8 @@ class CublasJSONRuntime : public JSONRuntimeBase { : EntryID(outputs_[i - input_var_eid_.size()]); const DLTensor* arg; - if (auto opt_nd = args[i].as()) { - NDArray arr = opt_nd.value(); + if (auto opt_nd = args[i].as()) { + Tensor arr = opt_nd.value(); arg = arr.operator->(); } else { arg = args[i].cast(); @@ -91,7 +91,7 @@ class CublasJSONRuntime : public JSONRuntimeBase { CUDA_CALL(cudaGetDevice(&device_id)); } auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { ICHECK_LT(idx, node.GetInputs().size()); @@ -153,19 +153,19 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable"; } }; -ffi::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module CublasJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.CublasJSONRuntimeCreate", CublasJSONRuntimeCreate) .def("ffi.Module.load_from_bytes.cublas_json", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 
0ba654c9ebc8..f5248fde7e00 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -44,8 +44,8 @@ typedef dmlc::ThreadLocalStore CuBlasThreadStore; CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal(DLDevice curr_device) { CuBlasThreadEntry* retval = CuBlasThreadStore::Get(); - cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id)); CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, stream)); return retval; } diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index 515263ef364e..d26f82645eaf 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -190,7 +190,7 @@ void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], c ret[0] = static_cast(best_algo); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.cudnn.conv2d.backward_data", @@ -269,7 +269,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, x_dim, dw_dim, data_dtype, conv_dtype, verbose, ret); }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 7a93e194ce3c..6a5737c183b0 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -156,7 +156,7 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid ret[0] = static_cast(best_algo); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.cudnn.conv2d.forward", @@ -240,7 +240,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ 
FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype, conv_dtype, verbose, ret); }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h index ae11764ce02c..248d44d9d65f 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h @@ -69,12 +69,12 @@ class CuDNNSDPARunnerNode : public tvm::runtime::Object { class CuDNNSDPARunner : public tvm::runtime::ObjectRef { public: static CuDNNSDPARunner Create() { - auto n = make_object(); + auto n = ffi::make_object(); return CuDNNSDPARunner(n); } - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CuDNNSDPARunner, tvm::runtime::ObjectRef, - CuDNNSDPARunnerNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CuDNNSDPARunner, tvm::runtime::ObjectRef, + CuDNNSDPARunnerNode); }; } // namespace contrib diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index 3888bca3df04..cefa2957b601 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include @@ -49,10 +49,10 @@ using namespace tvm::runtime::json; class cuDNNJSONRuntime : public JSONRuntimeBase { public: cuDNNJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { op_execs_.resize(nodes_.size()); // get some config from the graph for (size_t i = 0; i < nodes_.size(); ++i) { @@ -164,8 +164,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { std::function op_exec = [=]() { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = - 
static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream)); auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { @@ -238,19 +237,19 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { std::vector> op_execs_; }; -ffi::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module cuDNNJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.cuDNNJSONRuntimeCreate", cuDNNJSONRuntimeCreate) .def("ffi.Module.load_from_bytes.cudnn_json", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index acedf7a9e2dd..b0e3af9efb59 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -129,8 +129,8 @@ CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(Device curr_device, bool check_e ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED"; } - cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id)); CUDNN_CALL(cudnnSetStream(res->handle, stream)); return res; } @@ -267,14 +267,14 @@ SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_des SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; refl::GlobalDef().def("tvm.contrib.cudnn.exists", []() -> bool { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); return CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}, false)->exists(); }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index eb2fceb3d2db..10df70670c70 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -79,7 +79,7 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* r entry_ptr->softmax_entry.shape_desc, y->data)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.cudnn.softmax.forward", @@ -89,7 +89,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("tvm.contrib.cudnn.log_softmax.forward", [](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc index 7a9f2d598827..53505770f83a 100644 --- a/src/runtime/contrib/curand/curand.cc +++ b/src/runtime/contrib/curand/curand.cc @@ -110,13 +110,13 @@ void RandomFill(DLTensor* tensor) { } else { LOG(FATAL) << "ValueError: Unsupported dtype: " << tensor->dtype; } - TVMSynchronize(tensor->device.device_type, tensor->device.device_type, nullptr); + cuda_api->StreamSync(tensor->device, nullptr); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("runtime.contrib.curand.RandomFill", RandomFill); -}); +} } // namespace curand } // namespace runtime diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh index a09051a86e79..0527829c528d 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh +++ 
b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh @@ -19,8 +19,9 @@ #include #include +#include #include -#include +#include #include "cutlass/bfloat16.h" #include "cutlass/half.h" @@ -32,11 +33,12 @@ template struct CutlassGroupGemm; template -void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, - NDArray out) { +void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight, Tensor indptr, Tensor workspace, + Tensor out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_EQ(x->ndim, 2); CHECK_EQ(weight->ndim, 3); CHECK_EQ(indptr->ndim, 1); @@ -47,7 +49,6 @@ void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDAr int k = weight->shape[2]; float alpha = 1.0f; float beta = 0.0f; - cudaStream_t stream = static_cast(func().cast()); if (DataType(x->dtype) == DataType::Float(16)) { CHECK(DataType(weight->dtype) == DataType::Float(16)); diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu index 90802969c53e..0c9fe0fff14d 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu @@ -21,8 +21,8 @@ #include #include #include -#include #include +#include #include "fp16_group_gemm.cuh" #include "fp16_group_gemm_runner_sm100.cuh" @@ -42,15 +42,15 @@ struct CutlassGroupGemm<100, ElementA, ElementB, ElementC> { } }; -void tvm_cutlass_group_gemm_sm100(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, - NDArray out) { +void tvm_cutlass_group_gemm_sm100(Tensor x, Tensor weight, Tensor indptr, Tensor workspace, + Tensor out) { tvm_cutlass_group_gemm_impl<100>(x, weight, indptr, workspace, out); } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("cutlass.group_gemm", tvm_cutlass_group_gemm_sm100); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu index 0b240b85a4f4..e78fb06322e2 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include "fp16_group_gemm.cuh" #include "fp16_group_gemm_runner_sm90.cuh" @@ -41,15 +41,15 @@ struct CutlassGroupGemm<90, ElementA, ElementB, ElementC> { } }; -void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, - NDArray out) { +void tvm_cutlass_group_gemm_sm90(Tensor x, Tensor weight, Tensor indptr, Tensor workspace, + Tensor out) { tvm_cutlass_group_gemm_impl<90>(x, weight, indptr, workspace, out); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("cutlass.group_gemm", tvm_cutlass_group_gemm_sm90); -}); +} #endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu index 5cabd0ca7af2..d41064efbaf0 100644 --- a/src/runtime/contrib/cutlass/fp8_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include "../cublas/cublas_utils.h" #include "gemm_runner.cuh" @@ -39,12 +39,10 @@ namespace tvm { namespace runtime { template -void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray alpha, - NDArray out) { +void tvm_cutlass_fp8_gemm(Tensor x, Tensor weight, Tensor workspace, Tensor alpha, Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. 
- cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_GE(x->ndim, 2); CHECK_EQ(weight->ndim, 2); @@ -78,7 +76,7 @@ void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("cutlass.gemm_e5m2_e5m2_fp16", @@ -87,7 +85,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ tvm_cutlass_fp8_gemm) .def("cutlass.gemm_e4m3_e4m3_fp16", tvm_cutlass_fp8_gemm); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu index 150485b86822..b2e08b7570ab 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include "fp16_group_gemm_runner_sm90.cuh" @@ -42,12 +42,11 @@ namespace tvm { namespace runtime { template -void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, - NDArray alpha, NDArray out) { +void tvm_cutlass_fp8_group_gemm(Tensor x, Tensor weight, Tensor indptr, Tensor workspace, + Tensor alpha, Tensor out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommened size is 4MB. 
- cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_EQ(x->ndim, 2); CHECK_EQ(weight->ndim, 3); CHECK_EQ(indptr->ndim, 1); @@ -68,7 +67,7 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr static_cast(out->data), stream); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def( @@ -80,7 +79,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("cutlass.group_gemm_e4m3_e4m3_fp16", tvm_cutlass_fp8_group_gemm); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh index 0f688616d55e..35f08efbc57c 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include "cutlass/bfloat16.h" #include "cutlass/half.h" @@ -34,13 +34,13 @@ template -void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(NDArray a, NDArray b, NDArray scales_a, - NDArray scales_b, NDArray workspace, +void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scales_a, + Tensor scales_b, Tensor workspace, int64_t block_size_0, int64_t block_size_1, - NDArray out) { + Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. 
- cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_GE(a->ndim, 2); CHECK_EQ(scales_a->ndim, a->ndim); @@ -100,13 +100,13 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(NDArray a, NDArray b, NDArray sc } template -void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(NDArray a, NDArray b, NDArray scales_a, - NDArray scales_b, NDArray workspace, +void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales_a, + Tensor scales_b, Tensor workspace, int64_t block_size_0, int64_t block_size_1, - NDArray out) { + Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_EQ(a->ndim, 3); CHECK_EQ(scales_a->ndim, 3); diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh index 95fc578fd43f..87cd8108f9ee 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh @@ -53,7 +53,7 @@ } using namespace cute; -using tvm::runtime::NDArray; +using tvm::runtime::Tensor; template struct CutlassFP8ScaledGroupwiseGemmRunnerSM100 { diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh index 5ec9ed083916..d5321d157c74 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh @@ -54,7 +54,7 @@ using namespace cute; using ProblemShape = Shape; -using tvm::runtime::NDArray; +using 
tvm::runtime::Tensor; template diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu index 7201604a7c85..e8035c172a3c 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu @@ -21,8 +21,8 @@ #include #include #include -#include #include +#include #include "../cublas/cublas_utils.h" #include "fp8_groupwise_scaled_gemm.cuh" @@ -47,34 +47,34 @@ struct CutlassFP8GroupwiseGemm<100, TileShape, ClusterShape, ElementA, ElementB, } }; -void tvm_cutlass_fp8_groupwise_scaled_gemm_sm100(NDArray a, NDArray b, NDArray scales_a, - NDArray scales_b, NDArray workspace, +void tvm_cutlass_fp8_groupwise_scaled_gemm_sm100(Tensor a, Tensor b, Tensor scales_a, + Tensor scales_b, Tensor workspace, int64_t block_size_0, int64_t block_size_1, - NDArray out) { + Tensor out) { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; tvm_cutlass_fp8_groupwise_scaled_gemm_impl<100, TileShape, ClusterShape>( a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); } -void tvm_cutlass_fp8_groupwise_scaled_bmm_sm100(NDArray a, NDArray b, NDArray scales_a, - NDArray scales_b, NDArray workspace, +void tvm_cutlass_fp8_groupwise_scaled_bmm_sm100(Tensor a, Tensor b, Tensor scales_a, + Tensor scales_b, Tensor workspace, int64_t block_size_0, int64_t block_size_1, - NDArray out) { + Tensor out) { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; tvm_cutlass_fp8_groupwise_scaled_bmm_impl<100, TileShape, ClusterShape>( a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn", tvm_cutlass_fp8_groupwise_scaled_gemm_sm100) .def("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn", 
tvm_cutlass_fp8_groupwise_scaled_bmm_sm100); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu index 8099d91419e5..3c326e314386 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu @@ -21,8 +21,8 @@ #include #include #include -#include #include +#include #include "../cublas/cublas_utils.h" #include "fp8_groupwise_scaled_gemm.cuh" @@ -47,33 +47,32 @@ struct CutlassFP8GroupwiseGemm<90, TileShape, ClusterShape, ElementA, ElementB, } }; -void tvm_cutlass_fp8_groupwise_scaled_gemm_sm90(NDArray a, NDArray b, NDArray scales_a, - NDArray scales_b, NDArray workspace, +void tvm_cutlass_fp8_groupwise_scaled_gemm_sm90(Tensor a, Tensor b, Tensor scales_a, + Tensor scales_b, Tensor workspace, int64_t block_size_0, int64_t block_size_1, - NDArray out) { + Tensor out) { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; tvm_cutlass_fp8_groupwise_scaled_gemm_impl<90, TileShape, ClusterShape>( a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); } -void tvm_cutlass_fp8_groupwise_scaled_bmm_sm90(NDArray a, NDArray b, NDArray scales_a, - NDArray scales_b, NDArray workspace, - int64_t block_size_0, int64_t block_size_1, - NDArray out) { +void tvm_cutlass_fp8_groupwise_scaled_bmm_sm90(Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, + Tensor workspace, int64_t block_size_0, + int64_t block_size_1, Tensor out) { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; tvm_cutlass_fp8_groupwise_scaled_bmm_impl<90, TileShape, ClusterShape>( a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn", 
tvm_cutlass_fp8_groupwise_scaled_gemm_sm90) .def("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn", tvm_cutlass_fp8_groupwise_scaled_bmm_sm90); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu index 2745c0b1fc03..4f5dd1e1c706 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu @@ -19,10 +19,11 @@ #include #include +#include #include #include -#include #include +#include #include "fp8_groupwise_scaled_group_gemm_runner_sm100.cuh" @@ -31,14 +32,13 @@ namespace tvm { namespace runtime { -void tvm_fp8_groupwise_scaled_group_gemm_sm100(NDArray a, NDArray b, NDArray scales_a, - NDArray scales_b, NDArray indptr, NDArray workspace, +void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, + Tensor indptr, Tensor workspace, int64_t block_size_0, int64_t block_size_1, - NDArray out) { + Tensor out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommended size is 4MB. 
- cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_EQ(a->ndim, 2); CHECK_EQ(b->ndim, 3); CHECK_EQ(indptr->ndim, 1); @@ -85,11 +85,11 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(NDArray a, NDArray b, NDArray sca } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("cutlass.groupwise_scaled_group_gemm_e4m3fn_e4m3fn", tvm_fp8_groupwise_scaled_group_gemm_sm100); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc index c403039c586a..56e2b39b8094 100644 --- a/src/runtime/contrib/cutlass/weight_preprocess.cc +++ b/src/runtime/contrib/cutlass/weight_preprocess.cc @@ -19,7 +19,7 @@ #include #include -#include +#include #include "cutlass_kernels/cutlass_preprocessors.h" @@ -35,9 +35,9 @@ namespace runtime { // black box. // // The preprocessing functions are defined in C++, so we need to copy the input weight to CPU. -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("cutlass.ft_preprocess_weight", [](NDArray packed_weight, int sm, + refl::GlobalDef().def("cutlass.ft_preprocess_weight", [](Tensor packed_weight, int sm, bool is_int4) { bool is_2d = packed_weight->ndim == 2; int num_experts = is_2d ? 
1 : packed_weight->shape[0]; @@ -54,11 +54,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ } fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), num_experts, rows, cols, is_int4, sm); - auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device); + auto out = Tensor::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device); out.CopyFromBytes(output_cpu.data(), output_cpu.size()); return out; }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 3ae84a782e47..972c61e9436e 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -349,7 +349,7 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ } // DNNL Conv2d single OP -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.dnnl.conv2d", [](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); @@ -383,7 +383,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, channel_last, pre_cast, post_cast); }); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 686a8048c7b5..f0c47e5639d2 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include @@ -51,7 +51,7 @@ using namespace tvm::runtime::json; class DNNLJSONRuntime : public JSONRuntimeBase { public: DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names), next_unique_eid_offset_(data_entry_.size()), run_arg_eid_(input_var_eid_) { @@ 
-60,7 +60,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { const char* kind() const override { return "dnnl_json"; } - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; @@ -100,7 +100,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } /* Override GetFunction to reimplement Run method */ - ffi::Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { @@ -821,7 +821,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { TensorRequisite res; if (const_dl_tensor) { ICHECK(const_dl_tensor->data); - ICHECK(const_dl_tensor->strides == nullptr); + ICHECK(ffi::IsContiguous(*const_dl_tensor)); auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data); res = TensorRequisite::AsIs(mem, eid); } else { @@ -923,18 +923,18 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::vector run_arg_eid_; }; -ffi::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module DNNLJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.DNNLJSONRuntimeCreate", DNNLJSONRuntimeCreate) .def("ffi.Module.load_from_bytes.dnnl_json", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 
a52da2318b71..4e62659dd30e 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -64,16 +64,16 @@ void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, Device dev) { } ffi::Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { - auto exec = make_object(); + auto exec = ffi::make_object(); exec->Init(tflite_model_bytes, dev); return ffi::Module(exec); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.edgetpu_runtime.create", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = EdgeTPURuntimeCreate(args[0], args[1]); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index 628ffb5bdf8a..b1b264dea72a 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -408,7 +408,7 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t } // matrix multiplication for row major -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.hipblas.matmul", @@ -455,7 +455,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ CallBatchGemmEx(args, ret, entry_ptr->handle); } }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 08866fc1088a..45bfabc277cc 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include @@ -44,12 +44,12 @@ using namespace tvm::runtime::json; class HipblasJSONRuntime : public JSONRuntimeBase { public: HipblasJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array 
const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - void Init(const Array& consts) override {} + void Init(const ffi::Array& consts) override {} - ffi::Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since HipblasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function // and directly call hipBLAS on the inputs from ffi::PackedArgs. @@ -75,8 +75,8 @@ class HipblasJSONRuntime : public JSONRuntimeBase { : EntryID(outputs_[i - input_var_eid_.size()]); const DLTensor* arg; - if (auto opt_nd = args[i].as()) { - NDArray arr = opt_nd.value(); + if (auto opt_nd = args[i].as()) { + Tensor arr = opt_nd.value(); arg = arr.operator->(); } else { arg = args[i].cast(); @@ -89,7 +89,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase { ROCM_CALL(hipGetDevice(&device_id)); } auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(DLDevice{kDLROCM, device_id}); - hipStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + hipStream_t stream = static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { ICHECK_LT(idx, node.GetInputs().size()); @@ -140,19 +140,19 @@ class HipblasJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable"; } }; -ffi::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module HipblasJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef() .def("runtime.HipblasJSONRuntimeCreate", HipblasJSONRuntimeCreate) .def("ffi.Module.load_from_bytes.hipblas_json", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc index 1b61cbd38219..17ed9a0d936d 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.cc +++ b/src/runtime/contrib/hipblas/hipblas_utils.cc @@ -44,8 +44,7 @@ typedef dmlc::ThreadLocalStore HipBlasThreadStore; HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(DLDevice curr_device) { HipBlasThreadEntry* retval = HipBlasThreadStore::Get(); - TVMFFIStreamHandle stream = - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); + TVMFFIStreamHandle stream = TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id); CHECK_HIPBLAS_ERROR(hipblasSetStream(retval->handle, static_cast(stream))); return retval; } diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index d9e5af60f299..a8bb6c26083f 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -26,8 +26,8 @@ #define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ #include -#include #include +#include #include #include @@ -50,7 +50,7 @@ namespace json { class JSONRuntimeBase : public ffi::ModuleObj { public: JSONRuntimeBase(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : symbol_name_(symbol_name), graph_json_(graph_json), const_names_(const_names) { LoadGraph(graph_json_); } @@ -63,7 +63,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { } /*! \brief Initialize a specific json runtime. */ - virtual void Init(const Array& consts) = 0; + virtual void Init(const ffi::Array& consts) = 0; /*! \brief Invoke the execution engine to inteprete a specific json runtime. 
*/ virtual void Run() = 0; @@ -93,7 +93,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( @@ -123,8 +123,8 @@ class JSONRuntimeBase : public ffi::ModuleObj { // Bind argument tensors to data entries. this->SetInputOutputBuffers(args); - if (auto opt_str = rv->try_cast()) { - String purpose = std::move(opt_str.value()); + if (auto opt_str = rv->try_cast()) { + ffi::String purpose = std::move(opt_str.value()); if ("debug_dump" == purpose) { *rv = this->DebugDump(); } @@ -133,7 +133,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { profiling::Profiler* prof = static_cast(rv->cast()); this->RunProfile(prof); } - // String vendor_prof = this->RunProfile(prof); + // ffi::String vendor_prof = this->RunProfile(prof); }); } else if ("__init_" + this->symbol_name_ == name) { // The function to initialize constant tensors. 
@@ -141,7 +141,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { ICHECK_EQ(args.size(), 1U); std::lock_guard guard(this->initialize_mutex_); if (!this->initialized_) { - this->Init(args[0].cast>()); + this->Init(args[0].cast>()); this->initialized_ = true; } *rv = 0; @@ -180,11 +180,11 @@ class JSONRuntimeBase : public ffi::ModuleObj { ICHECK(stream->Read(&symbol)) << "Loading symbol name failed"; ICHECK(stream->Read(&graph_json)) << "Loading graph json failed"; ICHECK(stream->Read(&consts)) << "Loading the const name list failed"; - Array const_names; + ffi::Array const_names; for (const auto& it : consts) { const_names.push_back(it); } - auto n = make_object(symbol, graph_json, const_names); + auto n = ffi::make_object(symbol, graph_json, const_names); return ffi::Module(n); } @@ -194,7 +194,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { * \param format the format to return. * \return A string of JSON. */ - String InspectSource(const String& format) const override { return graph_json_; } + ffi::String InspectSource(const ffi::String& format) const override { return graph_json_; } protected: /*! @@ -212,14 +212,14 @@ class JSONRuntimeBase : public ffi::ModuleObj { : EntryID(outputs_[i - input_var_eid_.size()]); const DLTensor* arg; - if (auto opt_nd = args[i].as()) { - NDArray arr = opt_nd.value(); + if (auto opt_nd = args[i].as()) { + Tensor arr = opt_nd.value(); arg = arr.operator->(); } else { arg = args[i].cast(); } - // Assign input/output the NDArray pointers to data entry so that we can directly + // Assign input/output the Tensor pointers to data entry so that we can directly // read/write host buffers. data_entry_[eid] = arg; } @@ -268,9 +268,9 @@ class JSONRuntimeBase : public ffi::ModuleObj { * \brief Set up the constants/weights for inference by binding their DLTensor pointer to * the corresponding data entry. * - * \param consts A list of constant NDArray to be used. + * \param consts A list of constant Tensor to be used. 
*/ - void SetupConstants(const Array& consts) { + void SetupConstants(const ffi::Array& consts) { for (size_t i = 0; i < consts.size(); ++i) { data_entry_[EntryID(const_idx_[i], 0)] = consts[i].operator->(); } @@ -313,7 +313,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { /*! \brief The graph. */ std::string graph_json_; /*! \brief The required constant names. */ - Array const_names_; + ffi::Array const_names_; /*! \brief The json graph nodes. */ std::vector nodes_; /*! \brief The input nodes, including variables and constants. */ diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index 2c8a70aa6b34..620706250967 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -35,7 +35,7 @@ namespace miopen { using namespace runtime; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed( @@ -226,7 +226,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ entry_ptr->conv_entry.fwd_algo, &beta, entry_ptr->conv_entry.output_desc, y->data, entry_ptr->conv_entry.workspace, entry_ptr->conv_entry.workspace_size)); }); -}); +} } // namespace miopen } // namespace contrib diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index e860ba8ea7f2..617ea5aaf027 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -56,8 +56,7 @@ typedef dmlc::ThreadLocalStore MIOpenThreadStore; MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal(Device curr_device) { // Need to update stream per fetch to avoid stream switching MIOpenThreadEntry* res = MIOpenThreadStore::Get(); - TVMFFIStreamHandle stream = - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); + TVMFFIStreamHandle stream = TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id); MIOPEN_CALL(miopenSetStream(res->handle, stream)); return res; } diff 
--git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc index 5853cb2a7b11..c5e467626ee8 100644 --- a/src/runtime/contrib/miopen/softmax.cc +++ b/src/runtime/contrib/miopen/softmax.cc @@ -80,7 +80,7 @@ void softmax_impl(ffi::PackedArgs args, ffi::Any* ret, miopenSoftmaxAlgorithm_t entry_ptr->softmax_entry.shape_desc, y->data, alg, mode)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.miopen.softmax.forward", @@ -90,7 +90,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed( "tvm.contrib.miopen.log_softmax.forward", [](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); }); -}); +} } // namespace miopen } // namespace contrib diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index dfc98388d372..92da557160cb 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -25,7 +25,7 @@ using namespace runtime; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.mps.buffer2img", @@ -91,9 +91,9 @@ ICHECK_EQ(data->ndim, 4); ICHECK_EQ(weight->ndim, 4); ICHECK_EQ(output->ndim, 4); - ICHECK(output->strides == nullptr); - ICHECK(weight->strides == nullptr); - ICHECK(data->strides == nullptr); + ICHECK(ffi::IsContiguous(*output)); + ICHECK(ffi::IsContiguous(*weight)); + ICHECK(ffi::IsContiguous(*data)); ICHECK_EQ(data->shape[0], 1); ICHECK_EQ(output->shape[0], 1); @@ -161,7 +161,7 @@ (*f_img2buf)(&tmp_out, output); }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index 9f5270f38fec..b78d8f7d6e51 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -25,7 +25,7 @@ using namespace runtime; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.mps.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); @@ -37,9 +37,9 @@ ICHECK_EQ(A->ndim, 2); ICHECK_EQ(B->ndim, 2); ICHECK_EQ(C->ndim, 2); - ICHECK(C->strides == nullptr); - ICHECK(B->strides == nullptr); - ICHECK(A->strides == nullptr); + ICHECK(ffi::IsContiguous(*C)); + ICHECK(ffi::IsContiguous(*B)); + ICHECK(ffi::IsContiguous(*A)); ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); @@ -95,7 +95,7 @@ [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; [cb commit]; }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index bc1eb77ea18c..bfa2e1889b2e 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include @@ -212,7 +212,7 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The packed function. 
*/ - virtual Optional GetFunction(const String& name) { + virtual ffi::Optional GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( @@ -226,8 +226,9 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { use_dpdk_cb = true; }); } else if (name == "get_const_vars") { - return ffi::Function( - [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = Array{}; }); + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::Array{}; + }); } else if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { RunInference(args); @@ -274,8 +275,8 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { ICHECK(stream->Read(&num_inputs)) << "Loading num_inputs failed"; ICHECK(stream->Read(&num_outputs)) << "Loading num_outputs failed"; ICHECK(stream->Read(&batch_size)) << "Loading batch_size failed"; - auto n = make_object(symbol_name, nodes_json, bin_code, num_inputs, - num_outputs, batch_size); + auto n = ffi::make_object(symbol_name, nodes_json, bin_code, + num_inputs, num_outputs, batch_size); return ffi::Module(n); } @@ -285,7 +286,7 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { * \param format the format to return. * \return A string of JSON. 
*/ - String InspectSource(const String& format) const override { return nodes_json_; } + ffi::String InspectSource(const ffi::String& format) const override { return nodes_json_; } protected: std::string symbol_name_; @@ -309,8 +310,8 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { i_d_buf_float = reinterpret_cast(i_d_buf); for (int in = 0; in < num_inputs_; in++) { - if (args[in].IsObjectRef()) { - NDArray arr = args[in]; + if (args[in].IsObjectRef()) { + Tensor arr = args[in]; tensor = arr.operator->(); } else { tensor = args[in].operator DLTensor*(); @@ -345,8 +346,8 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { int out = num_inputs_; if (num_outputs_ == 1) { - if (args[out].IsObjectRef()) { - NDArray arr = args[out]; + if (args[out].IsObjectRef()) { + Tensor arr = args[out]; outTensor = arr.operator->(); } else { outTensor = args[out].operator DLTensor*(); @@ -361,8 +362,8 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { for (out = num_inputs_; out < args.size(); out++) { int out_tot_dim = 1; - if (args[out].IsObjectRef()) { - NDArray arr = args[out]; + if (args[out].IsObjectRef()) { + Tensor arr = args[out]; outTensor = arr.operator->(); } else { outTensor = args[out].operator DLTensor*(); @@ -382,8 +383,8 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { const DLTensor* tensor[64]; for (int in = 0; in < num_inputs_; in++) { - if (args[in].IsObjectRef()) { - NDArray arr = args[in]; + if (args[in].IsObjectRef()) { + Tensor arr = args[in]; tensor[in] = arr.operator->(); } else { tensor[in] = args[in].operator DLTensor*(); @@ -398,8 +399,8 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { int i = 0; for (int out = num_inputs_; out < args.size(); out++) { - if (args[out].IsObjectRef()) { - NDArray arr = args[out]; + if (args[out].IsObjectRef()) { + Tensor arr = args[out]; tensor[i] = arr.operator->(); } else { tensor[i] = args[out].operator DLTensor*(); @@ -469,11 +470,12 @@ class 
MarvellHardwareModuleNode : public ffi::ModuleObj { } }; -ffi::Module MarvellHardwareModuleRuntimeCreate(const String& symbol_name, const String& nodes_json, - const String& bin_code, int num_input, +ffi::Module MarvellHardwareModuleRuntimeCreate(const ffi::String& symbol_name, + const ffi::String& nodes_json, + const ffi::String& bin_code, int num_input, int num_output, int batch_size) { - auto n = make_object(symbol_name, nodes_json, bin_code, num_input, - num_output, batch_size); + auto n = ffi::make_object(symbol_name, nodes_json, bin_code, num_input, + num_output, batch_size); return ffi::Module(n); } @@ -483,12 +485,12 @@ bool MarvellHardwareModuleNode::use_dpdk_cb = false; ml_tvmc_cb MarvellHardwareModuleNode::tvmc_cb_ = {}; ml_dpdk_cb MarvellHardwareModuleNode::dpdk_cb_ = {}; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.mrvl_hw_runtime_create", MarvellHardwareModuleRuntimeCreate) .def("ffi.Module.load_from_bytes.mrvl_hw", MarvellHardwareModuleNode::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index 974ca4a69a1f..1a9ad8c47851 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include @@ -70,14 +70,15 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The packed function. 
*/ - virtual Optional GetFunction(const String& name) { + virtual ffi::Optional GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); } else if (name == "get_const_vars") { - return ffi::Function( - [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = Array{}; }); + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::Array{}; + }); } else if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { Run(args); @@ -111,7 +112,7 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { ICHECK(stream->Read(&nodes_json)) << "Marvell-Compiler-ERROR-Internal::Loading nodes json failed"; ICHECK(stream->Read(&bin_code)) << "Marvell-Compiler-ERROR-Internal::Loading bin code failed"; - auto n = make_object(symbol_name, nodes_json, bin_code); + auto n = ffi::make_object(symbol_name, nodes_json, bin_code); return ffi::Module(n); } @@ -121,7 +122,7 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { * \param format the format to return. * \return A string of JSON. 
*/ - String InspectSource(const String& format) const override { return nodes_json_; } + ffi::String InspectSource(const ffi::String& format) const override { return nodes_json_; } protected: std::string symbol_name_; @@ -149,18 +150,19 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { } }; -ffi::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, const String& nodes_json, - const String& bin_code) { - auto n = make_object(symbol_name, nodes_json, bin_code); +ffi::Module MarvellSimulatorModuleRuntimeCreate(const ffi::String& symbol_name, + const ffi::String& nodes_json, + const ffi::String& bin_code) { + auto n = ffi::make_object(symbol_name, nodes_json, bin_code); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.mrvl_runtime_create", MarvellSimulatorModuleRuntimeCreate) .def("ffi.Module.load_from_bytes.mrvl_sim", MarvellSimulatorModuleNode::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc index c63bafcd0089..a7d50f412c9d 100644 --- a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include @@ -36,7 +36,7 @@ using namespace tvm::runtime; template -static void NDArrayToFile(const tvm::runtime::NDArray& arr, std::ostream& os) { +static void TensorToFile(const tvm::runtime::Tensor& arr, std::ostream& os) { int ndim = arr->ndim; int tot_dim = 1; for (int i = 0; i < ndim; i++) { @@ -70,8 +70,8 @@ static void ReadInputsAndGenerateInputBin(ffi::PackedArgs args, const std::strin file_out << R"( "inputs": [)" << std::endl; for (size_t i = 0; i < num_inputs; ++i) { const DLTensor* tensor; - if (args[i].IsObjectRef()) { - NDArray arr = args[i]; + if (args[i].IsObjectRef()) { + 
Tensor arr = args[i]; tensor = arr.operator->(); } else { tensor = args[i].cast(); @@ -80,9 +80,9 @@ static void ReadInputsAndGenerateInputBin(ffi::PackedArgs args, const std::strin for (int64_t i = 0; i < tensor->ndim; i++) { shape.push_back(tensor->shape[i]); } - NDArray arr = NDArray::Empty(shape, tensor->dtype, tensor->device); + Tensor arr = Tensor::Empty(shape, tensor->dtype, tensor->device); arr.CopyFrom(tensor); - NDArrayToFile(arr, file_out); + TensorToFile(arr, file_out); if (i != num_inputs - 1) { file_out << std::endl << "\t," << std::endl; } @@ -108,8 +108,8 @@ static void ReadOutputsAndUpdateRuntime(ffi::PackedArgs args, size_t num_inputs, const std::string& out_bin_prefix) { for (int out = num_inputs; out < args.size(); out++) { const DLTensor* outTensor; - if (args[out].IsObjectRef()) { - NDArray arr = args[out]; + if (args[out].IsObjectRef()) { + Tensor arr = args[out]; outTensor = arr.operator->(); } else { outTensor = args[out].operator DLTensor*(); @@ -118,7 +118,7 @@ static void ReadOutputsAndUpdateRuntime(ffi::PackedArgs args, size_t num_inputs, for (int64_t i = 0; i < outTensor->ndim; i++) { shape.push_back(outTensor->shape[i]); } - NDArray arr = NDArray::Empty(shape, outTensor->dtype, outTensor->device); + Tensor arr = Tensor::Empty(shape, outTensor->dtype, outTensor->device); int ndim = arr->ndim; int tot_dim = 1; for (int i = 0; i < ndim; i++) { @@ -126,7 +126,7 @@ static void ReadOutputsAndUpdateRuntime(ffi::PackedArgs args, size_t num_inputs, } float f; float* data = new float[tot_dim](); - String outbin = out_bin_prefix + "-" + std::to_string(out - num_inputs) + ".bin"; + ffi::String outbin = out_bin_prefix + "-" + std::to_string(out - num_inputs) + ".bin"; std::ifstream fin(outbin, std::ios::binary); ICHECK(fin.is_open()) << "Cannot open file: " << outbin; int i = 0; diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 37ae9f254895..91e291ce30c1 100644 --- 
a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include @@ -62,7 +62,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { * \param const_names The names of each constant in the sub-graph. */ explicit MSCTensorRTRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} ~MSCTensorRTRuntime() { @@ -87,7 +87,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { * * \param consts The constant params from compiled model. */ - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; LoadGlobalOptions(); @@ -122,18 +122,18 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { if (tool_tag_.size() > 0) { const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; - Map input_datas; + ffi::Map input_datas; int device_id = 0; for (const auto& pair : input_bindings_) { const auto& tensor_name = engine_->getBindingName(pair.first); input_datas.Set(tensor_name, device_buffers_[pair.first]); device_id = data_entry_[pair.first]->device.device_id; } - Map> context; + ffi::Map> context; context.Set("datas", input_datas); (*pf)(context, "before_forward", graph_name_, tool_tag_); } - auto tvm_stream = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id); + auto tvm_stream = TVMFFIEnvGetStream(kDLCUDA, device_id); #if TRT_VERSION_GE(6, 0, 1) ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr)) << "Running TensorRT failed."; @@ -155,7 +155,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { if (tool_tag_.size() > 0) { const auto pf = 
tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; - Map output_datas; + ffi::Map output_datas; for (int bid = 0; bid < engine_->getNbBindings(); bid++) { if (input_bindings_.count(bid)) { continue; @@ -163,13 +163,13 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { const auto& tensor_name = engine_->getBindingName(bid); output_datas.Set(tensor_name, device_buffers_[bid]); } - Map> context; + ffi::Map> context; context.Set("datas", output_datas); (*pf)(context, "after_forward", graph_name_, tool_tag_); } } - bool LoadEngine(const String& engine_file) { + bool LoadEngine(const ffi::String& engine_file) { IRuntime* runtime = createInferRuntime(logger_); // build engine std::ifstream input(engine_file_, std::ifstream::binary); @@ -289,14 +289,14 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { const auto& pair = tensor_ids_[tensor_name]; auto shape = nodes_[pair.first].GetOpShape()[pair.second]; auto dtype = nodes_[pair.first].GetOpDataType()[pair.second]; - device_buffers_[bid] = runtime::NDArray::Empty(shape, dtype, {kDLCUDA, 0}); + device_buffers_[bid] = runtime::Tensor::Empty(shape, dtype, {kDLCUDA, 0}); } bindings_[bid] = device_buffers_[bid]->data; binded.insert(bid); } } - NDArray GetOrAllocateDeviceBuffer(int entry_id, int binding_index) { + Tensor GetOrAllocateDeviceBuffer(int entry_id, int binding_index) { std::vector shape(data_entry_[entry_id]->shape, data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); if (device_buffers_.count(binding_index)) { @@ -304,7 +304,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { if (shape[0] > device_buffers_[binding_index]->shape[0]) { // Buffer is too small. Need to allocate bigger buffer. 
device_buffers_[binding_index] = - runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); + runtime::Tensor::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); } else if (shape[0] < device_buffers_[binding_index]->shape[0]) { // Buffer is too large. Create view. return device_buffers_[binding_index].CreateView(shape, data_entry_[entry_id]->dtype); @@ -312,7 +312,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { } else { // Buffer not initialized yet. device_buffers_[binding_index] = - runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); + runtime::Tensor::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); } return device_buffers_.at(binding_index); } @@ -323,15 +323,15 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { << "Please build with USE_TENSORRT_RUNTIME."; } - bool LoadEngine(const String& engine_file) { return false; } + bool LoadEngine(const ffi::String& engine_file) { return false; } void DestroyEngine() {} #endif // TVM_GRAPH_EXECUTOR_TENSORRT private: - String engine_file_; - String tool_tag_; - String graph_name_; + ffi::String engine_file_; + ffi::String tool_tag_; + ffi::String graph_name_; std::unordered_map> tensor_ids_; #ifdef TVM_GRAPH_EXECUTOR_TENSORRT TensorRTLogger logger_; @@ -341,23 +341,23 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { std::unordered_map output_bindings_; std::vector bindings_; std::vector binding_sizes_; - std::unordered_map device_buffers_; + std::unordered_map device_buffers_; #endif }; -ffi::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module MSCTensorRTRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.msc_tensorrt_runtime_create", MSCTensorRTRuntimeCreate) .def("ffi.Module.load_from_bytes.msc_tensorrt", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/mscclpp/allreduce.cu b/src/runtime/contrib/mscclpp/allreduce.cu index 2b009c062585..147c306bf452 100644 --- a/src/runtime/contrib/mscclpp/allreduce.cu +++ b/src/runtime/contrib/mscclpp/allreduce.cu @@ -18,7 +18,7 @@ */ #include -#include +#include #include "msccl.cuh" diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index 71335f3ee287..6d3c55513889 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -51,10 +51,10 @@ using JSONGraphNode = tvm::runtime::json::JSONGraphNode; class NNAPIRuntime : public JSONRuntimeBase { public: explicit NNAPIRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - const char* type_key() const final { return "nnapi"; } + const char* kind() const final { return "nnapi"; } #ifdef TVM_GRAPH_EXECUTOR_NNAPI struct CompiledModel { @@ -70,7 +70,7 @@ class NNAPIRuntime : public JSONRuntimeBase { std::optional compiled_model_; - void Init(const Array& consts) final { + void Init(const ffi::Array& consts) final { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required constants."; SetupConstants(consts); @@ -225,7 +225,7 @@ class NNAPIRuntime : public JSONRuntimeBase { std::unordered_map node_output_map_; #else // ifdef TVM_GRAPH_EXECUTOR_NNAPI - void Init(const Array& consts) final { + void Init(const ffi::Array& consts) final { LOG(FATAL) << "NNAPI 
runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it."; } @@ -235,18 +235,18 @@ class NNAPIRuntime : public JSONRuntimeBase { #endif // ifdef TVM_GRAPH_EXECUTOR_NNAPI }; -ffi::Module NNAPIRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module NNAPIRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.nnapi_runtime_create", NNAPIRuntimeCreate) .def("ffi.Module.load_from_bytes.nnapi", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index 4cb0558d611b..3471902bc311 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -80,7 +80,7 @@ void InitNVSHMEM(ffi::Shape uid_64, int num_workers, int worker_id_start) { << ", npes=" << nvshmem_n_pes(); } -void InitNVSHMEMWrapper(String args) { +void InitNVSHMEMWrapper(ffi::String args) { picojson::value v; std::string err = picojson::parse(v, args); if (!err.empty()) { @@ -121,14 +121,14 @@ void NVSHMEMXCumoduleInit(void* cuModule) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.disco.nvshmem.init_nvshmem_uid", InitNVSHMEMUID) .def("runtime.disco.nvshmem.init_nvshmem", InitNVSHMEM) .def("runtime.disco.nvshmem.init_nvshmem_wrapper", InitNVSHMEMWrapper) .def("runtime.nvshmem.cumodule_init", NVSHMEMXCumoduleInit); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu index 
e225b1a346da..34916a614ae4 100644 --- a/src/runtime/contrib/nvshmem/kv_transfer.cu +++ b/src/runtime/contrib/nvshmem/kv_transfer.cu @@ -330,9 +330,9 @@ int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages, return 0; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("nvshmem.KVTransfer", _KVTransfer) .def("nvshmem.KVTransferPageToPage", _KVTransferPageToPage); -}); +} diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc index 6ac7aa04f7bb..5893d04ac33a 100644 --- a/src/runtime/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/contrib/nvshmem/memory_allocator.cc @@ -57,7 +57,7 @@ class NVSHMEMAllocator final : public PooledAllocator { return allocator; } - NDArray Empty(ffi::Shape shape, DataType dtype, Device device) { + Tensor Empty(ffi::Shape shape, DataType dtype, Device device) { class NVSHMEMAlloc { public: explicit NVSHMEMAlloc(Buffer buffer) : buffer_(buffer) {} @@ -68,8 +68,8 @@ class NVSHMEMAllocator final : public PooledAllocator { Buffer buffer_; }; - Buffer buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem")); - return NDArray::FromNDAlloc(NVSHMEMAlloc(buffer), shape, dtype, device); + Buffer buffer = PooledAllocator::Alloc(device, shape, dtype, ffi::String("nvshmem")); + return Tensor::FromNDAlloc(NVSHMEMAlloc(buffer), shape, dtype, device); } private: @@ -86,24 +86,24 @@ class NVSHMEMAllocator final : public PooledAllocator { void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); } }; -NDArray NVSHMEMEmpty(ffi::Shape shape, DataType dtype, Device device) { +Tensor NVSHMEMEmpty(ffi::Shape shape, DataType dtype, Device device) { return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("runtime.disco.nvshmem.empty", NVSHMEMEmpty); -}); +} void NVSHMEMFinalize() { NVSHMEMAllocator::Global()->Clear(); nvshmem_finalize(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("runtime.disco.nvshmem.finalize_nvshmem", NVSHMEMFinalize); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc index d847e05e1bee..91af80de3794 100644 --- a/src/runtime/contrib/papi/papi.cc +++ b/src/runtime/contrib/papi/papi.cc @@ -51,9 +51,7 @@ struct PAPIEventSetNode : public Object { explicit PAPIEventSetNode(std::vector start_values, Device dev) : start_values(start_values), dev(dev) {} - - static constexpr const char* _type_key = "PAPIEventSetNode"; - TVM_DECLARE_FINAL_OBJECT_INFO(PAPIEventSetNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("PAPIEventSetNode", PAPIEventSetNode, Object); }; /* Get the PAPI component id for the given device. @@ -101,7 +99,7 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { * collected on that device. You can find the names of available metrics by * running `papi_native_avail`. */ - explicit PAPIMetricCollectorNode(Map> metrics) { + explicit PAPIMetricCollectorNode(ffi::Map> metrics) { for (auto& p : metrics) { papi_metric_names[p.first->device] = {}; for (auto& metric : p.second) { @@ -114,7 +112,7 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { /*! \brief Initialization call. * \param devices The devices this collector will be running on */ - void Init(Array devices) { + void Init(ffi::Array devices) { if (!PAPI_is_initialized()) { if (sizeof(long_long) > sizeof(int64_t)) { LOG(WARNING) << "PAPI's long_long is larger than int64_t. 
Overflow may occur when " @@ -225,7 +223,7 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { int event_set = it->second; std::vector values(papi_metric_names[dev].size()); PAPI_CALL(PAPI_read(event_set, values.data())); - return ObjectRef(make_object(values, dev)); + return ObjectRef(ffi::make_object(values, dev)); } else { return ObjectRef(nullptr); } @@ -237,19 +235,19 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { * \param obj `PAPIEventSetNode` created by a call to `Start`. * \returns A mapping from metric name to value. */ - Map Stop(ObjectRef obj) final { + ffi::Map Stop(ObjectRef obj) final { const PAPIEventSetNode* event_set_node = obj.as(); std::vector end_values(papi_metric_names[event_set_node->dev].size()); PAPI_CALL(PAPI_read(event_sets[event_set_node->dev], end_values.data())); - std::unordered_map reported_metrics; + std::unordered_map reported_metrics; for (size_t i = 0; i < end_values.size(); i++) { if (end_values[i] < event_set_node->start_values[i]) { LOG(WARNING) << "Detected overflow when reading performance counter, setting value to -1."; reported_metrics[papi_metric_names[event_set_node->dev][i]] = - ObjectRef(make_object(-1)); + ObjectRef(ffi::make_object(-1)); } else { reported_metrics[papi_metric_names[event_set_node->dev][i]] = - ObjectRef(make_object(end_values[i] - event_set_node->start_values[i])); + ObjectRef(ffi::make_object(end_values[i] - event_set_node->start_values[i])); } } return reported_metrics; @@ -269,31 +267,32 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { /*! \brief Device-specific metric names. Order of names matches the order in the corresponding * `event_set`. 
*/ std::unordered_map> papi_metric_names; - - static constexpr const char* _type_key = "runtime.profiling.PAPIMetricCollector"; - TVM_DECLARE_FINAL_OBJECT_INFO(PAPIMetricCollectorNode, MetricCollectorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.profiling.PAPIMetricCollector", + PAPIMetricCollectorNode, MetricCollectorNode); }; /*! \brief Wrapper for `PAPIMetricCollectorNode`. */ class PAPIMetricCollector : public MetricCollector { public: - explicit PAPIMetricCollector(Map> metrics) { - data_ = make_object(metrics); + explicit PAPIMetricCollector(ffi::Map> metrics) { + data_ = ffi::make_object(metrics); } - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PAPIMetricCollector, MetricCollector, - PAPIMetricCollectorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PAPIMetricCollector, MetricCollector, + PAPIMetricCollectorNode); }; -MetricCollector CreatePAPIMetricCollector(Map> metrics) { +MetricCollector CreatePAPIMetricCollector( + ffi::Map> metrics) { return PAPIMetricCollector(metrics); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "runtime.profiling.PAPIMetricCollector", - [](Map> metrics) { return PAPIMetricCollector(metrics); }); -}); + refl::GlobalDef().def("runtime.profiling.PAPIMetricCollector", + [](ffi::Map> metrics) { + return PAPIMetricCollector(metrics); + }); +} } // namespace profiling } // namespace runtime diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 04b53d74b404..0158a66be5dd 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -75,7 +75,7 @@ class RandomEngine { */ void SampleUniform(DLTensor* data, float low, float high) { ICHECK_GT(high, low) << "high must be bigger than low"; - ICHECK(data->strides == nullptr); + ICHECK(ffi::IsContiguous(*data)); 
DLDataType dtype = data->dtype; int64_t size = 1; @@ -99,7 +99,7 @@ class RandomEngine { */ void SampleNormal(DLTensor* data, float loc, float scale) { ICHECK_GT(scale, 0) << "standard deviation must be positive"; - ICHECK(data->strides == nullptr); + ICHECK(ffi::IsContiguous(*data)); DLDataType dtype = data->dtype; int64_t size = 1; @@ -122,11 +122,12 @@ class RandomEngine { if (data->device.device_type == kDLCPU) { FillData(data); } else { - runtime::NDArray local = runtime::NDArray::Empty( + runtime::Tensor local = runtime::Tensor::Empty( std::vector{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0}); - DLTensor* tensor = const_cast(local.operator->()); - FillData(tensor); - runtime::NDArray::CopyFromTo(tensor, data); + + const DLTensor* tensor = local.GetDLTensorPtr(); + FillData(const_cast(tensor)); + runtime::Tensor::CopyFromTo(tensor, data); } } @@ -134,11 +135,11 @@ class RandomEngine { if (data->device.device_type == kDLCPU) { FillDataForMeasure(data); } else { - runtime::NDArray local = runtime::NDArray::Empty( + runtime::Tensor local = runtime::Tensor::Empty( std::vector{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0}); - DLTensor* tensor = const_cast(local.operator->()); - FillDataForMeasure(tensor); - runtime::NDArray::CopyFromTo(tensor, data); + const DLTensor* tensor = local.GetDLTensorPtr(); + FillDataForMeasure(const_cast(tensor)); + runtime::Tensor::CopyFromTo(tensor, data); } } diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 580ed1073a47..f444ab07409e 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -70,7 +70,7 @@ RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() { return RandomThreadLocalStore::Get(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.contrib.random.randint", @@ -80,7 +80,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ 
int64_t high = args[1].cast(); auto out = args[2].cast(); ICHECK_GT(high, low) << "high must be bigger than low"; - ICHECK(out->strides == nullptr); + ICHECK(ffi::IsContiguous(*out)); DLDataType dtype = out->dtype; int64_t size = 1; @@ -142,7 +142,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); entry->random_engine.RandomFillForMeasure(out); }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index 8fdce7e43bf0..73ec8c1b0f95 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -66,7 +66,7 @@ struct RocBlasThreadEntry { typedef dmlc::ThreadLocalStore RocBlasThreadStore; // matrix multiplication for row major -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed( @@ -81,9 +81,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK_EQ(A->ndim, 2); ICHECK_EQ(B->ndim, 2); ICHECK_EQ(C->ndim, 2); - ICHECK(C->strides == nullptr); - ICHECK(B->strides == nullptr); - ICHECK(A->strides == nullptr); + ICHECK(ffi::IsContiguous(*C)); + ICHECK(ffi::IsContiguous(*B)); + ICHECK(ffi::IsContiguous(*A)); ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); @@ -145,6 +145,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ RocBlasThreadStore::Get()->handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb, K * N, A_ptr, lda, M * K, &beta, C_ptr, ldc, M * N, batch_size)); }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index de67555b0a72..afbac3a84701 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -576,12 +576,12 @@ void RegisterTopk() { }); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { RegisterArgsortNMS(); RegisterArgsort(); 
RegisterSort(); RegisterTopk(); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index 9bf793bd3e49..179e75a669fa 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -24,7 +24,7 @@ #include "tensorrt_builder.h" -#include +#include #include #include @@ -233,8 +233,8 @@ nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, } weight.count = count; weight.values = new float[count]; - ICHECK_EQ(TVMArrayCopyToBytes(const_cast(dptr), const_cast(weight.values), - weight_bytes), + ICHECK_EQ(TVMTensorCopyToBytes(const_cast(dptr), const_cast(weight.values), + weight_bytes), 0) << TVMGetLastError(); trt_weights_.push_back(weight); diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h index 9bccc1ea4848..96905598737c 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.h +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h @@ -25,7 +25,7 @@ #ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_ #define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_ -#include +#include #include #include diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index ff565444e2b5..f89f0abe2acb 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include @@ -68,7 +68,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * \param const_names The names of each constant in the sub-graph. 
*/ explicit TensorRTRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names), use_implicit_batch_(true), max_workspace_size_(size_t(1) << 30), @@ -109,7 +109,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * * \param consts The constant params from compiled model. */ - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; LoadGlobalAttributes(); @@ -433,7 +433,7 @@ class TensorRTRuntime : public JSONRuntimeBase { } /*! \brief Retreive a GPU buffer for input or output or allocate if needed. */ - NDArray GetOrAllocateDeviceBuffer(int entry_id, int binding_index) { + Tensor GetOrAllocateDeviceBuffer(int entry_id, int binding_index) { std::vector shape(data_entry_[entry_id]->shape, data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); if (device_buffers_.count(binding_index)) { @@ -441,7 +441,7 @@ class TensorRTRuntime : public JSONRuntimeBase { if (shape[0] > device_buffers_[binding_index]->shape[0]) { // Buffer is too small. Need to allocate bigger buffer. device_buffers_[binding_index] = - runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); + runtime::Tensor::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); } else if (shape[0] < device_buffers_[binding_index]->shape[0]) { // Buffer is too large. Create view. return device_buffers_[binding_index].CreateView(shape, data_entry_[entry_id]->dtype); @@ -449,7 +449,7 @@ class TensorRTRuntime : public JSONRuntimeBase { } else { // Buffer not initialized yet. 
device_buffers_[binding_index] = - runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); + runtime::Tensor::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); } return device_buffers_.at(binding_index); } @@ -476,7 +476,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * is not "cuda". Since TensorRT execution can only read data from GPU, we need to copy data from * the runtime device to these buffers first. These will be allocated for the highest batch size * used by all engines. */ - std::unordered_map device_buffers_; + std::unordered_map device_buffers_; /*! \brief TensorRT logger. */ TensorRTLogger logger_; @@ -519,18 +519,18 @@ class TensorRTRuntime : public JSONRuntimeBase { bool use_fp16_; }; -ffi::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module TensorRTRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.tensorrt_runtime_create", TensorRTRuntimeCreate) .def("ffi.Module.load_from_bytes.tensorrt", JSONRuntimeBase::LoadFromBytes); -}); +} } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index c35af35eae13..9029a62f8da0 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -118,7 +118,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = interpreter_->typed_input_tensor(index); DType* src = static_cast(data_in->data); - ICHECK(data_in->strides == NULL); + ICHECK(ffi::IsContiguous(*data_in)); 
int64_t size = 1; for (int64_t i = 0; i < data_in->ndim; ++i) { size *= data_in->shape[i]; @@ -131,7 +131,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { void TFLiteRuntime::SetNumThreads(int num_threads) { interpreter_->SetNumThreads(num_threads); } -NDArray TFLiteRuntime::GetOutput(int index) const { +Tensor TFLiteRuntime::GetOutput(int index) const { TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[index]); DataType dtype = TfLiteDType2TVMDType(output->type); TfLiteIntArray* dims = output->dims; @@ -141,7 +141,7 @@ NDArray TFLiteRuntime::GetOutput(int index) const { shape.push_back(dims->data[i]); size *= dims->data[i]; } - NDArray ret = NDArray::Empty(shape, dtype, device_); + Tensor ret = Tensor::Empty(shape, dtype, device_); TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = static_cast(ret->data); DType* src = interpreter_->typed_output_tensor(index); @@ -152,7 +152,7 @@ NDArray TFLiteRuntime::GetOutput(int index) const { return ret; } -ffi::Optional TFLiteRuntime::GetFunction(const String& name) { +ffi::Optional TFLiteRuntime::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); // Return member functions during query. 
if (name == "set_input") { @@ -180,12 +180,12 @@ ffi::Optional TFLiteRuntime::GetFunction(const String& name) { } ffi::Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { - auto exec = make_object(); + auto exec = ffi::make_object(); exec->Init(tflite_model_bytes, dev); return ffi::Module(exec); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.tflite_runtime.create", @@ -193,6 +193,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = TFLiteRuntimeCreate(args[0].cast(), args[1].cast()); }) .def("target.runtime.tflite", TFLiteRuntimeCreate); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 396bd01104d5..a5703ee70749 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include @@ -54,7 +54,7 @@ class TFLiteRuntime : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual Optional GetFunction(const String& name); + virtual ffi::Optional GetFunction(const ffi::String& name); /*! * \return The type key of the executor. @@ -84,19 +84,19 @@ class TFLiteRuntime : public ffi::ModuleObj { */ void SetInput(int index, DLTensor* data_in); /*! - * \brief Return NDArray for given input index. + * \brief Return Tensor for given input index. * \param index The input index. * - * \return NDArray corresponding to given input node index. + * \return Tensor corresponding to given input node index. */ - NDArray GetInput(int index) const; + Tensor GetInput(int index) const; /*! - * \brief Return NDArray for given output index. + * \brief Return Tensor for given output index. * \param index The output index. * - * \return NDArray corresponding to given output node index. 
+ * \return Tensor corresponding to given output node index. */ - NDArray GetOutput(int index) const; + Tensor GetOutput(int index) const; /*! * \brief Set the number of threads available to the interpreter. * \param num_threads The number of threads to be set. diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 1adf95f69320..bf0a176862c1 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -94,7 +94,7 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); return thrust::cuda::par_nosync(memory_resouce).on(stream); } @@ -238,7 +238,7 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.thrust.sort", [](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_GE(args.size(), 4); @@ -258,7 +258,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype, workspace); }); -}); +} template void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out, @@ -287,7 +287,7 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* thrust::stable_sort_by_key(policy, keys_out_ptr, keys_out_ptr + size, values_out_ptr); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.thrust.stable_sort_by_key", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -348,7 +348,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ LOG(FATAL) << "Unsupported key dtype: " << 
key_dtype; } }); -}); +} template void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* workspace) { @@ -405,7 +405,7 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.thrust.sum_scan", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -484,7 +484,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ << ". Supported input dtypes are bool, int32, int64, float32, and float64"; } }); -}); +} } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index e5e45735fb55..1472cd73cbb9 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -735,7 +735,7 @@ void single_query_cached_kv_attention_v2( } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tvm.contrib.vllm.single_query_cached_kv_attention", @@ -760,17 +760,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ exp_sums, max_logits, tmp_out, out); } }); -}); +} // Expose for testing -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tvm.contrib.vllm.single_query_cached_kv_attention_v1", single_query_cached_kv_attention_v1) .def("tvm.contrib.vllm.single_query_cached_kv_attention_v2", single_query_cached_kv_attention_v2); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/vllm/cache_alloc.cc b/src/runtime/contrib/vllm/cache_alloc.cc index d616923ad78e..266138406cb9 100644 --- a/src/runtime/contrib/vllm/cache_alloc.cc +++ b/src/runtime/contrib/vllm/cache_alloc.cc @@ -19,15 +19,15 @@ #include #include #include -#include +#include namespace tvm { namespace 
runtime { namespace vllm { -Array AllocateKVCache(int head_size, int num_layers, int num_heads, int block_size, - int num_blocks) { - Array cache; +ffi::Array AllocateKVCache(int head_size, int num_layers, int num_heads, int block_size, + int num_blocks) { + ffi::Array cache; int element_size = 2; int vec_size = 16 / element_size; @@ -37,11 +37,11 @@ Array AllocateKVCache(int head_size, int num_layers, int num_heads, int DLDevice dev{DLDeviceType::kDLCUDA, device_id}; for (int i = 0; i < num_layers; ++i) { - NDArray key_blocks = - NDArray::Empty({num_blocks, num_heads, head_size / vec_size, block_size, vec_size}, - runtime::DataType::Float(16), dev); - NDArray value_blocks = NDArray::Empty({num_blocks, num_heads, head_size, block_size}, - runtime::DataType::Float(16), dev); + Tensor key_blocks = + Tensor::Empty({num_blocks, num_heads, head_size / vec_size, block_size, vec_size}, + runtime::DataType::Float(16), dev); + Tensor value_blocks = Tensor::Empty({num_blocks, num_heads, head_size, block_size}, + runtime::DataType::Float(16), dev); cache.push_back(key_blocks); cache.push_back(value_blocks); } @@ -49,10 +49,10 @@ Array AllocateKVCache(int head_size, int num_layers, int num_heads, int return cache; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tvm.contrib.vllm.allocate_kv_cache", AllocateKVCache); -}); +} } // namespace vllm } // namespace runtime diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index c7f91aa42fce..5ddf18e48208 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -18,7 +18,7 @@ */ #include #include -#include +#include #include #include @@ -130,12 +130,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, int64_t* value_cache namespace tvm { namespace runtime { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; refl::GlobalDef() .def("tvm.contrib.vllm.reshape_and_cache", - [](NDArray key, NDArray value, NDArray key_cache, NDArray value_cache, - NDArray slot_mapping) { + [](Tensor key, Tensor value, Tensor key_cache, Tensor value_cache, Tensor slot_mapping) { int num_tokens = key->shape[0]; int num_heads = key->shape[1]; int head_size = key->shape[2]; @@ -158,7 +157,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return Array{key_cache, value_cache}; }) .def("tvm.contrib.vllm.reconstruct_from_cache", - [](NDArray key_cache, NDArray value_cache, NDArray slot_mapping) { + [](Tensor key_cache, Tensor value_cache, Tensor slot_mapping) { int num_tokens = slot_mapping->shape[0]; int num_heads = value_cache->shape[1]; int head_size = value_cache->shape[2]; @@ -166,8 +165,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ int vec_size = key_cache->shape[4]; DLDevice dev = key_cache->device; - auto key = NDArray::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev); - auto value = NDArray::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev); + auto key = Tensor::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev); + auto value = Tensor::Empty({num_tokens, num_heads, head_size}, key_cache->dtype, dev); int key_stride = key->shape[1] * key->shape[2]; int value_stride = value->shape[1] * value->shape[2]; @@ -185,8 +184,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ return Array{key, value}; }) - .def("tvm.contrib.vllm.copy_blocks", [](Array key_value_caches, - NDArray block_mapping) { + .def("tvm.contrib.vllm.copy_blocks", [](ffi::Array key_value_caches, + Tensor block_mapping) { auto num_layers = key_value_caches.size() / 2; auto num_pairs = block_mapping->shape[0] / 2; @@ -203,20 +202,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ reinterpret_cast(key_value_caches[2 * layer_idx + 1]->data); } - NDArray key_cache = key_value_caches[1]; // [num_blocks, num_heads, head_size, block_size] + Tensor key_cache = key_value_caches[1]; // [num_blocks, num_heads, head_size, block_size] DLDevice 
dev = key_cache->device; - NDArray key_cache_ptrs_gpu = - NDArray::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); - NDArray value_cache_ptrs_gpu = - NDArray::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); + Tensor key_cache_ptrs_gpu = + Tensor::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); + Tensor value_cache_ptrs_gpu = + Tensor::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(), sizeof(int64_t) * key_cache_ptrs.size()); value_cache_ptrs_gpu.CopyFromBytes(value_cache_ptrs.data(), sizeof(int64_t) * value_cache_ptrs.size()); - NDArray block_mapping_gpu = - NDArray::Empty(block_mapping.Shape(), runtime::DataType::Int(64), dev); + Tensor block_mapping_gpu = + Tensor::Empty(block_mapping.Shape(), runtime::DataType::Int(64), dev); block_mapping_gpu.CopyFromBytes(block_mapping->data, sizeof(int64_t) * block_mapping->shape[0]); @@ -230,7 +229,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ static_cast(value_cache_ptrs_gpu->data), static_cast(block_mapping_gpu->data), numel_per_block); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index e9b16d003e3a..d9299832ddb3 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -151,12 +151,12 @@ void CPUDeviceAPI::FreeWorkspace(Device dev, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, data); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("device_api.cpu", [](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = CPUDeviceAPI::Global(); *rv = static_cast(ptr); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 451348afbf1a..bfd5f7cca98a 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ 
b/src/runtime/cuda/cuda_device_api.cc @@ -280,7 +280,7 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {} CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("device_api.cuda", @@ -292,7 +292,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ DeviceAPI* ptr = CUDADeviceAPI::Global(); *rv = static_cast(ptr); }); -}); +} class CUDATimerNode : public TimerNode { public: @@ -301,7 +301,7 @@ class CUDATimerNode : public TimerNode { // cudaEventRecord do some stream synchronization? int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - stream_ = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id); + stream_ = TVMFFIEnvGetStream(kDLCUDA, device_id); CUDA_CALL(cudaEventRecord(start_, static_cast(stream_))); } virtual void Stop() { @@ -323,9 +323,7 @@ class CUDATimerNode : public TimerNode { CUDA_CALL(cudaEventCreate(&start_)); CUDA_CALL(cudaEventCreate(&stop_)); } - - static constexpr const char* _type_key = "runtime.cuda.CUDATimerNode"; - TVM_DECLARE_FINAL_OBJECT_INFO(CUDATimerNode, TimerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.cuda.CUDATimerNode", CUDATimerNode, TimerNode); private: cudaEvent_t start_; @@ -333,13 +331,13 @@ class CUDATimerNode : public TimerNode { TVMStreamHandle stream_; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.cuda", - [](Device dev) { return Timer(make_object()); }); -}); + [](Device dev) { return Timer(ffi::make_object()); }); +} -TVM_DLL String GetCudaFreeMemory() { +TVM_DLL ffi::String GetCudaFreeMemory() { size_t free_mem, total_mem; CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem)); std::stringstream ss; @@ -348,18 +346,18 @@ TVM_DLL String GetCudaFreeMemory() { return ss.str(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; refl::GlobalDef() .def("runtime.GetCudaFreeMemory", GetCudaFreeMemory) .def("runtime.get_cuda_stream", []() { // TODO(tvm-team): remove once confirms all dep such as flashinfer - // migrated to TVMFFIEnvGetCurrentStream + // migrated to TVMFFIEnvGetStream int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - return static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + return static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); }); -}); +} TVM_DLL int GetCudaDeviceCount() { int count; @@ -367,10 +365,10 @@ TVM_DLL int GetCudaDeviceCount() { return count; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("runtime.GetCudaDeviceCount", GetCudaDeviceCount); -}); +} #if (CUDA_VERSION >= 12000) /** @@ -396,7 +394,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param l2_promotion_kind (int): An integer corresponding to the CUtensorMapL2promotion enum. * \param oob_fill_kind (int): An integer corresponding to the CUtensorMapFloatOOBfill enum. */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("runtime.cuTensorMapEncodeTiled", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -580,7 +578,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << errstr; } }); -}); +} #endif // CUDA_VERSION >= 12000 } // namespace runtime diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index eb3bee4757bf..b41bb0516e17 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -43,6 +43,15 @@ namespace tvm { namespace runtime { +namespace { + +inline void EnsureCurrentDeviceContext(int device_id) { + // Driver API entry points require a current context on this thread. `cudaGetDevice` + // reports the logical device, but it does not guarantee the primary context is bound. 
+ CUDA_CALL(cudaSetDevice(device_id)); +} + +} // namespace // Module to support thread-safe multi-GPU execution. // cuModule is a per-GPU module @@ -73,9 +82,9 @@ class CUDAModuleNode : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cu") { @@ -99,7 +108,7 @@ class CUDAModuleNode : public ffi::ModuleObj { return ffi::Bytes(buffer); } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { if (format == fmt_) return data_; if (cuda_source_.length() != 0) { return cuda_source_; @@ -112,6 +121,7 @@ class CUDAModuleNode : public ffi::ModuleObj { // get a CUfunction from primary context in device_id CUfunction GetFunc(int device_id, const std::string& func_name) { std::lock_guard lock(mutex_); + EnsureCurrentDeviceContext(device_id); // must recheck under the lock scope if (module_[device_id] == nullptr) { CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str())); @@ -132,6 +142,7 @@ class CUDAModuleNode : public ffi::ModuleObj { // get a global var from primary context in device_id CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) { std::lock_guard lock(mutex_); + EnsureCurrentDeviceContext(device_id); // must recheck under the lock scope if (module_[device_id] == nullptr) { CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str())); @@ -178,31 +189,120 @@ class CUDAWrappedFunc { sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); + // Track whether this 
kernel uses dynamic shared memory and the last size set per device. + std::fill(dyn_smem_initialized_.begin(), dyn_smem_initialized_.end(), false); + // Track whether cluster attribute has been set per device. + std::fill(cluster_attr_initialized_.begin(), cluster_attr_initialized_.end(), false); + use_dyn_shared_memory_ = false; + for (const auto& tag : launch_param_tags) { + if (tag == launch_param::kUseDynamicSharedMemoryTag) { + use_dyn_shared_memory_ = true; + break; + } + } launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(ffi::PackedArgs args, ffi::Any* rv, void** void_args) const { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); + EnsureCurrentDeviceContext(device_id); ThreadWorkLoad wl = launch_param_config_.Extract(args); if (fcache_[device_id] == nullptr) { fcache_[device_id] = m_->GetFunc(device_id, func_name_); - if (wl.dyn_shmem_size >= (48 << 10)) { - // Assumption: dyn_shmem_size doesn't change across different invocations of - // fcache_[device_id] - CUresult result = cuFuncSetAttribute( + } + + // If the kernel uses dynamic shared memory, we should ensure the attribute + // reflects the actual size needed for this launch. Some workloads vary the + // dynamic shared memory between invocations, in which case we cannot set it + // just once. Cache the last value per device to avoid redundant calls. 
+ bool need_dyn_attr = use_dyn_shared_memory_ || (wl.dyn_shmem_size > 0); + if (need_dyn_attr) { + if (!dyn_smem_initialized_[device_id] || dyn_smem_last_[device_id] != wl.dyn_shmem_size) { + CUresult attr_set = cuFuncSetAttribute( fcache_[device_id], CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, wl.dyn_shmem_size); - if (result != CUDA_SUCCESS) { + if (attr_set != CUDA_SUCCESS) { LOG(FATAL) << "Failed to set the allowed dynamic shared memory size to " << wl.dyn_shmem_size; } + dyn_smem_last_[device_id] = wl.dyn_shmem_size; + dyn_smem_initialized_[device_id] = true; } } - CUstream strm = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); - CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), - wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), - wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); + CUstream strm = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); + CUresult result; + + ICHECK(wl.grid_dim(0) > 0 && wl.grid_dim(1) > 0 && wl.grid_dim(2) > 0) + << "CUDALaunch Error: grid dimension must be positive, but got" + << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << ")" + << " in kernel " << func_name_ + << ". A zero grid dimension is often caused by a dynamic shape" + << " (e.g. 
num_tokens) being 0 at runtime."; + + if (wl.use_cluster_launch()) { + // SM90+ cluster launch + CUlaunchConfig config{}; + CUlaunchAttribute attribute[2]{}; + attribute[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + attribute[0].value.clusterDim.x = wl.cluster_dim[0]; + attribute[0].value.clusterDim.y = wl.cluster_dim[1]; + attribute[0].value.clusterDim.z = wl.cluster_dim[2]; + attribute[1].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; + attribute[1].value.programmaticStreamSerializationAllowed = 1; + + config.attrs = attribute; + config.numAttrs = 2; + config.hStream = strm; + config.gridDimX = wl.grid_dim(0); + config.gridDimY = wl.grid_dim(1); + config.gridDimZ = wl.grid_dim(2); + config.blockDimX = wl.block_dim(0); + config.blockDimY = wl.block_dim(1); + config.blockDimZ = wl.block_dim(2); + config.sharedMemBytes = wl.dyn_shmem_size; + + // Set non-portable cluster size allowed attribute + if (!cluster_attr_initialized_[device_id]) { + CUresult attr_result = cuFuncSetAttribute( + fcache_[device_id], CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1); + if (attr_result != CUDA_SUCCESS) { + const char* msg; + cuGetErrorName(attr_result, &msg); + LOG(FATAL) << "Failed to set cluster attribute for " << func_name_ << ": " << msg; + } + cluster_attr_initialized_[device_id] = true; + } + + result = cuLaunchKernelEx(&config, fcache_[device_id], void_args, nullptr); + } else if (launch_param_config_.use_programtic_dependent_launch()) { + CUlaunchConfig config{}; + CUlaunchAttribute attribute[1]{}; + attribute[0].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; + attribute[0].value.programmaticStreamSerializationAllowed = 1; + + config.attrs = attribute; + config.numAttrs = 1; + config.hStream = strm; + config.gridDimX = wl.grid_dim(0); + config.gridDimY = wl.grid_dim(1); + config.gridDimZ = wl.grid_dim(2); + config.blockDimX = wl.block_dim(0); + config.blockDimY = wl.block_dim(1); + config.blockDimZ = wl.block_dim(2); + 
config.sharedMemBytes = wl.dyn_shmem_size; + + result = cuLaunchKernelEx(&config, fcache_[device_id], void_args, nullptr); + } else if (launch_param_config_.use_cooperative_launch()) { + result = cuLaunchCooperativeKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), + wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), + wl.block_dim(2), wl.dyn_shmem_size, strm, void_args); + } else { + result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), + wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, + strm, void_args, nullptr); + } + if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { const char* msg; cuGetErrorName(result, &msg); @@ -210,7 +310,8 @@ class CUDAWrappedFunc { os << "CUDALaunch Error: " << msg << "\n" << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) - << ")\n"; + << ")" + << " dyn_smem_bytes=" << wl.dyn_shmem_size; std::string cuda = m_->InspectSource(""); if (cuda.length() != 0) { os << "// func_name=" << func_name_ << "\n" @@ -220,6 +321,24 @@ class CUDAWrappedFunc { } LOG(FATAL) << os.str(); } + + // Check for asynchronous CUDA errors that cuLaunchKernel's return value + // does not capture (e.g. illegal memory access during kernel execution). + // This matches the Cython backend's TILELANG_CHECK_LAST_ERROR macro. + if (result == CUDA_SUCCESS) { + cudaError_t last_err = cudaPeekAtLastError(); + if (last_err != cudaSuccess) { + // Use driver API cuGetErrorName for the error name (cudaGetErrorName + // is not available in the cudart stub). The numeric values of + // cudaError_t and CUresult are identical for matching error codes. + const char* err_name = nullptr; + cuGetErrorName(static_cast(last_err), &err_name); + const char* err_str = cudaGetErrorString(last_err); + // Clear the sticky error so subsequent CUDA calls are not poisoned. 
+ cudaGetLastError(); + LOG(FATAL) << func_name_ << ": " << (err_name ? err_name : "unknown") << " - " << err_str; + } + } } private: @@ -234,6 +353,15 @@ class CUDAWrappedFunc { mutable std::array fcache_; // launch parameters configuration LaunchParamConfig launch_param_config_; + // Whether this kernel uses dynamic shared memory + bool use_dyn_shared_memory_{false}; + // Cached last dynamic shared memory size per device and whether it's initialized + mutable std::array dyn_smem_last_; + mutable std::array dyn_smem_initialized_; + // Whether cluster attribute has been initialized per device + mutable std::array cluster_attr_initialized_; + // have pdl setting + bool has_programmatic_dependent_launch_; }; class CUDAPrepGlobalBarrier { @@ -245,6 +373,7 @@ class CUDAPrepGlobalBarrier { void operator()(const ffi::PackedArgs& args, ffi::Any* rv) const { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); + EnsureCurrentDeviceContext(device_id); if (pcache_[device_id] == 0) { pcache_[device_id] = m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned)); @@ -261,7 +390,7 @@ class CUDAPrepGlobalBarrier { mutable std::array pcache_; }; -Optional CUDAModuleNode::GetFunction(const String& name) { +ffi::Optional CUDAModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); if (name == symbol::tvm_prepare_global_barrier) { @@ -278,12 +407,12 @@ Optional CUDAModuleNode::GetFunction(const String& name) { ffi::Module CUDAModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string cuda_source) { - auto n = make_object(data, fmt, fmap, cuda_source); + auto n = ffi::make_object(data, fmt, fmap, cuda_source); return ffi::Module(n); } // Load module from module. 
-ffi::Module CUDAModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module CUDAModuleLoadFile(const std::string& file_name, const ffi::String& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -305,12 +434,12 @@ ffi::Module CUDAModuleLoadFromBytes(const ffi::Bytes& bytes) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ffi.Module.load_from_file.cuda", CUDAModuleLoadFile) .def("ffi.Module.load_from_file.ptx", CUDAModuleLoadFile) .def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index 0c7f939181a2..b69ecc71882c 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -34,16 +34,16 @@ typedef dmlc::ThreadLocalStore L2FlushStore; L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("l2_cache_flush_cuda", [](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); L2Flush::ThreadLocal()->Flush(stream); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 28dc313ba3e6..f8910f6e8800 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -21,7 +21,7 @@ * \file device_api.cc * \brief Device specific implementations */ -#include +#include #include #include #include @@ -107,7 
+107,7 @@ static size_t GetDataAlignment(const DLDataType dtype) { return align; } -size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { +size_t DeviceAPI::GetDataSize(const DLTensor& arr, ffi::Optional mem_scope) { if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { size_t size = 1; for (int i = 0; i < arr.ndim; ++i) { @@ -121,7 +121,7 @@ size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { } void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) { + ffi::Optional mem_scope) { if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { // by default, we can always redirect to the flat memory allocations DLTensor temp; @@ -169,13 +169,13 @@ void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { } TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { - return TVMFFIEnvGetCurrentStream(dev.device_type, dev.device_id); + return TVMFFIEnvGetStream(dev.device_type, dev.device_id); } void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.Device_StreamCreate", @@ -198,10 +198,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast(src), reinterpret_cast(dst)); }); -}); +} // set device api -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed(tvm::runtime::symbol::tvm_set_device, @@ -235,7 +235,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ dev.device_id = device_id; DeviceAPIManager::Get(dev)->SetStream(dev, stream); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 46ecb49f50fc..2ea9ef575d05 100644 --- 
a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -38,7 +38,7 @@ struct BcastSessionObj::Internal { } static DRef MakeDRef(int reg_id, Session session) { - ObjectPtr p = make_object(); + ObjectPtr p = ffi::make_object(); p->reg_id = reg_id; p->session = session; return DRef(std::move(p)); @@ -48,17 +48,17 @@ struct BcastSessionObj::Internal { DRef BcastSessionObj::GetGlobalFunc(const std::string& name) { int reg_id = AllocateReg(); BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kGetGlobalFunc, reg_id, name); - return BcastSessionObj::Internal::MakeDRef(reg_id, GetRef(this)); + return BcastSessionObj::Internal::MakeDRef(reg_id, ffi::GetRef(this)); } -void BcastSessionObj::CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) { - this->AppendHostNDArray(host_array); +void BcastSessionObj::CopyFromWorker0(const Tensor& host_array, const DRef& remote_array) { + this->AppendHostTensor(host_array); BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kCopyFromWorker0, remote_array->reg_id); } -void BcastSessionObj::CopyToWorker0(const NDArray& host_array, const DRef& remote_array) { - this->AppendHostNDArray(host_array); +void BcastSessionObj::CopyToWorker0(const Tensor& host_array, const DRef& remote_array) { + this->AppendHostTensor(host_array); BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kCopyToWorker0, remote_array->reg_id); } @@ -67,11 +67,11 @@ void BcastSessionObj::Shutdown() { BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kShutDown, 0); } -void BcastSessionObj::InitCCL(String ccl, ffi::Shape device_ids) { +void BcastSessionObj::InitCCL(ffi::String ccl, ffi::Shape device_ids) { const auto pf = tvm::ffi::Function::GetGlobal("runtime.disco." + ccl + ".init_ccl"); CHECK(pf.has_value()) << "ValueError: Cannot initialize CCL `" << ccl << "`, because cannot find function: runtime.disco." 
<< ccl << ".init_ccl"; - (*pf)(GetRef(this), device_ids); + (*pf)(ffi::GetRef(this), device_ids); } void BcastSessionObj::SyncWorker(int worker_id) { @@ -97,7 +97,7 @@ DRef BcastSessionObj::CallWithPacked(const ffi::PackedArgs& args) { args_vec[2] = func->reg_id; } this->BroadcastPacked(ffi::PackedArgs(args_vec, args.size())); - return BcastSessionObj::Internal::MakeDRef(reg_id, GetRef(this)); + return BcastSessionObj::Internal::MakeDRef(reg_id, ffi::GetRef(this)); } void BcastSessionObj::DeallocReg(int reg_id) { @@ -114,7 +114,7 @@ int BcastSessionObj::AllocateReg() { return reg_id; } -void BcastSessionObj::AppendHostNDArray(const NDArray& host_array) { +void BcastSessionObj::AppendHostTensor(const Tensor& host_array) { std::lock_guard lock(worker_zero_data_.queue_mutex_); worker_zero_data_.host_arrays.push(host_array); } diff --git a/src/runtime/disco/bcast_session.h b/src/runtime/disco/bcast_session.h index f92369d85337..119ca36409f0 100644 --- a/src/runtime/disco/bcast_session.h +++ b/src/runtime/disco/bcast_session.h @@ -37,11 +37,11 @@ class BcastSessionObj : public SessionObj { virtual ~BcastSessionObj() = default; DRef GetGlobalFunc(const std::string& name) override; - void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) override; - void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) override; + void CopyFromWorker0(const Tensor& host_array, const DRef& remote_array) override; + void CopyToWorker0(const Tensor& host_array, const DRef& remote_array) override; void SyncWorker(int worker_id) override; void Shutdown() override; - void InitCCL(String ccl, IntTuple device_ids) override; + void InitCCL(ffi::String ccl, IntTuple device_ids) override; ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0; void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id) override = 0; @@ -53,11 +53,11 @@ class BcastSessionObj : public SessionObj { /*! 
\brief Allocate a register id, either from `free_regs_` or by incrementing `reg_count_` */ virtual int AllocateReg(); /*! - * \brief Append an controler-side NDArray to a special queue used to communicate with + * \brief Append an controler-side Tensor to a special queue used to communicate with worker-0. * \param host_array The array to be appended to worker-0 */ - virtual void AppendHostNDArray(const NDArray& host_array); + virtual void AppendHostTensor(const Tensor& host_array); /*! * \brief Broadcast a command to all workers via TVM's ffi::Function calling convention. * As part of the calling convention, The first argument in the packed sequence must be @@ -102,7 +102,7 @@ class BcastSessionObj : public SessionObj { */ class BcastSession : public Session { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BcastSession, Session, BcastSessionObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BcastSession, Session, BcastSessionObj); }; } // namespace runtime diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index b650b143e401..8584d15c5e04 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -49,17 +49,17 @@ class DSOLibraryCache { std::mutex mutex_; }; -ffi::Module LoadVMModule(std::string path, Optional device) { +ffi::Module LoadVMModule(std::string path, ffi::Optional device) { static DSOLibraryCache cache; ffi::Module dso_mod = cache.Open(path); Device dev = UseDefaultDeviceIfNone(device); - Optional vm_load_executable = dso_mod->GetFunction("vm_load_executable"); + ffi::Optional vm_load_executable = dso_mod->GetFunction("vm_load_executable"); if (!vm_load_executable.has_value()) { // not built by RelaxVM, return the dso_mod directly return dso_mod; } auto mod = (*vm_load_executable)().cast(); - Optional vm_initialization = mod->GetFunction("vm_initialization"); + ffi::Optional vm_initialization = mod->GetFunction("vm_initialization"); if (!vm_initialization.has_value()) { LOG(FATAL) << "ValueError: File `" << 
path << "` is not built by RelaxVM, because `vm_initialization` does not exist"; @@ -70,8 +70,8 @@ ffi::Module LoadVMModule(std::string path, Optional device) { return mod; } -NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional device) { - return NDArray::Empty(shape, dtype, UseDefaultDeviceIfNone(device)); +Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, ffi::Optional device) { + return Tensor::Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } ffi::Function GetCCLFunc(const char* name) { @@ -83,37 +83,37 @@ ffi::Function GetCCLFunc(const char* name) { return *pf; } -void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) { +void AllReduce(Tensor send, ReduceKind reduce_kind, bool in_group, Tensor recv) { GetCCLFunc("allreduce")(send, static_cast(reduce_kind), in_group, recv); } -void AllGather(NDArray send, bool in_group, NDArray recv) { +void AllGather(Tensor send, bool in_group, Tensor recv) { GetCCLFunc("allgather")(send, in_group, recv); } -TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv) { +TVM_DLL void BroadcastFromWorker0(Tensor send, bool in_group, Tensor recv) { GetCCLFunc("broadcast_from_worker0")(send, in_group, recv); } -TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { +TVM_DLL void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) { GetCCLFunc("scatter_from_worker0")(send, in_group, recv); } -void GatherToWorker0(NDArray send, bool in_group, Optional recv) { +void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { GetCCLFunc("gather_to_worker0")(send, in_group, recv); } -void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); } +void RecvFromWorker0(Tensor buffer) { GetCCLFunc("recv_from_worker0")(buffer); } -void SendToNextGroup(NDArray buffer) { GetCCLFunc("send_to_next_group")(buffer); } +void SendToNextGroup(Tensor buffer) { GetCCLFunc("send_to_next_group")(buffer); } -void 
RecvFromPrevGroup(NDArray buffer) { GetCCLFunc("recv_from_prev_group")(buffer); } +void RecvFromPrevGroup(Tensor buffer) { GetCCLFunc("recv_from_prev_group")(buffer); } -void SendToWorker(NDArray buffer, int receiver_id) { +void SendToWorker(Tensor buffer, int receiver_id) { GetCCLFunc("send_to_worker")(buffer, receiver_id); } -void RecvFromWorker(NDArray buffer, int sender_id) { +void RecvFromWorker(Tensor buffer, int sender_id) { GetCCLFunc("recv_from_worker")(buffer, sender_id); } @@ -125,13 +125,13 @@ void SyncWorker() { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.disco.load_vm_module", LoadVMModule) .def("runtime.disco.empty", - [](ffi::Shape shape, DataType dtype, Optional device, bool worker0_only, - bool in_group) -> Optional { + [](ffi::Shape shape, DataType dtype, ffi::Optional device, bool worker0_only, + bool in_group) -> ffi::Optional { int worker_id = WorkerId(); int group_size = DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups; @@ -140,11 +140,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (worker0_only && !is_worker0) { return std::nullopt; } else { - return DiscoEmptyNDArray(shape, dtype, device); + return DiscoEmptyTensor(shape, dtype, device); } }) .def("runtime.disco.allreduce", - [](NDArray send, ffi::Shape reduce_kind, bool in_group, NDArray recv) { + [](Tensor send, ffi::Shape reduce_kind, bool in_group, Tensor recv) { int kind = IntegerFromShape(reduce_kind); CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; AllReduce(send, static_cast(kind), in_group, recv); @@ -169,7 +169,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ "tvm.runtime.threading.set_current_thread_affinity"); f_set_thread_affinity(ffi::Shape{cpu_ids[worker_id]}); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc index 1761f7f2dc7a..7dc55c0b4b7c 
100644 --- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc @@ -101,7 +101,7 @@ class CUDAIPCMemoryAllocator final : public memory::PooledAllocator { dev, barrier_ptr_size, alignment, DataType::UInt(32), /*reset_memory_to_zero=*/true); // Create the CUDAIPCMemory object. - ObjectPtr ipc_memory = make_object(); + ObjectPtr ipc_memory = ffi::make_object(); nccl::CCLThreadLocalContext* nccl_ctx = nccl::CCLThreadLocalContext::Get(); ipc_memory->remote_data = data_comm_ptrs; ipc_memory->barrier_in = barrier_in_comm_ptrs; @@ -202,7 +202,7 @@ class CUDAIPCMemoryAllocator final : public memory::PooledAllocator { * \return The allocated storage object with internal CUDA IPC memory buffer. */ memory::Storage IPCAllocStorage(ffi::Shape buffer_shape, DLDataType dtype_hint) { - auto storage_obj = ffi::SimpleObjAllocator().make_object(); + auto storage_obj = ffi::make_object(); nccl::CCLThreadLocalContext* nccl_ctx = nccl::CCLThreadLocalContext::Get(); Device device{DLDeviceType::kDLCUDA, nccl_ctx->device_id}; CUDAIPCMemoryAllocator* allocator = CUDAIPCMemoryAllocator::Global(); @@ -213,13 +213,13 @@ memory::Storage IPCAllocStorage(ffi::Shape buffer_shape, DLDataType dtype_hint) return storage; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.disco.cuda_ipc.alloc_storage", IPCAllocStorage) .def("runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear", []() { CUDAIPCMemoryAllocator::Global()->Clear(); }); -}); +} /******************** CUDAIPCMemoryObj ********************/ diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/disco/cuda_ipc/custom_allreduce.cc index f1293d4a4606..060a098a9d63 100644 --- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc @@ -113,10 +113,10 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { ctx->GetDefaultStream()); } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("runtime.disco.cuda_ipc.custom_allreduce", CustomAllReduce); -}); +} } // namespace cuda_ipc } // namespace nccl diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc index 8e63355283a8..d9865ca2bec4 100644 --- a/src/runtime/disco/disco_worker.cc +++ b/src/runtime/disco/disco_worker.cc @@ -36,10 +36,10 @@ TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() { void DiscoWorker::SetRegister(int reg_id, ffi::AnyView value) { ICHECK(0 <= reg_id && reg_id < static_cast(register_file.size())); ffi::Any& rv = register_file.at(reg_id); - if (rv.type_index() == ffi::TypeIndex::kTVMFFINDArray && - value.type_index() == ffi::TypeIndex::kTVMFFINDArray) { - NDArray dst = rv.cast(); - NDArray src = value.cast(); + if (rv.type_index() == ffi::TypeIndex::kTVMFFITensor && + value.type_index() == ffi::TypeIndex::kTVMFFITensor) { + Tensor dst = rv.cast(); + Tensor src = value.cast(); dst.CopyFrom(src); } else { rv = value; @@ -112,25 +112,25 @@ struct DiscoWorker::Impl { } } - static NDArray GetNDArrayFromHost(DiscoWorker* self) { + static Tensor GetTensorFromHost(DiscoWorker* self) { std::lock_guard lock(self->worker_zero_data->queue_mutex_); - NDArray array = self->worker_zero_data->host_arrays.front(); + Tensor array = self->worker_zero_data->host_arrays.front(); self->worker_zero_data->host_arrays.pop(); return array; } static void CopyFromWorker0(DiscoWorker* self, int reg_id) { if (self->worker_id == 0) { - NDArray tgt = GetNDArrayFromHost(self); - NDArray src = GetReg(self, reg_id).cast(); + Tensor tgt = GetTensorFromHost(self); + Tensor src = GetReg(self, reg_id).cast(); tgt.CopyFrom(src); } } static void CopyToWorker0(DiscoWorker* self, int reg_id) { if (self->worker_id == 0) { - NDArray src = GetNDArrayFromHost(self); - NDArray tgt = GetReg(self, reg_id).cast(); + Tensor src = GetTensorFromHost(self); + Tensor tgt = GetReg(self, 
reg_id).cast(); tgt.CopyFrom(src); } } diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc index b4933aa303ef..99c54933bf3a 100644 --- a/src/runtime/disco/distributed/socket_session.cc +++ b/src/runtime/disco/distributed/socket_session.cc @@ -56,7 +56,7 @@ class DiscoSocketChannel : public DiscoChannel { class SocketSessionObj : public BcastSessionObj { public: explicit SocketSessionObj(int num_nodes, int num_workers_per_node, int num_groups, - const String& host, int port) + const ffi::String& host, int port) : num_nodes_(num_nodes), num_workers_per_node_(num_workers_per_node) { const auto f_create_local_session = tvm::ffi::Function::GetGlobal("runtime.disco.create_socket_session_local_workers"); @@ -173,8 +173,8 @@ class SocketSessionObj : public BcastSessionObj { return remote_channels_[node_id - 1]->Recv(); } - void AppendHostNDArray(const NDArray& host_array) final { - local_session_->AppendHostNDArray(host_array); + void AppendHostTensor(const Tensor& host_array) final { + local_session_->AppendHostTensor(host_array); } void Shutdown() final { @@ -196,9 +196,8 @@ class SocketSessionObj : public BcastSessionObj { } ~SocketSessionObj() { Shutdown(); } - - static constexpr const char* _type_key = "runtime.disco.SocketSession"; - TVM_DECLARE_FINAL_OBJECT_INFO(SocketSessionObj, BcastSessionObj); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.SocketSession", SocketSessionObj, + BcastSessionObj); int num_nodes_; int num_workers_per_node_; TCPSocket socket_; @@ -209,7 +208,8 @@ class SocketSessionObj : public BcastSessionObj { class RemoteSocketSession { public: - explicit RemoteSocketSession(const String& server_host, int server_port, int num_local_workers) { + explicit RemoteSocketSession(const ffi::String& server_host, int server_port, + int num_local_workers) { socket_.Create(); socket_.SetKeepAlive(true); SockAddr server_addr{server_host.c_str(), server_port}; @@ -287,25 +287,27 @@ class 
RemoteSocketSession { int num_workers_per_node_{-1}; }; -void RemoteSocketSessionEntryPoint(const String& server_host, int server_port, +void RemoteSocketSessionEntryPoint(const ffi::String& server_host, int server_port, int num_local_workers) { RemoteSocketSession proxy(server_host, server_port, num_local_workers); proxy.MainLoop(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("runtime.disco.RemoteSocketSession", RemoteSocketSessionEntryPoint); -}); +} -Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, const String& host, - int port) { - auto n = make_object(num_nodes, num_workers_per_node, num_groups, host, port); +Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, + const ffi::String& host, int port) { + auto n = + ffi::make_object(num_nodes, num_workers_per_node, num_groups, host, port); return Session(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); refl::GlobalDef() .def("runtime.disco.SocketSession", SocketSession) .def("runtime.disco.socket_session_init_workers", @@ -318,7 +320,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ worker->worker_id = worker->worker_id + node_id * num_workers_per_node; worker->num_workers = num_nodes * num_workers_per_node; }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index 97af8bc9d3de..35fbf8abbb6f 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include @@ -39,9 +39,9 @@ namespace tvm { namespace runtime { -using vm::NDArrayCacheMetadata; -using FileRecord = NDArrayCacheMetadata::FileRecord; -using ParamRecord = NDArrayCacheMetadata::FileRecord::ParamRecord; +using vm::TensorCacheMetadata; +using FileRecord = TensorCacheMetadata::FileRecord; +using 
ParamRecord = TensorCacheMetadata::FileRecord::ParamRecord; struct ShardInfo { struct TensorInfo { @@ -78,7 +78,8 @@ ShardInfo::TensorInfo LoadTensorInfoFromJSON(const picojson::array& json_tensor_ shape.push_back(AsType(shape_json[i])); } std::string dtype = AsType(json_tensor_info[1]); - return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), DataType(StringToDLDataType(dtype))}; + return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), + DataType(ffi::StringToDLDataType(dtype))}; } ShardInfo::ShardFunc LoadShardFuncFromJSON(const picojson::array& json_shard_func) { @@ -117,28 +118,26 @@ class ShardLoaderObj : public Object { public: /*! \brief Create a shard loader. */ static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata, - std::string shard_info, Optional mod); + std::string shard_info, ffi::Optional mod); /*! \brief Load the i-th parameter */ - NDArray Load(int weight_index) const; + Tensor Load(int weight_index) const; - NDArray LoadParamOnWorker0(int weight_index) const; + Tensor LoadParamOnWorker0(int weight_index) const; /*! \brief Load all the parameters */ - Array LoadAll() const; + ffi::Array LoadAll() const; - NDArray ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const NDArray& param) const; + Tensor ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const Tensor& param) const; /*! \brief Load all the pre-sharded parameters */ - Array LoadAllPresharded() const; + ffi::Array LoadAllPresharded() const; /*! \brief Load the i-th parameter from presharded binaries */ - NDArray LoadPresharded(int weight_index) const; + Tensor LoadPresharded(int weight_index) const; /*! 
\brief Slice the given tensor at a specific dimension */ - NDArray Shard(NDArray source, int dim, int num_slices) const; - - static constexpr const char* _type_key = "runtime.disco.ShardLoader"; - TVM_DECLARE_FINAL_OBJECT_INFO(ShardLoaderObj, Object); + Tensor Shard(Tensor source, int dim, int num_slices) const; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.ShardLoader", ShardLoaderObj, Object); public: /*! \brief Information of how each weight is stored and sharded */ @@ -149,8 +148,8 @@ class ShardLoaderObj : public Object { }; /*! \brief The ffi::Functions being used during sharding */ std::unordered_map shard_funcs_; - /*! \brief The metadata loaded from `ndarray-cache.json` */ - NDArrayCacheMetadata metadata_; + /*! \brief The metadata loaded from `tensor-cache.json` */ + TensorCacheMetadata metadata_; /*! \brief Sharding information for each weight */ std::vector param_info_; /*! \brief Maps the name of a shard to its index */ @@ -167,22 +166,22 @@ class ShardLoaderObj : public Object { * check for post-processing that may be required. Instead, the * public function `Load` or `LoadPresharded` should be called. 
* - * \param weight_index The index of NDArray tensor to load + * \param weight_index The index of Tensor tensor to load * * \returns The full tensor at the specified index */ - NDArray LoadDirect(int weight_index) const; + Tensor LoadDirect(int weight_index) const; }; ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata, - std::string shard_info, Optional mod) { + std::string shard_info, ffi::Optional mod) { if (shard_info.empty() && mod.has_value()) { if (auto get_shard_info = (*mod)->GetFunction("get_shard_info")) { - shard_info = (*get_shard_info)().cast(); + shard_info = (*get_shard_info)().cast(); } } - ObjectPtr n = make_object(); - n->metadata_ = NDArrayCacheMetadata::LoadFromStr(metadata, path_to_metadata); + ObjectPtr n = ffi::make_object(); + n->metadata_ = TensorCacheMetadata::LoadFromStr(metadata, path_to_metadata); n->current_file_ = nullptr; n->param_info_.clear(); std::unordered_map shards = LoadShardInfoFromStr(shard_info); @@ -194,7 +193,7 @@ ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std: ShardInfo& shard_info = shards[name]; for (const ShardInfo::ShardFunc& shard_func : shard_info.funcs) { const std::string& name = shard_func.name; - if (Optional f = + if (ffi::Optional f = mod.has_value() ? 
(*mod)->GetFunction(name, true) : std::nullopt) { n->shard_funcs_[name] = *f; } else if (const auto f = tvm::ffi::Function::GetGlobal(name)) { @@ -209,10 +208,10 @@ ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std: return ObjectRef(std::move(n)); } -NDArray ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, - const NDArray& param) const { +Tensor ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, + const Tensor& param) const { Device device = param->device; - NDArray o = NDArray::Empty(shard_func.output_info.shape, shard_func.output_info.dtype, device); + Tensor o = Tensor::Empty(shard_func.output_info.shape, shard_func.output_info.dtype, device); ffi::Function f = this->shard_funcs_.at(shard_func.name); int n = static_cast(shard_func.params.size()); std::vector packed_args(n + 2); @@ -236,7 +235,7 @@ std::string GetSiblingPath(const std::string& path, const std::string& filename) LOG(FATAL) << "ValueError: Cannot find the parent directory: " << path; } -NDArray ShardLoaderObj::LoadParamOnWorker0(int weight_index) const { +Tensor ShardLoaderObj::LoadParamOnWorker0(int weight_index) const { DiscoWorker* worker = DiscoWorker::ThreadLocal(); int worker_id = worker->worker_id; Device device = worker->default_device; @@ -255,10 +254,10 @@ NDArray ShardLoaderObj::LoadParamOnWorker0(int weight_index) const { }; if (worker_id == 0) { - NDArray w = load(); + Tensor w = load(); return w; } else { - NDArray w = NDArray::Empty(param->shape, param->dtype, device); + Tensor w = Tensor::Empty(param->shape, param->dtype, device); return w; } } @@ -285,7 +284,7 @@ std::tuple ParseParamShardingInfo(const ParamRecord* param) { return {num_shards, worker_id}; } -NDArray ShardLoaderObj::LoadDirect(int weight_index) const { +Tensor ShardLoaderObj::LoadDirect(int weight_index) const { const ParamInfo& param_info = param_info_.at(weight_index); const ParamRecord* param = param_info.param; const FileRecord* file = 
param_info.file; @@ -301,7 +300,7 @@ NDArray ShardLoaderObj::LoadDirect(int weight_index) const { return param->Load(device, &this->current_file_stream_); } -NDArray ShardLoaderObj::Load(int weight_index) const { +Tensor ShardLoaderObj::Load(int weight_index) const { DiscoWorker* worker = DiscoWorker::ThreadLocal(); int worker_id = worker->worker_id; int num_shards = worker->num_workers; @@ -317,9 +316,9 @@ NDArray ShardLoaderObj::Load(int weight_index) const { << "ValueError: The first dimension of the " << "output shape must be equal to the " << "number of shards, but got: " << shape << " and num_shards = " << num_shards; - NDArray recv = NDArray::Empty(ffi::Shape(shape.begin() + 1, shape.end()), dtype, device); + Tensor recv = Tensor::Empty(ffi::Shape(shape.begin() + 1, shape.end()), dtype, device); if (worker_id == 0) { - NDArray w = LoadDirect(weight_index); + Tensor w = LoadDirect(weight_index); for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) { w = this->ApplyShardFunc(shard_func, w); } @@ -330,20 +329,20 @@ NDArray ShardLoaderObj::Load(int weight_index) const { return recv; } else { if (worker_id == 0) { - NDArray w = LoadDirect(weight_index); + Tensor w = LoadDirect(weight_index); BroadcastFromWorker0(w, /*in_group=*/false, w); return w; } else { - NDArray w = NDArray::Empty(param->shape, param->dtype, device); + Tensor w = Tensor::Empty(param->shape, param->dtype, device); BroadcastFromWorker0(w, /*in_group=*/false, w); return w; } } } -Array ShardLoaderObj::LoadAll() const { +ffi::Array ShardLoaderObj::LoadAll() const { int n = static_cast(param_info_.size()); - Array shards; + ffi::Array shards; shards.reserve(n); for (int i = 0; i < n; ++i) { std::string param_name = "param_" + std::to_string(i); @@ -354,7 +353,7 @@ Array ShardLoaderObj::LoadAll() const { return shards; } -NDArray ShardLoaderObj::LoadPresharded(int weight_index) const { +Tensor ShardLoaderObj::LoadPresharded(int weight_index) const { DiscoWorker* worker = 
DiscoWorker::ThreadLocal(); int worker_id = worker->worker_id; int num_shards = worker->num_workers; @@ -380,13 +379,13 @@ NDArray ShardLoaderObj::LoadPresharded(int weight_index) const { return LoadDirect(index); } -Array ShardLoaderObj::LoadAllPresharded() const { +ffi::Array ShardLoaderObj::LoadAllPresharded() const { DiscoWorker* worker = DiscoWorker::ThreadLocal(); size_t worker_id = static_cast(worker->worker_id); size_t num_workers = static_cast(worker->num_workers); size_t num_params = param_info_.size() / num_workers; - Array params; + ffi::Array params; params.reserve(num_params); for (size_t i_param = 0; i_param < num_params; ++i_param) { std::string param_name = static_cast( @@ -403,7 +402,7 @@ Array ShardLoaderObj::LoadAllPresharded() const { return params; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.disco.ShardLoader", ShardLoaderObj::Create) @@ -442,7 +441,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); return loader->LoadParamOnWorker0(param_index); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 32a194072653..fd4ad06c3fa8 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -116,7 +116,7 @@ void InitCCLPerWorker(ffi::Shape device_ids, std::string unique_id_bytes) { } } -void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) { +void AllReduce(Tensor send, ReduceKind reduce_kind, bool in_group, Tensor recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); ffi::Shape shape = send.Shape(); int64_t numel = shape->Product(); @@ -131,7 +131,7 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv in_group ? 
ctx->group_comm : ctx->global_comm, stream)); } -void AllGather(NDArray send, bool in_group, NDArray recv) { +void AllGather(Tensor send, bool in_group, Tensor recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); ffi::Shape shape = send.Shape(); int64_t numel = shape->Product(); @@ -141,7 +141,7 @@ void AllGather(NDArray send, bool in_group, NDArray recv) { in_group ? ctx->group_comm : ctx->global_comm, stream)); } -void BroadcastFromWorker0(Optional send, bool in_group, NDArray recv) { +void BroadcastFromWorker0(ffi::Optional send, bool in_group, Tensor recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; int group_size = ctx->worker->num_workers / ctx->worker->num_groups; @@ -150,13 +150,13 @@ void BroadcastFromWorker0(Optional send, bool in_group, NDArray recv) { const void* send_data = [&]() -> const void* { if (is_sender) { CHECK(send.defined()); - CHECK(send.value().Shape()->Product() == recv.Shape()->Product()); + CHECK(send.value().Shape().Product() == recv.Shape().Product()); return send.value()->data; } else { return nullptr; } }(); - int64_t numel = recv.Shape()->Product(); + int64_t numel = recv.Shape().Product(); deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclBroadcast(send_data, recv->data, numel, @@ -164,7 +164,7 @@ void BroadcastFromWorker0(Optional send, bool in_group, NDArray recv) { /*root=*/0, in_group ? 
ctx->group_comm : ctx->global_comm, stream)); } -void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { +void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) { CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; @@ -175,8 +175,8 @@ void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { deviceStream_t stream = ctx->GetDefaultStream(); if (is_sender) { CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0."; - NDArray buffer = send.value(); - int64_t numel = buffer.Shape()->Product(); + Tensor buffer = send.value(); + int64_t numel = buffer.Shape().Product(); CHECK_EQ(numel % num_receiver, 0) << "ValueError: Scattering evenly requires that the number " "of elements in the buffer to be " "divisible by the number of workers, but got numel = " @@ -184,11 +184,11 @@ void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { DataType dtype(buffer->dtype); int64_t numel_per_shard = numel / num_receiver; int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); - CHECK_EQ(numel_per_shard, recv.Shape()->Product()) + CHECK_EQ(numel_per_shard, recv.Shape().Product()) << "ValueError: The number of elements in buffer `recv` must be the same as each shard " "of " "buffer `send`. `send.size` is " - << numel << ", but `recv.size` is " << recv.Shape()->Product() << "."; + << numel << ", but `recv.size` is " << recv.Shape().Product() << "."; NCCL_CALL(ncclGroupStart()); uint8_t* data = static_cast(buffer->data); for (int i = 0; i < num_receiver; ++i) { @@ -204,14 +204,14 @@ void ScatterFromWorker0(Optional send, bool in_group, NDArray recv) { } NCCL_CALL(ncclGroupStart()); } - int64_t numel = recv.Shape()->Product(); + int64_t numel = recv.Shape().Product(); DataType dtype(recv->dtype); NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, in_group ? 
ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } -void GatherToWorker0(NDArray send, bool in_group, Optional recv) { +void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { CHECK(send.defined()) << "ValueError: buffer `send` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; @@ -222,8 +222,8 @@ void GatherToWorker0(NDArray send, bool in_group, Optional recv) { deviceStream_t stream = ctx->GetDefaultStream(); if (is_sender) { CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0."; - NDArray buffer = recv.value(); - int64_t numel = buffer.Shape()->Product(); + Tensor buffer = recv.value(); + int64_t numel = buffer.Shape().Product(); CHECK_EQ(numel % num_receiver, 0) << "ValueError: Gathering evenly requires that the number " "of elements in the buffer to be " "divisible by the number of workers, but got numel = " @@ -231,11 +231,11 @@ void GatherToWorker0(NDArray send, bool in_group, Optional recv) { DataType dtype(buffer->dtype); int64_t numel_per_shard = numel / num_receiver; int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); - CHECK_EQ(numel_per_shard, send.Shape()->Product()) + CHECK_EQ(numel_per_shard, send.Shape().Product()) << "ValueError: The number of elements in buffer `send` must be the same as each shard " "of " "buffer `recv`. `recv.size` is " - << numel << ", but `send.size` is " << send.Shape()->Product() << "."; + << numel << ", but `send.size` is " << send.Shape().Product() << "."; NCCL_CALL(ncclGroupStart()); uint8_t* data = static_cast(buffer->data); for (int i = 0; i < num_receiver; ++i) { @@ -251,25 +251,25 @@ void GatherToWorker0(NDArray send, bool in_group, Optional recv) { } NCCL_CALL(ncclGroupStart()); } - int64_t numel = send.Shape()->Product(); + int64_t numel = send.Shape().Product(); DataType dtype(send->dtype); NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, in_group ? 
ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } -void RecvFromWorker0(NDArray buffer) { +void RecvFromWorker0(Tensor buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); deviceStream_t stream = ctx->GetDefaultStream(); CHECK_NE(ctx->worker->worker_id, 0) << "ValueError: Worker 0 is not allowed to call RecvFromWorker0."; NCCL_CALL(ncclGroupStart()); - NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), 0, + NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), 0, ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } -void SendToNextGroup(NDArray buffer) { +void SendToNextGroup(Tensor buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); deviceStream_t stream = ctx->GetDefaultStream(); int worker_id = ctx->worker->worker_id; @@ -278,12 +278,12 @@ void SendToNextGroup(NDArray buffer) { CHECK_LT(receiver_id, ctx->worker->num_workers) << "The current group is already the last group and there is no such a next group."; NCCL_CALL(ncclGroupStart()); - NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + NCCL_CALL(ncclSend(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), receiver_id, ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } -void RecvFromPrevGroup(NDArray buffer) { +void RecvFromPrevGroup(Tensor buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); deviceStream_t stream = ctx->GetDefaultStream(); int worker_id = ctx->worker->worker_id; @@ -292,12 +292,12 @@ void RecvFromPrevGroup(NDArray buffer) { CHECK_GE(sender_id, 0) << "The current group is already the first group and there is no such a previous group."; NCCL_CALL(ncclGroupStart()); - NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), 
sender_id, ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); } -void SendToWorker(NDArray buffer, int receiver_id) { +void SendToWorker(Tensor buffer, int receiver_id) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); deviceStream_t stream = ctx->GetDefaultStream(); int worker_id = ctx->worker->worker_id; @@ -305,18 +305,18 @@ void SendToWorker(NDArray buffer, int receiver_id) { << "Invalid receiver id " << receiver_id << ". The world size is " << ctx->worker->num_workers; CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself."; - NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + NCCL_CALL(ncclSend(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), receiver_id, ctx->global_comm, stream)); } -void RecvFromWorker(NDArray buffer, int sender_id) { +void RecvFromWorker(Tensor buffer, int sender_id) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); deviceStream_t stream = ctx->GetDefaultStream(); int worker_id = ctx->worker->worker_id; CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers) << "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers; CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself."; - NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), sender_id, ctx->global_comm, stream)); } @@ -327,19 +327,19 @@ void SyncWorker() { StreamSynchronize(stream); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.disco.compiled_ccl", []() -> String { return TVM_DISCO_CCL_NAME; }) + .def("runtime.disco.compiled_ccl", []() -> ffi::String { return TVM_DISCO_CCL_NAME; }) .def("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl", InitCCL) .def("runtime.disco." 
TVM_DISCO_CCL_NAME ".init_ccl_per_worker", InitCCLPerWorker) .def("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce", - [](NDArray send, int kind, bool in_group, NDArray recv) { + [](Tensor send, int kind, bool in_group, Tensor recv) { CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; nccl::AllReduce(send, static_cast(kind), in_group, recv); }) .def("runtime.disco." TVM_DISCO_CCL_NAME ".allgather", - [](NDArray send, bool in_group, NDArray recv) { nccl::AllGather(send, in_group, recv); }) + [](Tensor send, bool in_group, Tensor recv) { nccl::AllGather(send, in_group, recv); }) .def("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0", BroadcastFromWorker0) .def("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0", ScatterFromWorker0) .def("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0", GatherToWorker0) @@ -350,7 +350,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker", RecvFromWorker) .def("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker", SyncWorker) .def("runtime.disco." TVM_DISCO_CCL_NAME ".test_send_to_next_group_recv_from_prev_group", - [](NDArray buffer) { + [](Tensor buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; @@ -362,18 +362,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ tvm::runtime::nccl::RecvFromPrevGroup(buffer); } }) - .def("runtime.disco." 
TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0", - [](NDArray buffer) { - CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; - CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; - if (ctx->worker->worker_id == 2) { - tvm::runtime::nccl::SendToWorker(buffer, 0); - } else if (ctx->worker->worker_id == 0) { - tvm::runtime::nccl::RecvFromWorker(buffer, 2); - } - }); -}); + .def("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0", [](Tensor buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; + CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + if (ctx->worker->worker_id == 2) { + tvm::runtime::nccl::SendToWorker(buffer, 0); + } else if (ctx->worker->worker_id == 0) { + tvm::runtime::nccl::RecvFromWorker(buffer, 2); + } + }); +} } // namespace nccl } // namespace runtime diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index d901b3eae42c..c13cd9e60e9d 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -168,20 +168,18 @@ class ProcessSessionObj final : public BcastSessionObj { ffi::Function process_pool_; std::unique_ptr worker_0_; std::vector> workers_; - - static constexpr const char* _type_key = "runtime.disco.ProcessSession"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProcessSessionObj, SessionObj); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.ProcessSession", ProcessSessionObj, SessionObj); }; -Session Session::ProcessSession(int num_workers, int num_group, String process_pool_creator, - String entrypoint) { +Session Session::ProcessSession(int num_workers, int num_group, ffi::String process_pool_creator, + ffi::String entrypoint) { CHECK_EQ(num_workers % num_group, 0) << "The number of workers 
should be divisible by the number of worker group."; const auto pf = tvm::ffi::Function::GetGlobal(process_pool_creator); CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator << " in the registry. Please check if it is registered."; auto process_pool = (*pf)(num_workers, num_group, entrypoint).cast(); - auto n = make_object(num_workers, num_group, process_pool); + auto n = ffi::make_object(num_workers, num_group, process_pool); return Session(n); } @@ -194,12 +192,13 @@ void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_f worker.MainLoop(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); refl::GlobalDef() .def("runtime.disco.SessionProcess", Session::ProcessSession) .def("runtime.disco.WorkerProcess", WorkerProcess); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index ee6d5bf32ccc..067a4f0d4a67 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -49,7 +49,7 @@ struct DiscoProtocol { /*! \brief Recycle all the memory used in the arena */ inline void RecycleAll() { - this->object_arena_.clear(); + this->any_arena_.clear(); this->arena_.RecycleAll(); } @@ -81,27 +81,27 @@ struct DiscoProtocol { } support::Arena arena_; - std::vector object_arena_; + std::vector any_arena_; friend struct RPCReference; }; /*! * \brief The debug extension of the communication protocol that allows serialization and - * deserialization of NDArrays and reflection-capable TVM objects. + * deserialization of Tensors and reflection-capable TVM objects. */ struct DiscoDebugObject : public Object { public: /*! \brief The data to be serialized */ ffi::Any data; - /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */ + /*! \brief Wrap an Tensor or reflection-capable TVM object into the debug extension. 
*/ static ObjectRef Wrap(const ffi::Any& data) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->data = data; return ObjectRef(n); } - /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */ + /*! \brief Wrap an Tensor or reflection-capable TVM object into the debug extension. */ static ObjectRef Wrap(const ffi::AnyView& data) { ffi::Any rv; rv = data; @@ -116,9 +116,7 @@ struct DiscoDebugObject : public Object { inline uint64_t GetFFIAnyProtocolBytes() const { return sizeof(uint64_t) + this->SaveToStr().size(); } - - static constexpr const char* _type_key = "runtime.disco.DiscoDebugObject"; - TVM_DECLARE_FINAL_OBJECT_INFO(DiscoDebugObject, SessionObj); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.DiscoDebugObject", DiscoDebugObject, SessionObj); }; template @@ -182,7 +180,7 @@ inline void DiscoProtocol::ReadFFIAny(TVMFFIAny* out) { uint32_t type_index; self->template Read(&type_index); if (type_index == TypeIndex::kRuntimeDiscoDRef) { - ObjectPtr dref = make_object(); + ObjectPtr dref = ffi::make_object(); self->template Read(&dref->reg_id); dref->session = Session{nullptr}; result = ObjectRef(std::move(dref)); @@ -191,7 +189,7 @@ inline void DiscoProtocol::ReadFFIAny(TVMFFIAny* out) { self->template Read(&size); std::string data(size, '\0'); self->template ReadArray(data.data(), size); - result = String(std::move(data)); + result = ffi::String(std::move(data)); } else if (type_index == ffi::TypeIndex::kTVMFFIBytes) { uint64_t size = 0; self->template Read(&size); @@ -215,12 +213,12 @@ inline void DiscoProtocol::ReadFFIAny(TVMFFIAny* out) { << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; } *reinterpret_cast(out) = result; - object_arena_.push_back(result); + any_arena_.push_back(result); } inline std::string DiscoDebugObject::SaveToStr() const { - if (auto opt_nd = this->data.as()) { - NDArray array = opt_nd.value(); + if (auto opt_nd = this->data.as()) { + Tensor array = 
opt_nd.value(); std::string result; { dmlc::MemoryStringStream mstrm(&result); @@ -247,7 +245,7 @@ inline ObjectPtr DiscoDebugObject::LoadFromStr(std::string jso ICHECK(!json_str.empty()); char control_bit = json_str.back(); json_str.pop_back(); - ObjectPtr result = make_object(); + ObjectPtr result = ffi::make_object(); if (control_bit == '0') { const auto f = tvm::ffi::Function::GetGlobal("node.LoadJSON"); CHECK(f.has_value()) << "ValueError: Cannot deserialize object in non-debugging mode"; @@ -256,7 +254,7 @@ inline ObjectPtr DiscoDebugObject::LoadFromStr(std::string jso dmlc::MemoryStringStream mstrm(&json_str); support::Base64InStream b64strm(&mstrm); b64strm.InitPosition(); - runtime::NDArray array; + runtime::Tensor array; ICHECK(array.Load(&b64strm)); result->data = std::move(array); } else { diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index 4f2ffb3d3f65..2bf132f362d5 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -30,8 +30,10 @@ struct SessionObj::FFI { } }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + refl::ObjectDef(); refl::GlobalDef() .def("runtime.disco.SessionThreaded", Session::ThreadedSession) .def_method("runtime.disco.DRefDebugGetFromRemote", &DRefObj::DebugGetFromRemote) @@ -48,7 +50,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = SessionObj::FFI::CallWithPacked(self, args.Slice(1)); }) .def_method("runtime.disco.SessionShutdown", &SessionObj::Shutdown); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index 7dba51e4900c..029038625faa 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -17,6 +17,7 @@ * under the License. 
*/ #include +#include #include #include #include @@ -180,9 +181,8 @@ class ThreadedSessionObj final : public BcastSessionObj { ffi::PackedArgs RecvReplyPacked(int worker_id) final { return this->workers_.at(worker_id).channel->RecvReply(); } - - static constexpr const char* _type_key = "runtime.disco.ThreadedSession"; - TVM_DECLARE_FINAL_OBJECT_INFO(ThreadedSessionObj, SessionObj); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.ThreadedSession", ThreadedSessionObj, + SessionObj); std::vector workers_; }; @@ -190,9 +190,14 @@ class ThreadedSessionObj final : public BcastSessionObj { Session Session::ThreadedSession(int num_workers, int num_group) { CHECK_EQ(num_workers % num_group, 0) << "The number of workers should be divisible by the number of worker group."; - ObjectPtr n = make_object(num_workers, num_group); + ObjectPtr n = ffi::make_object(num_workers, num_group); return Session(std::move(n)); } +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/utils.h b/src/runtime/disco/utils.h index f0a10b6093d4..fb68335d8c5e 100644 --- a/src/runtime/disco/utils.h +++ b/src/runtime/disco/utils.h @@ -27,7 +27,7 @@ namespace tvm { namespace runtime { -inline Device UseDefaultDeviceIfNone(Optional device) { +inline Device UseDefaultDeviceIfNone(ffi::Optional device) { return device.value_or(DiscoWorker::ThreadLocal()->default_device); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 4564d72e5eed..b3733ee6fdff 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -196,15 +196,15 @@ void CopyFile(const std::string& src_file_name, const std::string& dest_file_nam << " dest='" << dest_file_name << "'"; } -Map LoadParams(const std::string& param_blob) { +ffi::Map LoadParams(const std::string& param_blob) { dmlc::MemoryStringStream strm(const_cast(¶m_blob)); return LoadParams(&strm); } -Map 
LoadParams(dmlc::Stream* strm) { - Map params; +ffi::Map LoadParams(dmlc::Stream* strm) { + ffi::Map params; uint64_t header, reserved; ICHECK(strm->Read(&header)) << "Invalid parameters file format"; - ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + ICHECK(header == kTVMTensorListMagic) << "Invalid parameters file format"; ICHECK(strm->Read(&reserved)) << "Invalid parameters file format"; std::vector names; @@ -214,15 +214,15 @@ Map LoadParams(dmlc::Stream* strm) { size_t size = static_cast(sz); ICHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { - // The data_entry is allocated on device, NDArray.load always load the array into CPU. - NDArray temp; + // The data_entry is allocated on device, Tensor.load always load the array into CPU. + Tensor temp; temp.Load(strm); params.Set(names[i], temp); } return params; } -void SaveParams(dmlc::Stream* strm, const Map& params) { +void SaveParams(dmlc::Stream* strm, const ffi::Map& params) { std::vector names; std::vector arrays; for (auto& p : params) { @@ -230,7 +230,7 @@ void SaveParams(dmlc::Stream* strm, const Map& params) { arrays.push_back(p.second.operator->()); } - uint64_t header = kTVMNDArrayListMagic, reserved = 0; + uint64_t header = kTVMTensorListMagic, reserved = 0; strm->Write(header); strm->Write(reserved); strm->Write(names); @@ -243,7 +243,7 @@ void SaveParams(dmlc::Stream* strm, const Map& params) { } } -std::string SaveParams(const Map& params) { +std::string SaveParams(const ffi::Map& params) { std::string bytes; dmlc::MemoryStringStream strm(&bytes); dmlc::Stream* fo = &strm; @@ -251,25 +251,25 @@ std::string SaveParams(const Map& params) { return bytes; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.SaveParams", - [](const Map& params) { + [](const ffi::Map& params) { std::string s = ::tvm::runtime::SaveParams(params); return 
ffi::Bytes(std::move(s)); }) .def("runtime.SaveParamsToFile", - [](const Map& params, const String& path) { + [](const ffi::Map& params, const ffi::String& path) { tvm::runtime::SimpleBinaryFileStream strm(path, "wb"); SaveParams(&strm, params); }) .def("runtime.LoadParams", [](const ffi::Bytes& s) { return ::tvm::runtime::LoadParams(s); }) - .def("runtime.LoadParamsFromFile", [](const String& path) { + .def("runtime.LoadParamsFromFile", [](const ffi::String& path) { tvm::runtime::SimpleBinaryFileStream strm(path, "rb"); return LoadParams(&strm); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index b4da7adea813..6f5487f7fab0 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -104,31 +104,31 @@ void CopyFile(const std::string& src_file_name, const std::string& dest_file_nam */ void RemoveFile(const std::string& file_name); -constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; +constexpr uint64_t kTVMTensorListMagic = 0xF7E58D4F05049CB7; /*! * \brief Load parameters from a string. * \param param_blob Serialized string of parameters. * \return Map of parameter name to parameter value. */ -Map LoadParams(const std::string& param_blob); +ffi::Map LoadParams(const std::string& param_blob); /*! * \brief Load parameters from a stream. * \param strm Stream to load parameters from. * \return Map of parameter name to parameter value. */ -Map LoadParams(dmlc::Stream* strm); +ffi::Map LoadParams(dmlc::Stream* strm); /*! * \brief Serialize parameters to a byte array. * \param params Parameters to save. - * \return String containing binary parameter data. + * \return ffi::String containing binary parameter data. */ -std::string SaveParams(const Map& params); +std::string SaveParams(const ffi::Map& params); /*! * \brief Serialize parameters to a stream. * \param strm Stream to write to. * \param params Parameters to save. 
*/ -void SaveParams(dmlc::Stream* strm, const Map& params); +void SaveParams(dmlc::Stream* strm, const ffi::Map& params); /*! * \brief A dmlc stream which wraps standard file operations. diff --git a/src/runtime/hexagon/hexagon_buffer.cc b/src/runtime/hexagon/hexagon_buffer.cc index 48afa5770afd..c6dd9421fe63 100644 --- a/src/runtime/hexagon/hexagon_buffer.cc +++ b/src/runtime/hexagon/hexagon_buffer.cc @@ -109,7 +109,7 @@ std::unique_ptr Allocator(size_t return std::make_unique(nbytes, alignment); } -HexagonBuffer::HexagonBuffer(size_t nbytes, size_t alignment, Optional scope) +HexagonBuffer::HexagonBuffer(size_t nbytes, size_t alignment, ffi::Optional scope) : ndim_(1), nbytes_per_allocation_(nbytes) { SetStorageScope(scope); @@ -125,7 +125,7 @@ HexagonBuffer::HexagonBuffer(size_t nbytes, size_t alignment, Optional s } HexagonBuffer::HexagonBuffer(size_t nallocs, size_t nbytes, size_t alignment, - Optional scope) + ffi::Optional scope) : ndim_(2), nbytes_per_allocation_(nbytes) { SetStorageScope(scope); @@ -166,7 +166,7 @@ void* HexagonBuffer::GetPointer() { HexagonBuffer::StorageScope HexagonBuffer::GetStorageScope() const { return storage_scope_; } -void HexagonBuffer::SetStorageScope(Optional scope) { +void HexagonBuffer::SetStorageScope(ffi::Optional scope) { const std::string s = scope.value_or("global"); if (s == "global") { diff --git a/src/runtime/hexagon/hexagon_buffer.h b/src/runtime/hexagon/hexagon_buffer.h index 986d6b6e5ec6..2dd7c127e3ed 100644 --- a/src/runtime/hexagon/hexagon_buffer.h +++ b/src/runtime/hexagon/hexagon_buffer.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -49,7 +49,7 @@ class HexagonBuffer { * space in which to allocate. Defaults to global system * memory (DDR). */ - HexagonBuffer(size_t nbytes, size_t alignment, Optional scope); + HexagonBuffer(size_t nbytes, size_t alignment, ffi::Optional scope); /* \brief Allocate 2d (discontiguous) memory within Hexagon accessible * memory scopes. 
@@ -65,7 +65,7 @@ class HexagonBuffer { * space in which to allocate. Defaults to global system * memory (DDR). */ - HexagonBuffer(size_t nallocs, size_t nbytes, size_t alignment, Optional scope); + HexagonBuffer(size_t nallocs, size_t nbytes, size_t alignment, ffi::Optional scope); //! \brief Destruction deallocates the underlying allocations. ~HexagonBuffer(); @@ -140,7 +140,7 @@ class HexagonBuffer { size_t TotalBytes() const { return nbytes_per_allocation_ * allocations_.size(); } //! \brief Assign a storage scope to the buffer. - void SetStorageScope(Optional scope); + void SetStorageScope(ffi::Optional scope); /*! \brief Array of raw pointer allocations required by the buffer. * * For 1d (contiguous) storage a single allocation will result. diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index 491ded5730e6..61c7e4972ba0 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -46,19 +46,18 @@ class HexagonTimerNode : public TimerNode { virtual void Stop() { end = HAP_perf_get_time_us(); } virtual int64_t SyncAndGetElapsedNanos() { return (end - start) * 1e3; } virtual ~HexagonTimerNode() {} - - static constexpr const char* _type_key = "runtime.hexagon.HexagonTimerNode"; - TVM_DECLARE_FINAL_OBJECT_INFO(HexagonTimerNode, TimerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.hexagon.HexagonTimerNode", HexagonTimerNode, + TimerNode); private: uint64_t start, end; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.hexagon", - [](Device dev) { return Timer(make_object()); }); -}); + [](Device dev) { return Timer(ffi::make_object()); }); +} } // namespace hexagon namespace { @@ -89,14 +88,14 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } } // namespace detail -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; refl::GlobalDef().def_packed( "ffi.Module.load_from_file.hexagon", [](ffi::PackedArgs args, ffi::Any* rv) { auto floader = tvm::ffi::Function::GetGlobalRequired("ffi.Module.load_from_file.so"); - *rv = floader(args[0].cast(), "so"); + *rv = floader(args[0].cast(), "so"); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index a26f113f1e9b..15ee1ed52a8b 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include @@ -52,7 +52,7 @@ void HexagonDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) { // DataSpace: static allocations for Hexagon void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) { + ffi::Optional mem_scope) { CHECK(shape || ndim == 0) << "shape array is null for a non-scalar tensor, ndim = " << ndim; CHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; @@ -122,7 +122,7 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignme CHECK(runtime_hexbuffs) << "Attempted to allocate Hexagon data with " << "HexagonDeviceAPI::AllocDataSpace before initializing resources. 
" << "Please call HexagonDeviceAPI::AcquireResources"; - return runtime_hexbuffs->AllocateHexagonBuffer(nbytes, alignment, String("global")); + return runtime_hexbuffs->AllocateHexagonBuffer(nbytes, alignment, ffi::String("global")); } void HexagonDeviceAPI::FreeDataSpace(Device dev, void* ptr) { @@ -191,7 +191,7 @@ void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("device_api.hexagon.dma_copy_dltensor", @@ -272,7 +272,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ type_hint.lanes = 1; HexagonDeviceAPI* hexapi = HexagonDeviceAPI::Global(); - *rv = hexapi->AllocDataSpace(dev, ndim, shape, type_hint, String(scope)); + *rv = hexapi->AllocDataSpace(dev, ndim, shape, type_hint, ffi::String(scope)); }) .def_packed("device_api.hexagon.free_nd", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -309,7 +309,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ DeviceAPI* ptr = HexagonDeviceAPI::Global(); *rv = static_cast(ptr); }); -}); +} } // namespace hexagon } // namespace runtime diff --git a/src/runtime/hexagon/hexagon_device_api.h b/src/runtime/hexagon/hexagon_device_api.h index e77e681dd434..76439ef531ae 100644 --- a/src/runtime/hexagon/hexagon_device_api.h +++ b/src/runtime/hexagon/hexagon_device_api.h @@ -136,7 +136,7 @@ class HexagonDeviceAPI final : public DeviceAPI { * \return The allocated HexagonBuffer pointer. */ void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) final; + ffi::Optional mem_scope) final; /*! * \brief Copy data from one storage to another. 
diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 9db6a6680b06..5515c33e5f7d 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -42,11 +42,11 @@ HexagonModuleNode::HexagonModuleNode(std::string data, std::string fmt, std::string bc_str) : data_(data), fmt_(fmt), fmap_(fmap), asm_(asm_str), obj_(obj_str), ir_(ir_str), bc_(bc_str) {} -Optional HexagonModuleNode::GetFunction(const String& name) { +ffi::Optional HexagonModuleNode::GetFunction(const ffi::String& name) { LOG(FATAL) << "HexagonModuleNode::GetFunction is not implemented."; } -String HexagonModuleNode::InspectSource(const String& format) const { +ffi::String HexagonModuleNode::InspectSource(const ffi::String& format) const { if (format == "s" || format == "asm") { return asm_; } @@ -56,7 +56,7 @@ String HexagonModuleNode::InspectSource(const String& format) const { return ""; } -void HexagonModuleNode::WriteToFile(const String& file_name, const String& format) const { +void HexagonModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { std::string fmt = runtime::GetFileFormat(file_name, format); if (fmt == "so" || fmt == "dll" || fmt == "hexagon") { std::string meta_file = GetMetaFilePath(file_name); @@ -93,7 +93,7 @@ ffi::Module HexagonModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, std::string obj_str, std::string ir_str, std::string bc_str) { - auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str); + auto n = ffi::make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str); return ffi::Module(n); } diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index ae7174236622..1f99c278b28b 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -39,10 +39,10 @@ namespace runtime { * \param data The module data. 
* \param fmt The format of the data, can be "obj". * \param fmap The function information map of each function. - * \param asm_str String with the generated assembly source. - * \param obj_str String with the object file data. - * \param ir_str String with the disassembled LLVM IR source. - * \param bc_str String with the bitcode LLVM IR. + * \param asm_str ffi::String with the generated assembly source. + * \param obj_str ffi::String with the object file data. + * \param ir_str ffi::String with the disassembled LLVM IR source. + * \param bc_str ffi::String with the bitcode LLVM IR. */ ffi::Module HexagonModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, @@ -60,15 +60,15 @@ class HexagonModuleNode : public ffi::ModuleObj { HexagonModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, std::string obj_str, std::string ir_str, std::string bc_str); - Optional GetFunction(const String& name) final; - String InspectSource(const String& format) const final; + ffi::Optional GetFunction(const ffi::String& name) final; + ffi::String InspectSource(const ffi::String& format) const final; const char* kind() const final { return "hexagon"; } /*! 
\brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable | ffi::Module::kRunnable; } - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; protected: diff --git a/src/runtime/hexagon/hexagon_thread_manager.cc b/src/runtime/hexagon/hexagon_thread_manager.cc index 4f8ddd156b9f..a6ae62e39fa5 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.cc +++ b/src/runtime/hexagon/hexagon_thread_manager.cc @@ -140,11 +140,11 @@ void HexagonThreadManager::SpawnThreads(unsigned thread_stack_size_bytes, unsigned thread_pipe_size_words) { // allocate all stack space for threads stack_buffer_ = hexbuffs_.AllocateHexagonBuffer(thread_stack_size_bytes * nthreads_, - MEM_ALIGNMENT, String("global")); + MEM_ALIGNMENT, ffi::String("global")); // allocate space for pipe buffers (command queues) unsigned thread_pipe_size_bytes = thread_pipe_size_words * sizeof(qurt_pipe_data_t); pipe_buffer_ = hexbuffs_.AllocateHexagonBuffer(thread_pipe_size_bytes * nthreads_, MEM_ALIGNMENT, - String("global")); + ffi::String("global")); threads_.resize(nthreads_); pipes_.resize(nthreads_); diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.h b/src/runtime/hexagon/hexagon_vtcm_pool.h index ece8454b859a..d9918a873aa9 100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.h +++ b/src/runtime/hexagon/hexagon_vtcm_pool.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/runtime/hexagon/rpc/android/session.cc index 31fed010a3de..55eee5df27f0 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/runtime/hexagon/rpc/android/session.cc @@ -110,7 +110,7 @@ class HexagonTransportChannel : public RPCChannel { remote_handle64 _handle = AEE_EUNKNOWN; }; 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -128,7 +128,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto sess = CreateClientSession(ep); *rv = CreateRPCSessionModule(sess); }); -}); +} } // namespace hexagon } // namespace runtime diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index 96c45bfdf0d1..d9c2e647aea2 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -328,7 +328,7 @@ __attribute__((weak)) void _Get_eh_data() {} __attribute__((weak)) void _Parse_fde_instr() {} } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.hexagon.load_module", @@ -349,7 +349,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = false; } }); -}); +} void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); @@ -357,7 +357,7 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { fs.write(&data[0], data.length()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.rpc.server.upload", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { @@ -365,4 +365,4 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto data = args[1].cast(); SaveBinaryToFile(file_name, data); }); -}); +} diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index d511b0038f21..c3cec3039221 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -332,7 +332,7 @@ __attribute__((weak)) void _Get_eh_data() {} __attribute__((weak)) void _Parse_fde_instr() {} } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.hexagon.load_module", @@ -353,7 +353,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = false; } }); -}); +} void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); @@ -361,7 +361,7 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { fs.write(&data[0], data.length()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.rpc.server.upload", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { @@ -369,4 +369,4 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto data = args[1].cast(); SaveBinaryToFile(file_name, data); }); -}); +} diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index 687ff6e79a16..d7a9ade7234a 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -1370,7 +1370,7 @@ std::optional SimulatorRPCChannel::to_nullptr(const detail::Mayb .Default(std::nullopt); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -1385,7 +1385,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ std::shared_ptr session = CreateClientSession(endpoint); *rv = CreateRPCSessionModule(session); }); -}); +} } // namespace hexagon } // namespace runtime diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index cef445ee91c0..db4d33be3789 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -36,7 +36,7 @@ namespace runtime { namespace memory { Storage::Storage(Buffer buffer, Allocator* allocator) { - auto n = make_object(); + auto n = 
ffi::make_object(); n->buffer = std::move(buffer); n->allocator = allocator; data_ = std::move(n); @@ -60,10 +60,10 @@ inline size_t GetDataAlignment(const DLDataType& dtype) { return align; } -NDArray StorageObj::AllocNDArrayScoped(int64_t offset, ffi::Shape shape, DLDataType dtype, - String scope) { +Tensor StorageObj::AllocTensorScoped(int64_t offset, ffi::Shape shape, DLDataType dtype, + ffi::String scope) { if (scope == "global" || scope.empty()) { - return AllocNDArray(offset, shape, dtype); + return AllocTensor(offset, shape, dtype); } VerifyDataType(dtype); @@ -71,7 +71,7 @@ NDArray StorageObj::AllocNDArrayScoped(int64_t offset, ffi::Shape shape, DLDataT public: explicit StorageScopedAlloc(Storage storage) : storage_(storage) {} - void AllocData(DLTensor* tensor, const ffi::Shape& shape, const String& scope, + void AllocData(DLTensor* tensor, const ffi::Shape& shape, const ffi::String& scope, int64_t byte_offset) { tensor->data = storage_->allocator->CreateView(storage_->buffer, shape, tensor->dtype, scope); tensor->byte_offset = byte_offset; @@ -87,11 +87,11 @@ NDArray StorageObj::AllocNDArrayScoped(int64_t offset, ffi::Shape shape, DLDataT << "storage allocation failure, attempted to allocate " << needed_size << " at offset " << offset << " in region that is " << this->buffer.size << "bytes"; - return NDArray::FromNDAlloc(StorageScopedAlloc(GetRef(this)), shape, dtype, - this->buffer.device, shape, scope, offset); + return Tensor::FromNDAlloc(StorageScopedAlloc(ffi::GetRef(this)), shape, dtype, + this->buffer.device, shape, scope, offset); } -NDArray StorageObj::AllocNDArray(int64_t offset, ffi::Shape shape, DLDataType dtype) { +Tensor StorageObj::AllocTensor(int64_t offset, ffi::Shape shape, DLDataType dtype) { VerifyDataType(dtype); size_t needed_size = ffi::GetDataSize(shape.Product(), dtype); @@ -120,8 +120,8 @@ NDArray StorageObj::AllocNDArray(int64_t offset, ffi::Shape shape, DLDataType dt Storage storage_; }; - return 
NDArray::FromNDAlloc(StorageAlloc(GetRef(this)), shape, dtype, - this->buffer.device, offset); + return Tensor::FromNDAlloc(StorageAlloc(ffi::GetRef(this)), shape, dtype, + this->buffer.device, offset); } MemoryManager* MemoryManager::Global() { @@ -213,8 +213,8 @@ void MemoryManager::Clear() { } } -NDArray Allocator::Empty(ffi::Shape shape, DLDataType dtype, DLDevice dev, - Optional mem_scope) { +Tensor Allocator::Empty(ffi::Shape shape, DLDataType dtype, DLDevice dev, + ffi::Optional mem_scope) { VerifyDataType(dtype); class BufferAlloc { @@ -239,7 +239,7 @@ NDArray Allocator::Empty(ffi::Shape shape, DLDataType dtype, DLDevice dev, } else { buffer = this->Alloc(dev, shape, dtype, *mem_scope); } - return NDArray::FromNDAlloc(BufferAlloc(buffer), shape, dtype, dev); + return Tensor::FromNDAlloc(BufferAlloc(buffer), shape, dtype, dev); } bool Allocator::AllowMemoryScope(const std::string& mem_scope) const { @@ -265,10 +265,10 @@ void Allocator::Clear() { // Pooled allocator will override this method. 
} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.memory_manager.clear", MemoryManager::Clear); -}); +} } // namespace memory } // namespace runtime diff --git a/src/runtime/memory/naive_allocator.h b/src/runtime/memory/naive_allocator.h index aed990d22c3b..6a968c86ef3b 100644 --- a/src/runtime/memory/naive_allocator.h +++ b/src/runtime/memory/naive_allocator.h @@ -67,7 +67,7 @@ class NaiveAllocator final : public Allocator { buf.size = nbytes; buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, - String(mem_scope)); + ffi::String(mem_scope)); used_memory_.fetch_add(nbytes, std::memory_order_relaxed); DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B"; buf.alloc_type = kNaive; diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index aa629aef50a7..126a7d9d90de 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -38,7 +38,7 @@ namespace tvm { namespace runtime { -inline String get_name_mangled(const String& module_name, const String& name) { +inline ffi::String get_name_mangled(const ffi::String& module_name, const ffi::String& name) { std::stringstream ss; ss << module_name << "_" << name; return ss.str(); @@ -48,6 +48,16 @@ namespace launch_param { /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; +/*! \brief A tag to specify whether or not use programatic dependent launch */ +constexpr const char* kUseProgramaticDependentLaunch = "tir.use_programtic_dependent_launch"; +/*! \brief A tag to specify whether or not use cooperative launch */ +constexpr const char* kUseCooperativeLaunch = "tir.use_cooperative_launch"; +/*! 
\brief A tag to specify cluster dimension X for SM90+ cluster launch */ +constexpr const char* kClusterDimX = "tir.cluster_dim_x"; +/*! \brief A tag to specify cluster dimension Y for SM90+ cluster launch */ +constexpr const char* kClusterDimY = "tir.cluster_dim_y"; +/*! \brief A tag to specify cluster dimension Z for SM90+ cluster launch */ +constexpr const char* kClusterDimZ = "tir.cluster_dim_z"; } // namespace launch_param diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index f10489826a5a..38c8642ccabb 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -109,7 +109,7 @@ class Stream { public: explicit Stream(id device) { queue_ = [device newCommandQueue]; } ~Stream() { [queue_ release]; } - id GetCommandBuffer(std::string label = "", bool attach_error_callback = true) { + virtual id GetCommandBuffer(std::string label = "", bool attach_error_callback = true) { id cb = [queue_ commandBuffer]; if (!label.empty()) { cb.label = [NSString stringWithUTF8String:label.c_str()]; @@ -141,6 +141,19 @@ class Stream { std::string error_description_; }; +class MetalRawStream final : public Stream { +public: + explicit MetalRawStream(id commandBuffer): Stream(nullptr) { + buffer_ = commandBuffer; + } + id GetCommandBuffer(std::string label = "", bool attach_error_callback = true) override { + return buffer_; + } +private: + id buffer_; +}; + + /*! * \brief Process global Metal workspace. 
*/ diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 2a8544f6f17c..2a3cd558f2a5 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -352,7 +352,7 @@ int GetWarpSize(id dev) { MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("device_api.metal", @@ -362,7 +362,7 @@ int GetWarpSize(id dev) { }) .def("metal.ResetGlobalState", []() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); }); -}); +} class MetalTimerNode : public TimerNode { public: @@ -380,9 +380,7 @@ virtual void Stop() { [mtl_dev_ sampleTimestamps:&stop_cpu_time_ gpuTimestamp:&stop_gpu_time_]; } virtual int64_t SyncAndGetElapsedNanos() { return stop_gpu_time_ - start_gpu_time_; } - - static constexpr const char* _type_key = "runtime.metal.MetalTimerNode"; - TVM_DECLARE_FINAL_OBJECT_INFO(MetalTimerNode, TimerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.metal.MetalTimerNode", MetalTimerNode, TimerNode); private: Device dev_; @@ -394,11 +392,11 @@ virtual void Stop() { MTLTimestamp stop_gpu_time_; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.metal", - [](Device dev) { return Timer(make_object(dev)); }); -}); + [](Device dev) { return Timer(ffi::make_object(dev)); }); +} } // namespace metal } // namespace runtime diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 71c46504c4d4..ff0101ac9a92 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -33,6 +33,7 @@ #include "../pack_args.h" #include "../thread_storage_scope.h" #include "metal_common.h" +#include "tvm/runtime/device_api.h" namespace tvm { namespace runtime { @@ -58,9 +59,9 @@ int GetPropertyMask() const final { return 
ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { LOG(FATAL) << "Do not support save to file, use save to binary and export instead"; } @@ -75,7 +76,7 @@ void WriteToFile(const String& file_name, const String& format) const final { stream->Write(fmt_); return ffi::Bytes(buffer); } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { // return text source if available. return source_; } @@ -200,6 +201,12 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) auto stream = metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id); + if (!(stream = dynamic_cast(metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id)))) { + // stream is not MetalRawStream + stream->SetError("Internal error: stream not from torch."); + return; + } + // skip launching so the error can be printed during sync if (stream->HasErrorHappened()) return; @@ -239,7 +246,8 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) stream->SetError(os.str()); } }]; - [cb commit]; + // When we reuse torch's command buffer, torch will sync + // [cb commit]; }; } @@ -263,14 +271,14 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) LaunchParamConfig launch_param_config_; }; -Optional MetalModuleNode::GetFunction(const String& name) { +ffi::Optional MetalModuleNode::GetFunction(const ffi::String& name) { ffi::Function ret; AUTORELEASEPOOL { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); if (it == fmap_.end()) { - return 
std::nullopt; + return; } const FunctionInfo& info = it->second; MetalWrappedFunc f; @@ -285,26 +293,26 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) ffi::Module MetalModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string fmt, std::string source) { - ObjectPtr n; - AUTORELEASEPOOL { n = make_object(smap, fmap, fmt, source); }; + ObjectPtr n; + AUTORELEASEPOOL { n = ffi::make_object(smap, fmap, fmt, source); }; return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "runtime.module.create_metal_module", - [](Map smap, std::string fmap_json, std::string fmt, std::string source) { - std::istringstream stream(fmap_json); - std::unordered_map fmap; - dmlc::JSONReader reader(&stream); - reader.Read(&fmap); + refl::GlobalDef().def("runtime.module.create_metal_module", + [](ffi::Map smap, std::string fmap_json, + std::string fmt, std::string source) { + std::istringstream stream(fmap_json); + std::unordered_map fmap; + dmlc::JSONReader reader(&stream); + reader.Read(&fmap); - return MetalModuleCreate( - std::unordered_map(smap.begin(), smap.end()), fmap, fmt, - source); - }); -}); + return MetalModuleCreate(std::unordered_map( + smap.begin(), smap.end()), + fmap, fmt, source); + }); +} ffi::Module MetalModuleLoadFromBytes(const ffi::Bytes& bytes) { dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); @@ -324,9 +332,19 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) return MetalModuleCreate(smap, fmap, fmt, ""); } -TVM_FFI_STATIC_INIT_BLOCK({ +void SetMetalStream(TVMStreamHandle stream) { + metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); + auto s = new metal::MetalRawStream(static_cast>(stream)); + if (t->stream.size() <= t->device.device_id) { + t->stream.resize(t->device.device_id); + } + t->stream[t->device.device_id] = static_cast(s); +} + 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes); -}); + refl::GlobalDef().def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes) + .def("metal.SetStream", SetMetalStream); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index b5f1e6995f83..8b21b2492716 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -24,6 +24,8 @@ #ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ #define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ +#include + namespace tvm { namespace ffi { // Forward declare TVM Object to use `Object*` in RPC protocol. @@ -72,7 +74,7 @@ enum class RPCCode : int { enum class RPCServerStatus : int { kSuccess = 0, kInvalidTypeCodeObject, - kInvalidTypeCodeNDArray, + kInvalidTypeCodeTensor, kInvalidDLTensorFieldStride, kInvalidDLTensorFieldByteOffset, kUnknownTypeIndex, @@ -144,8 +146,8 @@ inline const char* RPCServerStatusToString(RPCServerStatus status) { return "kSuccess"; case RPCServerStatus::kInvalidTypeCodeObject: return "kInvalidTypeCodeObject"; - case RPCServerStatus::kInvalidTypeCodeNDArray: - return "kInvalidTypeCodeNDArray"; + case RPCServerStatus::kInvalidTypeCodeTensor: + return "kInvalidTypeCodeTensor"; case RPCServerStatus::kInvalidDLTensorFieldStride: return "kInvalidDLTensorFieldStride"; case RPCServerStatus::kInvalidDLTensorFieldByteOffset: { @@ -245,7 +247,7 @@ struct RPCReference { static void SendDLTensor(TChannelPtr channel, DLTensor* arr) { DLDevice dev; uint64_t data; - // When we return NDArray, we directly return + // When we return Tensor, we directly return // the space and the context // The client will be further wrapping dev = arr->device; @@ -255,7 +257,7 @@ struct RPCReference { channel->Write(arr->ndim); channel->Write(arr->dtype); channel->WriteArray(arr->shape, arr->ndim); - if (arr->strides != nullptr) { 
+ if (!ffi::IsContiguous(*arr)) { channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride); } channel->Write(arr->byte_offset); @@ -349,8 +351,8 @@ struct RPCReference { break; } - case ffi::TypeIndex::kTVMFFINDArray: { - channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray); + case ffi::TypeIndex::kTVMFFITensor: { + channel->ThrowError(RPCServerStatus::kInvalidTypeCodeTensor); break; } case ffi::TypeIndex::kTVMFFIDLTensorPtr: { @@ -470,7 +472,9 @@ struct RPCReference { break; } default: { - if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin || + type_index == ffi::TypeIndex::kTVMFFISmallStr || + type_index == ffi::TypeIndex::kTVMFFISmallBytes) { channel->ReadFFIAny(&(packed_args[i])); } else { channel->ThrowError(RPCServerStatus::kUnknownTypeIndex); diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 16c617ce3fcb..c782cb96c09f 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -35,7 +35,7 @@ namespace tvm { namespace runtime { -bool RuntimeEnabled(const String& target_str) { +bool RuntimeEnabled(const ffi::String& target_str) { std::string target = target_str; std::string f_name; if (target == "cpu") { @@ -72,7 +72,7 @@ bool RuntimeEnabled(const String& target_str) { TVM_FFI_CHECK_SAFE_CALL( \ TVMFFIEnvModRegisterContextSymbol("__" #FuncName, reinterpret_cast(FuncName))) -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; // Initialize the functions @@ -85,7 +85,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); refl::GlobalDef().def("runtime.RuntimeEnabled", RuntimeEnabled); -}); +} #undef TVM_INIT_CONTEXT_FUNC diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 3e0981146afc..933cd0b7a7cf 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -29,8 +29,8 @@ #include #include #include -#include 
#include +#include /* There are many OpenCL platforms that do not yet support OpenCL 2.0, * hence we use 1.2 APIs, some of which are now deprecated. In order @@ -341,7 +341,7 @@ class OpenCLWorkspace : public DeviceAPI { } void* AllocDataSpaceView(Device dev, void* data, ffi::Shape shape, DLDataType dtype, - Optional mem_scope = std::nullopt); + ffi::Optional mem_scope = std::nullopt); void FreeDataSpaceView(Device dev, void* ptr); cl_device_id GetCLDeviceID(int device_id); @@ -350,22 +350,23 @@ class OpenCLWorkspace : public DeviceAPI { void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final; void* AllocDataSpace(Device dev, size_t size, size_t alignment, DLDataType type_hint) final; void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope = std::nullopt) final; + ffi::Optional mem_scope = std::nullopt) final; void* AllocDataSpace(Device dev, size_t width, size_t height, DLDataType type_hint, - Optional mem_scope = std::nullopt); - void* GetNativePtr(const tvm::runtime::NDArray& narr); - void SetNativePtr(const tvm::runtime::NDArray& narr, void* host_ptr, size_t buf_size); + ffi::Optional mem_scope = std::nullopt); + void* GetNativePtr(const tvm::runtime::Tensor& narr); + void SetNativePtr(const tvm::runtime::Tensor& narr, void* host_ptr, size_t buf_size); void SetPerfHint(Device dev, cl_uint perf_hint); void FreeDataSpace(Device dev, void* ptr) final; void StreamSync(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; - size_t GetDataSize(const DLTensor& arr, Optional mem_scope = std::nullopt) final; + size_t GetDataSize(const DLTensor& arr, + ffi::Optional mem_scope = std::nullopt) final; // cl_mem alloc utils void* AllocCLBuffer(Device dev, size_t size, size_t alignment, DLDataType type_hint); void* AllocCLImage(Device dev, void* back_buffer, size_t width, size_t height, size_t row_pitch, - 
DLDataType type_hint, Optional mem_scope); + DLDataType type_hint, ffi::Optional mem_scope); /*! * \brief Get the thread local ThreadEntry @@ -436,9 +437,10 @@ struct BufferDescriptor { kImage2DNHWC, }; BufferDescriptor() = default; - explicit BufferDescriptor(Optional scope) : layout(MemoryLayoutFromScope(scope)) {} - static MemoryLayout MemoryLayoutFromScope(Optional mem_scope); - static String ScopeFromMemoryLayout(MemoryLayout mem_scope); + explicit BufferDescriptor(ffi::Optional scope) + : layout(MemoryLayoutFromScope(scope)) {} + static MemoryLayout MemoryLayoutFromScope(ffi::Optional mem_scope); + static ffi::String ScopeFromMemoryLayout(MemoryLayout mem_scope); /* clBuffer object */ // buffer should be the first element here @@ -479,7 +481,7 @@ class OpenCLModuleNodeBase : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) override; + ffi::Optional GetFunction(const ffi::String& name) override; // Initialize the programs virtual void Init() = 0; @@ -509,14 +511,14 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap, std::string source) : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; // Return true if OpenCL program for the requested function and device was created bool IsProgramCreated(const std::string& func_name, int device_id); - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; void SetPreCompiledPrograms(const std::string& bytes); std::string GetPreCompiledPrograms(); - String InspectSource(const String& format) const final; + ffi::String InspectSource(const ffi::String& format) const final; // Initialize the programs void Init() override; @@ -588,10 +590,9 
@@ class OpenCLTimerNode : public TimerNode { OpenCLTimerNode() {} explicit OpenCLTimerNode(Device dev) : dev_(dev) {} - static constexpr const char* _type_key = "runtime.opencl.OpenCLTimerNode"; static size_t count_timer_execs; static std::vector event_start_idxs; - TVM_DECLARE_FINAL_OBJECT_INFO(OpenCLTimerNode, TimerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.opencl.OpenCLTimerNode", OpenCLTimerNode, TimerNode); private: int64_t duration; diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index afa4dd0b8403..8b6fba24988e 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -76,7 +76,7 @@ ImageInfo GetImageInfo(const cl::BufferDescriptor* desc, const DLTensor* tensor) } cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope( - Optional mem_scope) { + ffi::Optional mem_scope) { if (!mem_scope.has_value()) { return cl::BufferDescriptor::MemoryLayout::kBuffer1D; } else if (mem_scope.value() == "global.texture") { @@ -89,7 +89,7 @@ cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope( LOG(FATAL) << "No memory layout defined for memory of scope: " << mem_scope.value(); } -String cl::BufferDescriptor::ScopeFromMemoryLayout(cl::BufferDescriptor::MemoryLayout layout) { +ffi::String cl::BufferDescriptor::ScopeFromMemoryLayout(cl::BufferDescriptor::MemoryLayout layout) { switch (layout) { case cl::BufferDescriptor::MemoryLayout::kBuffer1D: return "global"; @@ -261,7 +261,7 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t size, size_t alignment, } void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t width, size_t height, DLDataType type_hint, - Optional mem_scope) { + ffi::Optional mem_scope) { // Texture allocation given width and height cl_uint row_align = GetImageAlignment(dev.device_id); size_t pixel_size = (type_hint.bits * type_hint.lanes + 7) / 8; @@ -278,13 +278,13 @@ void* 
OpenCLWorkspace::AllocDataSpace(Device dev, size_t width, size_t height, D } if (!mem_scope.has_value()) { - mem_scope = String("global.texture"); + mem_scope = ffi::String("global.texture"); } return AllocCLImage(dev, back_buffer, width, height, row_pitch, type_hint, mem_scope); } void* OpenCLWorkspace::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) { + ffi::Optional mem_scope) { this->Init(); if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { size_t size = GetMemObjectSize(dev, ndim, shape, dtype); @@ -321,7 +321,7 @@ void* OpenCLWorkspace::AllocCLBuffer(Device dev, size_t size, size_t alignment, void* OpenCLWorkspace::AllocCLImage(Device dev, void* back_buffer, size_t width, size_t height, size_t row_pitch, DLDataType type_hint, - Optional mem_scope) { + ffi::Optional mem_scope) { this->Init(); ICHECK(std::string(mem_scope.value()).find("texture") != std::string::npos) << "Expect texture scope while creating an Image object"; @@ -348,7 +348,7 @@ void* OpenCLWorkspace::AllocCLImage(Device dev, void* back_buffer, size_t width, return desc; } -size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, Optional mem_scope) { +size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, ffi::Optional mem_scope) { if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { return DeviceAPI::GetDataSize(arr); } @@ -360,7 +360,7 @@ size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, Optional mem_sc } void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void* data, ffi::Shape shape, - DLDataType dtype, Optional mem_scope) { + DLDataType dtype, ffi::Optional mem_scope) { cl::BufferDescriptor* desc = static_cast(data); // Fall back for devices w/o "cl_khr_image2d_from_buffer" @@ -434,12 +434,12 @@ void OpenCLWorkspace::FreeDataSpaceView(Device dev, void* ptr) { } } -void* OpenCLWorkspace::GetNativePtr(const tvm::runtime::NDArray& narr) { +void* 
OpenCLWorkspace::GetNativePtr(const tvm::runtime::Tensor& narr) { cl::BufferDescriptor* desc = static_cast(narr.operator->()->data); return desc->host_ptr; } -void OpenCLWorkspace::SetNativePtr(const tvm::runtime::NDArray& narr, void* host_ptr, +void OpenCLWorkspace::SetNativePtr(const tvm::runtime::Tensor& narr, void* host_ptr, size_t buf_size) { cl::BufferDescriptor* desc = static_cast(narr.operator->()->data); @@ -630,7 +630,7 @@ std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name) { } std::string GetOpenCLVersion(cl_device_id pid) { - // String returned is "OpenCL $MAJOR.$MINOR $VENDOR_INFO". To + // ffi::String returned is "OpenCL $MAJOR.$MINOR $VENDOR_INFO". To // match other implementations, we want to return "$MAJOR.$MINOR" std::string ret = GetDeviceInfo(pid, CL_DEVICE_VERSION); @@ -761,7 +761,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic initialized_ = true; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("device_api.opencl.alloc_nd", @@ -789,7 +789,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = OpenCLWorkspace::Global()->AllocDataSpace( dev, static_cast(width), static_cast(height), type_hint, - String("global.texture")); + ffi::String("global.texture")); }) .def_packed("device_api.opencl.free_nd", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -809,13 +809,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ DeviceAPI* ptr = OpenCLWorkspace::Global(); *rv = static_cast(ptr); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.opencl", - [](Device dev) { return Timer(make_object(dev)); }); -}); + [](Device dev) { return Timer(ffi::make_object(dev)); }); +} class OpenCLPooledAllocator final : public memory::PooledAllocator { public: @@ -863,7 +863,7 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { buf.size = size; buf.alloc_type 
= AllocatorType::kPooled; buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, - String(mem_scope)); + ffi::String(mem_scope)); if (mem_scope.find("texture") == std::string::npos) { // All textures are backed by buffers - don't count in total memory used_memory_.fetch_add(size, std::memory_order_relaxed); @@ -887,7 +887,8 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { void* CreateView(const Buffer& buffer, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope) final { OpenCLWorkspace* ws_ = OpenCLWorkspace::Global(); - return ws_->AllocDataSpaceView(buffer.device, buffer.data, shape, type_hint, String(mem_scope)); + return ws_->AllocDataSpaceView(buffer.device, buffer.data, shape, type_hint, + ffi::String(mem_scope)); } void FreeView(Device dev, void* data) final { @@ -896,13 +897,13 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { } }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("DeviceAllocator.opencl", [](ffi::PackedArgs args, ffi::Any* rv) { Allocator* alloc = new OpenCLPooledAllocator(); *rv = static_cast(alloc); }); -}); +} } // namespace cl size_t OpenCLTimerNode::count_timer_execs = 0; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index a8e3b6fc20b6..3f9dadbb3af1 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -135,7 +135,7 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { return cl::OpenCLWorkspace::Global(); } -Optional OpenCLModuleNodeBase::GetFunction(const String& name) { +ffi::Optional OpenCLModuleNodeBase::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); @@ -160,7 +160,7 @@ Optional OpenCLModuleNodeBase::GetFunction(const String& name) { return 
PackFuncVoidAddr(f, info.arg_types); } -void OpenCLModuleNode::WriteToFile(const String& file_name, const String& format) const { +void OpenCLModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -178,7 +178,7 @@ ffi::Bytes OpenCLModuleNode::SaveToBytes() const { return ffi::Bytes(buffer); } -String OpenCLModuleNode::InspectSource(const String& format) const { +ffi::String OpenCLModuleNode::InspectSource(const ffi::String& format) const { if (format == fmt_) return data_; if (fmt_ == "cl") { return data_; @@ -349,7 +349,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { return data; } -Optional OpenCLModuleNode::GetFunction(const String& name) { +ffi::Optional OpenCLModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); if (name == "opencl.GetPreCompiledPrograms") { @@ -367,13 +367,13 @@ Optional OpenCLModuleNode::GetFunction(const String& name) { ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { - auto n = make_object(data, fmt, fmap, source); + auto n = ffi::make_object(data, fmt, fmap, source); n->Init(); return ffi::Module(n); } // Load module from module. 
-ffi::Module OpenCLModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module OpenCLModuleLoadFile(const std::string& file_name, const ffi::String& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -395,12 +395,12 @@ ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes& bytes) { return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ffi.Module.load_from_file.cl", OpenCLModuleLoadFile) .def("ffi.Module.load_from_file.clbin", OpenCLModuleLoadFile) .def("ffi.Module.load_from_bytes.opencl", OpenCLModuleLoadFromBytes); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc index 5b90e0b566c7..096b05382379 100644 --- a/src/runtime/opencl/opencl_module_spirv.cc +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -39,9 +39,9 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap) : OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {} - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; - String InspectSource(const String& format) const final { return spirv_text_; } + ffi::String InspectSource(const ffi::String& format) const final { return spirv_text_; } void Init() override; cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, @@ -52,7 +52,8 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { std::string spirv_text_; }; -void OpenCLSPIRVModuleNode::WriteToFile(const String& file_name, const String& format) const { +void OpenCLSPIRVModuleNode::WriteToFile(const ffi::String& file_name, + const ffi::String& format) const { // TODO(masahi): 
How SPIRV binaries should be save to a file? LOG(FATAL) << "Not implemented."; } @@ -132,7 +133,7 @@ cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenC ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, const std::string& spirv_text, std::unordered_map fmap) { - auto n = make_object(shaders, spirv_text, fmap); + auto n = ffi::make_object(shaders, spirv_text, fmap); n->Init(); return ffi::Module(n); } diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 8929f90b0f09..e1b1fec0a39a 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -39,6 +39,10 @@ namespace tvm { namespace runtime { + +/*! \brief TileLang Grid constant */ +constexpr unsigned int kDLGridConstant = 30U; + /*! * \brief argument union type of 32bit. */ @@ -134,7 +138,8 @@ enum ArgConvertCode { FLOAT64_TO_FLOAT32, FLOAT64_TO_FLOAT64, HANDLE_TO_HANDLE, - HANDLE_TO_TENSORMAP + HANDLE_TO_TENSORMAP, + HANDLE_TO_REFERENCE, }; inline ArgConvertCode GetArgConvertCode(DLDataType t) { @@ -149,6 +154,8 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { if (t.bits == 32U) return FLOAT64_TO_FLOAT32; } else if (t.code == kDLOpaqueHandle) { return HANDLE_TO_HANDLE; + } else if (t.code == kDLGridConstant) { + return HANDLE_TO_REFERENCE; } LOG(FATAL) << "Cannot handle " << t << " as device function argument"; } @@ -191,6 +198,9 @@ inline ffi::Function PackFuncVoidAddr_(F f, const std::vector& c addr[i] = raw_args[i].v_ptr; break; } + case HANDLE_TO_REFERENCE: { + addr[i] = raw_args[i].v_obj; + } } } f(args, ret, addr); @@ -231,6 +241,7 @@ inline ffi::Function PackFuncNonBufferArg_(F f, int base, break; } case HANDLE_TO_HANDLE: + case HANDLE_TO_REFERENCE: case HANDLE_TO_TENSORMAP: { LOG(FATAL) << "not reached"; break; @@ -293,6 +304,7 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector(dev)); } +Timer DefaultTimer(Device dev) { return Timer(ffi::make_object(dev)); } class CPUTimerNode : public TimerNode { public: @@ 
-72,20 +71,18 @@ class CPUTimerNode : public TimerNode { virtual void Stop() { duration_ = std::chrono::high_resolution_clock::now() - start_; } virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); } virtual ~CPUTimerNode() {} - - static constexpr const char* _type_key = "runtime.CPUTimerNode"; - TVM_DECLARE_FINAL_OBJECT_INFO(CPUTimerNode, TimerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.CPUTimerNode", CPUTimerNode, TimerNode); private: std::chrono::high_resolution_clock::time_point start_; std::chrono::duration duration_; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.cpu", - [](Device dev) { return Timer(make_object()); }); -}); + [](Device dev) { return Timer(ffi::make_object()); }); +} // keep track of which timers are not defined but we have already warned about std::set seen_devices; @@ -114,20 +111,20 @@ Timer Timer::Start(Device dev) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.start_timer", Timer::Start); -}); +} namespace profiling { Profiler::Profiler(std::vector devs, std::vector metric_collectors, - std::unordered_map configuration) + std::unordered_map configuration) : devs_(devs), collectors_(metric_collectors), configuration_(configuration) { is_running_ = false; std::vector wrapped_devs; for (auto dev : devs) { - wrapped_devs.push_back(DeviceWrapper(make_object(dev))); + wrapped_devs.push_back(DeviceWrapper(ffi::make_object(dev))); } for (auto& x : collectors_) { x->Init(wrapped_devs); @@ -135,8 +132,8 @@ Profiler::Profiler(std::vector devs, std::vector metric // reset the thread pool so that PAPI eventset hooks are set in all threads. 
threading::ResetThreadPool(); - configuration_[String("Number of threads")] = - ObjectRef(make_object(threading::NumThreads())); + configuration_[ffi::String("Number of threads")] = + ObjectRef(ffi::make_object(threading::NumThreads())); } void Profiler::Start() { @@ -146,7 +143,7 @@ void Profiler::Start() { } } -void Profiler::StartCall(String name, Device dev, +void Profiler::StartCall(ffi::String name, Device dev, std::unordered_map extra_metrics) { std::vector> objs; for (auto& collector : collectors_) { @@ -182,7 +179,7 @@ void Profiler::Stop() { } } -std::vector ToShape(NDArray shape_tensor) { +std::vector ToShape(Tensor shape_tensor) { std::vector shape; auto rank = shape_tensor.Shape().size(); auto dtype = shape_tensor.DataType(); @@ -212,9 +209,11 @@ std::vector ToShape(NDArray shape_tensor) { return shape; } -String ShapeString(NDArray shape, DLDataType dtype) { return ShapeString(ToShape(shape), dtype); } +ffi::String ShapeString(Tensor shape, DLDataType dtype) { + return ShapeString(ToShape(shape), dtype); +} -String ShapeString(const std::vector& shape, DLDataType dtype) { +ffi::String ShapeString(const std::vector& shape, DLDataType dtype) { std::stringstream sizes; sizes << dtype << "["; for (size_t i = 0; i < shape.size(); i++) { @@ -224,12 +223,12 @@ String ShapeString(const std::vector& shape, DLDataType dtype) { sizes << shape[i]; } sizes << "]"; - return String(sizes.str()); + return ffi::String(sizes.str()); } -String ShapeString(const std::vector& shapes) { +ffi::String ShapeString(const std::vector& shapes) { std::stringstream sizes; - for (const NDArray& ary : shapes) { + for (const Tensor& ary : shapes) { if (sizes.tellp() > 0) { sizes << ", "; } @@ -243,10 +242,10 @@ String ShapeString(const std::vector& shapes) { } sizes << "]"; } - return String(sizes.str()); + return ffi::String(sizes.str()); } -String ReportNode::AsCSV() const { +ffi::String ReportNode::AsCSV() const { // get unique headers std::set unique_headers; @@ -300,7 +299,7 @@ 
String ReportNode::AsCSV() const { namespace { void metric_as_json(std::ostream& os, ffi::Any o) { - if (auto opt_str = o.as()) { + if (auto opt_str = o.as()) { os << "{\"string\":" << "\"" << *opt_str << "\"" << "}"; @@ -321,7 +320,7 @@ void metric_as_json(std::ostream& os, ffi::Any o) { } } // namespace -String ReportNode::AsJSON() const { +ffi::String ReportNode::AsJSON() const { std::ostringstream s; // DMLC's JSONWriter does not allow us to write a key value pair without // implementing Write for the value. We want a specific write for the value, @@ -395,29 +394,29 @@ Any AggregateMetric(const std::vector& metrics) { for (auto& metric : metrics) { sum += metric.as()->microseconds; } - return ObjectRef(make_object(sum)); + return ObjectRef(ffi::make_object(sum)); } else if (metrics[0].as()) { int64_t sum = 0; for (auto& metric : metrics) { sum += metric.as()->value; } - return ObjectRef(make_object(sum)); + return ObjectRef(ffi::make_object(sum)); } else if (metrics[0].as()) { double sum = 0; for (auto& metric : metrics) { sum += metric.as()->percent; } - return ObjectRef(make_object(sum)); + return ObjectRef(ffi::make_object(sum)); } else if (metrics[0].as()) { double sum = 0; for (auto& metric : metrics) { sum += metric.as()->ratio; } - return ObjectRef(make_object(sum / metrics.size())); + return ObjectRef(ffi::make_object(sum / metrics.size())); } else if (auto opt_str = metrics[0].as()) { for (auto& m : metrics) { if (*opt_str != m.as()) { - return String(""); + return ffi::String(""); } } // Assume all strings in metrics are the same. 
@@ -442,7 +441,7 @@ static void set_locale_for_separators(std::stringstream& s) { } } -static String print_metric(ffi::Any metric) { +static ffi::String print_metric(ffi::Any metric) { std::string val; if (metric.as()) { std::stringstream s; @@ -471,23 +470,23 @@ static String print_metric(ffi::Any metric) { return val; } -String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const { +ffi::String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const { // aggregate calls by op hash (or op name if hash is not set) + argument shapes - std::vector> aggregated_calls; + std::vector> aggregated_calls; if (aggregate) { std::unordered_map> aggregates; for (size_t i = 0; i < calls.size(); i++) { auto& frame = calls[i]; auto it = frame.find("Hash"); - std::string name = frame["Name"].cast(); + std::string name = frame["Name"].cast(); if (it != frame.end()) { - name = (*it).second.cast(); + name = (*it).second.cast(); } if (frame.find("Argument Shapes") != frame.end()) { - name += frame["Argument Shapes"].cast(); + name += frame["Argument Shapes"].cast(); } if (frame.find("Device") != frame.end()) { - name += frame["Device"].cast(); + name += frame["Device"].cast(); } if (aggregates.find(name) == aggregates.end()) { @@ -497,7 +496,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con } } for (const auto& p : aggregates) { - std::unordered_map aggregated; + std::unordered_map aggregated; std::unordered_set metrics; for (auto& call : calls) { for (auto& metric : call) { @@ -509,7 +508,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con for (auto i : p.second) { auto& call = calls[i]; auto it = std::find_if(call.begin(), call.end(), - [&metric](const std::pair& call_metric) { + [&metric](const std::pair& call_metric) { return std::string(call_metric.first) == metric; }); if (it != call.end()) { @@ -530,16 +529,17 @@ String ReportNode::AsTable(bool sort, bool aggregate, 
bool compute_col_sums) con // sort rows by duration if (sort) { - std::sort(aggregated_calls.begin(), aggregated_calls.end(), - [&](const Map& a, const Map& b) { - return a.at("Duration (us)").as()->microseconds > - b.at("Duration (us)").as()->microseconds; - }); + std::sort( + aggregated_calls.begin(), aggregated_calls.end(), + [&](const ffi::Map& a, const ffi::Map& b) { + return a.at("Duration (us)").as()->microseconds > + b.at("Duration (us)").as()->microseconds; + }); } // compute columnwise sums if (compute_col_sums) { - std::unordered_map col_sums; + std::unordered_map col_sums; for (auto call : aggregated_calls) { for (auto p : call) { if (p.second.as()) { @@ -548,35 +548,35 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con if (it != col_sums.end()) { val += it->second.as()->value; } - col_sums[p.first] = ObjectRef(make_object(val)); + col_sums[p.first] = ObjectRef(ffi::make_object(val)); } else if (p.second.as()) { double val = p.second.as()->microseconds; auto it = col_sums.find(p.first); if (it != col_sums.end()) { val += it->second.as()->microseconds; } - col_sums[p.first] = ObjectRef(make_object(val)); + col_sums[p.first] = ObjectRef(ffi::make_object(val)); } else if (p.second.as()) { double val = p.second.as()->percent; auto it = col_sums.find(p.first); if (it != col_sums.end()) { val += it->second.as()->percent; } - col_sums[p.first] = ObjectRef(make_object(val)); + col_sums[p.first] = ObjectRef(ffi::make_object(val)); } else if (p.second.as()) { // It does not make sense to sum ratios } } } - col_sums["Name"] = String("Sum"); - aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator + col_sums["Name"] = ffi::String("Sum"); + aggregated_calls.push_back({{ffi::String("Name"), ffi::String("----------")}}); // separator aggregated_calls.push_back(col_sums); } // per-device metrics for (auto p : device_metrics) { - Map metrics = p.second; - metrics.Set("Name", String("Total")); + ffi::Map 
metrics = p.second; + metrics.Set("Name", ffi::String("Total")); aggregated_calls.push_back(metrics); } @@ -660,14 +660,14 @@ std::string DeviceString(Device dev) { Report Profiler::Report() { // sync all timers and normalize rows - std::vector> rows; + std::vector> rows; for (auto& cf : calls_) { - std::unordered_map row; + std::unordered_map row; double us = cf.timer->SyncAndGetElapsedNanos() / 1e3; - row["Duration (us)"] = ObjectRef(make_object(us)); - row["Count"] = ObjectRef(make_object(1)); + row["Duration (us)"] = ObjectRef(ffi::make_object(us)); + row["Count"] = ObjectRef(ffi::make_object(1)); row["Name"] = cf.name; - row["Device"] = String(DeviceString(cf.dev)); + row["Device"] = ffi::String(DeviceString(cf.dev)); for (auto p : cf.extra_metrics) { row[p.first] = p.second; } @@ -676,23 +676,23 @@ Report Profiler::Report() { // the last frames are the overall times double overall_time_us = 0; - std::unordered_map> device_metrics; + std::unordered_map> device_metrics; for (size_t i = 0; i < devs_.size(); i++) { auto row = rows[rows.size() - 1]; rows.pop_back(); - device_metrics[row["Device"].cast()] = row; + device_metrics[row["Device"].cast()] = row; overall_time_us = std::max(overall_time_us, row["Duration (us)"].as()->microseconds); } // Calculate percentages for (auto& row : rows) { - row["Percent"] = ObjectRef(make_object( + row["Percent"] = ObjectRef(ffi::make_object( row["Duration (us)"].as()->microseconds / overall_time_us * 100)); } // convert to map - std::vector> converted_rows; + std::vector> converted_rows; for (const auto& row : rows) { converted_rows.push_back(row); } @@ -700,20 +700,20 @@ Report Profiler::Report() { return profiling::Report(converted_rows, device_metrics, configuration_); } -Report::Report(Array> calls, - Map> device_metrics, - Map configuration) { - auto node = make_object(); +Report::Report(ffi::Array> calls, + ffi::Map> device_metrics, + ffi::Map configuration) { + auto node = ffi::make_object(); node->calls = 
std::move(calls); node->device_metrics = std::move(device_metrics); node->configuration = std::move(configuration); data_ = std::move(node); } -Map parse_metrics(dmlc::JSONReader* reader) { +ffi::Map parse_metrics(dmlc::JSONReader* reader) { reader->BeginObject(); std::string metric_name, metric_value_name; - Map metrics; + ffi::Map metrics; while (reader->NextObjectItem(&metric_name)) { ffi::Any o; reader->BeginObject(); @@ -721,23 +721,23 @@ Map parse_metrics(dmlc::JSONReader* reader) { if (metric_value_name == "microseconds") { double microseconds; reader->Read(&microseconds); - o = ObjectRef(make_object(microseconds)); + o = ObjectRef(ffi::make_object(microseconds)); } else if (metric_value_name == "percent") { double percent; reader->Read(&percent); - o = ObjectRef(make_object(percent)); + o = ObjectRef(ffi::make_object(percent)); } else if (metric_value_name == "count") { int64_t count; reader->Read(&count); - o = ObjectRef(make_object(count)); + o = ObjectRef(ffi::make_object(count)); } else if (metric_value_name == "ratio") { double ratio; reader->Read(&ratio); - o = ObjectRef(make_object(ratio)); + o = ObjectRef(ffi::make_object(ratio)); } else if (metric_value_name == "string") { std::string s; reader->Read(&s); - o = String(s); + o = ffi::String(s); } else { LOG(FATAL) << "Cannot parse metric of type " << metric_value_name << " valid types are microseconds, percent, count."; @@ -752,13 +752,13 @@ Map parse_metrics(dmlc::JSONReader* reader) { return metrics; } -Report Report::FromJSON(String json) { +Report Report::FromJSON(ffi::String json) { std::stringstream input(json.operator std::string()); dmlc::JSONReader reader(&input); std::string key; - Array> calls; - Map> device_metrics; - Map configuration; + ffi::Array> calls; + ffi::Map> device_metrics; + ffi::Map configuration; reader.BeginObject(); while (reader.NextObjectItem(&key)) { @@ -782,18 +782,22 @@ Report Report::FromJSON(String json) { return Report(calls, device_metrics, configuration); } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + refl::ObjectDef(); + refl::GlobalDef() .def_method("runtime.profiling.AsTable", &ReportNode::AsTable) .def("runtime.profiling.AsCSV", [](Report n) { return n->AsCSV(); }) .def("runtime.profiling.AsJSON", [](Report n) { return n->AsJSON(); }) .def("runtime.profiling.FromJSON", Report::FromJSON) .def("runtime.profiling.DeviceWrapper", [](Device dev) { return DeviceWrapper(dev); }); -}); +} ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device_type, - int device_id, int warmup_iters, Array collectors) { + int device_id, int warmup_iters, + ffi::Array collectors) { // Module::GetFunction is not const, so this lambda has to be mutable return ffi::Function::FromPacked([=](const ffi::AnyView* args, int32_t num_args, ffi::Any* ret) mutable { @@ -810,7 +814,7 @@ ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device for (auto& collector : collectors) { collector->Init({DeviceWrapper(dev)}); } - std::vector> results; + std::vector> results; results.reserve(collectors.size()); std::vector> collector_data; collector_data.reserve(collectors.size()); @@ -828,7 +832,7 @@ ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device for (auto& kv : collector_data) { results.push_back(kv.first->Stop(kv.second)); } - Map combined_results; + ffi::Map combined_results; for (auto m : results) { for (auto p : m) { // assume that there is no shared metric name between collectors @@ -839,12 +843,12 @@ ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device }); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "runtime.profiling.ProfileFunction", - [](ffi::Module mod, String func_name, int device_type, int device_id, int warmup_iters, - Array collectors) { + [](ffi::Module mod, ffi::String func_name, 
int device_type, int device_id, int warmup_iters, + ffi::Array collectors) { if (mod->kind() == std::string("rpc")) { LOG(FATAL) << "Profiling a module over RPC is not yet supported"; // because we can't send @@ -854,7 +858,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return ProfileFunction(mod, func_name, device_type, device_id, warmup_iters, collectors); } }); -}); +} ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, @@ -871,10 +875,10 @@ ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int re pf.CallPacked(args, num_args, &temp); // allocate two large arrays to flush L2 cache - NDArray arr1, arr2; + Tensor arr1, arr2; if (cache_flush_bytes > 0) { - arr1 = NDArray::Empty({cache_flush_bytes / 4}, {kDLInt, 32, 1}, dev); - arr2 = NDArray::Empty({cache_flush_bytes / 4}, {kDLInt, 32, 1}, dev); + arr1 = Tensor::Empty({cache_flush_bytes / 4}, {kDLInt, 32, 1}, dev); + arr2 = Tensor::Empty({cache_flush_bytes / 4}, {kDLInt, 32, 1}, dev); } DeviceAPI::Get(dev)->StreamSync(dev, nullptr); @@ -921,23 +925,24 @@ ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int re return ffi::Function::FromPacked(ftimer); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.profiling.Report", - [](Array> calls, Map> device_metrics, - Map configuration) { + [](ffi::Array> calls, + ffi::Map> device_metrics, + ffi::Map configuration) { return Report(calls, device_metrics, configuration); }) .def("runtime.profiling.Count", - [](int64_t count) { return ObjectRef(make_object(count)); }) + [](int64_t count) { return ObjectRef(ffi::make_object(count)); }) .def("runtime.profiling.Percent", - [](double percent) { return ObjectRef(make_object(percent)); }) + [](double percent) { return ObjectRef(ffi::make_object(percent)); }) .def("runtime.profiling.Duration", - [](double duration) { return 
ObjectRef(make_object(duration)); }) + [](double duration) { return ObjectRef(ffi::make_object(duration)); }) .def("runtime.profiling.Ratio", - [](double ratio) { return ObjectRef(make_object(ratio)); }); -}); + [](double ratio) { return ObjectRef(ffi::make_object(ratio)); }); +} } // namespace profiling } // namespace runtime diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 9692b811a40c..016169653552 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -245,7 +245,7 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("device_api.rocm", @@ -257,14 +257,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ DeviceAPI* ptr = ROCMDeviceAPI::Global(); *rv = static_cast(ptr); }); -}); +} class ROCMTimerNode : public TimerNode { public: virtual void Start() { int device_id; ROCM_CALL(hipGetDevice(&device_id)); - stream_ = TVMFFIEnvGetCurrentStream(kDLROCM, device_id); + stream_ = TVMFFIEnvGetStream(kDLROCM, device_id); ROCM_CALL(hipEventRecord(start_, static_cast(stream_))); } virtual void Stop() { @@ -286,9 +286,7 @@ class ROCMTimerNode : public TimerNode { ROCM_CALL(hipEventCreate(&start_)); ROCM_CALL(hipEventCreate(&stop_)); } - - static constexpr const char* _type_key = "runtime.rocm.ROCMTimerNode"; - TVM_DECLARE_FINAL_OBJECT_INFO(ROCMTimerNode, TimerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.rocm.ROCMTimerNode", ROCMTimerNode, TimerNode); private: hipEvent_t start_; @@ -296,16 +294,17 @@ class ROCMTimerNode : public TimerNode { TVMStreamHandle stream_; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("profiling.timer.rocm", [](Device dev) { return Timer(make_object()); }) + 
.def("profiling.timer.rocm", + [](Device dev) { return Timer(ffi::make_object()); }) .def("runtime.get_rocm_stream", []() { int device_id; ROCM_CALL(hipGetDevice(&device_id)); - return static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + return static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index f6beaca210bc..ca1a47400bc1 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -69,9 +69,9 @@ class ROCMModuleNode : public ffi::ModuleObj { int GetPropertyMask() const final { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); // note: llvm and asm formats are not laodable, so we don't save them @@ -90,7 +90,7 @@ class ROCMModuleNode : public ffi::ModuleObj { return ffi::Bytes(buffer); } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { if (format == fmt_) { return data_; } @@ -172,7 +172,7 @@ class ROCMWrappedFunc { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } - hipStream_t strm = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + hipStream_t strm = static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); ThreadWorkLoad wl = launch_param_config_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, @@ -198,7 +198,7 @@ class ROCMWrappedFunc { LaunchParamConfig launch_param_config_; }; -Optional ROCMModuleNode::GetFunction(const String& name) { 
+ffi::Optional ROCMModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); @@ -212,7 +212,7 @@ Optional ROCMModuleNode::GetFunction(const String& name) { ffi::Module ROCMModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string hip_source, std::string assembly) { - auto n = make_object(data, fmt, fmap, hip_source, assembly); + auto n = ffi::make_object(data, fmt, fmap, hip_source, assembly); return ffi::Module(n); } @@ -238,13 +238,13 @@ ffi::Module ROCMModuleLoadFromBytes(const ffi::Bytes& bytes) { return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ffi.Module.load_from_bytes.hsaco", ROCMModuleLoadFromBytes) .def("ffi.Module.load_from_bytes.hip", ROCMModuleLoadFromBytes) .def("ffi.Module.load_from_file.hsaco", ROCMModuleLoadFile) .def("ffi.Module.load_from_file.hip", ROCMModuleLoadFile); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index a02acd9611e3..88e01255d82a 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -45,7 +45,7 @@ class RPCDeviceAPI final : public DeviceAPI { } void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) final { + ffi::Optional mem_scope) final { auto sess = GetSess(dev); auto remote_dev = RemoveRPCSessionMask(dev); void* data = @@ -151,13 +151,13 @@ class RPCDeviceAPI final : public DeviceAPI { } }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("device_api.rpc", [](ffi::PackedArgs args, ffi::Any* rv) { static RPCDeviceAPI inst; DeviceAPI* ptr = &inst; *rv = static_cast(ptr); }); -}); +} } // namespace 
runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index e1282c17878a..0778b5539474 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -171,6 +171,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { for (int i = 0; i < args.size(); ++i) { if (args[i] == nullptr) continue; if (args[i].type_index() == ffi::TypeIndex::kTVMFFIModule) continue; + if (args[i].type_index() == ffi::TypeIndex::kTVMFFISmallStr || + args[i].type_index() == ffi::TypeIndex::kTVMFFISmallBytes) + continue; + if (args[i].type_index() == ffi::TypeIndex::kTVMFFIStr || + args[i].type_index() == ffi::TypeIndex::kTVMFFIBytes) + continue; if (const Object* obj = args[i].as()) { if (!obj->IsInstance()) { LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " << obj->GetTypeKey() @@ -221,14 +227,20 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { void WriteFFIAny(const TVMFFIAny* in) { // NOTE: for now all remote object are encoded as RPCObjectRef // follow the same disco protocol in case we would like to upgrade later - // - // Rationale note: Only handle remote object allows the same mechanism to work for minRPC - // which is needed for wasm and other env that goes through C API + // TODO(tqchen): consider merge with disco protocol const AnyView* any_view_ptr = reinterpret_cast(in); if (const auto* ref = any_view_ptr->as()) { this->template Write(runtime::TypeIndex::kRuntimeRPCObjectRef); uint64_t handle = reinterpret_cast(ref->object_handle()); this->template Write(handle); + } else if (auto opt_str = any_view_ptr->as()) { + this->template Write(ffi::TypeIndex::kTVMFFIStr); + this->template Write((*opt_str).size()); + this->template WriteArray((*opt_str).data(), (*opt_str).size()); + } else if (auto opt_bytes = any_view_ptr->as()) { + this->template Write(ffi::TypeIndex::kTVMFFIBytes); + this->template Write((*opt_bytes).size()); + this->template WriteArray((*opt_bytes).data(), 
(*opt_bytes).size()); } else { LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() @@ -239,6 +251,10 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { const AnyView* any_view_ptr = reinterpret_cast(in); if (any_view_ptr->as()) { return sizeof(uint32_t) + sizeof(int64_t); + } else if (auto opt_str = any_view_ptr->as()) { + return sizeof(uint32_t) + sizeof(uint64_t) + (*opt_str).size(); + } else if (auto opt_bytes = any_view_ptr->as()) { + return sizeof(uint32_t) + sizeof(uint64_t) + (*opt_bytes).size(); } else { LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() @@ -261,11 +277,28 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // Always wrap things back in RPCObjectRef // this is because we want to enable multi-hop RPC // and next hop would also need to check the object index - RPCObjectRef rpc_obj(make_object(reinterpret_cast(handle), nullptr)); + RPCObjectRef rpc_obj( + ffi::make_object(reinterpret_cast(handle), nullptr)); // Legacy ABI translation // TODO(tqchen): remove this once we have upgraded to new ABI *reinterpret_cast(out) = rpc_obj; - object_arena_.push_back(rpc_obj); + any_arena_.emplace_back(rpc_obj); + } else if (type_index == ffi::TypeIndex::kTVMFFIStr) { + uint64_t size; + this->template Read(&size); + std::string data(size, '\0'); + this->template ReadArray(data.data(), size); + ffi::String ret(std::move(data)); + *reinterpret_cast(out) = ret; + any_arena_.emplace_back(ret); + } else if (type_index == ffi::TypeIndex::kTVMFFIBytes) { + uint64_t size; + this->template Read(&size); + std::string data(size, '\0'); + this->template ReadArray(data.data(), size); + ffi::Bytes ret(std::move(data)); + *reinterpret_cast(out) = ret; + any_arena_.emplace_back(ret); } else { LOG(FATAL) << "ValueError: Object type is 
not supported in Disco calling convention: " << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; @@ -284,7 +317,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { /*! \brief Recycle all the memory used in the arena */ void RecycleAll() { - this->object_arena_.clear(); + this->any_arena_.clear(); this->arena_.RecycleAll(); } @@ -309,7 +342,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // Internal arena support::Arena arena_; // internal arena for temp objects - std::vector object_arena_; + std::vector any_arena_; // State switcher void SwitchToState(State state) { @@ -433,7 +466,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { if (code == RPCCode::kException) { // switch to the state before sending exception. this->SwitchToState(kRecvPacketNumBytes); - String msg = args[0].cast(); + ffi::String msg = args[0].cast(); if (!support::StartsWith(msg, "RPCSessionTimeoutError: ")) { msg = "RPCError: Error caught from RPC call:\n" + msg; } @@ -962,7 +995,7 @@ void RPCDevAllocDataWithScope(RPCSession* handler, ffi::PackedArgs args, ffi::An int ndim = arr->ndim; int64_t* shape = arr->shape; DLDataType dtype = arr->dtype; - auto mem_scope = args[1].cast>(); + auto mem_scope = args[1].cast>(); void* data = handler->GetDeviceAPI(dev)->AllocDataSpace(dev, ndim, shape, dtype, mem_scope); *rv = data; } @@ -1154,7 +1187,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { } void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) final { + ffi::Optional mem_scope) final { DLTensor temp; temp.data = nullptr; temp.device = dev; diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index 195adef053bd..9438470cb215 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -78,8 +78,8 @@ class RPCEndpoint { * Shutdown has no effect if the connection has already been shut down. 
* Shutdown will wait for all output currently queued from the RPC connection (i.e. The user * doesn't need to wait for completion before calling Shutdown.) Any further use of objects that - * depended on the endpoint (e.g. A tvm.nd.array allocated on the remote RPC session) may throw an - * exception when used. + * depended on the endpoint (e.g. A tvm.runtime.tensor allocated on the remote RPC session) may + * throw an exception when used. */ void Shutdown(); diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index abf635020afe..4eefb2b2b978 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -45,9 +45,9 @@ ffi::Function CreateEventDrivenServer(ffi::Function fsend, std::string name, }); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("rpc.CreateEventDrivenServer", CreateEventDrivenServer); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 3d4928f8b43a..2cfeacfcd71f 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include @@ -54,13 +54,13 @@ void LocalSession::EncodeReturn(ffi::Any rv, const FEncodeReturn& encode_return) if (rv == nullptr) { packed_args[1] = rv; encode_return(ffi::PackedArgs(packed_args, 2)); - } else if (rv.as()) { - // We follow a special protocol to return NDArray to client side - // The first pack value is the NDArray handle as DLTensor - // The second pack value is a customized deleter that deletes the NDArray. + } else if (rv.as()) { + // We follow a special protocol to return Tensor to client side + // The first pack value is the Tensor handle as DLTensor + // The second pack value is a customized deleter that deletes the Tensor. 
TVMFFIAny ret_any = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(rv)); void* opaque_handle = ret_any.v_obj; - packed_args[1] = TVMFFINDArrayGetDLTensorPtr(opaque_handle); + packed_args[1] = TVMFFITensorGetDLTensorPtr(opaque_handle); packed_args[2] = opaque_handle; encode_return(ffi::PackedArgs(packed_args, 3)); } else if (const auto opt_bytes = rv.as()) { @@ -149,11 +149,11 @@ DeviceAPI* LocalSession::GetDeviceAPI(Device dev, bool allow_missing) { return DeviceAPI::Get(dev, allow_missing); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("rpc.LocalSession", []() { return CreateRPCSessionModule(std::make_shared()); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index bcf661960f06..a90c69c63c8b 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -41,18 +41,18 @@ namespace tvm { namespace runtime { /*! - * \brief Build a local NDArray with remote backing storage. + * \brief Build a local Tensor with remote backing storage. * \param sess the RPCSession which owns the given handle. * \param handle A pointer valid on the remote end which should form the `data` field of the * underlying DLTensor. * \param template_tensor An empty DLTensor whose shape and dtype fields are used to fill the newly * created array. Needed because it's difficult to pass a shape vector as a ffi::Function arg. * \param dev Remote device used with this tensor. Must have non-zero RPCSessMask. - * \param remote_ndarray_handle The handle returned by RPC server to identify the NDArray. + * \param remote_tensor_handle The handle returned by RPC server to identify the Tensor. 
*/ -NDArray NDArrayFromRemoteOpaqueHandle(std::shared_ptr sess, void* handle, - DLTensor* template_tensor, Device dev, - void* remote_ndarray_handle) { +Tensor TensorFromRemoteOpaqueHandle(std::shared_ptr sess, void* handle, + DLTensor* template_tensor, Device dev, + void* remote_tensor_handle) { ICHECK_EQ(sess->table_index(), GetRPCSessionIndex(dev)) << "The Device given does not belong to the given session"; class RemoteSpaceAlloc { @@ -71,7 +71,7 @@ NDArray NDArrayFromRemoteOpaqueHandle(std::shared_ptr sess, void* ha space.sess = sess; space.data = handle; ffi::Shape shape(template_tensor->shape, template_tensor->shape + template_tensor->ndim); - return NDArray::FromNDAlloc(RemoteSpaceAlloc(space), shape, template_tensor->dtype, dev); + return Tensor::FromNDAlloc(RemoteSpaceAlloc(space), shape, template_tensor->dtype, dev); } /*! @@ -104,9 +104,9 @@ class RPCWrappedFunc : public Object { // run a remote translation to translate RPC related objects to // their remote counterparts. switch (args[i].type_index()) { - case ffi::TypeIndex::kTVMFFINDArray: { - // Pass NDArray as DLTensor - auto dptr = std::make_unique(*args[i].cast().operator->()); + case ffi::TypeIndex::kTVMFFITensor: { + // Pass Tensor as DLTensor + auto dptr = std::make_unique(*args[i].cast().operator->()); dptr->device = RemoveSessMask(dptr->device); dptr->data = static_cast(dptr->data)->data; packed_args[i] = dptr.get(); @@ -190,7 +190,7 @@ class RPCModuleNode final : public ffi::ModuleObj { /*! 
\brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ffi::Module::ModulePropertyMask::kRunnable; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { if (name == "CloseRPCConnection") { return ffi::Function([this](ffi::PackedArgs, ffi::Any*) { sess_->Shutdown(); }); } @@ -199,7 +199,7 @@ class RPCModuleNode final : public ffi::ModuleObj { return WrapRemoteFunc(sess_->GetFunction(name)); } else { InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction"); - return remote_mod_get_function_(GetRef(this), name, true); + return remote_mod_get_function_(ffi::GetRef(this), name, true); } } @@ -215,12 +215,12 @@ class RPCModuleNode final : public ffi::ModuleObj { if (module_handle_ != nullptr) { return remote_get_time_evaluator_( - GetRef(this), name, static_cast(dev.device_type), dev.device_id, number, - repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, + ffi::GetRef(this), name, static_cast(dev.device_type), dev.device_id, + number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } else { return remote_get_time_evaluator_( - Optional(std::nullopt), name, static_cast(dev.device_type), + ffi::Optional(std::nullopt), name, static_cast(dev.device_type), dev.device_id, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } @@ -231,9 +231,9 @@ class RPCModuleNode final : public ffi::ModuleObj { return remote_load_module_(name); } - void ImportModule(ffi::Module other) { + void ImportModule(const ffi::Module& other) final { InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule"); - remote_import_module_(GetRef(this), other); + remote_import_module_(ffi::GetRef(this), other); } const std::shared_ptr& sess() { return sess_; } @@ -261,8 +261,8 @@ class 
RPCModuleNode final : public ffi::ModuleObj { // The local channel std::shared_ptr sess_; // remote function to get time evaluator - ffi::TypedFunction, std::string, int, int, int, int, int, int, - int, int, int, std::string)> + ffi::TypedFunction, std::string, int, int, int, int, int, + int, int, int, int, std::string)> remote_get_time_evaluator_; // remote function getter for modules. ffi::TypedFunction remote_mod_get_function_; @@ -303,16 +303,16 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) } else if (type_index == ffi::TypeIndex::kTVMFFIModule) { ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); - auto n = make_object(handle, sess_); + auto n = ffi::make_object(handle, sess_); *rv = ffi::Module(n); - } else if (type_index == ffi::TypeIndex::kTVMFFINDArray || + } else if (type_index == ffi::TypeIndex::kTVMFFITensor || type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr) { ICHECK_EQ(args.size(), 3); auto tensor = args[1].cast(); void* nd_handle = args[2].cast(); - *rv = NDArrayFromRemoteOpaqueHandle(sess_, tensor->data, tensor, - AddRPCSessionMask(tensor->device, sess_->table_index()), - nd_handle); + *rv = TensorFromRemoteOpaqueHandle(sess_, tensor->data, tensor, + AddRPCSessionMask(tensor->device, sess_->table_index()), + nd_handle); } else if (type_index == ffi::TypeIndex::kTVMFFIBytes || type_index == ffi::TypeIndex::kTVMFFIStr || type_index == ffi::TypeIndex::kTVMFFISmallStr || @@ -322,7 +322,7 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) } else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); - auto n = make_object(handle, sess_); + auto n = ffi::make_object(handle, sess_); *rv = ObjectRef(n); } else { ICHECK_EQ(args.size(), 2); @@ -331,7 +331,7 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) } ffi::Module CreateRPCSessionModule(std::shared_ptr sess) { - auto n = 
make_object(nullptr, sess); + auto n = ffi::make_object(nullptr, sess); RPCSession::InsertToSessionTable(sess); return ffi::Module(n); } @@ -393,11 +393,11 @@ inline void CPUCacheFlush(int begin_index, const ffi::PackedArgs& args) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.RPCTimeEvaluator", - [](Optional opt_mod, std::string name, int device_type, int device_id, + [](ffi::Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, std::string f_preproc_name) { @@ -420,7 +420,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - Optional pf = m->GetFunction(name); + ffi::Optional pf = m->GetFunction(name); CHECK(pf.has_value()) << "Cannot find " << name << "` in the global registry"; return profiling::WrapTimeEvaluator( *pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, @@ -443,10 +443,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("cache_flush_cpu_non_first_arg", [](ffi::PackedArgs args, ffi::Any* rv) { CPUCacheFlush(1, args); }); -}); +} // server function registration. -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tvm.rpc.server.ImportModule", @@ -455,10 +455,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](ffi::Module parent, std::string name, bool query_imports) { return parent->GetFunction(name, query_imports); }); -}); +} // functions to access an RPC module. 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("rpc.LoadRemoteModule", @@ -480,13 +480,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK_EQ(tkey, "rpc"); *rv = static_cast(m.operator->())->sess()->table_index(); }) - .def("tvm.rpc.NDArrayFromRemoteOpaqueHandle", + .def("tvm.rpc.TensorFromRemoteOpaqueHandle", [](ffi::Module mod, void* remote_array, DLTensor* template_tensor, Device dev, - void* ndarray_handle) -> NDArray { - return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, - template_tensor, dev, ndarray_handle); + void* tensor_handle) -> Tensor { + return TensorFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, + template_tensor, dev, tensor_handle); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index 22619289d053..0bc608ccc253 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -113,7 +113,7 @@ ffi::Module CreatePipeClient(std::vector cmd) { return CreateRPCSessionModule(CreateClientSession(endpt)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("rpc.CreatePipeClient", [](ffi::PackedArgs args, ffi::Any* rv) { std::vector cmd; @@ -122,7 +122,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } *rv = CreatePipeClient(cmd); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index 52d04a72631f..c8e7a4ee81c9 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -36,7 +36,7 @@ std::string RPCGetPath(const std::string& name) { return (*f)(name).cast(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvm.rpc.server.upload", @@ -57,7 +57,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ 
std::string file_name = RPCGetPath(args[0].cast()); RemoveFile(file_name); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index ace9cf9b9485..1fee1424ea22 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -24,6 +24,7 @@ #include "rpc_session.h" #include +#include #include #include @@ -127,5 +128,7 @@ void RPCSession::InsertToSessionTable(std::shared_ptr sess) { sess->table_index_ = RPCSessTable::Global()->Insert(sess); } +TVM_FFI_STATIC_INIT_BLOCK() { tvm::ffi::reflection::ObjectDef(); } + } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index c0e09ec004ba..d7f629a0254f 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -55,8 +55,8 @@ class RPCSession { /*! \brief Module handle in the remote. */ using ModuleHandle = void*; - /*! \brief NDArray handle in the remote. */ - using NDArrayHandle = void*; + /*! \brief Tensor handle in the remote. */ + using TensorHandle = void*; /*! * \brief Callback to send an encoded return values via encode_args. @@ -66,7 +66,7 @@ class RPCSession { * Encoding convention (as list of arguments): * - str/float/int/byte: [tcode: int, value: TVMValue] value follows ffi::Function convention. * - ffi::Function/Module: [tcode: int, handle: void*] - * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*] + * - Tensor: [tcode: int, meta: DLTensor*, nd_handle: void*] * DLTensor* contains the meta-data as well as handle into the remote data. * nd_handle can be used for deletion. */ @@ -98,7 +98,7 @@ class RPCSession { * - type_code is follows the ffi::Function convention. * - int/float/string/bytes follows the ffi::Function convention, all data are local. * - ffi::Function/Module and future remote objects: pass remote handle instead. 
- * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor + * - Tensor/DLTensor: pass a DLTensor pointer, the data field of DLTensor * points to a remote data handle returned by the Device API. * The meta-data of the DLTensor sits on local. * @@ -109,8 +109,8 @@ class RPCSession { * * The callee need to store the return value into ret_value. * - ffi::Function/Module are stored as void* - * - NDArray is stored as local NDArray, whose data field is a remote handle. - * Notably the NDArray's deleter won't delete remote handle. + * - Tensor is stored as local Tensor, whose data field is a remote handle. + * Notably the Tensor's deleter won't delete remote handle. * It is up to the user of the RPCSession to such wrapping. * - In short, remote handles are "moved" as return values * and the callee needs to explicitly manage them by calling @@ -315,9 +315,8 @@ class RPCObjectRefObj : public Object { void* object_handle() const { return object_handle_; } static constexpr const uint32_t _type_index = TypeIndex::kRuntimeRPCObjectRef; - static constexpr const char* _type_key = "runtime.RPCObjectRef"; static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(RPCObjectRefObj, Object); + TVM_FFI_DECLARE_OBJECT_INFO_STATIC("runtime.RPCObjectRef", RPCObjectRefObj, Object); private: // The object handle @@ -333,7 +332,10 @@ class RPCObjectRefObj : public Object { */ class RPCObjectRef : public ObjectRef { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef, RPCObjectRefObj); + explicit RPCObjectRef(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RPCObjectRef, ObjectRef, RPCObjectRefObj); }; /*! 
diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index d2f141ee21e0..c19b91801e77 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -122,7 +122,7 @@ void RPCServerLoop(ffi::Function fsend, ffi::Function frecv) { ->ServerLoop(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("rpc.Connect", @@ -140,7 +140,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ RPCServerLoop(args[0].cast(), args[1].cast()); } }); -}); +} class SimpleSockHandler : public dmlc::Stream { // Things that will interface with user directly. @@ -167,14 +167,14 @@ class SimpleSockHandler : public dmlc::Stream { support::TCPSocket sock_; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("rpc.ReturnException", [](int sockfd, String msg) { + refl::GlobalDef().def("rpc.ReturnException", [](int sockfd, ffi::String msg) { auto handler = SimpleSockHandler(sockfd); RPCReference::ReturnException(msg.c_str(), &handler); return; }); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index b816fb600e1e..2cf7d3394599 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -47,7 +47,7 @@ class StaticLibraryNode final : public ffi::ModuleObj { public: const char* kind() const final { return "static_library"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { const ObjectPtr& sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_func_names") { return ffi::Function( @@ -65,13 +65,13 @@ class StaticLibraryNode final : public ffi::ModuleObj { std::vector func_names; for (const auto func_name : func_names_) func_names.push_back(func_name); stream->Write(func_names); - return Bytes(buffer); + return ffi::Bytes(buffer); } static ffi::Module 
LoadFromBytes(ffi::Bytes bytes) { dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()), bytes.size()); dmlc::Stream* stream = &ms; - auto n = make_object(); + auto n = ffi::make_object(); // load data std::string data; ICHECK(stream->Read(&data)) << "Loading data failed"; @@ -80,12 +80,12 @@ class StaticLibraryNode final : public ffi::ModuleObj { // load func names std::vector func_names; ICHECK(stream->Read(&func_names)) << "Loading func names failed"; - for (auto func_name : func_names) n->func_names_.push_back(String(func_name)); + for (auto func_name : func_names) n->func_names_.push_back(ffi::String(func_name)); return ffi::Module(n); } - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { VLOG(0) << "Saving static library of " << data_.size() << " bytes implementing " << FuncNames() << " to '" << file_name << "'"; SaveBinaryToFile(file_name, data_); @@ -96,7 +96,7 @@ class StaticLibraryNode final : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable; } - bool ImplementsFunction(const String& name) final { + bool ImplementsFunction(const ffi::String& name) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); } @@ -119,25 +119,25 @@ class StaticLibraryNode final : public ffi::ModuleObj { /*! \brief Contents of the object file. */ std::string data_; /*! \brief Function names exported by the above. 
*/ - Array func_names_; + ffi::Array func_names_; }; } // namespace -ffi::Module LoadStaticLibrary(const std::string& filename, Array func_names) { - auto node = make_object(); +ffi::Module LoadStaticLibrary(const std::string& filename, ffi::Array func_names) { + auto node = ffi::make_object(); LoadBinaryFromFile(filename, &node->data_); node->func_names_ = std::move(func_names); VLOG(0) << "Loaded static library from '" << filename << "' implementing " << node->FuncNames(); return ffi::Module(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.ModuleLoadStaticLibrary", LoadStaticLibrary) .def("ffi.Module.load_from_bytes.static_library", StaticLibraryNode::LoadFromBytes); -}); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h index 8a5600fc0588..2ebca2edd277 100644 --- a/src/runtime/static_library.h +++ b/src/runtime/static_library.h @@ -43,7 +43,7 @@ namespace runtime { * \brief Returns a static library with the contents loaded from filename which exports * func_names with the usual packed-func calling convention. */ -ffi::Module LoadStaticLibrary(const std::string& filename, Array func_names); +ffi::Module LoadStaticLibrary(const std::string& filename, ffi::Array func_names); } // namespace runtime } // namespace tvm diff --git a/src/runtime/ndarray.cc b/src/runtime/tensor.cc similarity index 67% rename from src/runtime/ndarray.cc rename to src/runtime/tensor.cc index 115d55c8f4e7..4ef744452c3c 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/tensor.cc @@ -18,15 +18,15 @@ */ /*! - * \file ndarray.cc - * \brief NDArray container infratructure. + * \file tensor.cc + * \brief Tensor container infratructure. 
*/ #include #include #include #include #include -#include +#include #include "tvm/runtime/data_type.h" @@ -59,10 +59,10 @@ inline void VerifyDataType(DLDataType dtype) { ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); } -void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { +void TensorCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { size_t arr_size = GetDataSize(*handle); - ICHECK_EQ(arr_size, nbytes) << "ArrayCopyFromBytes: size mismatch"; - ICHECK(IsContiguous(*handle)) << "ArrayCopyFromBytes only support contiguous array for now"; + ICHECK_EQ(arr_size, nbytes) << "TensorCopyFromBytes: size mismatch"; + ICHECK(IsContiguous(*handle)) << "TensorCopyFromBytes only support contiguous array for now"; DLTensor from; from.data = const_cast(data); @@ -77,8 +77,8 @@ void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { DeviceAPI::Get(handle->device)->StreamSync(handle->device, nullptr); } -void NDArray::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, - TVMStreamHandle stream) { +void Tensor::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, + TVMStreamHandle stream) { size_t arr_size = GetDataSize(*handle); ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; @@ -97,7 +97,28 @@ void NDArray::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); } -NDArray NDArray::Empty(ffi::Shape shape, DLDataType dtype, Device dev, Optional mem_scope) { +void Tensor::CopyFromBytes(const DLTensor* handle, void* data, size_t nbytes, + TVMStreamHandle stream) { + size_t arr_size = GetDataSize(*handle); + ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; + ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; + + DLTensor from; + from.data = const_cast(data); + 
from.device = Device{kDLCPU, 0}; + from.ndim = handle->ndim; + from.dtype = handle->dtype; + from.shape = handle->shape; + from.strides = nullptr; + from.byte_offset = 0; + + DeviceAPI::Get(handle->device)->CopyDataFromTo(&from, const_cast(handle), stream); + // Synchronize in case data become unavailable later. + DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); +} + +Tensor Tensor::Empty(ffi::Shape shape, DLDataType dtype, Device dev, + ffi::Optional mem_scope) { struct DeviceAPIAlloc { void AllocData(DLTensor* tensor, ffi::Optional mem_scope) { tensor->data = DeviceAPI::Get(tensor->device) @@ -108,11 +129,10 @@ NDArray NDArray::Empty(ffi::Shape shape, DLDataType dtype, Device dev, Optional< DeviceAPI::Get(tensor->device)->FreeDataSpace(tensor->device, tensor->data); } }; - return ffi::NDArray::FromNDAlloc(DeviceAPIAlloc(), shape, dtype, dev, mem_scope); + return ffi::Tensor::FromNDAlloc(DeviceAPIAlloc(), shape, dtype, dev, mem_scope); } -NDArray NDArray::CreateView(ffi::Shape shape, DLDataType dtype, - uint64_t relative_byte_offset) const { +Tensor Tensor::CreateView(ffi::Shape shape, DLDataType dtype, uint64_t relative_byte_offset) const { ICHECK(data_ != nullptr); const DLTensor& orig = *get_mutable(); @@ -145,14 +165,14 @@ NDArray NDArray::CreateView(ffi::Shape shape, DLDataType dtype, << view_size << " bytes. " << "This would occupy bytes " << relative_byte_offset << " <= i_byte < " << (relative_byte_offset + view_size) << " within the backing array. 
" - << "However, the NDArray being viewed only contains " << curr_size << " bytes (shape = " + << "However, the Tensor being viewed only contains " << curr_size << " bytes (shape = " << ffi::Shape(curr_dl_tensor.shape, curr_dl_tensor.shape + curr_dl_tensor.ndim) << ", dtype= " << curr_dl_tensor.dtype << ")."; - // helper allocator class that retains ref count of original NDArray + // helper allocator class that retains ref count of original Tensor class ViewBasedAlloc { public: - explicit ViewBasedAlloc(NDArray source) : source_(source) {} + explicit ViewBasedAlloc(Tensor source) : source_(source) {} void AllocData(DLTensor* tensor, int64_t byte_offset) { tensor->data = source_.get_mutable()->data; tensor->byte_offset = byte_offset; @@ -161,30 +181,30 @@ NDArray NDArray::CreateView(ffi::Shape shape, DLDataType dtype, void FreeData(DLTensor* tensor) {} private: - NDArray source_; + Tensor source_; }; - NDArray ret = NDArray::FromNDAlloc(ViewBasedAlloc(NDArray(*this)), shape, dtype, (*this)->device, - curr_dl_tensor.byte_offset + relative_byte_offset); + Tensor ret = Tensor::FromNDAlloc(ViewBasedAlloc(Tensor(*this)), shape, dtype, (*this)->device, + curr_dl_tensor.byte_offset + relative_byte_offset); return ret; } -void NDArray::CopyToBytes(void* data, size_t nbytes) const { +void Tensor::CopyToBytes(void* data, size_t nbytes) const { ICHECK(data != nullptr); ICHECK(data_ != nullptr); - NDArray::CopyToBytes(get_mutable(), data, nbytes); + Tensor::CopyToBytes(get_mutable(), data, nbytes); } -void NDArray::CopyFromBytes(const void* data, size_t nbytes) { +void Tensor::CopyFromBytes(const void* data, size_t nbytes) { ICHECK(data != nullptr); ICHECK(data_ != nullptr); - ArrayCopyFromBytes(get_mutable(), data, nbytes); + TensorCopyFromBytes(get_mutable(), data, nbytes); } -NDArray NDArray::CopyTo(const Device& dev, Optional mem_scope) const { +Tensor Tensor::CopyTo(const Device& dev, ffi::Optional mem_scope) const { ICHECK(data_ != nullptr); const DLTensor* dptr = 
operator->(); - NDArray ret = + Tensor ret = Empty(ffi::Shape(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev, mem_scope); this->CopyTo(ret); Device copy_gpu_dev = dptr->device.device_type != kDLCPU ? dptr->device : dev; @@ -192,10 +212,10 @@ NDArray NDArray::CopyTo(const Device& dev, Optional mem_scope) const { return ret; } -void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) { +void Tensor::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) { size_t from_size = GetDataSize(*from); size_t to_size = GetDataSize(*to); - ICHECK_EQ(from_size, to_size) << "TVMArrayCopyFromTo: The size in bytes must exactly match."; + ICHECK_EQ(from_size, to_size) << "TVMTensorCopyFromTo: The size in bytes must exactly match."; ICHECK(from->device.device_type == to->device.device_type || from->device.device_type == kDLCPU || to->device.device_type == kDLCPU || from->device.device_type == kDLCUDAHost || @@ -216,16 +236,15 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str using namespace tvm::runtime; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.TVMArrayAllocWithScope", NDArray::Empty) - .def_method("runtime.TVMArrayCreateView", &NDArray::CreateView) - .def("runtime.TVMArrayCopyFromBytes", - [](DLTensor* arr, void* data, size_t nbytes) { ArrayCopyFromBytes(arr, data, nbytes); }) - .def( - "runtime.TVMArrayCopyToBytes", - [](DLTensor* arr, void* data, size_t nbytes) { NDArray::CopyToBytes(arr, data, nbytes); }) - .def("runtime.TVMArrayCopyFromTo", - [](DLTensor* from, DLTensor* to) { NDArray::CopyFromTo(from, to); }); -}); + .def("runtime.TVMTensorAllocWithScope", Tensor::Empty) + .def_method("runtime.TVMTensorCreateView", &Tensor::CreateView) + .def("runtime.TVMTensorCopyFromBytes", + [](DLTensor* arr, void* data, size_t nbytes) { TensorCopyFromBytes(arr, data, nbytes); }) + 
.def("runtime.TVMTensorCopyToBytes", + [](DLTensor* arr, void* data, size_t nbytes) { Tensor::CopyToBytes(arr, data, nbytes); }) + .def("runtime.TVMTensorCopyFromTo", + [](DLTensor* from, DLTensor* to) { Tensor::CopyFromTo(from, to); }); +} diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index deaeec6ad3a0..132369e9b427 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -141,7 +141,7 @@ class ParallelLauncher { // The counter page. std::atomic* sync_counter_{nullptr}; // The error message - std::vector> par_errors_; + std::vector> par_errors_; }; /*! \brief Lock-free single-producer-single-consumer queue for each thread */ @@ -379,7 +379,7 @@ class ThreadPool { * \brief args[0] is the AffinityMode, args[1] is the number of threads. * args2 is a list of CPUs which is used to set the CPU affinity. */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("runtime.config_threadpool", @@ -389,7 +389,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ int nthreads = args[1].cast(); std::vector cpus; if (args.size() >= 3) { - auto cpu_array = args[2].cast>(); + auto cpu_array = args[2].cast>(); for (auto cpu : cpu_array) { ICHECK(IsNumber(cpu)) << "The CPU core information '" << cpu << "' is not a number."; @@ -399,7 +399,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ threading::Configure(mode, nthreads, cpus); }) .def("runtime.NumThreads", []() -> int32_t { return threading::NumThreads(); }); -}); +} namespace threading { diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 914fe67819de..d085ed40613f 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -224,6 +224,8 @@ struct ThreadWorkLoad { size_t work_size[6]; // Dynamic shared memory allocation size in bytes. size_t dyn_shmem_size{0}; + // Cluster dimensions for SM90+ cluster launch (x, y, z) + size_t cluster_dim[3] = {1, 1, 1}; /*! 
* \param i The block dimension. * \return i-th block dim @@ -234,6 +236,12 @@ struct ThreadWorkLoad { * \return i-th grid dim */ inline size_t grid_dim(size_t i) const { return work_size[i]; } + /*! + * \return whether cluster launch is enabled + */ + inline bool use_cluster_launch() const { + return cluster_dim[0] > 1 || cluster_dim[1] > 1 || cluster_dim[2] > 1; + } }; /*! \brief Launch parameters configuration */ class LaunchParamConfig { @@ -247,6 +255,19 @@ class LaunchParamConfig { ICHECK_EQ(i, launch_param_tags.size() - 1) << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; + } else if (tag == launch_param::kUseProgramaticDependentLaunch) { + use_programmatic_dependent_launch_ = true; + } else if (tag == launch_param::kUseCooperativeLaunch) { + use_cooperative_launch_ = true; + } else if (tag == launch_param::kClusterDimX) { + cluster_dim_x_arg_index_ = arg_index_map_.size(); + arg_index_map_.push_back(100); // Special marker for cluster dim x + } else if (tag == launch_param::kClusterDimY) { + cluster_dim_y_arg_index_ = arg_index_map_.size(); + arg_index_map_.push_back(101); // Special marker for cluster dim y + } else if (tag == launch_param::kClusterDimZ) { + cluster_dim_z_arg_index_ = arg_index_map_.size(); + arg_index_map_.push_back(102); // Special marker for cluster dim z } else { ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); @@ -267,10 +288,19 @@ class LaunchParamConfig { const TVMFFIAny* raw_args = reinterpret_cast(args.data()); for (size_t i = 0; i < arg_index_map_.size(); ++i) { - // Dynamic shapes can result in 0 dim size. Guard to ensure that the dim size is at least 1. + uint32_t idx = arg_index_map_[i]; size_t size = static_cast(raw_args[base_ + i].v_int64); - if (size > 0) { - w.work_size[arg_index_map_[i]] = size; + if (idx == 100) { + // Cluster dim X + w.cluster_dim[0] = size > 0 ? 
size : 1; + } else if (idx == 101) { + // Cluster dim Y + w.cluster_dim[1] = size > 0 ? size : 1; + } else if (idx == 102) { + // Cluster dim Z + w.cluster_dim[2] = size > 0 ? size : 1; + } else { + w.work_size[idx] = size; } } if (use_dyn_shared_memory_) { @@ -281,6 +311,15 @@ class LaunchParamConfig { // return the work dim size_t work_dim() const { return work_dim_; } + bool use_programtic_dependent_launch() const { return use_programmatic_dependent_launch_; } + + bool use_cooperative_launch() const { return use_cooperative_launch_; } + + bool use_cluster_launch() const { + return cluster_dim_x_arg_index_ >= 0 || cluster_dim_y_arg_index_ >= 0 || + cluster_dim_z_arg_index_ >= 0; + } + private: /*! \brief base axis */ size_t base_; @@ -290,6 +329,14 @@ class LaunchParamConfig { std::vector arg_index_map_; /*! \brief Whether or not use dynamic shared memory. */ bool use_dyn_shared_memory_{false}; + /*! \brief Whether or not use programmatic dependent launch. */ + bool use_programmatic_dependent_launch_{false}; + /*! \brief Whether or not use cooperative launch. */ + bool use_cooperative_launch_{false}; + /*! \brief Cluster dimension argument indices (-1 if not used) */ + int cluster_dim_x_arg_index_{-1}; + int cluster_dim_y_arg_index_{-1}; + int cluster_dim_z_arg_index_{-1}; }; } // namespace runtime diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index cb56ed181243..c4f6b3e17777 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -438,14 +438,14 @@ int MaxConcurrency() { // This global function can be used by disco runtime to bind processes // to CPUs. 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tvm.runtime.threading.set_current_thread_affinity", [](ffi::Shape cpu_ids) { SetThreadAffinity(CURRENT_THREAD_HANDLE, std::vector{cpu_ids.begin(), cpu_ids.end()}); }); -}); +} } // namespace threading } // namespace runtime diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc index c8fbd9082103..13e151ecd202 100644 --- a/src/runtime/vm/attn_backend.cc +++ b/src/runtime/vm/attn_backend.cc @@ -25,12 +25,12 @@ namespace tvm { namespace runtime { namespace vm { -std::unique_ptr ConvertPagedPrefillFunc(Array args, +std::unique_ptr ConvertPagedPrefillFunc(ffi::Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); @@ -47,33 +47,41 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, throw; } -std::unique_ptr ConvertRaggedPrefillFunc(Array args, +std::unique_ptr ConvertRaggedPrefillFunc(ffi::Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { - CHECK_EQ(args.size(), 3); + CHECK(args.size() == 3 || args.size() == 5); ffi::Function attn_func = args[1].cast(); ffi::Function plan_func = args[2].cast(); + int64_t qk_head_dim_override = -1; + int64_t v_head_dim_override = -1; + if (args.size() == 5) { + qk_head_dim_override = args[3].cast(); + v_head_dim_override = args[4].cast(); + } return std::make_unique(std::move(attn_func), std::move(plan_func), - attn_kind); + attn_kind, qk_head_dim_override, + v_head_dim_override); } LOG(FATAL) << 
"Cannot reach here"; throw; } -std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind) { +std::unique_ptr ConvertPagedDecodeFunc(ffi::Array args, + AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); @@ -90,12 +98,12 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, At throw; } -std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, +std::unique_ptr ConvertPagedPrefillTreeMaskFunc(ffi::Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); @@ -105,12 +113,12 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< throw; } -std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, - AttnKind attn_kind) { +std::unique_ptr ConvertRaggedPrefillTreeMaskFunc( + ffi::Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index 449a1def0a38..31f1ce9f4ad2 100644 --- a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -57,6 +58,22 @@ class AttnBackendFunc { virtual ~AttnBackendFunc() = default; protected: + // helper allocator class for creating strided view of a Tensor + // that applies byte offset to the original data pointer + class ViewBasedAlloc { + public: + explicit ViewBasedAlloc(Tensor source) : source_(source) {} + void AllocData(DLTensor* tensor, int64_t* strides, int64_t extra_byte_offset) { + 
tensor->data = static_cast(source_->data) + extra_byte_offset; + tensor->strides = strides; + } + + void FreeData(DLTensor* tensor) {} + + private: + Tensor source_; + }; + ffi::Function attn_func_; public: @@ -71,22 +88,22 @@ class PagedPrefillFunc : public AttnBackendFunc { AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} - virtual void MHA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, - NDArray page_indices, NDArray length_info, NDArray q_rope_position, - NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, - double rotary_theta, double sm_scale, NDArray attn_output, NDArray attn_lse, + virtual void MHA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, + Tensor page_indices, Tensor length_info, Tensor q_rope_position, + Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, + double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MHA computation is not supported by the current backend"; } - virtual void MLA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, - NDArray page_indices, NDArray length_info, bool causal, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + virtual void MLA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, + Tensor page_indices, Tensor length_info, bool causal, double sm_scale, + Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MLA computation is not supported by the current backend"; } - virtual void BeginForward(int depth, NDArray float_workspace_buffer, NDArray int_workspace_buffer, - NDArray page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, + virtual void BeginForward(int depth, Tensor float_workspace_buffer, Tensor int_workspace_buffer, + Tensor 
page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, HostMemoryVector* page_indptr, HostMemoryVector* last_page_len, int64_t batch_size, int64_t total_qo_len, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, @@ -101,10 +118,10 @@ class TIRPagedPrefillFunc : public PagedPrefillFunc { explicit TIRPagedPrefillFunc(ffi::Function attn_func, AttnKind attn_kind) : PagedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} - void MHA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, - NDArray page_indices, NDArray length_info, NDArray q_rope_position, - NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, - double rotary_theta, double sm_scale, NDArray attn_output, NDArray attn_lse, + void MHA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, + Tensor page_indices, Tensor length_info, Tensor q_rope_position, + Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, + double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { attn_func_(q, qo_indptr, pages, page_indptr, page_indices, length_info, k_rope_pos_offset, q_rope_position, attn_output, attn_lse, static_cast(causal), @@ -112,9 +129,9 @@ class TIRPagedPrefillFunc : public PagedPrefillFunc { rotary_theta, sm_scale); } - void MLA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, - NDArray page_indices, NDArray length_info, bool causal, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + void MLA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, + Tensor page_indices, Tensor length_info, bool causal, double sm_scale, + Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { attn_func_(q, qo_indptr, pages, page_indptr, page_indices, length_info, attn_output, attn_lse, static_cast(causal), 
sm_scale); } @@ -128,64 +145,123 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { : PagedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), plan_func_(std::move(plan_func)) {} - void MHA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, - NDArray page_indices, NDArray length_info, NDArray q_rope_position, - NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, - double rotary_theta, double sm_scale, NDArray attn_output, NDArray attn_lse, + void MHA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, + Tensor page_indices, Tensor length_info, Tensor q_rope_position, + Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, + double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { + Device device = q->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, compute_stream); auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, plan_info_vec] = cached_buffers_[depth]; double rope_rcp_scale = 1 / rotary_scale; double rope_rcp_theta = 1 / rotary_theta; - attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, qo_indptr, - page_indptr, page_indices, length_info, q_rope_position, k_rope_pos_offset, - attn_output, attn_lse, /*mask_mode_code=*/static_cast(causal), - /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), - /*layout(HND)=*/1, /*window_left=*/-1, sm_scale, /*rope_rcp_scale=*/rope_rcp_scale, - /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + + ICHECK_EQ(pages.ndim(), 5); + int H = pages->shape[2]; + int N = pages->shape[3]; + int D = pages->shape[4]; + CHECK(pages.IsContiguous()); + std::vector pages_k_v_shape = {pages->shape[0], H, N, D}; + std::vector pages_k_v_strides = {2 * H * N * D, N * D, D, 1}; + 
Tensor pages_k = + Tensor::FromNDAlloc(ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, + pages->device, pages_k_v_strides.data(), pages->byte_offset); + Tensor pages_v = Tensor::FromNDAlloc( + ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, pages->device, + pages_k_v_strides.data(), pages->byte_offset + (H * N * D) * pages.DataType().bytes()); + + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages_k, pages_v, + qo_indptr, page_indptr, page_indices, length_info, attn_output, attn_lse, + /*mask_mode_code=*/static_cast(causal), /*layout(HND)=*/1, + /*window_left=*/-1, /*enable_pdl=*/false, sm_scale, + /*rope_rcp_scale=*/rope_rcp_scale, /*rope_rcp_theta=*/rope_rcp_theta); + DeviceAPI::Get(device)->SetStream(device, original_stream); } - void MLA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, - NDArray page_indices, NDArray length_info, bool causal, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + void MLA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, + Tensor page_indices, Tensor length_info, bool causal, double sm_scale, + Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, plan_info_vec] = cached_buffers_[depth]; - attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, page_indices, - attn_output, attn_lse, /*mask_mode_code=*/static_cast(causal), - /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1], sm_scale, compute_stream); + Device device = q->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, compute_stream); + ICHECK_NE(qk_head_dim_, -1); + ICHECK_NE(v_head_dim_, -1); + int64_t H = q->shape[1]; + int64_t page_size = pages->shape[1]; + int64_t rope_head_dim = 
qk_head_dim_ - v_head_dim_; + int64_t nope_head_dim = q->shape[2] - rope_head_dim; + + // Split q into q_nope and q_pe + CHECK(q.IsContiguous()); + std::vector q_nope_shape = {q->shape[0], H, nope_head_dim}; + std::vector q_pe_shape = {q->shape[0], H, rope_head_dim}; + std::vector q_strides = {H * q->shape[2], q->shape[2], 1}; + Tensor q_nope = Tensor::FromNDAlloc(ViewBasedAlloc(q), ffi::Shape(q_nope_shape), q->dtype, + q->device, q_strides.data(), q->byte_offset); + Tensor q_pe = Tensor::FromNDAlloc(ViewBasedAlloc(q), ffi::Shape(q_pe_shape), q->dtype, + q->device, q_strides.data(), + q->byte_offset + nope_head_dim * q.DataType().bytes()); + // Split pages into kv_nope and kv_pe + CHECK(pages.IsContiguous()); + std::vector kv_nope_shape = {pages->shape[0], page_size, nope_head_dim}; + std::vector kv_pe_shape = {pages->shape[0], page_size, rope_head_dim}; + std::vector kv_strides = {page_size * pages->shape[2], pages->shape[2], 1}; + Tensor kv_nope = + Tensor::FromNDAlloc(ViewBasedAlloc(pages), ffi::Shape(kv_nope_shape), pages->dtype, + pages->device, kv_strides.data(), pages->byte_offset); + Tensor kv_pe = Tensor::FromNDAlloc( + ViewBasedAlloc(pages), ffi::Shape(kv_pe_shape), pages->dtype, pages->device, + kv_strides.data(), pages->byte_offset + nope_head_dim * pages.DataType().bytes()); + + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q_nope, q_pe, kv_nope, + kv_pe, page_indices, attn_output, attn_lse, + /*mask_mode_code=*/static_cast(causal), + /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1], sm_scale); + DeviceAPI::Get(device)->SetStream(device, original_stream); } - void BeginForward(int depth, NDArray float_workspace_buffer, NDArray int_workspace_buffer, - NDArray page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, + void BeginForward(int depth, Tensor float_workspace_buffer, Tensor int_workspace_buffer, + Tensor page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, HostMemoryVector* page_indptr, 
HostMemoryVector* last_page_len, int64_t batch_size, int64_t total_qo_len, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, bool causal, TVMStreamHandle copy_stream) final { - std::vector kv_len; - kv_len.reserve(batch_size); + Tensor kv_len_arr = Tensor::Empty({batch_size}, DataType::Int(32), Device{kDLCPU, 0}); + int32_t* kv_len_arr_data = static_cast(kv_len_arr.data_ptr()); for (int i = 0; i < static_cast(batch_size); ++i) { - kv_len.push_back((*page_indptr)[i + 1] != (*page_indptr)[i] - ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * page_size + - (*last_page_len)[i] - : 0); + kv_len_arr_data[i] = + (*page_indptr)[i + 1] != (*page_indptr)[i] + ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * page_size + (*last_page_len)[i] + : 0; } - IntTuple plan_info_vec; + qk_head_dim_ = qk_head_dim; + v_head_dim_ = v_head_dim; + ffi::Array plan_info_vec; + Device device = float_workspace_buffer->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, copy_stream); if (attn_kind == AttnKind::kMHA) { // Todo(tvm-team): enable cuda graph plan_info_vec = plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, - qo_indptr->as_ndarray(), page_indptr->as_ndarray(), - IntTuple(std::move(kv_len)), total_qo_len, batch_size, num_qo_heads, - num_kv_heads, page_size, - /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, copy_stream) - .cast(); + qo_indptr->as_tensor(), page_indptr->as_tensor(), kv_len_arr, total_qo_len, + batch_size, num_qo_heads, num_kv_heads, page_size, + /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, + /*window_left=*/-1, /*fixed_split_size=*/-1, /*disable_split_kv=*/false, + /*num_colocated_ctas=*/0) + .cast>(); } else if (attn_kind == AttnKind::kMLA) { plan_info_vec = plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, - 
qo_indptr->as_ndarray(), page_indptr->as_ndarray(), - IntTuple(std::move(kv_len)), num_qo_heads, v_head_dim, causal, copy_stream) - .cast(); + qo_indptr->as_tensor(), page_indptr->as_tensor(), kv_len_arr, num_qo_heads, + v_head_dim, causal) + .cast>(); } + DeviceAPI::Get(device)->SetStream(device, original_stream); if (cached_buffers_.size() <= static_cast(depth)) { cached_buffers_.resize(depth + 1); @@ -196,8 +272,10 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { } private: + int64_t qk_head_dim_ = -1; + int64_t v_head_dim_ = -1; ffi::Function plan_func_; - std::vector> cached_buffers_; + std::vector>> cached_buffers_; }; /*! \brief The ragged prefill attention function base class. */ @@ -207,15 +285,15 @@ class RaggedPrefillFunc : public AttnBackendFunc { AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} - virtual void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, - NDArray q_rope_position, NDArray k_rope_pos_offset, bool causal, + virtual void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, + Tensor q_rope_position, Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MHA computation is not supported by the current backend"; } - virtual void BeginForward(NDArray float_workspace_buffer, NDArray int_workspace_buffer, - NDArray page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, + virtual void BeginForward(Tensor float_workspace_buffer, Tensor int_workspace_buffer, + Tensor page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, HostMemoryVector* kv_indptr, int64_t batch_size, int64_t total_qo_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, bool causal, TVMStreamHandle 
copy_stream) { @@ -229,10 +307,10 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc { explicit TIRRaggedPrefillFunc(ffi::Function attn_func, AttnKind attn_kind) : RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} - void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, - NDArray q_rope_position, NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, - double rotary_scale, double rotary_theta, double sm_scale, NDArray attn_output, - NDArray attn_lse, TVMStreamHandle compute_stream) final { + void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, Tensor q_rope_position, + Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, + double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, + TVMStreamHandle compute_stream) final { attn_func_(q, qo_indptr, k, v, kv_indptr, q_rope_position, k_rope_pos_offset, attn_output, attn_lse, static_cast(causal), /*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, @@ -244,53 +322,74 @@ class TIRRaggedPrefillFunc : public RaggedPrefillFunc { class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc { public: explicit FlashInferRaggedPrefillFunc(ffi::Function attn_func, ffi::Function plan_func, - AttnKind attn_kind) + AttnKind attn_kind, int64_t qk_head_dim_override, + int64_t v_head_dim_override) : RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), + qk_head_dim_override_(qk_head_dim_override), + v_head_dim_override_(v_head_dim_override), plan_func_(std::move(plan_func)) {} - void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, - NDArray q_rope_position, NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, - double rotary_scale, double rotary_theta, double sm_scale, NDArray attn_output, - NDArray attn_lse, TVMStreamHandle compute_stream) final { + void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, 
Tensor q_rope_position, + Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, + double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, + TVMStreamHandle compute_stream) final { + Device device = q->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, compute_stream); double rope_rcp_scale = 1 / rotary_scale; double rope_rcp_theta = 1 / rotary_theta; attn_func_(float_workspace_buffer_, int_workspace_buffer_, plan_info_vec_, q, k, v, qo_indptr, - kv_indptr, q_rope_position, k_rope_pos_offset, attn_output, attn_lse, + kv_indptr, attn_output, attn_lse, /*mask_mode_code=*/static_cast(causal), - /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), - /*layout(NHD)=*/0, /*window_left=*/-1, sm_scale, + /*layout(NHD)=*/0, /*window_left=*/-1, + /*enable_pdl=*/false, sm_scale, /*rope_rcp_scale=*/rope_rcp_scale, - /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + /*rope_rcp_theta=*/rope_rcp_theta); + DeviceAPI::Get(device)->SetStream(device, original_stream); } - void BeginForward(NDArray float_workspace_buffer, NDArray int_workspace_buffer, - NDArray page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, + void BeginForward(Tensor float_workspace_buffer, Tensor int_workspace_buffer, + Tensor page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, HostMemoryVector* kv_indptr, int64_t batch_size, int64_t total_qo_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, bool causal, TVMStreamHandle copy_stream) final { - std::vector kv_len; - kv_len.reserve(batch_size); + Tensor kv_len_arr = Tensor::Empty({batch_size}, DataType::Int(32), Device{kDLCPU, 0}); + int32_t* kv_len_arr_data = static_cast(kv_len_arr.data_ptr()); for (int i = 0; i < static_cast(batch_size); ++i) { - kv_len.push_back((*kv_indptr)[i + 1] - (*kv_indptr)[i]); + kv_len_arr_data[i] = (*kv_indptr)[i + 1] - 
(*kv_indptr)[i]; + } + if (qk_head_dim_override_ != -1) { + qk_head_dim = qk_head_dim_override_; + } + if (v_head_dim_override_ != -1) { + v_head_dim = v_head_dim_override_; } // Todo(tvm-team): enable cuda graph float_workspace_buffer_ = float_workspace_buffer; int_workspace_buffer_ = int_workspace_buffer; page_locked_int_workspace_buffer_ = page_locked_int_workspace_buffer; + Device device = float_workspace_buffer->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, copy_stream); plan_info_vec_ = plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, - qo_indptr->as_ndarray(), kv_indptr->as_ndarray(), IntTuple(std::move(kv_len)), - total_qo_len, batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1, - /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, copy_stream) - .cast(); + qo_indptr->as_tensor(), kv_indptr->as_tensor(), kv_len_arr, total_qo_len, + batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1, + /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, + /*window_left=*/-1, /*fixed_split_size=*/-1, /*disable_split_kv=*/false, + /*num_colocated_ctas=*/0) + .cast>(); + DeviceAPI::Get(device)->SetStream(device, original_stream); } private: + int64_t qk_head_dim_override_; + int64_t v_head_dim_override_; ffi::Function plan_func_; - NDArray float_workspace_buffer_; - NDArray int_workspace_buffer_; - NDArray page_locked_int_workspace_buffer_; - IntTuple plan_info_vec_; + Tensor float_workspace_buffer_; + Tensor int_workspace_buffer_; + Tensor page_locked_int_workspace_buffer_; + ffi::Array plan_info_vec_; }; /*! \brief The paged decode attention function base class. 
*/ @@ -300,21 +399,21 @@ class PagedDecodeFunc : public AttnBackendFunc { AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} - virtual void MHA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, - NDArray length_info, NDArray k_rope_pos_offset, NDArray q_rope_position, + virtual void MHA(int depth, Tensor q, Tensor pages, Tensor page_indptr, Tensor page_indices, + Tensor length_info, Tensor k_rope_pos_offset, Tensor q_rope_position, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MHA computation is not supported by the current backend"; } - virtual void MLA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, - NDArray length_info, double sm_scale, NDArray attn_output, NDArray attn_lse, + virtual void MLA(int depth, Tensor q, Tensor pages, Tensor page_indptr, Tensor page_indices, + Tensor length_info, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MLA computation is not supported by the current backend"; } - virtual void BeginForward(int depth, NDArray float_workspace_buffer, NDArray int_workspace_buffer, - NDArray page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, + virtual void BeginForward(int depth, Tensor float_workspace_buffer, Tensor int_workspace_buffer, + Tensor page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, int64_t batch_size, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, @@ -329,18 +428,18 @@ class TIRPagedDecodeFunc : public PagedDecodeFunc { explicit TIRPagedDecodeFunc(ffi::Function attn_func, AttnKind attn_kind) : PagedDecodeFunc(std::move(attn_func), 
attn_kind, AttnBackendKind::kTIR) {} - void MHA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, - NDArray length_info, NDArray k_rope_pos_offset, NDArray q_rope_position, - RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + void MHA(int depth, Tensor q, Tensor pages, Tensor page_indptr, Tensor page_indices, + Tensor length_info, Tensor k_rope_pos_offset, Tensor q_rope_position, RoPEMode rope_mode, + double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, + Tensor attn_lse, TVMStreamHandle compute_stream) final { attn_func_(q, pages, page_indptr, page_indices, length_info, k_rope_pos_offset, q_rope_position, attn_output, attn_lse, /*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, rotary_theta, sm_scale); } - void MLA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, - NDArray length_info, double sm_scale, NDArray attn_output, NDArray attn_lse, + void MLA(int depth, Tensor q, Tensor pages, Tensor page_indptr, Tensor page_indices, + Tensor length_info, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) final { attn_func_(q, pages, page_indptr, page_indices, length_info, attn_output, attn_lse, sm_scale); } @@ -354,35 +453,58 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { : PagedDecodeFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), plan_func_(std::move(plan_func)) {} - void MHA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, - NDArray length_info, NDArray k_rope_pos_offset, NDArray q_rope_position, - RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + void MHA(int depth, Tensor q, Tensor pages, Tensor page_indptr, Tensor page_indices, + Tensor 
length_info, Tensor k_rope_pos_offset, Tensor q_rope_position, RoPEMode rope_mode, + double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, + Tensor attn_lse, TVMStreamHandle compute_stream) final { + Device device = q->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, compute_stream); auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, plan_info_vec] = cached_buffers_[depth]; double rope_rcp_scale = 1 / rotary_scale; double rope_rcp_theta = 1 / rotary_theta; - attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, page_indptr, - page_indices, length_info, q_rope_position, k_rope_pos_offset, attn_output, attn_lse, - /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), - /*layout(HND)=*/1, /*window_left=*/-1, sm_scale, /*rope_rcp_scale=*/rope_rcp_scale, - /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + + ICHECK_EQ(pages.ndim(), 5); + int H = pages->shape[2]; + int N = pages->shape[3]; + int D = pages->shape[4]; + CHECK(pages.IsContiguous()); + std::vector pages_k_v_shape = {pages->shape[0], H, N, D}; + std::vector pages_k_v_strides = {2 * H * N * D, N * D, D, 1}; + Tensor pages_k = + Tensor::FromNDAlloc(ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, + pages->device, pages_k_v_strides.data(), pages->byte_offset); + Tensor pages_v = Tensor::FromNDAlloc( + ViewBasedAlloc(pages), ffi::Shape(pages_k_v_shape), pages->dtype, pages->device, + pages_k_v_strides.data(), pages->byte_offset + (H * N * D) * pages.DataType().bytes()); + + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages_k, pages_v, + page_indptr, page_indices, length_info, attn_output, attn_lse, + /*layout(HND)=*/1, /*window_left=*/-1, /*enable_pdl=*/false, sm_scale, + /*rope_rcp_scale=*/rope_rcp_scale, /*rope_rcp_theta=*/rope_rcp_theta); + 
DeviceAPI::Get(device)->SetStream(device, original_stream); } - void BeginForward(int depth, NDArray float_workspace_buffer, NDArray int_workspace_buffer, - NDArray page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, + void BeginForward(int depth, Tensor float_workspace_buffer, Tensor int_workspace_buffer, + Tensor page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, int64_t batch_size, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, TVMStreamHandle copy_stream) final { // Todo(tvm-team): enable cuda graph - IntTuple plan_info_vec = + Tensor empty_qkv_data = Tensor::Empty({1}, q_dtype, Device{kDLCPU, 0}); + Device device = float_workspace_buffer->device; + TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); + DeviceAPI::Get(device)->SetStream(device, copy_stream); + ffi::Array plan_info_vec = plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, - page_indptr->as_ndarray(), batch_size, num_qo_heads, num_kv_heads, page_size, + page_indptr->as_tensor(), batch_size, num_qo_heads, num_kv_heads, page_size, /*enable_cuda_graph=*/false, - static_cast(rope_mode == RoPEMode::kInline), - /*window_left=*/-1, qk_head_dim, v_head_dim, q_dtype, kv_dtype, copy_stream) - .cast(); + /*window_left=*/-1, /*logits_soft_cap=*/0.0, qk_head_dim, v_head_dim, + empty_qkv_data, empty_qkv_data) + .cast>(); + DeviceAPI::Get(device)->SetStream(device, original_stream); if (cached_buffers_.size() <= static_cast(depth)) { cached_buffers_.resize(depth + 1); @@ -394,7 +516,7 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { private: ffi::Function plan_func_; - std::vector> cached_buffers_; + std::vector>> cached_buffers_; }; /*! \brief The paged prefill with tree mask attention function base class. 
*/ @@ -404,22 +526,22 @@ class PagedPrefillTreeMaskFunc : public AttnBackendFunc { AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} - virtual void MHA(NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, - NDArray page_indices, NDArray length_info, NDArray k_rope_pos_offset, - NDArray q_rope_position, NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, + virtual void MHA(Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, + Tensor page_indices, Tensor length_info, Tensor k_rope_pos_offset, + Tensor q_rope_position, Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MHA computation is not supported by the current backend"; } - virtual void MLA(NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, - NDArray page_indices, NDArray length_info, NDArray tree_attn_mn_indptr, - NDArray tree_attn_mask, double sm_scale, NDArray attn_output, NDArray attn_lse, + virtual void MLA(Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, + Tensor page_indices, Tensor length_info, Tensor tree_attn_mn_indptr, + Tensor tree_attn_mask, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MLA computation is not supported by the current backend"; } - virtual void BeginForward(NDArray temp_float_attn_workspace, NDArray temp_int_attn_workspace, + virtual void BeginForward(Tensor temp_float_attn_workspace, Tensor temp_int_attn_workspace, HostMemoryVector* page_indptr, HostMemoryVector* last_page_len, HostMemoryVector* qo_indptr, int64_t batch_size, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, @@ -434,11 +556,11 @@ class TIRPagedPrefillTreeMaskFunc : public 
PagedPrefillTreeMaskFunc { explicit TIRPagedPrefillTreeMaskFunc(ffi::Function attn_func, AttnKind attn_kind) : PagedPrefillTreeMaskFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} - void MHA(NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, NDArray page_indices, - NDArray length_info, NDArray k_rope_pos_offset, NDArray q_rope_position, - NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, RoPEMode rope_mode, - double rotary_scale, double rotary_theta, double sm_scale, NDArray attn_output, - NDArray attn_lse, TVMStreamHandle compute_stream) final { + void MHA(Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, Tensor page_indices, + Tensor length_info, Tensor k_rope_pos_offset, Tensor q_rope_position, + Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, RoPEMode rope_mode, + double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, + Tensor attn_lse, TVMStreamHandle compute_stream) final { attn_func_(q, qo_indptr, pages, page_indptr, page_indices, length_info, k_rope_pos_offset, q_rope_position, attn_output, attn_lse, /*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, @@ -453,21 +575,20 @@ class RaggedPrefillTreeMaskFunc : public AttnBackendFunc { AttnBackendKind backend_kind) : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} - virtual void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, - NDArray q_rope_position, NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, + virtual void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, + Tensor q_rope_position, Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MHA computation is not supported by the current backend"; } - virtual void 
MLA(NDArray q, NDArray compressed_kv, NDArray k_pe, NDArray qo_indptr, - NDArray kv_indptr, NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, - double sm_scale, NDArray attn_output, NDArray attn_lse, - TVMStreamHandle compute_stream) { + virtual void MLA(Tensor q, Tensor compressed_kv, Tensor k_pe, Tensor qo_indptr, Tensor kv_indptr, + Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, double sm_scale, + Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { LOG(FATAL) << "MLA computation is not supported by the current backend"; } - virtual void BeginForward(NDArray temp_float_attn_workspace, NDArray temp_int_attn_workspace, + virtual void BeginForward(Tensor temp_float_attn_workspace, Tensor temp_int_attn_workspace, HostMemoryVector* page_indptr, HostMemoryVector* last_page_len, HostMemoryVector* qo_indptr, int64_t batch_size, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, @@ -482,10 +603,10 @@ class TIRRaggedPrefillTreeMaskFunc : public RaggedPrefillTreeMaskFunc { explicit TIRRaggedPrefillTreeMaskFunc(ffi::Function attn_func, AttnKind attn_kind) : RaggedPrefillTreeMaskFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} - void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, - NDArray q_rope_position, NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, - RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, - NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + void MHA(Tensor q, Tensor k, Tensor v, Tensor qo_indptr, Tensor kv_indptr, Tensor q_rope_position, + Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, RoPEMode rope_mode, + double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, + Tensor attn_lse, TVMStreamHandle compute_stream) final { attn_func_(q, qo_indptr, k, v, kv_indptr, q_rope_position, tree_attn_mn_indptr, tree_attn_mask, attn_output, attn_lse, 
/*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, @@ -499,7 +620,8 @@ class TIRRaggedPrefillTreeMaskFunc : public RaggedPrefillTreeMaskFunc { * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedPrefillFunc pointer. */ -std::unique_ptr ConvertPagedPrefillFunc(Array args, AttnKind attn_kind); +std::unique_ptr ConvertPagedPrefillFunc(ffi::Array args, + AttnKind attn_kind); /*! * \brief Create a PagedDecodeFunc from the given arguments and the attention kind. @@ -507,7 +629,8 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedDecodeFunc pointer. */ -std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind); +std::unique_ptr ConvertPagedDecodeFunc(ffi::Array args, + AttnKind attn_kind); /*! * \brief Create a RaggedPrefillFunc from the given arguments and the attention kind. @@ -515,7 +638,7 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, At * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * RaggedPrefillFunc pointer. */ -std::unique_ptr ConvertRaggedPrefillFunc(Array args, +std::unique_ptr ConvertRaggedPrefillFunc(ffi::Array args, AttnKind attn_kind); /*! @@ -524,7 +647,7 @@ std::unique_ptr ConvertRaggedPrefillFunc(Array args * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedPrefillTreeMaskFunc pointer. */ -std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, +std::unique_ptr ConvertPagedPrefillTreeMaskFunc(ffi::Array args, AttnKind attn_kind); /*! @@ -533,8 +656,8 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * RaggedPrefillTreeMaskFunc pointer. 
*/ -std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, - AttnKind attn_kind); +std::unique_ptr ConvertRaggedPrefillTreeMaskFunc( + ffi::Array args, AttnKind attn_kind); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index 290ca02653d2..1c695a10e25d 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_VM_ATTN_UTILS_H_ #define TVM_RUNTIME_VM_ATTN_UTILS_H_ -#include +#include #include #include @@ -355,14 +355,14 @@ class HostMemoryVector { explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device) : reserved_size_(reserved_size) { ICHECK(DataType(dtype) == DataType::Int(32)); - data_ = NDArray::Empty({reserved_size}, dtype, device); + data_ = Tensor::Empty({reserved_size}, dtype, device); } void push_back(int32_t value) { ICHECK_LE(current_size_, reserved_size_); if (current_size_ == reserved_size_) { reserved_size_ *= 2; - NDArray new_data = NDArray::Empty({reserved_size_}, data_->dtype, data_->device); + Tensor new_data = Tensor::Empty({reserved_size_}, data_->dtype, data_->device); std::memcpy(new_data->data, data_->data, current_size_ * DataType(data_->dtype).bytes()); data_ = new_data; } @@ -386,8 +386,8 @@ class HostMemoryVector { void clear() { current_size_ = 0; } - /*! \brief Return the vector as an NDArray. */ - NDArray as_ndarray() { return data_.CreateView({current_size_}, data_->dtype); } + /*! \brief Return the vector as an Tensor. */ + Tensor as_tensor() { return data_.CreateView({current_size_}, data_->dtype); } IntTuple as_int_tuple() const { std::vector values; @@ -401,7 +401,7 @@ class HostMemoryVector { private: int64_t reserved_size_ = 0; int64_t current_size_ = 0; - NDArray data_{nullptr}; + Tensor data_{nullptr}; }; /*! @@ -411,12 +411,12 @@ class HostMemoryVector { * * The core functions of this class is `CopyXXXAsync` and `CommitAttnAuxDataCopy`. 
* `CopyXXXAsync` takes the input data on CPU host, and copy the input data - * to GPU in an asynchronous way, and returns the NDArray view of the data + * to GPU in an asynchronous way, and returns the Tensor view of the data * on GPU device. * * Being asynchronous here means the `CopyXXXAsync` function may not perform * data copy from CPU to GPU at the time of being called. Therefore, the - * returned NDArray view may have wrong result, until `CommitAttnAuxDataCopy` is + * returned Tensor view may have wrong result, until `CommitAttnAuxDataCopy` is * explicitly invoked and the data copy stream is synchronized. * * We design this manager class in order to reduce the data copy overhead. @@ -436,16 +436,16 @@ class PagedKVCacheAuxDataManager { /*! \brief Reset the attention auxiliary data status of copy manager. */ virtual void ResetAttnAuxDataCopy() = 0; /*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */ - virtual NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; + virtual Tensor CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the indptr array of page table. */ - virtual NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; + virtual Tensor CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the indices array of page table. */ - virtual NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) = 0; + virtual Tensor CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the array of KV slot number used in the last page of the seq. */ - virtual NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) = 0; + virtual Tensor CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! * \brief Copy the length information of the sequences. - * Each NDArray is in shape `(3, n)`. "n" is the number of sequences. + * Each Tensor is in shape `(3, n)`. 
"n" is the number of sequences. * For a sequence "i", location * - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), * - "(1, i)" is the starting offset of the sliding window in the seq, @@ -453,51 +453,51 @@ class PagedKVCacheAuxDataManager { * \note When sliding window is not enabled, only the * "last_page_len" (a.k.a., the first "n" elements) will be effectively used. */ - virtual NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, - HostMemoryVector* sliding_window_offset, - HostMemoryVector* sink_size, int depth) = 0; + virtual Tensor CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, + HostMemoryVector* sliding_window_offset, + HostMemoryVector* sink_size, int depth) = 0; /*! \brief Copy the k position offset of applying RoPE for each sequence. */ - virtual NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) = 0; + virtual Tensor CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! * \brief Copy the append length indptr array on device. * \note Since the Q/K/V data may have raggedness in terms of lengths, * we represent the append lengths in CSR format. */ - virtual NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) = 0; /*! \brief Copy the k position offset of applying RoPE for each sequence. */ - virtual NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) = 0; /*! \brief Copy the q position mapping of applying RoPE for each sequence. */ - virtual NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyQRoPEPosMapAsync(HostMemoryVector* data) = 0; /*! * \brief Copy the corresponding position in global KV cache (pages) * for each position along the length dimension of K/V data when * appending new K/V data. 
*/ - virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyAppendPositionMapAsync(HostMemoryVector* data) = 0; /*! \brief Copy the remote position map for KV transfer. */ - virtual NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) = 0; /*! \brief Copy the receiver id for KV transfer. */ - virtual NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyKVTransferRecverIDAsync(HostMemoryVector* data) = 0; /*! \brief Copy the local position map for KV page-to-page transfer. */ - virtual NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) = 0; /*! \brief Copy the remote position map for KV page-to-page transfer. */ - virtual NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) = 0; /*! \brief Copy the receiver id for KV page-to-page transfer. */ - virtual NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) = 0; /*! \brief Copy the tree attention mask. */ - virtual NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) = 0; + virtual Tensor CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the mn indptr of the tree attention mask. */ - virtual NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; + virtual Tensor CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Commit all the attention auxiliary data copy operations since the last commit. */ virtual void CommitAttnAuxDataCopy() = 0; /*! \brief Reset the compact KV auxiliary data status of copy manager. 
*/ virtual void ResetCompactKVAuxDataCopy() = 0; /*! \brief Copy the length indptr array of KV data copy for each sequence. */ - virtual NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) = 0; + virtual Tensor CopyCommitLengthIndptrAsync(HostMemoryVector* data) = 0; /*! \brief Copy the src/dst position arrays for each sequence. */ - virtual NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, - HostMemoryVector* dst_data) = 0; + virtual Tensor CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, + HostMemoryVector* dst_data) = 0; /*! \brief Commit all the compact KV auxiliary data copy operations since the last commit. */ virtual void CommitCompactKVAuxDataCopy() = 0; @@ -525,144 +525,144 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream) { for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { qo_indptr_on_depths_device_.push_back( - NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); + Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); page_indptr_on_depths_device_.push_back( - NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); + Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); page_indices_on_depths_device_.push_back( - NDArray::Empty({num_total_pages}, dtype_aux_, device)); + Tensor::Empty({num_total_pages}, dtype_aux_, device)); length_info_on_depths_device_.push_back( - NDArray::Empty({3, reserved_num_seqs}, dtype_aux_, device)); + Tensor::Empty({3, reserved_num_seqs}, dtype_aux_, device)); k_rope_pos_offset_on_depths_device_.push_back( - NDArray::Empty({reserved_num_seqs}, dtype_aux_, device)); - tree_attn_mask_device_.push_back(NDArray::Empty( + Tensor::Empty({reserved_num_seqs}, dtype_aux_, device)); + tree_attn_mask_device_.push_back(Tensor::Empty( {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device)); 
tree_attn_mn_indptr_device_.push_back( - NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); + Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); } - cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); - k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); - q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + cur_append_length_indptr_device_ = Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + k_ragged_rope_pos_offset_device_ = Tensor::Empty({reserved_num_seqs}, dtype_aux_, device); + q_rope_position_map_device_ = Tensor::Empty({prefill_chunk_size}, dtype_aux_, device); + append_position_map_device_ = Tensor::Empty({prefill_chunk_size}, dtype_aux_, device); kv_transfer_remote_position_map_device = - NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - kv_transfer_recver_id_device = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + Tensor::Empty({prefill_chunk_size}, dtype_aux_, device); + kv_transfer_recver_id_device = Tensor::Empty({prefill_chunk_size}, dtype_aux_, device); kv_transfer_page_to_page_local_position_map_device = kv_transfer_page_to_page_remote_position_map_device = - NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + Tensor::Empty({prefill_chunk_size}, dtype_aux_, device); kv_transfer_page_to_page_recver_id_device = - NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + Tensor::Empty({prefill_chunk_size}, dtype_aux_, device); + commit_copy_length_indptr_device_ = Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device); commit_copy_src_dst_pos_in_page_table_device_ = - NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)}, - dtype_aux_, device); 
+ Tensor::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)}, + dtype_aux_, device); } // The reset of the plain auxiliary data manager is no-op. void ResetAttnAuxDataCopy() final {} - NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = qo_indptr_on_depths_device_[depth].CreateView( + Tensor CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor view = qo_indptr_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = page_indptr_on_depths_device_[depth].CreateView( + Tensor CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor view = page_indptr_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = page_indices_on_depths_device_[depth].CreateView( + Tensor CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor view = page_indices_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = length_info_on_depths_device_[depth].CreateView( + Tensor CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor view = length_info_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = k_rope_pos_offset_on_depths_device_[depth].CreateView( + Tensor CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { 
+ Tensor view = k_rope_pos_offset_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { - NDArray view = cur_append_length_indptr_device_.CreateView({static_cast(data->size())}, - dtype_aux_); + Tensor CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { + Tensor view = cur_append_length_indptr_device_.CreateView({static_cast(data->size())}, + dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { - NDArray view = k_ragged_rope_pos_offset_device_.CreateView({static_cast(data->size())}, - dtype_aux_); + Tensor CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { + Tensor view = k_ragged_rope_pos_offset_device_.CreateView({static_cast(data->size())}, + dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { - NDArray view = + Tensor CopyQRoPEPosMapAsync(HostMemoryVector* data) final { + Tensor view = q_rope_position_map_device_.CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { - NDArray view = + Tensor CopyAppendPositionMapAsync(HostMemoryVector* data) final { + Tensor view = append_position_map_device_.CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final { - NDArray view = kv_transfer_remote_position_map_device.CreateView( + Tensor CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final { + Tensor view = kv_transfer_remote_position_map_device.CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray 
CopyKVTransferRecverIDAsync(HostMemoryVector* data) final { - NDArray view = + Tensor CopyKVTransferRecverIDAsync(HostMemoryVector* data) final { + Tensor view = kv_transfer_recver_id_device.CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final { - NDArray view = kv_transfer_page_to_page_local_position_map_device.CreateView( + Tensor CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final { + Tensor view = kv_transfer_page_to_page_local_position_map_device.CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final { - NDArray view = kv_transfer_page_to_page_remote_position_map_device.CreateView( + Tensor CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final { + Tensor view = kv_transfer_page_to_page_remote_position_map_device.CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final { - NDArray view = kv_transfer_page_to_page_recver_id_device.CreateView( + Tensor CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final { + Tensor view = kv_transfer_page_to_page_recver_id_device.CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = + Tensor CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor view = tree_attn_mask_device_[depth].CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = 
tree_attn_mn_indptr_device_[depth].CreateView( + Tensor CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor view = tree_attn_mn_indptr_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, - HostMemoryVector* sliding_window_offset, - HostMemoryVector* sink_size, int depth) final { + Tensor CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, + HostMemoryVector* sliding_window_offset, + HostMemoryVector* sink_size, int depth) final { int n_elem = last_page_len->size(); ICHECK_GT(n_elem, 0); - NDArray view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_); + Tensor view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_); ffi::Shape copy_shape{n_elem}; CopyVecDataToArray(view, last_page_len->data(), copy_shape); CopyVecDataToArray(view, sliding_window_offset->data(), copy_shape, @@ -678,18 +678,17 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { // The reset of the plain auxiliary data manager is no-op. 
void ResetCompactKVAuxDataCopy() final {} - NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { - NDArray view = commit_copy_length_indptr_device_.CreateView( - {static_cast(data->size())}, dtype_aux_); + Tensor CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { + Tensor view = commit_copy_length_indptr_device_.CreateView({static_cast(data->size())}, + dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, - HostMemoryVector* dst_data) final { + Tensor CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, + HostMemoryVector* dst_data) final { int n_elem = src_data->size(); ICHECK_GT(n_elem, 0); - NDArray view = - commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_); + Tensor view = commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_); ffi::Shape copy_shape{n_elem}; CopyVecDataToArray(view, src_data->data(), copy_shape); CopyVecDataToArray(view, dst_data->data(), copy_shape, @@ -702,12 +701,12 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { private: /*! - * \brief Copy a vector of data to the input NDArray. + * \brief Copy a vector of data to the input Tensor. * It optionally supports specifying the shape of copy and the element - * offset to the destination NDArray. + * offset to the destination Tensor. 
*/ - void CopyVecDataToArray(NDArray array, int32_t* vec_data, - Optional shape = std::nullopt, int dst_elem_offset = 0) { + void CopyVecDataToArray(Tensor array, int32_t* vec_data, + ffi::Optional shape = std::nullopt, int dst_elem_offset = 0) { if (array->shape[0] == 0) { return; } @@ -743,27 +742,27 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { copy_src.shape = copy_dst.shape; copy_src.strides = nullptr; copy_src.byte_offset = 0; - NDArray::CopyFromTo(©_src, ©_dst, copy_stream_); + Tensor::CopyFromTo(©_src, ©_dst, copy_stream_); } - std::vector qo_indptr_on_depths_device_; - std::vector page_indptr_on_depths_device_; - std::vector page_indices_on_depths_device_; - std::vector length_info_on_depths_device_; - std::vector k_rope_pos_offset_on_depths_device_; - std::vector tree_attn_mask_device_; - std::vector tree_attn_mn_indptr_device_; - NDArray cur_append_length_indptr_device_; - NDArray k_ragged_rope_pos_offset_device_; - NDArray q_rope_position_map_device_; - NDArray append_position_map_device_; - NDArray kv_transfer_remote_position_map_device; - NDArray kv_transfer_recver_id_device; - NDArray kv_transfer_page_to_page_local_position_map_device; - NDArray kv_transfer_page_to_page_remote_position_map_device; - NDArray kv_transfer_page_to_page_recver_id_device; - NDArray commit_copy_length_indptr_device_; - NDArray commit_copy_src_dst_pos_in_page_table_device_; + std::vector qo_indptr_on_depths_device_; + std::vector page_indptr_on_depths_device_; + std::vector page_indices_on_depths_device_; + std::vector length_info_on_depths_device_; + std::vector k_rope_pos_offset_on_depths_device_; + std::vector tree_attn_mask_device_; + std::vector tree_attn_mn_indptr_device_; + Tensor cur_append_length_indptr_device_; + Tensor k_ragged_rope_pos_offset_device_; + Tensor q_rope_position_map_device_; + Tensor append_position_map_device_; + Tensor kv_transfer_remote_position_map_device; + Tensor kv_transfer_recver_id_device; + Tensor 
kv_transfer_page_to_page_local_position_map_device; + Tensor kv_transfer_page_to_page_remote_position_map_device; + Tensor kv_transfer_page_to_page_recver_id_device; + Tensor commit_copy_length_indptr_device_; + Tensor commit_copy_src_dst_pos_in_page_table_device_; }; /*! @@ -790,7 +789,7 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { merged_attn_aux_data_host_ = HostMemoryVector(attn_aux_data_cache_size, dtype_aux, preferred_host_device); // - Initialize the device auxiliary data buffer. - merged_attn_aux_data_device_ = NDArray::Empty({attn_aux_data_cache_size}, dtype_aux, device); + merged_attn_aux_data_device_ = Tensor::Empty({attn_aux_data_cache_size}, dtype_aux, device); // - Calculate cache size of all the compact KV auxiliary arrays in // local cache and the large on-device array. @@ -800,60 +799,60 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { merged_compact_kv_aux_data_host_ = HostMemoryVector(compact_kv_aux_data_cache_size, dtype_aux, preferred_host_device); merged_compact_kv_aux_data_device_ = - NDArray::Empty({compact_kv_aux_data_cache_size}, dtype_aux, device); + Tensor::Empty({compact_kv_aux_data_cache_size}, dtype_aux, device); } void ResetAttnAuxDataCopy() final { attn_aux_data_copy_offset_ = 0; } - NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor 
CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { + Tensor CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { + Tensor CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { + Tensor CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } + Tensor CopyAppendPositionMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final { + Tensor CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final { + Tensor CopyKVTransferRecverIDAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final { + Tensor CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final { + Tensor CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final { + Tensor CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) 
final { return CopyAttnAuxVecToCache(data); } - NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray mask_1d = CopyAttnAuxVecToCache(data); + Tensor CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor mask_1d = CopyAttnAuxVecToCache(data); return mask_1d.CreateView({static_cast(data->size() / 2), 2}, mask_1d->dtype); } - NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + Tensor CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, - HostMemoryVector* sliding_window_offset, - HostMemoryVector* sink_size, int depth) final { + Tensor CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, + HostMemoryVector* sliding_window_offset, + HostMemoryVector* sink_size, int depth) final { int64_t n_elem = last_page_len->size(); std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, last_page_len->data(), n_elem * elem_byte_size_); @@ -861,8 +860,9 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { sliding_window_offset->data(), n_elem * elem_byte_size_); std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + 2 * n_elem, sink_size->data(), n_elem * elem_byte_size_); - NDArray view = merged_attn_aux_data_device_.CreateView( - {3, n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); + Tensor view = + Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_), ffi::Shape({3, n_elem}), + dtype_aux_, device_, attn_aux_data_copy_offset_ * elem_byte_size_); attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem); return view; } @@ -881,23 +881,24 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { DLTensor copy_src = copy_dst; copy_src.data = merged_attn_aux_data_host_.data(); copy_src.device = Device{kDLCPU, 0}; - 
NDArray::CopyFromTo(©_src, ©_dst, copy_stream_); + Tensor::CopyFromTo(©_src, ©_dst, copy_stream_); } void ResetCompactKVAuxDataCopy() final { compact_kv_aux_data_copy_offset_ = 0; } - NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { + Tensor CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { return CopyCompactKVAuxVecToCache(data); } - NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, - HostMemoryVector* dst_data) final { + Tensor CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, + HostMemoryVector* dst_data) final { int64_t n_elem = src_data->size(); std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_, src_data->data(), n_elem * elem_byte_size_); std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_ + n_elem, dst_data->data(), n_elem * elem_byte_size_); - NDArray view = merged_compact_kv_aux_data_device_.CreateView( - {2, n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); + Tensor view = Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_), + ffi::Shape({2, n_elem}), dtype_aux_, device_, + compact_kv_aux_data_copy_offset_ * elem_byte_size_); compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem); return view; } @@ -916,10 +917,24 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { DLTensor copy_src = copy_dst; copy_src.data = merged_compact_kv_aux_data_host_.data(); copy_src.device = Device{kDLCPU, 0}; - NDArray::CopyFromTo(©_src, ©_dst, copy_stream_); + Tensor::CopyFromTo(©_src, ©_dst, copy_stream_); } private: + // helper allocator class that applies byte offset to the original data pointer + class ViewHelper { + public: + explicit ViewHelper(Tensor source) : source_(source) {} + void AllocData(DLTensor* tensor, int64_t extra_byte_offset) { + tensor->data = static_cast(source_->data) + extra_byte_offset; + } + + void FreeData(DLTensor* tensor) {} + + 
private: + Tensor source_; + }; + /*! * \brief Calculate the start element offsets of the auxiliary arrays in the local cache. * \return Return the local cache size (total number of elements in the local cache). @@ -985,24 +1000,26 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { /*! * \brief Copy the input data to the cache at the given offset. - * And return the NDArray view of the cache starting at the offset. + * And return the Tensor view of the cache starting at the offset. */ - NDArray CopyAttnAuxVecToCache(HostMemoryVector* data) { + Tensor CopyAttnAuxVecToCache(HostMemoryVector* data) { int64_t n_elem = data->size(); std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, data->data(), n_elem * elem_byte_size_); - NDArray view = merged_attn_aux_data_device_.CreateView( - {n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); + Tensor view = + Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_), ffi::Shape({n_elem}), + dtype_aux_, device_, attn_aux_data_copy_offset_ * elem_byte_size_); attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); return view; } - NDArray CopyCompactKVAuxVecToCache(HostMemoryVector* data) { + Tensor CopyCompactKVAuxVecToCache(HostMemoryVector* data) { int64_t n_elem = data->size(); std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_, data->data(), n_elem * elem_byte_size_); - NDArray view = merged_compact_kv_aux_data_device_.CreateView( - {n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); + Tensor view = Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_), + ffi::Shape({n_elem}), dtype_aux_, device_, + compact_kv_aux_data_copy_offset_ * elem_byte_size_); compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); return view; } @@ -1020,8 +1037,8 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { int64_t compact_kv_aux_data_copy_offset_ = 0; 
HostMemoryVector merged_attn_aux_data_host_; HostMemoryVector merged_compact_kv_aux_data_host_; - NDArray merged_attn_aux_data_device_; - NDArray merged_compact_kv_aux_data_device_; + Tensor merged_attn_aux_data_device_; + Tensor merged_compact_kv_aux_data_device_; }; } // namespace vm diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 90e3b4c54922..1bd3084c210b 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -29,16 +29,18 @@ #include #include #include -#include +#include #include #include #include +#include + namespace tvm { namespace runtime { namespace vm { -using tvm::runtime::NDArray; +using tvm::runtime::Tensor; //------------------------------------------------- // Shape/StructInfo handling. @@ -47,9 +49,9 @@ using tvm::runtime::NDArray; * \brief Builtin function to allocate shape heap. * \param ctx_ptr The context module pointer. * \param size the size of the heap. - * \return An allocate NDArray as shape heap. + * \return An allocate Tensor as shape heap. */ -NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) { +Tensor AllocShapeHeap(void* ctx_ptr, int64_t size) { VirtualMachine* vm = static_cast(ctx_ptr); // use host allocator, which is always last element. size_t host_device_index = vm->devices.size() - 1; @@ -64,10 +66,10 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) { return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.alloc_shape_heap", AllocShapeHeap); -}); +} /*! * \brief Builtin match R.Prim function. @@ -88,7 +90,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \sa MatchShape */ void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t reg, - Optional err_ctx) { + ffi::Optional err_ctx) { int64_t* heap_data = heap == nullptr ? 
nullptr : static_cast(heap->data); MatchShapeCode code = static_cast(code_value); @@ -107,10 +109,10 @@ void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.match_prim_value", MatchPrimValue); -}); +} /*! * \brief Builtin match shape function. @@ -122,7 +124,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { // input shape the first argument can take in tensor or shape. ffi::Shape input_shape; - if (auto opt_nd = args[0].as()) { + if (auto opt_nd = args[0].as()) { input_shape = opt_nd.value().Shape(); } else { input_shape = args[0].cast(); @@ -134,7 +136,7 @@ void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { ICHECK_LE(kBeginCode + size * 2, args.size()); // a function that lazily get context for error reporting const int64_t kErrorContextOffset = kBeginCode + size * 2; - Optional err_ctx = args[kErrorContextOffset].cast(); + ffi::Optional err_ctx = args[kErrorContextOffset].cast(); CHECK_EQ(input_shape.size(), size) << "RuntimeError: " << err_ctx.value_or("") << " match_cast shape size mismatch."; @@ -161,10 +163,10 @@ void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("vm.builtin.match_shape", MatchShape); -}); +} /*! * \brief Builtin make prim value function. @@ -188,10 +190,10 @@ int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.make_prim_value", MakePrimValue); -}); +} /*! * \brief Builtin make shape function. 
@@ -222,10 +224,10 @@ void MakeShape(ffi::PackedArgs args, ffi::Any* rv) { *rv = ffi::Shape(std::move(shape)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("vm.builtin.make_shape", MakeShape); -}); +} /*! * \brief Builtin function to check if arg is Tensor(dtype, ndim) @@ -238,14 +240,14 @@ void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { ffi::AnyView arg = args[0]; int ndim = args[1].cast(); DataType dtype; - Optional err_ctx; + ffi::Optional err_ctx; if (args.size() == 3) { dtype = DataType::Void(); - err_ctx = args[2].cast>(); + err_ctx = args[2].cast>(); } else { dtype = args[2].cast(); - err_ctx = args[3].cast>(); + err_ctx = args[3].cast>(); } auto opt_ptr = arg.try_cast(); @@ -265,10 +267,10 @@ void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("vm.builtin.check_tensor_info", CheckTensorInfo); -}); +} /*! * \brief Builtin function to check if arg is Shape(ndim) @@ -276,7 +278,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param ndim Expected size of the shape, can be -1 (indicate unknown). * \param err_ctx Additional context if error occurs. */ -void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { +void CheckShapeInfo(ObjectRef arg, int ndim, ffi::Optional err_ctx) { // a function that lazily get context for error reporting auto* ptr = arg.as(); CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Shape but get " @@ -288,10 +290,10 @@ void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.check_shape_info", CheckShapeInfo); -}); +} /*! 
* \brief Builtin function to check if arg is PrimValue(dtype) @@ -299,7 +301,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param dtype Expected dtype of the PrimValue. Can be DataType::Void() for unknown dtype. * \param err_ctx Additional context if error occurs. */ -void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, Optional err_ctx) { +void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, ffi::Optional err_ctx) { if (auto opt_obj = arg.as()) { LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", expected dtype " << dtype << ", but received ObjectRef of type " << opt_obj.value()->GetTypeKey(); @@ -318,10 +320,10 @@ void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, Optional err_c } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.check_prim_value_info", CheckPrimValueInfo); -}); +} /*! * \brief Builtin function to check if arg is Tuple with size elements. @@ -329,7 +331,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param size The expected size of the tuple. * \param err_ctx Additional context if error occurs. */ -void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { +void CheckTupleInfo(ObjectRef arg, int64_t size, ffi::Optional err_ctx) { // a function that lazily get context for error reporting auto* ptr = arg.as(); CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get " @@ -339,33 +341,33 @@ void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { << " but get a Tuple with " << ptr->size() << " elements."; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.check_tuple_info", CheckTupleInfo); -}); +} /*! * \brief Builtin function to check if arg is a callable function. * \param arg The input argument. * \param err_ctx Additional context if error occurs. 
*/ -void CheckFuncInfo(ObjectRef arg, Optional err_ctx) { +void CheckFuncInfo(ObjectRef arg, ffi::Optional err_ctx) { // a function that lazily get context for error reporting bool is_func = arg.as() || arg.as(); CHECK(is_func) << "TypeError: " << err_ctx.value_or("") << " expect a Function but get " << arg->GetTypeKey(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.check_func_info", CheckFuncInfo); -}); +} //------------------------------------------------- // Storage management. //------------------------------------------------- Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_index, - DLDataType dtype_hint, String mem_scope) { + DLDataType dtype_hint, ffi::String mem_scope) { VirtualMachine* vm = static_cast(ctx_ptr); ICHECK_LT(device_index, vm->devices.size()) @@ -384,17 +386,17 @@ Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_inde return Storage(buffer, alloc); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("vm.builtin.alloc_storage", VMAllocStorage) - .def_method("vm.builtin.alloc_tensor", &StorageObj::AllocNDArray); -}); + .def_method("vm.builtin.alloc_tensor", &StorageObj::AllocTensor); +} //------------------------------------------------- // Closure function handling, calling convention //------------------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("vm.builtin.make_closure", @@ -428,26 +430,97 @@ TVM_FFI_STATIC_INIT_BLOCK({ } func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); }); -}); +} + +//------------------------------------- +// Python function call support +//------------------------------------- + +// Global registry for Python functions +static std::unordered_map py_func_registry; + 
+/*! + * \brief Clear the Python function registry on shutdown + */ +void ClearPyFuncRegistry() { py_func_registry.clear(); } + +/*! + * \brief Register a Python function for call_py_func + * \param name The function name + * \param func The Python function wrapped as ffi::Function + */ +void RegisterPyFunc(const std::string& name, ffi::Function func) { py_func_registry[name] = func; } + +/*! + * \brief Get a registered Python function + * \param name The function name + * \return The Python function + */ +ffi::Function GetPyFunc(const std::string& name) { + auto it = py_func_registry.find(name); + if (it == py_func_registry.end()) { + LOG(FATAL) << "Python function '" << name << "' not found in registry"; + } + return it->second; +} + +/*! + * \brief Call a Python function from VM + * \param args The packed function arguments (tuple containing function name and arguments) + * \param rv The return value + */ +void CallPyFunc(ffi::PackedArgs args, ffi::Any* rv) { + // args[0] should be a tuple containing (func_name, args_tuple) + if (args.size() != 1) { + LOG(FATAL) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)"; + } + + auto tuple_arg = args[0].cast>(); + if (tuple_arg.size() != 2) { + LOG(FATAL) << "vm.builtin.call_py_func tuple should contain (func_name, args)"; + } + + // Get function name + std::string func_name = tuple_arg[0].cast(); + + // Get arguments tuple + auto func_args = tuple_arg[1].cast>(); + + // Look up Python function in registry + ffi::Function py_func = GetPyFunc(func_name); + + // Call the Python function with the arguments + std::vector py_args_vec(func_args.begin(), func_args.end()); + ffi::PackedArgs py_args(py_args_vec.data(), py_args_vec.size()); + py_func.CallPacked(py_args, rv); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def_packed("vm.builtin.call_py_func", CallPyFunc) + .def("vm.builtin.register_py_func", RegisterPyFunc) + .def("vm.builtin.get_py_func", 
GetPyFunc) + .def("vm.builtin.clear_py_func_registry", ClearPyFuncRegistry); +} //------------------------------------- // Builtin runtime operators. //------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_method("vm.builtin.shape_of", &NDArray::Shape) + .def_method("vm.builtin.shape_of", [](Tensor data) -> ffi::Shape { return data.Shape(); }) .def("vm.builtin.copy", [](ffi::Any a) -> ffi::Any { return a; }) - .def("vm.builtin.reshape", - [](NDArray data, ffi::Shape new_shape) { - return data.CreateView(new_shape, data->dtype); - }) + .def( + "vm.builtin.reshape", + [](Tensor data, ffi::Shape new_shape) { return data.CreateView(new_shape, data->dtype); }) .def("vm.builtin.null_value", []() -> std::nullptr_t { return nullptr; }) - .def("vm.builtin.to_device", [](NDArray data, int dev_type, int dev_id) { + .def("vm.builtin.to_device", [](Tensor data, int dev_type, int dev_id) { Device dst_device = {(DLDeviceType)dev_type, dev_id}; return data.CopyTo(dst_device); }); -}); +} /*! * \brief Load the scalar value in cond and return the result value. 
@@ -458,11 +531,11 @@ bool ReadIfCond(ffi::AnyView cond) { if (auto opt_int = cond.try_cast()) { return opt_int.value(); } - NDArray arr = cond.cast(); + Tensor arr = cond.cast(); if (arr->device.device_type != kDLCPU) { arr = arr.CopyTo(DLDevice{kDLCPU, 0}); } - ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt); + ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || arr->dtype.code == kDLBool); int64_t result; switch (arr->dtype.bits) { case 1: { @@ -492,16 +565,16 @@ bool ReadIfCond(ffi::AnyView cond) { return result != 0; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.read_if_cond", ReadIfCond); -}); +} //------------------------------------- // Debugging API //------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "vm.builtin.invoke_debug_func", [](ffi::PackedArgs args, ffi::Any* rv) -> void { @@ -509,12 +582,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ int num_args = args.size() - 3; ObjectRef io_effect = args[0].cast(); ICHECK(!io_effect.defined()) << "ValueError: IOEffect is expected to be lowered to None."; - String debug_func_name = args[1].cast(); + ffi::String debug_func_name = args[1].cast(); const auto debug_func = tvm::ffi::Function::GetGlobal(debug_func_name); CHECK(debug_func.has_value()) << "ValueError: " << debug_func_name << " is not found. 
" - << "Use the decorator `@tvm.register_func(\"" + << "Use the decorator `@tvm.register_global_func(\"" << debug_func_name << "\")` to register it."; - String line_info = args[2].cast(); + ffi::String line_info = args[2].cast(); std::vector call_args(num_args + 1); { call_args[0] = line_info; @@ -525,31 +598,31 @@ TVM_FFI_STATIC_INIT_BLOCK({ debug_func->CallPacked(ffi::PackedArgs(call_args.data(), call_args.size()), rv); *rv = io_effect; }); -}); +} //------------------------------------- // Data structure API //------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("vm.builtin.tuple_getitem", - [](Array arr, int64_t index) { return arr[index]; }) + [](ffi::Array arr, int64_t index) { return arr[index]; }) .def("vm.builtin.tuple_reset_item", [](const ffi::ArrayObj* arr, int64_t index) { const_cast(arr)->SetItem(index, nullptr); }) .def_packed("vm.builtin.make_tuple", [](ffi::PackedArgs args, ffi::Any* rv) { - Array arr; + ffi::Array arr; for (int i = 0; i < args.size(); ++i) { arr.push_back(args[i]); } *rv = arr; }) .def("vm.builtin.tensor_to_shape", - [](NDArray data) { - NDArray arr = data; + [](Tensor data) { + Tensor arr = data; if (data->device.device_type != kDLCPU) { arr = data.CopyTo(DLDevice{kDLCPU, 0}); } @@ -581,7 +654,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } return ffi::Shape(out_shape); }) - .def("vm.builtin.ensure_zero_offset", [](NDArray data) { + .def("vm.builtin.ensure_zero_offset", [](Tensor data) { if (data->byte_offset == 0) { return data; } @@ -592,14 +665,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ dl_tensor->dl_tensor.data = reinterpret_cast(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset; dl_tensor->dl_tensor.byte_offset = 0; - return NDArray::FromDLPack(dl_tensor); + return Tensor::FromDLPack(dl_tensor); } else { - auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device); + auto new_array = Tensor::Empty(data.Shape(), 
data->dtype, data->device); new_array.CopyFrom(data); return new_array; } }); -}); +} } // namespace vm } // namespace runtime @@ -662,7 +735,7 @@ int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMFFIAny* a using namespace tvm::runtime; TVM_FFI_SAFE_CALL_BEGIN(); auto* list = static_cast(anylist); - list[index] = tvm::ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(args[ret_offset])); + list[index] = tvm::ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(&args[ret_offset]); TVM_FFI_SAFE_CALL_END(); } } // extern "C" diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index d7ccff66a046..9523fd3f4b30 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -44,7 +44,7 @@ struct CUDAGraphCaptureKey { // identified by this shape tuple. This is default constructed as an empty tuple. ffi::Shape shape_expr; - CUDAGraphCaptureKey(int64_t index, const Optional& shape_expr) : index(index) { + CUDAGraphCaptureKey(int64_t index, const ffi::Optional& shape_expr) : index(index) { if (shape_expr) { this->shape_expr = shape_expr.value(); } @@ -140,8 +140,6 @@ class CUDACaptureStream { /*! \brief The VM extension of CUDA graph. */ class CUDAGraphExtensionNode : public VMExtensionNode { public: - TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphExtensionNode, VMExtensionNode); - /*! * \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode. * \param vm The virtual machine. @@ -153,20 +151,20 @@ class CUDAGraphExtensionNode : public VMExtensionNode { * \return The return value of the capture function. 
*/ ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, Any args, - int64_t entry_index, Optional shape_expr) { + int64_t entry_index, ffi::Optional shape_expr) { CUDAGraphCaptureKey entry_key{entry_index, shape_expr}; if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) { // Launch CUDA graph const auto& [states, exec] = it->second; int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - CUDA_CALL(cudaGraphLaunch( - exec, static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)))); + CUDA_CALL( + cudaGraphLaunch(exec, static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)))); return states; } // Set up arguments for the graph execution - Array tuple_args = args.cast>(); + ffi::Array tuple_args = args.cast>(); int nargs = static_cast(tuple_args.size()); std::vector packed_args(nargs); @@ -220,7 +218,9 @@ class CUDAGraphExtensionNode : public VMExtensionNode { return alloc_result; } - static constexpr const char* _type_key = "vm.CUDAGraphExtension"; + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("vm.CUDAGraphExtension", CUDAGraphExtensionNode, + VMExtensionNode); private: /*! @@ -240,14 +240,15 @@ class CUDAGraphExtensionNode : public VMExtensionNode { /*! 
Managed reference to CUDAGraphExtensionNode */ class CUDAGraphExtension : public VMExtension { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAGraphExtension, VMExtension, CUDAGraphExtensionNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CUDAGraphExtension, VMExtension, + CUDAGraphExtensionNode); static CUDAGraphExtension Create() { - auto data_ = make_object(); + auto data_ = ffi::make_object(); return CUDAGraphExtension(std::move(data_)); } }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("vm.builtin.cuda_graph.run_or_capture", @@ -258,7 +259,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto capture_func = args[1].cast(); Any func_args = args[2]; int64_t entry_index = args[3].cast(); - Optional shape_expr = std::nullopt; + ffi::Optional shape_expr = std::nullopt; if (args.size() == 5) { shape_expr = args[4].cast(); } @@ -273,7 +274,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ int64_t entry_index = args[2].cast(); *rv = extension->GetCachedAllocation(vm, alloc_func, entry_index); }); -}); +} } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index ef6fbe6373af..40edbc14c433 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -48,11 +48,11 @@ std::string VMExecutable::Stats() const { oss << "Relax VM executable statistics:" << std::endl; // Get the number of constants. - // If the constant is an NDArray, get the shape of each of them. + // If the constant is an Tensor, get the shape of each of them. // If the constant is an DLDataType, get the data type of each of them. 
oss << " Constant pool (# " << constants.size() << "): ["; for (const auto& it : constants) { - if (auto opt_nd = it.as()) { + if (auto opt_nd = it.as()) { const auto ndarray = opt_nd.value(); const auto& shape = ndarray.Shape(); // Scalar @@ -74,7 +74,7 @@ std::string VMExecutable::Stats() const { } oss.seekp(-2, oss.cur); oss << "], "; - } else if (auto opt_str = it.as()) { + } else if (auto opt_str = it.as()) { std::string f = opt_str.value(); oss << "\""; oss << f; @@ -181,7 +181,7 @@ ffi::Bytes VMExecutable::SaveToBytes() const { return ffi::Bytes(code); } -void VMExecutable::WriteToFile(const String& file_name, const String& format) const { +void VMExecutable::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { runtime::SaveBinaryToFile(file_name, VMExecutable::SaveToBytes()); } @@ -189,7 +189,7 @@ ffi::Module VMExecutable::LoadFromBytes(const ffi::Bytes& bytes) { std::string code; dmlc::MemoryFixedSizeStream strm(const_cast(bytes.data()), bytes.size()); - ObjectPtr exec = make_object(); + ObjectPtr exec = ffi::make_object(); // Load header. 
LoadHeader(&strm); @@ -206,18 +206,18 @@ ffi::Module VMExecutable::LoadFromBytes(const ffi::Bytes& bytes) { return ffi::Module(exec); } -ffi::Module VMExecutable::LoadFromFile(const String& file_name) { +ffi::Module VMExecutable::LoadFromFile(const ffi::String& file_name) { std::string data; runtime::LoadBinaryFromFile(file_name, &data); return VMExecutable::LoadFromBytes(ffi::Bytes(data)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ffi.Module.load_from_file.relax.VMExecutable", VMExecutable::LoadFromFile) .def("ffi.Module.load_from_bytes.relax.VMExecutable", VMExecutable::LoadFromBytes); -}); +} void VMFuncInfo::Save(dmlc::Stream* strm) const { int32_t temp_kind = static_cast(kind); @@ -248,8 +248,8 @@ void VMExecutable::SaveGlobalSection(dmlc::Stream* strm) const { strm->Write(fun void VMExecutable::SaveConstantSection(dmlc::Stream* strm) const { strm->Write(static_cast(this->constants.size())); for (const auto& it : this->constants) { - if (auto opt_nd = it.as()) { - strm->Write(ffi::TypeIndex::kTVMFFINDArray); + if (auto opt_nd = it.as()) { + strm->Write(ffi::TypeIndex::kTVMFFITensor); runtime::SaveDLTensor(strm, opt_nd.value().operator->()); } else if (auto opt_shape = it.as()) { ffi::Shape shape = opt_shape.value(); @@ -258,8 +258,8 @@ void VMExecutable::SaveConstantSection(dmlc::Stream* strm) const { for (size_t i = 0; i < shape.size(); ++i) { strm->Write(shape.at(i)); } - } else if (auto opt_str = it.as()) { - String str = opt_str.value(); + } else if (auto opt_str = it.as()) { + ffi::String str = opt_str.value(); strm->Write(ffi::TypeIndex::kTVMFFIStr); strm->Write(str.size()); for (size_t i = 0; i < str.size(); ++i) { @@ -299,13 +299,13 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); size_t size = static_cast(sz); - runtime::NDArray ndarray; + runtime::Tensor ndarray; DLDataType dtype; // Load each of the 
constants. for (size_t i = 0; i < size; i++) { int constant_type; STREAM_CHECK(strm->Read(&constant_type, sizeof(constant_type)), "constant"); - if (constant_type == ffi::TypeIndex::kTVMFFINDArray) { + if (constant_type == ffi::TypeIndex::kTVMFFITensor) { ndarray.Load(strm); ffi::Any cell; cell = ndarray; @@ -333,7 +333,7 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { strm->Read(&(data[i])); } ffi::Any cell; - cell = String(std::string(data.begin(), data.end())); + cell = ffi::String(std::string(data.begin(), data.end())); this->constants.push_back(cell); } else if (constant_type == ffi::TypeIndex::kTVMFFIInt) { int64_t value; @@ -348,7 +348,7 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { cell = value; this->constants.push_back(cell); } else { - LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " + LOG(FATAL) << "Constant pool can only contain Tensor and DLDataType, but got " << ffi::TypeIndexToTypeKey(constant_type) << " when loading the VM constant pool."; } } @@ -395,9 +395,9 @@ ffi::Module VMExecutable::VMProfilerLoadExecutable() const { return ffi::Module(vm); } -bool VMExecutable::HasFunction(const String& name) const { return func_map.count(name); } +bool VMExecutable::HasFunction(const ffi::String& name) const { return func_map.count(name); } -String VMExecutable::AsText() const { +ffi::String VMExecutable::AsText() const { auto get_func_name = [&](Index index) -> std::string { if (static_cast(index) < func_table.size()) { return func_table[index].name; @@ -471,10 +471,10 @@ String VMExecutable::AsText() const { } os << "\n"; } - return String(os.str()); + return ffi::String(os.str()); } -String VMExecutable::AsPython() const { +ffi::String VMExecutable::AsPython() const { auto get_func_name = [&](Index index) -> std::string { if (static_cast(index) < func_table.size()) { return "\"" + func_table[index].name + "\""; @@ -549,13 +549,13 @@ String VMExecutable::AsPython() const { } } } - return 
String(os.str()); + return ffi::String(os.str()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.ExecutableLoadFromFile", VMExecutable::LoadFromFile); -}); +} } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/hexagon/builtin.cc b/src/runtime/vm/hexagon/builtin.cc index be5c7f5fd6f9..72929dd3d8f2 100644 --- a/src/runtime/vm/hexagon/builtin.cc +++ b/src/runtime/vm/hexagon/builtin.cc @@ -31,12 +31,13 @@ namespace tvm { namespace runtime { namespace vm { +// clang-format off -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("vm.builtin.hexagon.dma_copy", - [](ffi::AnyView vm_ptr, NDArray src_arr, NDArray dst_arr, int queue_id, + [](ffi::AnyView vm_ptr, Tensor src_arr, Tensor dst_arr, int queue_id, bool bypass_cache) { const DLTensor* dptr = dst_arr.operator->(); const DLTensor* sptr = src_arr.operator->(); @@ -57,8 +58,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ CHECK(ret == DMA_SUCCESS); }) .def("vm.builtin.hexagon.dma_wait", [](ffi::AnyView vm_ptr, int queue_id, int inflight_dma, - bool bypass_cache, [[maybe_unused]] NDArray src_arr, - [[maybe_unused]] NDArray dst_arr) { + bool bypass_cache, [[maybe_unused]] Tensor src_arr, + [[maybe_unused]] Tensor dst_arr) { ICHECK(inflight_dma >= 0); tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight_dma); if (bypass_cache) { @@ -69,7 +70,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ QURT_MEM_DCACHE); } }); -}); +} + +// clang-format on } // namespace vm } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/kv_state.cc b/src/runtime/vm/kv_state.cc index 5d13be7ef519..5d04139a32c8 100644 --- a/src/runtime/vm/kv_state.cc +++ b/src/runtime/vm/kv_state.cc @@ -30,7 +30,7 @@ namespace vm { // Register Object Type // KV State base methods -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef() .def_method("vm.builtin.kv_state_clear", &KVStateObj::Clear) @@ -45,17 +45,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ KVState kv_state = args[0].cast(); ffi::Shape seq_ids = args[1].cast(); ffi::Shape append_lengths = args[2].cast(); - Optional token_tree_parent_ptr; + ffi::Optional token_tree_parent_ptr; if (args.size() == 4) { - token_tree_parent_ptr = args[3].cast>(); + token_tree_parent_ptr = args[3].cast>(); } kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr); }) .def_method("vm.builtin.kv_state_end_forward", &KVStateObj::EndForward); -}); +} // Attention KV Cache methods -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("vm.builtin.kv_cache_disagg_prepare_recv", @@ -76,50 +76,50 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("vm.builtin.attention_kv_cache_debug_get_kv_mla", &AttentionKVCacheObj::DebugGetKVMLA) .def("vm.builtin.attention_kv_cache_attention_with_fused_qkv", - [](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray qkv_data, - NDArray o_data) { + [](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, Tensor qkv_data, + Tensor o_data) { kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), std::nullopt, std::move(o_data), sm_scale); }) .def("vm.builtin.attention_kv_cache_self_attention", - [](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, - NDArray k_data, NDArray v_data, NDArray o_data, NDArray lse_data) { + [](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, Tensor q_data, + Tensor k_data, Tensor v_data, Tensor o_data, Tensor lse_data) { kv_cache->SelfAttention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data), std::move(o_data), std::move(lse_data), sm_scale); }) .def("vm.builtin.attention_kv_cache_cross_attention", - [](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, - NDArray o_data, NDArray lse_data) { + 
[](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, Tensor q_data, + Tensor o_data, Tensor lse_data) { kv_cache->CrossAttention(layer_id, std::move(q_data), std::move(o_data), std::move(lse_data), sm_scale); }) .def("vm.builtin.attention_kv_cache_append_mla_kv", - [](AttentionKVCache kv_cache, int64_t layer_id, NDArray kv_data) { + [](AttentionKVCache kv_cache, int64_t layer_id, Tensor kv_data) { kv_cache->AppendMLAKV(layer_id, std::move(kv_data)); return kv_cache; }) .def("vm.builtin.attention_kv_cache_merge_attn_output_inplace", - [](AttentionKVCache kv_cache, NDArray o_self_attn, NDArray lse_self_attn, - NDArray o_cross_attn, NDArray lse_cross_attn) { + [](AttentionKVCache kv_cache, Tensor o_self_attn, Tensor lse_self_attn, + Tensor o_cross_attn, Tensor lse_cross_attn) { return kv_cache->MergeAttnOutputInplace( std::move(o_self_attn), std::move(lse_self_attn), std::move(o_cross_attn), std::move(lse_cross_attn)); }); -}); +} // RNN State methods -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("vm.builtin.rnn_state_get", &RNNStateObj::Get) .def("vm.builtin.rnn_state_set", - [](RNNState state, int64_t layer_id, int64_t state_id, NDArray data) { + [](RNNState state, int64_t layer_id, int64_t state_id, Tensor data) { state->Set(layer_id, state_id, data); return state; }) .def_method("vm.builtin.rnn_state_debug_get", &RNNStateObj::DebugGet); -}); +} } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/kv_state.h b/src/runtime/vm/kv_state.h index 46d8f4f59603..33c669f18ab2 100644 --- a/src/runtime/vm/kv_state.h +++ b/src/runtime/vm/kv_state.h @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace runtime { @@ -94,8 +94,9 @@ class KVStateObj : public Object { * is the sum of "append_lengths". Nullptr means the token tree of each sequence * is a chain. 
*/ - virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, - const Optional& token_tree_parent_ptr = std::nullopt) = 0; + virtual void BeginForward( + const IntTuple& seq_ids, const IntTuple& append_lengths, + const ffi::Optional& token_tree_parent_ptr = std::nullopt) = 0; /*! * \brief Mark the start of the forward function. @@ -104,13 +105,13 @@ class KVStateObj : public Object { */ virtual void EndForward() = 0; - static constexpr const char* _type_key = "relax.vm.KVState"; - TVM_DECLARE_BASE_OBJECT_INFO(KVStateObj, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("relax.vm.KVState", KVStateObj, Object); }; class KVState : public ObjectRef { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(KVState, ObjectRef, KVStateObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(KVState, ObjectRef, KVStateObj); }; /*! @@ -178,8 +179,8 @@ class AttentionKVCacheObj : public KVStateObj { * \param sm_scale The additional attention scaling factor. * \sa AttentionKVCache::Attention */ - virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, - NDArray o_data, double sm_scale) = 0; + virtual void AttentionWithFusedQKV(int64_t layer_id, Tensor qkv_data, ffi::Optional mask, + Tensor o_data, double sm_scale) = 0; /*! * \brief Fine-grained API that computes ragged self attention with Q/K/V data. @@ -191,8 +192,8 @@ class AttentionKVCacheObj : public KVStateObj { * \param lse_data The output attention LSE data, in layout `(total_length, num_qo_heads)`. * \param sm_scale The additional attention scaling factor. */ - virtual void SelfAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - NDArray o_data, NDArray lse_data, double sm_scale) = 0; + virtual void SelfAttention(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data, + Tensor o_data, Tensor lse_data, double sm_scale) = 0; /*! 
* \brief Fine-grained API that computes paged cross attention with Q and in-cache KV data. @@ -202,7 +203,7 @@ class AttentionKVCacheObj : public KVStateObj { * \param lse_data The output attention LSE data, in layout `(total_length, num_qo_heads)`. * \param sm_scale The additional attention scaling factor. */ - virtual void CrossAttention(int64_t layer_id, NDArray q_data, NDArray o_data, NDArray lse_data, + virtual void CrossAttention(int64_t layer_id, Tensor q_data, Tensor o_data, Tensor lse_data, double sm_scale) = 0; /*! @@ -210,7 +211,7 @@ class AttentionKVCacheObj : public KVStateObj { * \param layer_id The model layer where the attention compute happens. * \param kv_data The input KV data to append, in layout `(total_length, qk_head_dim)`. */ - virtual void AppendMLAKV(int64_t layer_id, NDArray kv_data) = 0; + virtual void AppendMLAKV(int64_t layer_id, Tensor kv_data) = 0; /*! * \brief Fine-grained API that merges the attention output from two sources. @@ -220,8 +221,8 @@ class AttentionKVCacheObj : public KVStateObj { * \param lse2_data The second source LSE data. * \return The merged O and LSE data. */ - virtual Array MergeAttnOutputInplace(NDArray o_self_attn, NDArray lse_self_attn, - NDArray o_cross_attn, NDArray lse_cross_attn) = 0; + virtual ffi::Array MergeAttnOutputInplace(Tensor o_self_attn, Tensor lse_self_attn, + Tensor o_cross_attn, Tensor lse_cross_attn) = 0; /*! * \brief Compute linear attention with Q/K/V data. @@ -233,7 +234,7 @@ class AttentionKVCacheObj : public KVStateObj { * \param sm_scale The additional attention scaling factor. 
* \sa AttentionKVCache::Attention */ - virtual void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + virtual void LinearAttention(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data, double sm_scale) = 0; /************** Positions **************/ @@ -243,7 +244,7 @@ class AttentionKVCacheObj : public KVStateObj { * This function is supposed to be invoked after calling BeginForward. * \return The in-sequence query positions, in shape `(total_length,)`. */ - virtual NDArray GetQueryPositions() = 0; + virtual Tensor GetQueryPositions() = 0; /************** Debug Helpers **************/ @@ -265,7 +266,7 @@ class AttentionKVCacheObj : public KVStateObj { * \param V_data The output V data of the given sequence in layout elaborated above. */ virtual void DebugGetKV(int64_t seq_id, // - int64_t start_pos, int64_t end_pos, NDArray k_data, NDArray v_data) = 0; + int64_t start_pos, int64_t end_pos, Tensor k_data, Tensor v_data) = 0; /*! * \brief Fetch the compact K/V data of the given sequence for MLA cache. @@ -275,7 +276,7 @@ class AttentionKVCacheObj : public KVStateObj { * \param kv_data The output KV data of the given sequence in layout elaborated above. */ virtual void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos, - NDArray kv_data) = 0; + Tensor kv_data) = 0; /*! * \brief Set the K/V data of the given sequence from input K/V data. @@ -291,15 +292,15 @@ class AttentionKVCacheObj : public KVStateObj { * \param k_data The K data to set in layout elaborated above. * \param v_data The V data to set in layout elaborated above. 
*/ - virtual void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data, NDArray v_data) = 0; + virtual void DebugSetKV(int64_t seq_id, int64_t start_pos, Tensor k_data, Tensor v_data) = 0; - static constexpr const char* _type_key = "relax.vm.AttentionKVCache"; - TVM_DECLARE_BASE_OBJECT_INFO(AttentionKVCacheObj, KVStateObj); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("relax.vm.AttentionKVCache", AttentionKVCacheObj, KVStateObj); }; class AttentionKVCache : public KVState { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCache, KVState, AttentionKVCacheObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttentionKVCache, KVState, AttentionKVCacheObj); }; /*! @@ -317,7 +318,7 @@ class RNNStateObj : public KVStateObj { * \return The array of State data, each element corresponds to a state. * \throws Error if the given sequence id is not valid. */ - virtual void Get(int64_t layer_id, int64_t state_id, NDArray o_data) = 0; + virtual void Get(int64_t layer_id, int64_t state_id, Tensor o_data) = 0; /*! * \brief Set the State data for the specified sequence. @@ -326,7 +327,7 @@ class RNNStateObj : public KVStateObj { * \param data The data to be set. * \throws Error if the given sequence id is not valid. */ - virtual void Set(int64_t layer_id, int64_t state_id, NDArray data) = 0; + virtual void Set(int64_t layer_id, int64_t state_id, Tensor data) = 0; /*! * \brief Fetch the compact rnn state data of the given sequence. @@ -334,15 +335,15 @@ class RNNStateObj : public KVStateObj { * \param state_id The state id within the layer. * \param seq_id The sequence whose state data is to be fetched. 
*/ - virtual NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) = 0; + virtual Tensor DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) = 0; - static constexpr const char* _type_key = "relax.vm.RNNState"; - TVM_DECLARE_BASE_OBJECT_INFO(RNNStateObj, KVStateObj); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("relax.vm.RNNState", RNNStateObj, KVStateObj); }; class RNNState : public KVState { public: - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RNNState, KVState, RNNStateObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RNNState, KVState, RNNStateObj); }; } // namespace vm diff --git a/src/runtime/vm/lm_support.cc b/src/runtime/vm/lm_support.cc index 599978579f67..e4bdb7e86607 100644 --- a/src/runtime/vm/lm_support.cc +++ b/src/runtime/vm/lm_support.cc @@ -42,7 +42,7 @@ #include #include #include -#include +#include #include #include @@ -66,7 +66,7 @@ class AttentionKVCacheLegacyObj : public Object { /*! * \brief Underlying support data. */ - NDArray data; + Tensor data; /*! * \brief number of slots already filled. @@ -82,7 +82,7 @@ class AttentionKVCacheLegacyObj : public Object { * \brief View all current cached values as one array. * \param shape The cached values. 
*/ - NDArray View(const ffi::Shape& shape) { + Tensor View(const ffi::Shape& shape) { CHECK_EQ(shape[0], fill_count) << "Requested shape do not match the filled count"; for (int i = 1; i < this->data->ndim; ++i) { CHECK_EQ(shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; @@ -102,7 +102,7 @@ class AttentionKVCacheLegacyObj : public Object { this->fill_count -= n; } - void Update(NDArray value) { + void Update(Tensor value) { CHECK(data.DataType() == value.DataType()) << "dtype mismatch"; CHECK_EQ(value->shape[0], fill_count) << "Requested shape do not match the filled count"; ICHECK(data.IsContiguous()); @@ -111,7 +111,7 @@ class AttentionKVCacheLegacyObj : public Object { DLTensor copy_dst = *(data.operator->()); copy_dst.byte_offset = 0; copy_dst.shape = value->shape; - NDArray::CopyFromTo(value.operator->(), ©_dst); + Tensor::CopyFromTo(value.operator->(), ©_dst); this->fill_count = value->shape[0]; } @@ -121,7 +121,7 @@ class AttentionKVCacheLegacyObj : public Object { * \param max_cache_size max size of the cache. * \param num_attention_sinks number of sinks to store (https://arxiv.org/abs/2309.17453). 
*/ - void WindowOverride(NDArray value, int64_t max_cache_size, int64_t num_attention_sinks = 0) { + void WindowOverride(Tensor value, int64_t max_cache_size, int64_t num_attention_sinks = 0) { CHECK(data.DataType() == value.DataType()) << "dtype mismatch"; CHECK_LE(value->shape[0], max_cache_size - num_attention_sinks) << "dim 0 of value too large"; // reallocate cache @@ -133,7 +133,7 @@ class AttentionKVCacheLegacyObj : public Object { if (reserved_slots != data->shape[0]) { std::vector new_shape(data->shape, data->shape + data->ndim); new_shape[0] = reserved_slots; - NDArray new_data = NDArray::Empty(new_shape, data->dtype, data->device); + Tensor new_data = Tensor::Empty(new_shape, data->dtype, data->device); new_data.CreateView(data.Shape(), data->dtype).CopyFrom(data); this->data = new_data; } @@ -165,7 +165,7 @@ class AttentionKVCacheLegacyObj : public Object { copy_src.byte_offset = 0; copy_src.shape = &shape[0]; - NDArray::CopyFromTo(©_src, ©_dst); + Tensor::CopyFromTo(©_src, ©_dst); } // copy the remainder to the beginning of the cache @@ -186,7 +186,7 @@ class AttentionKVCacheLegacyObj : public Object { num_filled_elements * ((value->dtype.bits * value->dtype.lanes + 7) / 8); copy_src.shape = &shape[0]; - NDArray::CopyFromTo(©_src, ©_dst); + Tensor::CopyFromTo(©_src, ©_dst); this->window_attention_current_pos = value->shape[0] - num_elements_to_copy + num_attention_sinks; } @@ -196,7 +196,7 @@ class AttentionKVCacheLegacyObj : public Object { * \brief Append value to the cache. * \param value The value to be appended. 
*/ - void Append(NDArray value) { + void Append(Tensor value) { CHECK(data.DataType() == value.DataType()) << "dtype mismatch"; // reallocate cache int64_t reserved_slots = data->shape[0]; @@ -206,7 +206,7 @@ class AttentionKVCacheLegacyObj : public Object { if (reserved_slots != data->shape[0]) { std::vector new_shape(data->shape, data->shape + data->ndim); new_shape[0] = reserved_slots; - NDArray new_data = NDArray::Empty(new_shape, data->dtype, data->device); + Tensor new_data = Tensor::Empty(new_shape, data->dtype, data->device); new_data.CreateView(data.Shape(), data->dtype).CopyFrom(data); this->data = new_data; } @@ -223,12 +223,13 @@ class AttentionKVCacheLegacyObj : public Object { DLTensor copy_dst = *(data.operator->()); copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8); copy_dst.shape = value->shape; - NDArray::CopyFromTo(value.operator->(), ©_dst); + Tensor::CopyFromTo(value.operator->(), ©_dst); this->fill_count += value->shape[0]; } - static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy"; - TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheLegacyObj, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.vm.AttentionKVCacheLegacy", AttentionKVCacheLegacyObj, + Object); }; /*! \brief reference to closure. */ @@ -238,10 +239,10 @@ class AttentionKVCacheLegacy : public ObjectRef { * \brief Create the attention kv cache. * \param init_data The initial reserved. 
*/ - static AttentionKVCacheLegacy Create(NDArray init_data, ffi::Shape reserve_shape, + static AttentionKVCacheLegacy Create(Tensor init_data, ffi::Shape reserve_shape, int init_fill_count) { - auto n = make_object(); - n->data = NDArray::Empty(reserve_shape, init_data->dtype, init_data->device); + auto n = ffi::make_object(); + n->data = Tensor::Empty(reserve_shape, init_data->dtype, init_data->device); n->fill_count = 0; n->Append(init_data); if (init_fill_count >= 0) { @@ -251,69 +252,68 @@ class AttentionKVCacheLegacy : public ObjectRef { return AttentionKVCacheLegacy(n); } - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef, - AttentionKVCacheLegacyObj); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttentionKVCacheLegacy, ObjectRef, + AttentionKVCacheLegacyObj); }; //------------------------------------------------- // Register runtime functions //------------------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.attention_kv_cache_create", AttentionKVCacheLegacy::Create); -}); +} -AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, NDArray value) { +AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, Tensor value) { cache->Update(value); return cache; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.attention_kv_cache_update", AttentionKVCacheUpdate); -}); +} -AttentionKVCacheLegacy AttentionKVCacheAppend(AttentionKVCacheLegacy cache, NDArray value) { +AttentionKVCacheLegacy AttentionKVCacheAppend(AttentionKVCacheLegacy cache, Tensor value) { cache->Append(value); return cache; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.attention_kv_cache_append", AttentionKVCacheAppend); -}); +} 
-AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cache, NDArray value, +AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cache, Tensor value, int64_t max_cache_size) { cache->WindowOverride(value, max_cache_size); return cache; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.attention_kv_cache_window_override", AttentionKVCacheWindowOverride); -}); +} AttentionKVCacheLegacy AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheLegacy cache, - NDArray value, - int64_t max_cache_size, + Tensor value, int64_t max_cache_size, int64_t num_attention_sinks) { cache->WindowOverride(value, max_cache_size, num_attention_sinks); return cache; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.attention_kv_cache_window_override_with_sinks", AttentionKVCacheWindowOverrideWithSinks); -}); +} -NDArray AttentionKVCacheView(AttentionKVCacheLegacy cache, ffi::Shape shape) { +Tensor AttentionKVCacheView(AttentionKVCacheLegacy cache, ffi::Shape shape) { return cache->View(shape); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "vm.builtin.attention_kv_cache_view", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -333,32 +333,32 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = cache->View(ffi::Shape(shape)); } }); -}); +} -void AttentionKVCacheArrayPopN(Array caches, int64_t n) { +void AttentionKVCacheArrayPopN(ffi::Array caches, int64_t n) { for (AttentionKVCacheLegacy cache : caches) { cache->PopN(static_cast(n)); } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.attention_kv_cache_array_popn", AttentionKVCacheArrayPopN); -}); +} -void AttentionKVCacheArrayClear(Array caches) { +void 
AttentionKVCacheArrayClear(ffi::Array caches) { for (AttentionKVCacheLegacy cache : caches) { cache->Clear(); } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.attention_kv_cache_array_clear", AttentionKVCacheArrayClear); -}); +} // NOTE this is a built-in highly related to LM so we put it here. -int SampleTopPFromLogits(NDArray logits, double temperature, double top_p, double uniform_sample) { +int SampleTopPFromLogits(Tensor logits, double temperature, double top_p, double uniform_sample) { ICHECK(logits.IsContiguous()); ICHECK(logits.DataType() == DataType::Float(32)); @@ -419,12 +419,12 @@ int SampleTopPFromLogits(NDArray logits, double temperature, double top_p, doubl return data[0].second; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.sample_top_p_from_logits", SampleTopPFromLogits); -}); +} -int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { +int SampleTopPFromProb(Tensor prob, double top_p, double uniform_sample) { ICHECK(prob.IsContiguous()); ICHECK(prob.DataType() == DataType::Float(32)); @@ -517,12 +517,12 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { return sampled_index; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.sample_top_p_from_prob", SampleTopPFromProb); -}); +} -NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { +Tensor MultinomialFromUniform(Tensor prob, Tensor uniform_sample) { ICHECK(prob.IsContiguous()); ICHECK(uniform_sample.IsContiguous()); @@ -540,7 +540,7 @@ NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { int64_t vocab_size = prob->shape[prob->ndim - 1]; const float* pprob = static_cast(prob->data); const float* psample = static_cast(uniform_sample->data); - NDArray new_array = 
NDArray::Empty({batch_size, 1}, DataType::Int(64), uniform_sample->device); + Tensor new_array = Tensor::Empty({batch_size, 1}, DataType::Int(64), uniform_sample->device); int64_t* parray = static_cast(new_array->data); for (int64_t i = 0; i < batch_size; ++i) { float cum_sum_prob = 0.0f; @@ -557,13 +557,13 @@ NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { return new_array; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.multinomial_from_uniform", MultinomialFromUniform); -}); +} // This is an inplace operation. -void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { +void ApplyRepetitionPenalty(Tensor logits, Tensor token_ids, double penalty) { ICHECK(logits.IsContiguous()); ICHECK(token_ids.IsContiguous()); ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; @@ -583,10 +583,10 @@ void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.apply_repetition_penalty", ApplyRepetitionPenalty); -}); +} /*! * \brief Apply presence and frequency penalty. This is an inplace operation. @@ -597,7 +597,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param presence_penalty The penalty factor, applied if a token appeared in an one-off manner. * \param frequency_penalty The penalty factor, contributes more the more frequent a token appears. 
*/ -void ApplyPresenceAndFrequencyPenalty(NDArray logits, NDArray token_ids, NDArray token_freqs, +void ApplyPresenceAndFrequencyPenalty(Tensor logits, Tensor token_ids, Tensor token_freqs, double presence_penalty, double frequency_penalty) { // See https://platform.openai.com/docs/guides/text-generation/frequency-and-presence-penalties ICHECK(logits.IsContiguous()); @@ -621,14 +621,14 @@ void ApplyPresenceAndFrequencyPenalty(NDArray logits, NDArray token_ids, NDArray } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.apply_presence_and_frequency_penalty", ApplyPresenceAndFrequencyPenalty); -}); +} // This is an inplace operation. -void ApplySoftmaxWithTemperature(NDArray logits, double temperature) { +void ApplySoftmaxWithTemperature(Tensor logits, double temperature) { ICHECK(logits.IsContiguous()); ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; @@ -649,10 +649,10 @@ void ApplySoftmaxWithTemperature(NDArray logits, double temperature) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("vm.builtin.apply_softmax_with_temperature", ApplySoftmaxWithTemperature); -}); +} } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 405f2f482a01..4fb3cd69d60f 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include @@ -111,7 +111,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief The RoPE theta. */ const double rotary_theta_; /*! \brief The optional RoPE extension factors for RoPE scaling. */ - const Optional rope_ext_factors_; + const ffi::Optional rope_ext_factors_; /*! 
\brief The KV cache dtype. */ const DataType kv_dtype_; @@ -122,15 +122,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! * \brief The KV data managed by the KV cache. - * If KV transfer function is specifed, pages_ will be allocated by NVSHMEM as a whole NDArray. + * If KV transfer function is specifed, pages_ will be allocated by NVSHMEM as a whole Tensor. * pages_ will contain tensor view of each layer. - * Otherwise, pages_ has `num_layers` NDArrays, each of them + * Otherwise, pages_ has `num_layers` Tensors, each of them * has layout (num_pages, 2, num_heads, page_size, qk_head_dim). * Along on the "2" dimension, index 0 stands for K and 1 stands for V. */ - std::vector pages_; + std::vector pages_; /*! \brief The whole KV cache allocated by NVSHMEM*/ - NDArray nvshmem_pages_; + Tensor nvshmem_pages_; /*! \brief The list of ids of released pages for page reuse. */ std::vector free_page_ids_; /*! \brief The mapping from sequence ids to sequences. */ @@ -181,15 +181,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::unique_ptr aux_data_manager_; // Temporary arrays to store intermediate attention results. - NDArray temp_attn_q_device_; - NDArray temp_attn_k_device_; - NDArray temp_attn_v_device_; - NDArray temp_attn_output_device_; - NDArray temp_attn_lse_device_; - NDArray merged_attn_lse_device_; - std::vector temp_int_attn_workspace_; - std::vector temp_int_pinned_attn_workspace_; - NDArray temp_float_attn_workspace_; + Tensor temp_attn_q_device_; + Tensor temp_attn_k_device_; + Tensor temp_attn_v_device_; + Tensor temp_attn_output_device_; + Tensor temp_attn_lse_device_; + Tensor merged_attn_lse_device_; + std::vector temp_int_attn_workspace_; + std::vector temp_int_pinned_attn_workspace_; + Tensor temp_float_attn_workspace_; //------------------------------------------- // Below are the auxiliary data structure on CPU. 
@@ -227,34 +227,34 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // after each synchronization and pass these views as input for // attention/append. //------------------------------------------- - NDArray cur_append_length_indptr_view_; - NDArray k_ragged_rope_pos_offset_view_; - NDArray q_rope_position_map_view_; - NDArray append_position_map_view_; - NDArray kv_transfer_remote_position_map_view_; - NDArray kv_transfer_recver_id_view_; - NDArray kv_transfer_page_to_page_local_position_map_view_; - NDArray kv_transfer_page_to_page_remote_position_map_view_; - NDArray kv_transfer_page_to_page_recver_id_view_; - NDArray temp_attn_output_view_; - NDArray temp_attn_lse_view_; - NDArray merged_attn_lse_view_; - std::vector qo_indptr_on_depths_view_; - std::vector page_indptr_on_depths_view_; - std::vector page_indices_on_depths_view_; - std::vector page_indptr_sliding_window_on_depths_view_; - std::vector page_indices_sliding_window_on_depths_view_; - std::vector length_info_on_depths_view_; - std::vector layer_sliding_window_length_info_on_depths_view_; - std::vector k_rope_pos_offset_view_; - std::vector k_rope_pos_offset_sliding_window_view_; - std::vector tree_attn_mask_view_; - std::vector tree_attn_mn_indptr_view_; - - Optional f_transpose_append_mha_; - Optional f_transpose_append_mla_; - Optional f_transfer_kv_; - Optional f_transfer_kv_page_to_page_ = std::nullopt; + Tensor cur_append_length_indptr_view_; + Tensor k_ragged_rope_pos_offset_view_; + Tensor q_rope_position_map_view_; + Tensor append_position_map_view_; + Tensor kv_transfer_remote_position_map_view_; + Tensor kv_transfer_recver_id_view_; + Tensor kv_transfer_page_to_page_local_position_map_view_; + Tensor kv_transfer_page_to_page_remote_position_map_view_; + Tensor kv_transfer_page_to_page_recver_id_view_; + Tensor temp_attn_output_view_; + Tensor temp_attn_lse_view_; + Tensor merged_attn_lse_view_; + std::vector qo_indptr_on_depths_view_; + std::vector 
page_indptr_on_depths_view_; + std::vector page_indices_on_depths_view_; + std::vector page_indptr_sliding_window_on_depths_view_; + std::vector page_indices_sliding_window_on_depths_view_; + std::vector length_info_on_depths_view_; + std::vector layer_sliding_window_length_info_on_depths_view_; + std::vector k_rope_pos_offset_view_; + std::vector k_rope_pos_offset_sliding_window_view_; + std::vector tree_attn_mask_view_; + std::vector tree_attn_mn_indptr_view_; + + ffi::Optional f_transpose_append_mha_; + ffi::Optional f_transpose_append_mla_; + ffi::Optional f_transfer_kv_; + ffi::Optional f_transfer_kv_page_to_page_ = std::nullopt; ffi::Function f_compact_copy_; std::unique_ptr f_attention_prefill_ragged_; std::unique_ptr f_attention_prefill_; @@ -264,10 +264,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv_; std::unique_ptr f_attention_prefill_with_tree_mask_; std::unique_ptr f_mla_prefill_; - Array f_merge_inplace_; + ffi::Array f_merge_inplace_; ffi::Function f_split_rotary_; ffi::Function f_copy_single_page_; - Optional f_debug_get_kv_; + ffi::Optional f_debug_get_kv_; /*! \brief The device this PagedKVCache runs on. */ Device device_; @@ -279,16 +279,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { TVMStreamHandle kv_transfer_stream_ = nullptr; public: - /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */ + /*! \brief Constructor. Take the cache configuration and initialize the Tensors. 
*/ explicit PagedAttentionKVCacheObj( int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, int64_t layer_id_end_offset, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, std::vector attn_kinds, int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta, - Optional rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, Device device, - Optional f_transpose_append_mha, - Optional f_transpose_append_mla, ffi::Function f_compact_copy, + ffi::Optional rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, + Device device, ffi::Optional f_transpose_append_mha, + ffi::Optional f_transpose_append_mla, ffi::Function f_compact_copy, std::unique_ptr f_attention_prefill_ragged, std::unique_ptr f_attention_prefill, std::unique_ptr f_attention_decode, @@ -296,7 +296,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::unique_ptr f_attention_decode_sliding_window, std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv, std::unique_ptr f_attention_prefill_with_tree_mask, - std::unique_ptr f_mla_prefill, Array f_merge_inplace, + std::unique_ptr f_mla_prefill, ffi::Array f_merge_inplace, ffi::Function f_split_rotary, ffi::Function f_copy_single_page, ffi::Function f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), @@ -360,7 +360,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { (*f_nvshmem_empty)( ffi::Shape({num_layers, num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}), dtype, device) - .cast(); + .cast(); for (int i = 0; i < num_layers; ++i) { pages_.push_back(nvshmem_pages_.CreateView( {num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_}, nvshmem_pages_->dtype, @@ -380,7 +380,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ffi::Shape kv_cache_shape = GetKVCacheShape(attn_kinds_[layer_id_begin_offset_ + i], 
num_total_pages, reserved_num_seqs, num_kv_heads, page_size, qk_head_dim, v_head_dim); - pages_.push_back(NDArray::Empty(kv_cache_shape, dtype, device)); + pages_.push_back(Tensor::Empty(kv_cache_shape, dtype, device)); } } @@ -442,47 +442,47 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { if (NeedKernelBeginForward()) { temp_int_attn_workspace_.push_back( - NDArray::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); - temp_int_pinned_attn_workspace_.push_back(NDArray::Empty( + Tensor::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); + temp_int_pinned_attn_workspace_.push_back(Tensor::Empty( {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device))); } - qo_indptr_on_depths_view_.push_back(NDArray()); - page_indptr_on_depths_view_.push_back(NDArray()); - page_indices_on_depths_view_.push_back(NDArray()); - page_indptr_sliding_window_on_depths_view_.push_back(NDArray()); - page_indices_sliding_window_on_depths_view_.push_back(NDArray()); - length_info_on_depths_view_.push_back(NDArray()); - layer_sliding_window_length_info_on_depths_view_.push_back(NDArray()); - k_rope_pos_offset_view_.push_back(NDArray()); - k_rope_pos_offset_sliding_window_view_.push_back(NDArray()); - tree_attn_mask_view_.push_back(NDArray()); - tree_attn_mn_indptr_view_.push_back(NDArray()); + qo_indptr_on_depths_view_.push_back(Tensor()); + page_indptr_on_depths_view_.push_back(Tensor()); + page_indices_on_depths_view_.push_back(Tensor()); + page_indptr_sliding_window_on_depths_view_.push_back(Tensor()); + page_indices_sliding_window_on_depths_view_.push_back(Tensor()); + length_info_on_depths_view_.push_back(Tensor()); + layer_sliding_window_length_info_on_depths_view_.push_back(Tensor()); + k_rope_pos_offset_view_.push_back(Tensor()); + k_rope_pos_offset_sliding_window_view_.push_back(Tensor()); + tree_attn_mask_view_.push_back(Tensor()); + 
tree_attn_mn_indptr_view_.push_back(Tensor()); is_chain_on_depths_.push_back(true); } // Additional workspace for the "prefill with ragged kv" kernel. if (NeedKernelBeginForward()) { temp_int_attn_workspace_.push_back( - NDArray::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); - temp_int_pinned_attn_workspace_.push_back(NDArray::Empty( + Tensor::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); + temp_int_pinned_attn_workspace_.push_back(Tensor::Empty( {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device))); temp_float_attn_workspace_ = - NDArray::Empty({kFloatAttnWorkspaceByte}, DataType::UInt(8), device); + Tensor::Empty({kFloatAttnWorkspaceByte}, DataType::UInt(8), device); } if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHA) != attn_kinds_.end()) { temp_attn_q_device_ = - NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device); + Tensor::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device); temp_attn_k_device_ = - NDArray::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, dtype, device); + Tensor::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, dtype, device); temp_attn_v_device_ = - NDArray::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, device); + Tensor::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, device); } temp_attn_output_device_ = - NDArray::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device); + Tensor::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device); temp_attn_lse_device_ = - NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); + Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); merged_attn_lse_device_ = - NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); + Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); for (int64_t page_id = num_total_pages 
- 1; page_id >= 0; --page_id) { free_page_ids_.push_back(page_id); } @@ -694,7 +694,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { DeviceAPI::Get(device_)->SetStream(device_, copy_stream_); } for (int layer = 0; layer < num_layers_; ++layer) { - NDArray page_layer_view = pages_[layer]; + Tensor page_layer_view = pages_[layer]; f_copy_single_page_(page_layer_view, src_page_id, tgt_page_id, copy_length); } if (copy_stream_ != compute_stream_) { @@ -712,9 +712,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Copy indptr/src/dst arrays to GPU. aux_data_manager_->ResetCompactKVAuxDataCopy(); - NDArray commit_copy_length_indptr_view = + Tensor commit_copy_length_indptr_view = aux_data_manager_->CopyCommitLengthIndptrAsync(&commit_copy_length_indptr_host_); - NDArray commit_copy_src_dst_pos_in_page_table_view = + Tensor commit_copy_src_dst_pos_in_page_table_view = aux_data_manager_->CopyCommitSrcDstPosInPageTableAsync( &commit_copy_src_pos_in_page_table_host_, &commit_copy_dst_pos_in_page_table_host_); aux_data_manager_->CommitCompactKVAuxDataCopy(); @@ -849,7 +849,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /************** Attention **************/ void BeginForward(const ffi::Shape& seq_ids, const ffi::Shape& append_lengths, - const Optional& opt_token_tree_parent_ptr) final { + const ffi::Optional& opt_token_tree_parent_ptr) final { // Note: MLA does not supported tree attention for now. 
if (attn_kinds_[0] == AttnKind::kMLA) { CHECK(!opt_token_tree_parent_ptr.defined()) << "Tree attention is not supported yet for MLA"; @@ -1271,13 +1271,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sequence->kv_transfer_metadata.local_position_map.end()); } - void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, - NDArray o_data, double sm_scale) final { + void AttentionWithFusedQKV(int64_t layer_id, Tensor qkv_data, ffi::Optional mask, + Tensor o_data, double sm_scale) final { // Part 1. Shape and dtype check. int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); CHECK_LT(local_layer_id, num_layers_); - NDArray pages = pages_[local_layer_id]; + Tensor pages = pages_[local_layer_id]; CHECK(qkv_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); CHECK(attn_kinds_[layer_id] == AttnKind::kMHA || @@ -1308,15 +1308,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // The auxiliary data structure on device must have been synchronized. 
ICHECK(!dirty_aux_data_device_); - NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_}, - qkv_data->dtype); - NDArray k_data = temp_attn_k_device_.CreateView({total_seq_length, num_kv_heads_, qk_head_dim_}, - qkv_data->dtype); - NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, qk_head_dim_}, - qkv_data->dtype); + Tensor q_data = temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_}, + qkv_data->dtype); + Tensor k_data = temp_attn_k_device_.CreateView({total_seq_length, num_kv_heads_, qk_head_dim_}, + qkv_data->dtype); + Tensor v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, qk_head_dim_}, + qkv_data->dtype); - NDArray qkv_data_view = qkv_data; - NDArray o_data_view = o_data; + Tensor qkv_data_view = qkv_data; + Tensor o_data_view = o_data; if (total_seq_length != qkv_data->shape[0]) { qkv_data_view = qkv_data.CreateView( {total_seq_length, qkv_data->shape[1], qkv_data->shape[2]}, qkv_data->dtype); @@ -1372,13 +1372,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - void SelfAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - NDArray o_data, NDArray lse_data, double sm_scale) final { + void SelfAttention(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data, + Tensor lse_data, double sm_scale) final { // Shape and dtype check. 
int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); CHECK_LT(local_layer_id, num_layers_); - NDArray pages = pages_[local_layer_id]; + Tensor pages = pages_[local_layer_id]; CHECK(q_data.DataType() == pages.DataType()); CHECK(k_data.DataType() == pages.DataType()); CHECK(v_data.DataType() == pages.DataType()); @@ -1415,13 +1415,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - void CrossAttention(int64_t layer_id, NDArray q_data, NDArray o_data, NDArray lse_data, + void CrossAttention(int64_t layer_id, Tensor q_data, Tensor o_data, Tensor lse_data, double sm_scale) final { // Shape and dtype check. int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); CHECK_LT(local_layer_id, num_layers_); - NDArray pages = pages_[local_layer_id]; + Tensor pages = pages_[local_layer_id]; CHECK(q_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); AttnKind attn_kind = attn_kinds_[layer_id]; @@ -1455,12 +1455,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - void AppendMLAKV(int64_t layer_id, NDArray kv_data) final { + void AppendMLAKV(int64_t layer_id, Tensor kv_data) final { // Shape and dtype check. 
int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); CHECK_LT(local_layer_id, num_layers_); - NDArray pages = pages_[local_layer_id]; + Tensor pages = pages_[local_layer_id]; CHECK(kv_data.DataType() == pages.DataType()); CHECK(attn_kinds_[layer_id] == AttnKind::kMLA); @@ -1481,14 +1481,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_transpose_append_mla_.value()(pages_[local_layer_id], kv_data, append_position_map_view_); } - Array MergeAttnOutputInplace(NDArray o_self_attn, NDArray lse_self_attn, - NDArray o_cross_attn, NDArray lse_cross_attn) final { + ffi::Array MergeAttnOutputInplace(Tensor o_self_attn, Tensor lse_self_attn, + Tensor o_cross_attn, Tensor lse_cross_attn) final { CHECK_GE(f_merge_inplace_.size(), 2) << "The general attention merge function is not defined."; f_merge_inplace_[1](o_self_attn, lse_self_attn, o_cross_attn, lse_cross_attn); return {o_self_attn, lse_self_attn}; } - void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + void LinearAttention(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data, double sm_scale) { // Todo(ruihang): implement it } @@ -1586,7 +1586,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - NDArray GetQueryPositions() final { + Tensor GetQueryPositions() final { // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. @@ -1594,8 +1594,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { return q_rope_position_map_view_; }; - void DebugGetKV(int64_t seq_id, int64_t start_pos, int64_t end_pos, NDArray k_data, - NDArray v_data) final { + void DebugGetKV(int64_t seq_id, int64_t start_pos, int64_t end_pos, Tensor k_data, + Tensor v_data) final { CHECK(f_debug_get_kv_.defined()) << "PageAttentionKVCache requires the `f_debug_get_kv` to be explicitly passed in when " "initialization. 
Please construct the KV cache with `f_debug_get_kv`."; @@ -1609,8 +1609,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { static constexpr const char* error_msg = "DebugGetKV expects the k_data in layout (num_layers, seq_length, num_kv_heads, " "qk_head_dim)."; - std::vector vec_kv_data = {&k_data, &v_data}; - for (const NDArray* data_ptr : vec_kv_data) { + std::vector vec_kv_data = {&k_data, &v_data}; + for (const Tensor* data_ptr : vec_kv_data) { CHECK_EQ((*data_ptr)->ndim, 4) << error_msg; CHECK_EQ((*data_ptr)->shape[0], num_layers_) << error_msg << " The number of layers mismatches."; @@ -1635,7 +1635,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { append_position_map.push_back(page_id * page_size_ + page_offset); } } - NDArray position_map_device = NDArray::Empty({end_pos - start_pos}, dtype_aux_, device_); + Tensor position_map_device = Tensor::Empty({end_pos - start_pos}, dtype_aux_, device_); position_map_device.CopyFromBytes( append_position_map.data() + start_pos, (end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8)); @@ -1645,7 +1645,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos, NDArray kv_data) final { + void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos, Tensor kv_data) final { CHECK(f_debug_get_kv_.defined()) << "PageAttentionKVCache requires the `f_debug_get_kv` to be explicitly passed in when " "initialization. 
Please construct the KV cache with `f_debug_get_kv`."; @@ -1678,7 +1678,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { append_position_map.push_back(page_id * page_size_ + page_offset); } } - NDArray position_map_device = NDArray::Empty({end_pos - start_pos}, dtype_aux_, device_); + Tensor position_map_device = Tensor::Empty({end_pos - start_pos}, dtype_aux_, device_); position_map_device.CopyFromBytes( append_position_map.data() + start_pos, (end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8)); @@ -1688,12 +1688,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data, NDArray v_data) final { + void DebugSetKV(int64_t seq_id, int64_t start_pos, Tensor k_data, Tensor v_data) final { ICHECK(false) << "DebugSetKV for PageAttentionKVCache not implemented yet."; } - - static constexpr const char* _type_key = "relax.vm.PagedAttentionKVCache"; - TVM_DECLARE_FINAL_OBJECT_INFO(PagedAttentionKVCacheObj, AttentionKVCacheObj); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.vm.PagedAttentionKVCache", PagedAttentionKVCacheObj, + AttentionKVCacheObj); private: /*! \brief Get a new free page and return its id. */ @@ -2053,7 +2052,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { temp_float_attn_workspace_, temp_int_attn_workspace_[0], temp_int_pinned_attn_workspace_[0], &cur_append_lengths_indptr_host_, &cur_append_lengths_indptr_host_, cur_batch_size_, - cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_kv_heads_, qk_head_dim_, + cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_qo_heads_, qk_head_dim_, v_head_dim_, /*causal=*/true, copy_stream_); } } @@ -2080,8 +2079,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { * \brief Compute attention for between the input q data and the * input k/v data and the k/v data in cache on the given layer. 
*/ - void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - NDArray output, double sm_scale) { + void AttentionInternal(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data, + Tensor output, double sm_scale) { int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); CHECK_LT(local_layer_id, num_layers_); @@ -2099,8 +2098,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { << "Both self-attention and cross-attention are not computed."; } - void MHASelfAttnInternal(NDArray q_data, NDArray k_data, NDArray v_data, NDArray o_data, - NDArray lse_data, double sm_scale) { + void MHASelfAttnInternal(Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data, + Tensor lse_data, double sm_scale) { if (is_chain_on_depths_[0]) { // If the batch does not form a tree, use raggedness prefill kernel. ICHECK_NOTNULL(f_attention_prefill_ragged_); @@ -2121,8 +2120,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - void MLASelfAttnInternal(NDArray q_data, NDArray k_data, NDArray v_data, NDArray o_data, - NDArray lse_data, double sm_scale) { + void MLASelfAttnInternal(Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data, + Tensor lse_data, double sm_scale) { CHECK(is_chain_on_depths_[0]) << "Tree attn not able for MLA for now."; // If the batch does not form a tree, use raggedness prefill kernel. ICHECK_NOTNULL(f_attention_prefill_ragged_); @@ -2133,8 +2132,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } /*! \brief Compute cross-attention for MHA. Return if there is effective computation. 
*/ - bool MHACrossAttnInternal(int64_t local_layer_id, NDArray q_data, NDArray o_data, - NDArray lse_data, double sm_scale, bool is_first_kernel) { + bool MHACrossAttnInternal(int64_t local_layer_id, Tensor q_data, Tensor o_data, Tensor lse_data, + double sm_scale, bool is_first_kernel) { std::unique_ptr& f_prefill = (!support_sliding_window_ && attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) @@ -2152,8 +2151,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } - NDArray attn_output; - NDArray attn_lse; + Tensor attn_output; + Tensor attn_lse; if (is_first_kernel) { attn_output = o_data; attn_lse = lse_data; @@ -2162,10 +2161,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_lse = temp_attn_lse_view_; } // If layer is sliding window, use sliding window index pointer/indices - NDArray page_indptr; - NDArray page_indices; - NDArray length_info; - NDArray k_rope_pos; + Tensor page_indptr; + Tensor page_indices; + Tensor length_info; + Tensor k_rope_pos; double rotary_theta; double rotary_scale; @@ -2219,8 +2218,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } /*! \brief Compute cross-attention for MLA. Return if there is effective computation. 
*/ - bool MLACrossAttnInternal(int64_t local_layer_id, NDArray q_data, NDArray o_data, - NDArray lse_data, double sm_scale) { + bool MLACrossAttnInternal(int64_t local_layer_id, Tensor q_data, Tensor o_data, Tensor lse_data, + double sm_scale) { CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; bool is_first_kernel = true; @@ -2228,8 +2227,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } - NDArray attn_output; - NDArray attn_lse; + Tensor attn_output; + Tensor attn_lse; if (is_first_kernel) { attn_output = o_data; attn_lse = lse_data; @@ -2259,7 +2258,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // If the auxiliary data is already synced, return and no need to sync again. return; } - // - Sync NDArrays to GPU. + // - Sync Tensors to GPU. SyncAuxArrayToDevice(); KernelBeginForward(); // - Clear the dirty flag. @@ -2434,7 +2433,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Register runtime functions //------------------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "vm.builtin.paged_attention_kv_cache_create", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -2463,36 +2462,36 @@ TVM_FFI_STATIC_INIT_BLOCK({ int rope_mode = args[8].cast(); double rotary_scale = args[9].cast(); double rotary_theta = args[10].cast(); - Optional rope_ext_factors = std::nullopt; // args[11] - NDArray init = args[12].cast(); - Optional f_transpose_append_mha = std::nullopt; // args[13] - Optional f_transpose_append_mla = std::nullopt; // args[14] + ffi::Optional rope_ext_factors = std::nullopt; // args[11] + Tensor init = args[12].cast(); + ffi::Optional f_transpose_append_mha = std::nullopt; // args[13] + ffi::Optional f_transpose_append_mla = std::nullopt; // args[14] std::unique_ptr f_attention_prefill_ragged = 
- ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); + ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill = - ConvertPagedPrefillFunc(args[16].cast>(), AttnKind::kMHA); + ConvertPagedPrefillFunc(args[16].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_decode = - ConvertPagedDecodeFunc(args[17].cast>(), AttnKind::kMHA); + ConvertPagedDecodeFunc(args[17].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_sliding_window = - ConvertPagedPrefillFunc(args[18].cast>(), AttnKind::kMHA); + ConvertPagedPrefillFunc(args[18].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_decode_sliding_window = - ConvertPagedDecodeFunc(args[19].cast>(), AttnKind::kMHA); + ConvertPagedDecodeFunc(args[19].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv = - ConvertPagedPrefillTreeMaskFunc(args[20].cast>(), AttnKind::kMHA); + ConvertPagedPrefillTreeMaskFunc(args[20].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_with_tree_mask = - ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); + ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); std::unique_ptr f_mla_prefill = - ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); - Array f_merge_inplace = args[23].cast>(); + ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); + ffi::Array f_merge_inplace = args[23].cast>(); ffi::Function f_split_rotary = args[24].cast(); ffi::Function f_copy_single_page = args[25].cast(); ffi::Function f_debug_get_kv = args[26].cast(); ffi::Function f_compact_copy = args[27].cast(); - if (auto opt_nd = args[11].as()) { + if (auto opt_nd = args[11].as()) { rope_ext_factors = opt_nd.value(); } - auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { + auto f_convert_optional_packed_func = [&args](int arg_idx) -> ffi::Optional { if (auto opt_func = args[arg_idx].as()) { return opt_func.value(); } @@ -2521,7 +2520,7 @@ 
TVM_FFI_STATIC_INIT_BLOCK({ } // NOTE: We will remove this legacy construction after finishing the transition phase. // Some `ffi::Function()` here are placeholders that will be filled. - ObjectPtr n = make_object( + ObjectPtr n = ffi::make_object( page_size, num_layers, layer_id_begin_offset, layer_id_end_offset, num_qo_heads, num_kv_heads, qk_head_dim, v_head_dim, attn_kinds_vec, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), @@ -2538,7 +2537,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); }); -}); +} } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc index 8963df065258..61194b5dade2 100644 --- a/src/runtime/vm/rnn_state.cc +++ b/src/runtime/vm/rnn_state.cc @@ -78,9 +78,9 @@ class RNNStateImpObj : public RNNStateObj { const int64_t max_history_ = 1; /*! * \brief The init value for ALL layer in the storage. - * The array has `num_states_per_layer_` NDArrays + * The array has `num_states_per_layer_` Tensors */ - const Array init_layer_value_; + const ffi::Array init_layer_value_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); @@ -89,12 +89,12 @@ class RNNStateImpObj : public RNNStateObj { /*! * \brief The storages of space state models. - * The array has `num_layers * num_states_per_layer_` NDArrays, + * The array has `num_layers * num_states_per_layer_` Tensors, * each of them has layout `(num_seq, max_history, state_size)`. * \note As `num_states_per_layer_` may vary for different dtype and shape, - * we use a 2D array to store the NDArrays for each layer. + * we use a 2D array to store the Tensors for each layer. */ - Array> storages_; + ffi::Array> storages_; /*! \brief The list of ids of released seq slot for reuse. */ std::vector free_slot_ids_; /*! 
\brief The mapping from sequence ids to sequences. */ @@ -117,19 +117,19 @@ class RNNStateImpObj : public RNNStateObj { */ bool dirty_aux_data_device_ = false; /*! \brief The device array of the sequence ids. */ - NDArray seq_slot_ids_device_; + Tensor seq_slot_ids_device_; /*! * \brief The view of the device array of the sequence ids. * The view is used to reuse the memory but with different shape. */ - NDArray seq_slot_ids_view_; + Tensor seq_slot_ids_view_; /*! \brief The device array of the history slot ids. */ - NDArray history_slot_ids_device_; + Tensor history_slot_ids_device_; /*! * \brief The view of the device array of the history slot ids. * The view is used to reuse the memory but with different shape. */ - NDArray history_slot_ids_view_; + Tensor history_slot_ids_view_; /******************* Interaction Functions *******************/ @@ -140,28 +140,28 @@ class RNNStateImpObj : public RNNStateObj { * \note Each state data per layer may have different dtype and shape, so we use a * different function for each state data. */ - Array f_gets_; + ffi::Array f_gets_; /*! * \brief The function to set the state data to the storage. * The function signature is `f_set_(state, seq_slot_ids, history_slot_ids, data, max_history)`. - * where `state` is the storage NDArray, `seq_slot_ids` and `history_slot_ids` are + * where `state` is the storage Tensor, `seq_slot_ids` and `history_slot_ids` are * 1-D int32 arrays of the same length as the batch size, and `data` is the input data. * \note The `history_slot_ids` is the slot of this round, but we need to write to the * slot of the next round. * \note Each state data per layer may have different dtype and shape, so we use a * different function for each state data. */ - Array f_sets_; + ffi::Array f_sets_; public: - /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. 
*/ - explicit RNNStateImpObj(int64_t num_layers, // - int64_t reserved_num_seqs, // - int64_t max_history, // - DLDevice device, // - Array f_gets, // - Array f_sets, // - Array init_layer_value) + /*! \brief Constructor. Take the cache configuration and initialize the Tensors. */ + explicit RNNStateImpObj(int64_t num_layers, // + int64_t reserved_num_seqs, // + int64_t max_history, // + DLDevice device, // + ffi::Array f_gets, // + ffi::Array f_sets, // + ffi::Array init_layer_value) : num_layers_(num_layers), reserved_num_seqs_(reserved_num_seqs), num_states_per_layer_(init_layer_value.size()), @@ -172,14 +172,14 @@ class RNNStateImpObj : public RNNStateObj { // Allocate the storage for the space state models. storages_.reserve(num_layers_); for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) { - Array layer_storages; + ffi::Array layer_storages; layer_storages.reserve(num_states_per_layer_); for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id) { ffi::Shape state_shape = init_layer_value[state_id].Shape(); std::vector storage_shape = {reserved_num_seqs, max_history}; storage_shape.insert(storage_shape.end(), state_shape.begin(), state_shape.end()); - NDArray state_storage = - NDArray::Empty(storage_shape, init_layer_value[state_id].DataType(), device); + Tensor state_storage = + Tensor::Empty(storage_shape, init_layer_value[state_id].DataType(), device); layer_storages.push_back(state_storage); } storages_.push_back(layer_storages); @@ -188,8 +188,8 @@ class RNNStateImpObj : public RNNStateObj { CHECK_GT(max_history_, 0) << "At least 1 history slot to store the current state"; // Allocate the auxiliary arrays on device. 
- seq_slot_ids_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); - history_slot_ids_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); + seq_slot_ids_device_ = Tensor::Empty({reserved_num_seqs}, dtype_aux_, device); + history_slot_ids_device_ = Tensor::Empty({reserved_num_seqs}, dtype_aux_, device); Clear(); } @@ -208,7 +208,7 @@ class RNNStateImpObj : public RNNStateObj { /************** Interaction **************/ void BeginForward(const ffi::Shape& seq_ids, const ffi::Shape& append_lengths, - const Optional& opt_token_tree_parent_ptr) final { + const ffi::Optional& opt_token_tree_parent_ptr) final { CHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; @@ -259,7 +259,7 @@ class RNNStateImpObj : public RNNStateObj { dirty_aux_data_device_ = true; } - void Get(int64_t layer_id, int64_t state_id, NDArray o_data) final { + void Get(int64_t layer_id, int64_t state_id, Tensor o_data) final { // The auxiliary data structure on device must have been synchronized. CHECK(!dirty_aux_data_device_) << "The auxiliary arrays are not synchronized to device. Please call " @@ -269,11 +269,11 @@ class RNNStateImpObj : public RNNStateObj { CHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater than 0."; // TODO(siyuan): support zero-copy when seq_len is one // Copy the state data to the return array. - NDArray state = storages_[layer_id][state_id]; + Tensor state = storages_[layer_id][state_id]; f_gets_[state_id](state, seq_slot_ids_view_, history_slot_ids_view_, o_data); } - void Set(int64_t layer_id, int64_t state_id, NDArray data) final { + void Set(int64_t layer_id, int64_t state_id, Tensor data) final { // The auxiliary data structure on device must have been synchronized. CHECK(!dirty_aux_data_device_) << "The auxiliary arrays are not synchronized to device. 
Please call " @@ -282,24 +282,24 @@ class RNNStateImpObj : public RNNStateObj { << "The batch size is not consistent with the number of sequence ids."; CHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater than 0."; - NDArray state = storages_[layer_id][state_id]; + Tensor state = storages_[layer_id][state_id]; f_sets_[state_id](state, seq_slot_ids_view_, history_slot_ids_view_, data); } - NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) { + Tensor DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) { auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in the space state storage."; - NDArray state = storages_[layer_id][state_id]; + Tensor state = storages_[layer_id][state_id]; int64_t seq_slot_id = it->second.seq_slot_id; int64_t history_slot_id = it->second.history_slot_id; std::vector shape{state.Shape().begin() + 2, state.Shape().end()}; - NDArray result = NDArray::Empty(shape, state->dtype, state->device); + Tensor result = Tensor::Empty(shape, state->dtype, state->device); DLTensor copy_src = GetStatePtrBySeqHistory(layer_id, state_id, seq_slot_id, history_slot_id); DLTensor copy_dst = *result.operator->(); - NDArray::CopyFromTo(©_src, ©_dst); + Tensor::CopyFromTo(©_src, ©_dst); return result; } @@ -316,8 +316,8 @@ class RNNStateImpObj : public RNNStateObj { for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id) { DLTensor dst = GetStatePtrBySeqHistory(layer_id, state_id, seq_slot_id, /*history_slot_id=*/0); - NDArray init = init_layer_value_[state_id]; - NDArray::CopyFromTo(init.operator->(), &dst); + Tensor init = init_layer_value_[state_id]; + Tensor::CopyFromTo(init.operator->(), &dst); } } @@ -352,7 +352,7 @@ class RNNStateImpObj : public RNNStateObj { for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id) { DLTensor copy_src = GetStatePtrBySeq(layer_id, state_id, parent_slot_id); DLTensor copy_dst = 
GetStatePtrBySeq(layer_id, state_id, child_slot_id); - NDArray::CopyFromTo(©_src, ©_dst); + Tensor::CopyFromTo(©_src, ©_dst); } } dirty_aux_data_device_ = true; @@ -385,7 +385,7 @@ class RNNStateImpObj : public RNNStateObj { DLTensor GetStatePtrBySeqHistory(int64_t layer_id, int64_t state_id, int64_t seq_slot_id, int64_t history_slot_id) { - NDArray state = storages_[layer_id][state_id]; + Tensor state = storages_[layer_id][state_id]; int64_t state_size = 1; for (int64_t i = 2; i < state->ndim; ++i) { state_size *= state->shape[i]; @@ -396,11 +396,12 @@ class RNNStateImpObj : public RNNStateObj { _state.byte_offset = elem_offset * state->dtype.bits / 8; _state.ndim = state->ndim - 2; _state.shape = const_cast(_state.shape + 2); + _state.strides = const_cast(_state.strides + 2); return _state; } DLTensor GetStatePtrBySeq(int64_t layer_id, int64_t state_id, int64_t seq_slot_id) { - NDArray state = storages_[layer_id][state_id]; + Tensor state = storages_[layer_id][state_id]; int64_t state_size = 1; for (int64_t i = 1; i < state->ndim; ++i) { state_size *= state->shape[i]; @@ -411,6 +412,7 @@ class RNNStateImpObj : public RNNStateObj { _state.byte_offset = elem_offset * state->dtype.bits / 8; _state.ndim = state->ndim - 1; _state.shape = const_cast(_state.shape + 1); + _state.strides = const_cast(_state.strides + 1); return _state; } @@ -420,7 +422,7 @@ class RNNStateImpObj : public RNNStateObj { * invoked before running attention computation on device. 
*/ void SyncAuxArrayToDevice() { - auto fcopy_from_vec = [](NDArray array, std::vector vec_data) { + auto fcopy_from_vec = [](Tensor array, std::vector vec_data) { DLTensor copy_dst = *array.operator->(); DLTensor copy_src; copy_src.data = vec_data.data(); @@ -428,9 +430,9 @@ class RNNStateImpObj : public RNNStateObj { copy_src.ndim = 1; copy_src.dtype = array->dtype; copy_src.shape = array->shape; - copy_src.strides = nullptr; + copy_src.strides = array->strides; copy_src.byte_offset = 0; - NDArray::CopyFromTo(©_src, ©_dst); + Tensor::CopyFromTo(©_src, ©_dst); }; std::vector seq_slot_ids; @@ -456,29 +458,28 @@ class RNNStateImpObj : public RNNStateObj { } public: - static constexpr const char* _type_key = "relax.vm.RNNStateImp"; - TVM_DECLARE_FINAL_OBJECT_INFO(RNNStateImpObj, RNNStateObj); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.vm.RNNStateImp", RNNStateImpObj, RNNStateObj); }; //------------------------------------------------- // Register runtime functions //------------------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("vm.builtin.rnn_state_create", [](int64_t num_layers, // - int64_t reserved_num_seqs, // - int64_t max_history, // - Array f_gets, // - Array f_sets, // - Array init_layer_value) { + refl::GlobalDef().def("vm.builtin.rnn_state_create", [](int64_t num_layers, // + int64_t reserved_num_seqs, // + int64_t max_history, // + ffi::Array f_gets, // + ffi::Array f_sets, // + ffi::Array init_layer_value) { CHECK_GT(num_layers, 0) << "The number of layers should be greater than 0."; CHECK_GT(reserved_num_seqs, 0) << "The number of reserved sequences should be greater than 0."; CHECK_GE(max_history, 0) << "The maximum history length should be greater or equal than 0."; CHECK_GT(init_layer_value.size(), 0) << "The number of states per layer should be greater than 0."; Device device = init_layer_value[0]->device; - for (const NDArray& state : 
init_layer_value) { + for (const Tensor& state : init_layer_value) { CHECK(state->device.device_type == device.device_type && state->device.device_id == device.device_id) << "The device type of all states should be the same."; @@ -490,11 +491,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ << "The number of state setters should be the same as the number of states per layer, " << "but got " << f_sets.size() << " and " << init_layer_value.size() << " respectively."; ObjectPtr n = - make_object(num_layers, reserved_num_seqs, max_history, device, - std::move(f_gets), std::move(f_sets), init_layer_value); + ffi::make_object(num_layers, reserved_num_seqs, max_history, device, + std::move(f_gets), std::move(f_sets), init_layer_value); return RNNState(std::move(n)); }); -}); +} } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/ndarray_cache_support.cc b/src/runtime/vm/tensor_cache_support.cc similarity index 69% rename from src/runtime/vm/ndarray_cache_support.cc rename to src/runtime/vm/tensor_cache_support.cc index cfd979cc6f24..1f727241cd25 100644 --- a/src/runtime/vm/ndarray_cache_support.cc +++ b/src/runtime/vm/tensor_cache_support.cc @@ -17,17 +17,17 @@ * under the License. */ /*! - * \file src/runtime/vm/ndarray_cache_support.cc - * \brief Runtime to support ndarray cache file loading. + * \file src/runtime/vm/tensor_cache_support.cc + * \brief Runtime to support tensor cache file loading. * - * This file provides a minimum support for ndarray cache file loading. + * This file provides a minimum support for tensor cache file loading. * * The main focus of this implementation is to enable loading * with minimum set of intermediate files while also being * compatible to some of the multi-shard files that are more * friendly in some of the environments. * - * NDArray cache also provides a way to do system-wide + * Tensor cache also provides a way to do system-wide * parameter sharing across multiple VMs. 
* * There are likely other ways to load the parameters ndarray-ache. @@ -41,8 +41,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -65,7 +65,7 @@ inline ValueType GetValue(const picojson::object& json, const std::string& key) return AsType(json.at(key)); } -NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson::object& json) { +TensorCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson::object& json) { std::vector shape; { picojson::array shape_json = GetValue(json, "shape"); @@ -74,10 +74,10 @@ NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson:: shape.push_back(AsType(d)); } } - NDArrayCacheMetadata::FileRecord::ParamRecord result; + TensorCacheMetadata::FileRecord::ParamRecord result; std::string dtype = GetValue(json, "dtype"); result.name = GetValue(json, "name"); - result.dtype = DataType(StringToDLDataType(dtype)); + result.dtype = DataType(ffi::StringToDLDataType(dtype)); result.format = GetValue(json, "format"); result.nbytes = GetValue(json, "nbytes"); result.byte_offset = GetValue(json, "byteOffset"); @@ -85,9 +85,9 @@ NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson:: return result; } -NDArrayCacheMetadata::FileRecord JSONAsFileRecord(const picojson::object& json) { +TensorCacheMetadata::FileRecord JSONAsFileRecord(const picojson::object& json) { picojson::array records = GetValue(json, "records"); - NDArrayCacheMetadata::FileRecord result; + TensorCacheMetadata::FileRecord result; result.data_path = GetValue(json, "dataPath"); result.format = GetValue(json, "format"); result.nbytes = GetValue(json, "nbytes"); @@ -98,9 +98,9 @@ NDArrayCacheMetadata::FileRecord JSONAsFileRecord(const picojson::object& json) return result; } -NDArrayCacheMetadata JSONAsNDArrayCacheMetadata(const picojson::object& json) { +TensorCacheMetadata JSONAsTensorCacheMetadata(const picojson::object& json) { picojson::array records = 
GetValue(json, "records"); - NDArrayCacheMetadata result; + TensorCacheMetadata result; result.records.reserve(records.size()); for (const picojson::value& item : records) { result.records.push_back(JSONAsFileRecord(AsType(item))); @@ -108,8 +108,8 @@ NDArrayCacheMetadata JSONAsNDArrayCacheMetadata(const picojson::object& json) { return result; } -NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromStr(const std::string& json_str, - const std::string& path) { +TensorCacheMetadata TensorCacheMetadata::LoadFromStr(const std::string& json_str, + const std::string& path) { picojson::value json_info; { std::string err = picojson::parse(json_info, json_str); @@ -119,16 +119,16 @@ NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromStr(const std::string& json_s CHECK(json_info.is()) << "ValueError: The given string is not a JSON object: " << json_str; } - NDArrayCacheMetadata result = JSONAsNDArrayCacheMetadata(AsType(json_info)); + TensorCacheMetadata result = JSONAsTensorCacheMetadata(AsType(json_info)); result.path = path; return result; } -TVM_DLL NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) { +TVM_DLL TensorCacheMetadata TensorCacheMetadata::Load(const std::string& path) { picojson::value json_info; { std::string json_str; - LoadBinaryFromFile(path + "/ndarray-cache.json", &json_str); + LoadBinaryFromFile(path + "/tensor-cache.json", &json_str); std::string err = picojson::parse(json_info, json_str); if (!err.empty()) { LOG(FATAL) << "Failed to parse JSON: err. 
The JSON string is:" << json_str; @@ -136,13 +136,13 @@ TVM_DLL NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) CHECK(json_info.is()) << "ValueError: The given string is not a JSON object: " << json_str; } - NDArrayCacheMetadata result = JSONAsNDArrayCacheMetadata(AsType(json_info)); + TensorCacheMetadata result = JSONAsTensorCacheMetadata(AsType(json_info)); result.path = path; return result; } -void CopyNDArrayFromBytes(NDArray param, const void* data, size_t nbytes, - Optional* staging_buffer) { +void CopyTensorFromBytes(Tensor param, const void* data, size_t nbytes, + ffi::Optional* staging_buffer) { Device device = param->device; if (device.device_type != kDLOpenCL || staging_buffer == nullptr) { param.CopyFromBytes(data, nbytes); @@ -158,17 +158,17 @@ void CopyNDArrayFromBytes(NDArray param, const void* data, size_t nbytes, } } if (!staging_buffer->defined()) { - *staging_buffer = NDArray::Empty(param.Shape(), param->dtype, param->device); + *staging_buffer = Tensor::Empty(param.Shape(), param->dtype, param->device); } - NDArray staging_view = staging_buffer->value().CreateView(param.Shape(), param->dtype); + Tensor staging_view = staging_buffer->value().CreateView(param.Shape(), param->dtype); staging_view.CopyFromBytes(data, nbytes); param.CopyFrom(staging_view); DeviceAPI::Get(device)->StreamSync(device, nullptr); } -NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load( - Device device, const std::string* raw_data, Optional* staging_buffer) const { - NDArray arr = NDArray::Empty(shape, dtype, device); +Tensor TensorCacheMetadata::FileRecord::ParamRecord::Load( + Device device, const std::string* raw_data, ffi::Optional* staging_buffer) const { + Tensor arr = Tensor::Empty(shape, dtype, device); if (dtype == DataType::Float(32) && format == "f32-to-bf16") { // decode bf16 to f32 std::vector buffer(nbytes / 2); @@ -177,24 +177,24 @@ NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load( for (size_t i = 0; i < 
buffer.size(); ++i) { decoded[i] = static_cast(buffer[i]) << 16; } - CopyNDArrayFromBytes(arr, decoded.data(), decoded.size() * sizeof(uint32_t), staging_buffer); + CopyTensorFromBytes(arr, decoded.data(), decoded.size() * sizeof(uint32_t), staging_buffer); } else { - CopyNDArrayFromBytes(arr, raw_data->data() + byte_offset, nbytes, staging_buffer); + CopyTensorFromBytes(arr, raw_data->data() + byte_offset, nbytes, staging_buffer); } return arr; } -TVM_DLL Array NDArrayCacheMetadata::FileRecord::Load( +TVM_DLL ffi::Array TensorCacheMetadata::FileRecord::Load( Device device, const std::string& path_prefix, // std::string* raw_data_buffer, // - Optional* staging_buffer) const { + ffi::Optional* staging_buffer) const { LoadBinaryFromFile(path_prefix + "/" + this->data_path, raw_data_buffer); CHECK_EQ(this->format, "raw-shard") << "ValueError: Only `raw-shard` format is supported"; CHECK_EQ(this->nbytes, raw_data_buffer->length()) << "ValueError: Encountered an corrupted parameter shard. It means it is not downloaded " "completely or downloading is interrupted. Please try to download again."; - Array result; + ffi::Array result; result.reserve(this->records.size()); for (const ParamRecord& nd_rec : this->records) { result.push_back(nd_rec.Load(device, raw_data_buffer, staging_buffer)); @@ -203,25 +203,25 @@ TVM_DLL Array NDArrayCacheMetadata::FileRecord::Load( } /*! - * A NDArray cache to store pre-loaded arrays in the system. + * A Tensor cache to store pre-loaded arrays in the system. 
*/ -class NDArrayCache { +class TensorCache { public: - static NDArrayCache* Global() { - static NDArrayCache* inst = new NDArrayCache(); + static TensorCache* Global() { + static TensorCache* inst = new TensorCache(); return inst; } - static void Update(String name, NDArray arr, bool override) { - NDArrayCache* pool = Global(); + static void Update(ffi::String name, Tensor arr, bool override) { + TensorCache* pool = Global(); if (!override) { ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already exists in the cache"; } pool->pool_.Set(name, arr); } - static Optional Get(String name) { - NDArrayCache* pool = Global(); + static ffi::Optional Get(ffi::String name) { + TensorCache* pool = Global(); auto it = pool->pool_.find(name); if (it != pool->pool_.end()) { return (*it).second; @@ -230,8 +230,8 @@ class NDArrayCache { } } - static void Remove(String name) { - NDArrayCache* pool = Global(); + static void Remove(ffi::String name) { + TensorCache* pool = Global(); pool->pool_.erase(name); } @@ -245,11 +245,11 @@ class NDArrayCache { */ static void Load(const std::string& cache_path, int device_type, int device_id) { DLDevice device{static_cast(device_type), device_id}; - NDArrayCacheMetadata metadata = NDArrayCacheMetadata::Load(cache_path); - Optional staging_buffer; + TensorCacheMetadata metadata = TensorCacheMetadata::Load(cache_path); + ffi::Optional staging_buffer; std::string raw_data; - Array params; - for (const NDArrayCacheMetadata::FileRecord& shard_rec : metadata.records) { + ffi::Array params; + for (const TensorCacheMetadata::FileRecord& shard_rec : metadata.records) { try { params = shard_rec.Load(device, cache_path, &raw_data, &staging_buffer); } catch (const dmlc::Error& e) { @@ -264,41 +264,41 @@ class NDArrayCache { } private: - Map pool_; + ffi::Map pool_; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("vm.builtin.ndarray_cache.get", NDArrayCache::Get) 
- .def_packed("vm.builtin.ndarray_cache.update", + .def("vm.builtin.tensor_cache.get", TensorCache::Get) + .def_packed("vm.builtin.tensor_cache.update", [](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 2 || args.size() == 3); - String name = args[0].cast(); + ffi::String name = args[0].cast(); bool is_override = args.size() == 2 ? false : args[2].cast(); - NDArray arr; - if (auto opt_nd = args[1].as()) { + Tensor arr; + if (auto opt_nd = args[1].as()) { arr = opt_nd.value(); } else { - // We support converting DLTensors to NDArrays as RPC references are always + // We support converting DLTensors to Tensors as RPC references are always // DLTensors auto tensor = args[1].cast(); std::vector shape; for (int64_t i = 0; i < tensor->ndim; i++) { shape.push_back(tensor->shape[i]); } - arr = NDArray::Empty(shape, tensor->dtype, tensor->device); + arr = Tensor::Empty(shape, tensor->dtype, tensor->device); arr.CopyFrom(tensor); DeviceAPI::Get(arr->device)->StreamSync(arr->device, nullptr); } - NDArrayCache::Update(name, arr, is_override); + TensorCache::Update(name, arr, is_override); }) - .def("vm.builtin.ndarray_cache.remove", NDArrayCache::Remove) - .def("vm.builtin.ndarray_cache.clear", NDArrayCache::Clear) - .def("vm.builtin.ndarray_cache.load", NDArrayCache::Load); -}); + .def("vm.builtin.tensor_cache.remove", TensorCache::Remove) + .def("vm.builtin.tensor_cache.clear", TensorCache::Clear) + .def("vm.builtin.tensor_cache.load", TensorCache::Load); +} // This param module node can be useful to get param dict in RPC mode // when the remote already have loaded parameters from file. 
@@ -306,7 +306,7 @@ class ParamModuleNode : public ffi::ModuleObj { public: const char* kind() const final { return "param_module"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { if (name == "get_params") { auto params = params_; return ffi::Function([params](ffi::PackedArgs args, ffi::Any* rv) { *rv = params; }); @@ -315,11 +315,11 @@ class ParamModuleNode : public ffi::ModuleObj { } } - static Array GetParams(const String& prefix, int num_params) { - Array params; + static ffi::Array GetParams(const ffi::String& prefix, int num_params) { + ffi::Array params; for (int i = 0; i < num_params || num_params == -1; ++i) { std::string name = prefix + "_" + std::to_string(i); - auto opt = NDArrayCache::Get(name); + auto opt = TensorCache::Get(name); if (opt) { params.push_back(opt.value()); } else { @@ -330,11 +330,11 @@ class ParamModuleNode : public ffi::ModuleObj { return params; } - static Array GetParamByName(const Array& names) { - Array result; + static ffi::Array GetParamByName(const ffi::Array& names) { + ffi::Array result; result.reserve(names.size()); - for (const String& name : names) { - if (Optional opt = NDArrayCache::Get(name)) { + for (const ffi::String& name : names) { + if (ffi::Optional opt = TensorCache::Get(name)) { result.push_back(opt.value()); } else { LOG(FATAL) << "ValueError: Cannot find parameter in cache: " << name; @@ -344,22 +344,22 @@ class ParamModuleNode : public ffi::ModuleObj { } static ffi::Module Create(const std::string& prefix, int num_params) { - auto n = make_object(); + auto n = ffi::make_object(); n->params_ = GetParams(prefix, num_params); return ffi::Module(n); } - static ffi::Module CreateByName(const Array& names) { - auto n = make_object(); + static ffi::Module CreateByName(const ffi::Array& names) { + auto n = ffi::make_object(); n->params_ = GetParamByName(names); return ffi::Module(n); } private: - Array params_; + ffi::Array params_; }; 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("vm.builtin.param_module_from_cache", ParamModuleNode::Create) @@ -368,18 +368,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("vm.builtin.param_array_from_cache_by_name", ParamModuleNode::GetParamByName) .def_packed("vm.builtin.param_array_from_cache_by_name_unpacked", [](ffi::PackedArgs args, ffi::Any* rv) { - Array names; + ffi::Array names; names.reserve(args.size()); for (int i = 0; i < args.size(); ++i) { - if (!args[i].try_cast()) { + if (!args[i].try_cast()) { LOG(FATAL) << "ValueError: Expect string as input, but get " << args[i].GetTypeKey() << " at " << i; } - names.push_back(args[i].cast()); + names.push_back(args[i].cast()); } *rv = ParamModuleNode::GetParamByName(names); }); -}); +} } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c4fdedd815a9..be981b205cbb 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -38,8 +38,8 @@ namespace vm { // VM Closure object //--------------------------------------------- -VMClosure::VMClosure(String func_name, ffi::Function impl) { - auto ptr = make_object(); +VMClosure::VMClosure(ffi::String func_name, ffi::Function impl) { + auto ptr = ffi::make_object(); ptr->func_name = func_name; ptr->impl = std::move(impl); data_ = std::move(ptr); @@ -84,7 +84,7 @@ ffi::Any IndexIntoNestedObject(ffi::Any obj, ffi::PackedArgs args, int starting_ return obj; } -NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev, Allocator* alloc) { +Tensor ConvertTensorToDevice(Tensor src, const DLDevice& dev, Allocator* alloc) { if (src->device.device_type == dev.device_type && src->device.device_id == dev.device_id) { return src; } else { @@ -95,15 +95,15 @@ NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev, Allocator* allo } Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { - if (src.as()) { - return 
ConvertNDArrayToDevice(src.cast(), dev, alloc); + if (src.as()) { + return ConvertTensorToDevice(src.cast(), dev, alloc); } else if (src.as()) { std::vector ret; auto arr = src.cast>(); for (size_t i = 0; i < arr.size(); i++) { ret.push_back(ConvertObjectToDevice(arr[i], dev, alloc)); } - return Array(ret.begin(), ret.end()); + return ffi::Array(ret.begin(), ret.end()); } else { return src; } @@ -112,8 +112,8 @@ Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { ffi::Any ConvertArgToDevice(ffi::AnyView input, Device dev, Allocator* alloc) { // in terms of memory-behavior. // To be extra careful, we copy DLTensor. - // The developer can still explicitly allocate NDArray - // in TVM Native API or NDArray::FromDLPack to regain zero copy behavior. + // The developer can still explicitly allocate Tensor + // in TVM Native API or Tensor::FromDLPack to regain zero copy behavior. ffi::Any ret; if (auto opt_obj = input.as()) { ret = ConvertObjectToDevice(opt_obj.value(), dev, alloc); @@ -189,7 +189,7 @@ class VirtualMachineImpl : public VirtualMachine { void LoadExecutable(ObjectPtr exec) final; void Init(const std::vector& devices, const std::vector& alloc_types) final; - VMClosure GetClosure(const String& func_name) final { + VMClosure GetClosure(const ffi::String& func_name) final { return this->GetClosureInternal(func_name, false).value(); } void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, ffi::PackedArgs args, @@ -210,7 +210,7 @@ class VirtualMachineImpl : public VirtualMachine { void _SetInputWithParamModule(ffi::PackedArgs args, ffi::Any* rv); int _GetFunctionArity(std::string func_name); std::string _GetFunctionParamName(std::string func_name, int index); - ffi::Function _LookupFunction(const String& name); + ffi::Function _LookupFunction(const ffi::String& name); TVM_MODULE_VTABLE_BEGIN("relax.VirtualMachine"); TVM_MODULE_VTABLE_ENTRY_PACKED("vm_initialization", &VirtualMachineImpl::_Init); @@ -236,7 +236,7 @@ class 
VirtualMachineImpl : public VirtualMachine { * \param allow_missing Whether none is allowed. * \return The result */ - Optional GetClosureInternal(const String& func_name, bool allow_missing); + ffi::Optional GetClosureInternal(const ffi::String& func_name, bool allow_missing); /*! * \brief Set inputs to a function. @@ -245,7 +245,7 @@ class VirtualMachineImpl : public VirtualMachine { * correct device for the function, they will be copied to the device. * \param with_param_module If set to true, the last argument will be a module and can be invoked * to get the argument, this is mainly used for debugging purposes and setting composite - * objects. \note This interface works when using VM over RPC by internally converting NDArray in + * objects. \note This interface works when using VM over RPC by internally converting Tensor in * the arguments to DLTensor, which is supported in RPC where remote could only have a minimal C * runtime. */ @@ -276,7 +276,7 @@ class VirtualMachineImpl : public VirtualMachine { * \param args The arguments to bound to the function. * \note This function is used by RPC server to help benchmarking. */ - void SaveClosure(const String& func_name, const String& save_name, bool include_return, + void SaveClosure(const ffi::String& func_name, const ffi::String& save_name, bool include_return, ffi::PackedArgs args); /*! * \brief Internal function to invoke a closure. @@ -300,7 +300,7 @@ class VirtualMachineImpl : public VirtualMachine { * \param name The name of the function. * \return The result function, can return ffi::Function(nullptr) if nothing is found. */ - Optional GetFuncFromImports(const String& name) { + ffi::Optional GetFuncFromImports(const ffi::String& name) { for (auto& lib : this->imports_) { if (auto opt_func = lib.cast()->GetFunction(name, true)) { return *opt_func; @@ -470,7 +470,7 @@ void VirtualMachineImpl::Init(const std::vector& devices, // Setup constant sections. 
this->const_pool_.reserve(exec_->constants.size()); for (const auto& constant : exec_->constants) { - if (auto opt_nd = constant.as()) { + if (auto opt_nd = constant.as()) { this->const_pool_.push_back(ConvertRegToDevice(opt_nd.value(), devices[0], allocators[0])); } else { this->const_pool_.push_back(constant); @@ -572,7 +572,7 @@ RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_pa return ret; } -void VirtualMachineImpl::SaveClosure(const String& func_name, const String& save_name, +void VirtualMachineImpl::SaveClosure(const ffi::String& func_name, const ffi::String& save_name, bool include_return, ffi::PackedArgs args) { VMClosure clo = this->GetClosure(func_name); std::vector inputs(args.size()); @@ -589,8 +589,8 @@ void VirtualMachineImpl::SaveClosure(const String& func_name, const String& save saved_closures_[save_name] = VMClosure(save_name, impl); } -Optional VirtualMachineImpl::GetClosureInternal(const String& func_name, - bool allow_missing) { +ffi::Optional VirtualMachineImpl::GetClosureInternal(const ffi::String& func_name, + bool allow_missing) { // look up saved closures. 
auto saved_it = saved_closures_.find(func_name); if (saved_it != saved_closures_.end()) { @@ -621,7 +621,7 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na } else { ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) << "Cannot support closure with function kind " << static_cast(finfo.kind); - Optional tir_func = GetFuncFromImports("__vmtir__" + finfo.name); + ffi::Optional tir_func = GetFuncFromImports("__vmtir__" + finfo.name); ICHECK(tir_func.has_value()) << "Cannot find underlying compiled tir function of VMTIRFunc " << finfo.name; auto impl = ffi::Function([this, finfo, tir_func](ffi::PackedArgs args, ffi::Any* rv) { @@ -697,7 +697,7 @@ void VirtualMachineImpl::InitFuncPool() { const VMFuncInfo& info = exec_->func_table[func_index]; if (info.kind == VMFuncInfo::FuncKind::kPackedFunc) { // only look through imports first - Optional func = GetFuncFromImports(info.name); + ffi::Optional func = GetFuncFromImports(info.name); if (!func.has_value()) { const auto p_func = tvm::ffi::Function::GetGlobal(info.name); if (p_func.has_value()) func = *p_func; @@ -846,7 +846,9 @@ void VirtualMachineImpl::RunLoop() { } } -ObjectPtr VirtualMachine::Create() { return make_object(); } +ObjectPtr VirtualMachine::Create() { + return ffi::make_object(); +} //-------------------------------------------------------------------- // FFI related code @@ -869,7 +871,7 @@ void VirtualMachineImpl::_Init(ffi::PackedArgs args, ffi::Any* rv) { void VirtualMachineImpl::_SaveClosure(ffi::PackedArgs args, ffi::Any* rv) { ICHECK_GE(args.size(), 3); std::string func_name = args[0].cast(); - this->SaveClosure(func_name, args[1].cast(), args[2].cast(), args.Slice(3)); + this->SaveClosure(func_name, args[1].cast(), args[2].cast(), args.Slice(3)); } void VirtualMachineImpl::_InvokeClosure(ffi::PackedArgs args, ffi::Any* rv) { @@ -894,7 +896,7 @@ void VirtualMachineImpl::_SetInstrument(ffi::PackedArgs args, ffi::Any* rv) { if (args[0].as()) { 
this->SetInstrument(args[0].cast()); } else { - String func_name = args[0].cast(); + ffi::String func_name = args[0].cast(); const auto factory = tvm::ffi::Function::GetGlobal(func_name); CHECK(factory.has_value()) << "Cannot find factory " << func_name; ffi::Any rv; @@ -950,9 +952,9 @@ std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name, int return vm_func.param_names[index]; } -ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { - if (Optional opt = this->GetClosureInternal(name, true)) { - return ffi::Function([clo = opt.value(), _self = GetRef(this)]( +ffi::Function VirtualMachineImpl::_LookupFunction(const ffi::String& name) { + if (ffi::Optional opt = this->GetClosureInternal(name, true)) { + return ffi::Function([clo = opt.value(), _self = ffi::GetRef(this)]( ffi::PackedArgs args, ffi::Any* rv) -> void { auto* self = const_cast(_self.as()); ICHECK(self); @@ -973,7 +975,7 @@ ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { */ class VirtualMachineProfiler : public VirtualMachineImpl { public: - Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "profile") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { @@ -987,7 +989,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } } - prof_ = profiling::Profiler(devices, {}, {{String("Executor"), String("VM")}}); + prof_ = profiling::Profiler(devices, {}, {{ffi::String("Executor"), ffi::String("VM")}}); auto inputs = GetInputsFor(f_name); @@ -1029,11 +1031,11 @@ class VirtualMachineProfiler : public VirtualMachineImpl { if (prof_ && prof_->IsRunning()) { auto f_name = GetFuncName(inst.func_idx); std::optional dev; - std::vector arrs; + std::vector arrs; - auto f_check_ndarray_arg = [&dev, &arrs](const RegType& arg) { - if (auto opt_nd = arg.as()) { - NDArray arr = opt_nd.value(); + 
auto f_check_tensor_arg = [&dev, &arrs](const RegType& arg) { + if (auto opt_nd = arg.as()) { + Tensor arr = opt_nd.value(); if (arr.defined()) { dev = arr->device; arrs.push_back(arr); @@ -1045,10 +1047,10 @@ class VirtualMachineProfiler : public VirtualMachineImpl { Instruction::Arg arg = inst.args[i]; if (arg.kind() == Instruction::ArgKind::kRegister) { auto reg = ReadRegister(curr_frame, arg.value()); - f_check_ndarray_arg(reg); + f_check_tensor_arg(reg); } else if (arg.kind() == Instruction::ArgKind::kConstIdx) { const auto& const_val = this->const_pool_[arg.value()]; - f_check_ndarray_arg(const_val); + f_check_tensor_arg(const_val); } } @@ -1074,7 +1076,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { }; ObjectPtr VirtualMachine::CreateProfiler() { - return make_object(); + return ffi::make_object(); } #else diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 023d34e68bda..a2ff8bb7ce0e 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -451,7 +451,7 @@ VulkanDevice& VulkanDeviceAPI::device(size_t device_id) { return const_cast(const_cast(this)->device(device_id)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("device_api.vulkan", @@ -464,7 +464,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); return rv; }); -}); +} } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index a5fb6c2293fa..dbf2d9fff76c 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -33,11 +33,11 @@ namespace vulkan { ffi::Module VulkanModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string source) { - auto n = make_object(smap, fmap, source); + auto n = ffi::make_object(smap, fmap, source); return 
ffi::Module(n); } -ffi::Module VulkanModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module VulkanModuleLoadFile(const std::string& file_name, const ffi::String& format) { std::string data; std::unordered_map smap; std::unordered_map fmap; @@ -67,12 +67,12 @@ ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) { return VulkanModuleCreate(smap, fmap, ""); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ffi.Module.load_from_file.vulkan", VulkanModuleLoadFile) .def("ffi.Module.load_from_bytes.vulkan", VulkanModuleLoadFromBytes); -}); +} } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 2f50a0154658..007d6abdbadb 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -205,7 +205,7 @@ VulkanModuleNode::~VulkanModuleNode() { } } -Optional VulkanModuleNode::GetFunction(const String& name) { +ffi::Optional VulkanModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); @@ -403,7 +403,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, return pe; } -void VulkanModuleNode::WriteToFile(const String& file_name, const String& format) const { +void VulkanModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; std::string meta_file = GetMetaFilePath(file_name); @@ -427,7 +427,7 @@ ffi::Bytes VulkanModuleNode::SaveToBytes() const { return ffi::Bytes(buffer); } -String VulkanModuleNode::InspectSource(const String& format) const { +ffi::String VulkanModuleNode::InspectSource(const ffi::String& format) const { // can only return disassembly code. 
return source_; } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index 2ff90568de9d..53ae3ac4ba82 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -94,15 +94,15 @@ class VulkanModuleNode final : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args); - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; - String InspectSource(const String& format) const final; + ffi::String InspectSource(const ffi::String& format) const final; private: // function information table. diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 1b02e7dfb8c0..658e76be466c 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -25,13 +25,13 @@ namespace tvm { namespace script { namespace ir_builder { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { IRBuilderFrameNode::RegisterReflection(); IRBuilderNode::RegisterReflection(); -}); +} void IRBuilderFrameNode::EnterWithScope() { - IRBuilder::Current()->frames.push_back(GetRef(this)); + IRBuilder::Current()->frames.push_back(ffi::GetRef(this)); } void IRBuilderFrameNode::ExitWithScope() { @@ -50,7 +50,7 @@ void IRBuilderFrameNode::AddCallback(ffi::TypedFunction callback) { } IRBuilder::IRBuilder() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->frames.clear(); n->result = std::nullopt; data_ = n; @@ -95,7 +95,7 @@ Namer::FType& Namer::vtable() { return inst; } -void Namer::Name(ObjectRef node, String name) { +void Namer::Name(ObjectRef node, ffi::String name) { 
static const FType& f = vtable(); CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name; CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \"" @@ -105,7 +105,7 @@ void Namer::Name(ObjectRef node, String name) { } // namespace details -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("script.ir_builder.IRBuilderFrameEnter", &IRBuilderFrameNode::EnterWithScope) @@ -118,7 +118,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("script.ir_builder.IRBuilderIsInScope", IRBuilder::IsInScope) .def_method("script.ir_builder.IRBuilderGet", &IRBuilderNode::Get) .def("script.ir_builder.IRBuilderName", IRBuilder::Name); -}); +} } // namespace ir_builder } // namespace script diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index 9a1e5cdd109c..fae4ba41bfda 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -25,10 +25,10 @@ namespace script { namespace ir_builder { namespace ir { -TVM_FFI_STATIC_INIT_BLOCK({ IRModuleFrameNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { IRModuleFrameNode::RegisterReflection(); } void IRModuleFrameNode::ExitWithScope() { - Map func_map; + ffi::Map func_map; CHECK_EQ(functions.size(), global_var_map.size()) << "All functions must be defined in the IRModule. 
Got " << global_var_map.size() << "declared function(s), but only " << functions.size() << "defined function(s)."; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 26af0e55c76d..e609f1b8efd2 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -32,7 +32,7 @@ namespace ir_builder { namespace ir { IRModuleFrame IRModule() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); @@ -49,14 +49,15 @@ inline relax::StructInfo GetGlobalVarStructInfo(const BaseFunc& func) { } } -GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { +GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signature) { IRModuleFrame frame = FindModuleFrame(); CHECK(!frame->global_var_map.count(func_name)) << "ValueError: function " << func_name << " already exists"; auto gvar_type = [&]() -> Type { if (auto prim_func = func_signature.as()) { - Array arg_types = prim_func->params.Map([](const auto& var) { return GetType(var); }); + ffi::Array arg_types = + prim_func->params.Map([](const auto& var) { return GetType(var); }); return FuncType(arg_types, prim_func->ret_type); } @@ -72,7 +73,7 @@ GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) return gv; } -void DefFunction(const String& func_name, const BaseFunc& func) { +void DefFunction(const ffi::String& func_name, const BaseFunc& func) { IRModuleFrame frame = FindModuleFrame(); auto it = frame->global_var_map.find(func_name); CHECK(it != frame->global_var_map.end()) @@ -82,7 +83,7 @@ void DefFunction(const String& func_name, const BaseFunc& func) { gv->struct_info_ = GetGlobalVarStructInfo(func); } -void ModuleAttrs(Map attrs, bool allow_overwrite) { +void ModuleAttrs(ffi::Map attrs, bool allow_overwrite) { if (IRBuilder::IsInScope()) { // TODO(hongyi): add comments to explain why we need to check if the module frame 
is in scope IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); @@ -93,7 +94,7 @@ void ModuleAttrs(Map attrs, bool allow_overwrite) { } } -Optional ModuleGetAttr(const String& key) { +ffi::Optional ModuleGetAttr(const ffi::String& key) { if (IRBuilder::IsInScope()) { IRModuleFrame frame = FindModuleFrame(); if (frame->attrs.find(key) != frame->attrs.end()) { @@ -103,7 +104,8 @@ Optional ModuleGetAttr(const String& key) { return std::nullopt; } -void ModuleSetAttr(const String& key, const Optional& value, bool allow_override) { +void ModuleSetAttr(const ffi::String& key, const ffi::Optional& value, + bool allow_override) { if (IRBuilder::IsInScope()) { IRModuleFrame frame = FindModuleFrame(); if (!allow_override && frame->attrs.find(key) != frame->attrs.end() && value.defined()) { @@ -119,7 +121,7 @@ void ModuleSetAttr(const String& key, const Optional& value, bool all } } -void ModuleGlobalInfos(Map> global_infos) { +void ModuleGlobalInfos(ffi::Map> global_infos) { if (IRBuilder::IsInScope()) { IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos"); if (!frame->global_infos.empty()) { @@ -130,13 +132,13 @@ void ModuleGlobalInfos(Map> global_infos) { } } -VDevice LookupVDevice(String target_kind, int device_index) { +VDevice LookupVDevice(ffi::String target_kind, int device_index) { if (IRBuilder::IsInScope()) { IRModuleFrame frame = FindModuleFrame(); if (frame->global_infos.empty()) { LOG(FATAL) << "ValueError: The GlobalInfos in the IRModule is not defined."; } - Array vdevices = frame->global_infos["vdevice"]; + ffi::Array vdevices = frame->global_infos["vdevice"]; if (vdevices.empty() || device_index < 0 || static_cast(device_index) >= vdevices.size()) { LOG(FATAL) << "ValueError: The target VDevice in the GlobalInfos was not found."; @@ -159,7 +161,7 @@ VDevice LookupVDevice(String target_kind, int device_index) { return VDevice(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() 
.def("script.ir_builder.ir.IRModule", IRModule) @@ -170,7 +172,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("script.ir_builder.ir.ModuleSetAttr", ModuleSetAttr) .def("script.ir_builder.ir.ModuleGlobalInfos", ModuleGlobalInfos) .def("script.ir_builder.ir.LookupVDevice", LookupVDevice); -}); +} } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h index b12e5e270d89..54ea6ce6ad92 100644 --- a/src/script/ir_builder/ir/utils.h +++ b/src/script/ir_builder/ir/utils.h @@ -26,10 +26,10 @@ namespace script { namespace ir_builder { namespace ir { -inline IRModuleFrame FindModuleFrame(const String& method) { +inline IRModuleFrame FindModuleFrame(const ffi::String& method) { IRBuilder builder = IRBuilder::Current(); - if (Optional frame = builder->FindFrame()) { - const Optional& last_module_frame = builder->GetLastFrame(); + if (ffi::Optional frame = builder->FindFrame()) { + const ffi::Optional& last_module_frame = builder->GetLastFrame(); if (last_module_frame.defined() && last_module_frame.value() == frame) { return frame.value(); } @@ -43,7 +43,7 @@ inline IRModuleFrame FindModuleFrame(const String& method) { inline IRModuleFrame FindModuleFrame() { IRBuilder builder = IRBuilder::Current(); - if (Optional frame = builder->FindFrame()) { + if (ffi::Optional frame = builder->FindFrame()) { return frame.value(); } else { LOG(FATAL) << "ValueError: IRModule frame not find. 
Please ensure it" diff --git a/src/script/ir_builder/relax/distributed.cc b/src/script/ir_builder/relax/distributed.cc index 424d20980ad2..3efb38d44bf5 100644 --- a/src/script/ir_builder/relax/distributed.cc +++ b/src/script/ir_builder/relax/distributed.cc @@ -28,8 +28,9 @@ namespace tvm { namespace relax { -Expr MakeCallTIRDist(Expr func, Tuple args, Array out_sinfo_list, - Optional packed_ints) { +Expr MakeCallTIRDist(Expr func, Tuple args, + ffi::Array out_sinfo_list, + ffi::Optional packed_ints) { for (const distributed::DTensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->tensor_sinfo->shape.as(); CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. " @@ -55,10 +56,10 @@ Expr MakeCallTIRDist(Expr func, Tuple args, Array block_frame = IRBuilder::Current()->GetLastFrame()) { + if (ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame()) { block_frame.value()->ExitWithScope(); ICHECK(!IRBuilder::Current()->GetLastFrame().defined()) << "ValueError: There is some remaining BlockFrame that is not properly popped out."; @@ -87,12 +87,12 @@ void FunctionFrameNode::ExitWithScope() { // Case 0. No outer frame, return function directly ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; - } else if (Optional opt_frame = builder->FindFrame()) { + } else if (ffi::Optional opt_frame = builder->FindFrame()) { // Case 1. A global function of an IRModule CHECK(name.has_value()) << "ValueError: The function name must be defined before exiting the " "function scope, if it's defined in a Module"; const IRModuleFrame& frame = opt_frame.value(); - const String& func_name = name.value_or(""); + const ffi::String& func_name = name.value_or(""); if (!frame->global_var_map.count(func_name)) { // First time visiting the function. 
ir::DeclFunction(func_name, func); @@ -108,7 +108,7 @@ void FunctionFrameNode::ExitWithScope() { void BlockFrameNode::EnterWithScope() { // Step 1. If the last frame is a block frame. The start of a new block frame marks the end of the // last block frame. - Optional block_frame = IRBuilder::Current()->GetLastFrame(); + ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame(); if (block_frame.defined()) { block_frame.value()->ExitWithScope(); // Block frames cannot appear consecutively. @@ -116,7 +116,7 @@ void BlockFrameNode::EnterWithScope() { } // Step 2. Deal with the new block frame. RelaxFrameNode::EnterWithScope(); - Optional func_frame = IRBuilder::Current()->FindFrame(); + ffi::Optional func_frame = IRBuilder::Current()->FindFrame(); CHECK(func_frame.defined()) << "ValueError: Cannot find FunctionFrame when creating BindingBlocks, Please ensure " "creating the block under Relax function scope."; @@ -162,7 +162,7 @@ void BlockFrameNode::ExitWithScope() { // Step 3. Rewrite the dataflow block. if (is_dataflow) { // Step 3.0. Define a map to replace variables - Array new_output_vars; + ffi::Array new_output_vars; std::unordered_map var_remap; for (const auto& output_var : output_vars) { tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var)); @@ -185,7 +185,7 @@ void BlockFrameNode::ExitWithScope() { } // Step 3. Get the last frame from the IRBuilder frame stack. - Optional opt_last_frame = IRBuilder::Current()->GetLastFrame(); + ffi::Optional opt_last_frame = IRBuilder::Current()->GetLastFrame(); ICHECK(opt_last_frame.defined()); RelaxFrame last_frame = opt_last_frame.value(); @@ -195,7 +195,7 @@ void BlockFrameNode::ExitWithScope() { // Step 5. Push the block frame into the corresponding field of the last frame. 
if (const auto* seq_frame = last_frame.as()) { - auto frame = GetRef(seq_frame); + auto frame = ffi::GetRef(seq_frame); frame->binding_blocks.push_back(block); } else { LOG(FATAL) << "ValueError: Currently the last frame is supposed to be either a function frame " @@ -210,7 +210,7 @@ void BlockFrameNode::ExitWithScope() { } void IfFrameNode::EnterWithScope() { - const Array& frames = IRBuilder::Current()->frames; + const ffi::Array& frames = IRBuilder::Current()->frames; for (const IRBuilderFrame& frame : frames) { const auto* block_frame = frame.as(); if (block_frame && block_frame->is_dataflow) { @@ -241,8 +241,8 @@ void ThenFrameNode::EnterWithScope() { void ThenFrameNode::ExitWithScope() { SeqExprFrameNode::ExitWithScope(); - String var_name; - output = GetSeqExprForBranch(GetRef(this), &var_name); + ffi::String var_name; + output = GetSeqExprForBranch(ffi::GetRef(this), &var_name); IfFrame frame = FindIfFrame("R.Then"); frame->then_expr = output; frame->var_name = var_name; @@ -259,8 +259,8 @@ void ElseFrameNode::EnterWithScope() { void ElseFrameNode::ExitWithScope() { SeqExprFrameNode::ExitWithScope(); - String var_name; - output = GetSeqExprForBranch(GetRef(this), &var_name); + ffi::String var_name; + output = GetSeqExprForBranch(ffi::GetRef(this), &var_name); IfFrame frame = FindIfFrame("R.Else"); frame->else_expr = output; CHECK(frame->var_name == var_name) diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index b845434e917b..db77d4db5b26 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -34,7 +34,7 @@ namespace relax { using tvm::script::ir_builder::details::Namer; TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, String name) -> void { + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { using tvm::relax::VarNode; using tvm::relax::IdNode; const VarNode* var = node.as(); @@ -43,7 +43,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) }); 
TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, String name) -> void { + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { using tvm::relax::DataflowVarNode; using tvm::relax::IdNode; const DataflowVarNode* var = node.as(); @@ -54,10 +54,11 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) /////////////////////////////// Function //////////////////////////////// FunctionFrame Function(const Bool& is_pure, const Bool& is_private) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); const IRBuilder& ir_builder = IRBuilder::Current(); - Optional mod = std::nullopt; - if (const Optional mod_frame = ir_builder->GetLastFrame()) { + ffi::Optional mod = std::nullopt; + if (const ffi::Optional mod_frame = + ir_builder->GetLastFrame()) { mod = tvm::IRModule(mod_frame.value()->functions); } n->block_builder = tvm::relax::BlockBuilder::Create( @@ -67,7 +68,7 @@ FunctionFrame Function(const Bool& is_pure, const Bool& is_private) { return FunctionFrame(n); } -tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info) { +tvm::relax::Var Arg(const ffi::String& name, const tvm::relax::StructInfo& struct_info) { FunctionFrame frame = FindFunctionFrame("R.Arg"); tvm::relax::Var var(name, struct_info); frame->params.push_back(var); @@ -76,7 +77,7 @@ tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_inf return var; } -void FuncName(const String& name) { +void FuncName(const ffi::String& name) { FunctionFrame frame = FindFunctionFrame("R.func_name"); if (frame->name.has_value()) { LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value() @@ -85,7 +86,7 @@ void FuncName(const String& name) { frame->name = name; } -void FuncAttrs(Map attrs) { +void FuncAttrs(ffi::Map attrs) { FunctionFrame frame = FindFunctionFrame("R.func_attr"); for (const auto& [key, value] : attrs) { if (key == tvm::attr::kGlobalSymbol && 
frame->is_private.value_or(Bool(false))->value) { @@ -145,7 +146,7 @@ void FuncRetValue(const tvm::relax::Expr& value) { frame->output = std::move(normalized_value); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.relax.Function", Function) @@ -154,27 +155,27 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("script.ir_builder.relax.FuncAttrs", FuncAttrs) .def("script.ir_builder.relax.FuncRetStructInfo", FuncRetStructInfo) .def("script.ir_builder.relax.FuncRetValue", FuncRetValue); -}); +} ///////////////////////////// BindingBlock ////////////////////////////// BlockFrame Dataflow() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->is_dataflow = true; n->block_ended = false; return BlockFrame(n); } BlockFrame BindingBlock() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->is_dataflow = false; n->block_ended = false; return BlockFrame(n); } -void DataflowBlockOutput(const Array& vars) { +void DataflowBlockOutput(const ffi::Array& vars) { // Step 1. Check that we're in a Dataflow block that is not ended. - Optional block_frame = IRBuilder::Current()->GetLastFrame(); + ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame(); CHECK(block_frame.defined() && block_frame.value()->is_dataflow) << "ValueError: `R.output` should appear inside a dataflow block. However, the current " "innermost block is not a dataflow block."; @@ -187,7 +188,7 @@ void DataflowBlockOutput(const Array& vars) { // Step 3. All the output variables must be global variables and must be emitted by this dataflow // block. - const Array& emitted_vars = block_frame.value()->emitted_vars; + const ffi::Array& emitted_vars = block_frame.value()->emitted_vars; for (const tvm::relax::Var& var : vars) { CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end()) << "ValueError: An output variable is not emitted by this dataflow block. 
Please make sure " @@ -196,18 +197,18 @@ void DataflowBlockOutput(const Array& vars) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.relax.Dataflow", Dataflow) .def("script.ir_builder.relax.BindingBlock", BindingBlock) .def("script.ir_builder.relax.DataflowBlockOutput", DataflowBlockOutput); -}); +} /////////////////////////////// Bindings /////////////////////////////// tvm::relax::Var Emit(const tvm::relax::Expr& expr, - const Optional& annotate_struct_info) { + const ffi::Optional& annotate_struct_info) { using tvm::relax::GetStructInfo; BlockFrame block_frame = CheckBlockFrameExistAndUnended(); const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); @@ -244,30 +245,30 @@ tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { return binding->var; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.relax.Emit", Emit) .def("script.ir_builder.relax.EmitMatchCast", EmitMatchCast) .def("script.ir_builder.relax.EmitVarBinding", EmitVarBinding); -}); +} /////////////////////////////// SeqExpr /////////////////////////////// SeqExprFrame SeqExpr() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return SeqExprFrame(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.ir_builder.relax.SeqExpr", SeqExpr); -}); +} ///////////////////////////// If Then Else ///////////////////////////// IfFrame If(tvm::relax::Expr condition) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->condition = condition; n->then_expr = std::nullopt; n->else_expr = std::nullopt; @@ -275,22 +276,22 @@ IfFrame If(tvm::relax::Expr condition) { } ThenFrame Then() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return ThenFrame(n); } ElseFrame 
Else() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return ElseFrame(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.relax.If", If) .def("script.ir_builder.relax.Then", Then) .def("script.ir_builder.relax.Else", Else); -}); +} } // namespace relax } // namespace ir_builder diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h index 7fd7e21a6739..e24b4a27593d 100644 --- a/src/script/ir_builder/relax/utils.h +++ b/src/script/ir_builder/relax/utils.h @@ -31,8 +31,8 @@ namespace script { namespace ir_builder { namespace relax { -inline FunctionFrame FindFunctionFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->FindFrame()) { +inline FunctionFrame FindFunctionFrame(const ffi::String& method) { + if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } LOG(FATAL) << "ValueError: Function frame not find. Please ensure '" << method @@ -40,8 +40,8 @@ inline FunctionFrame FindFunctionFrame(const String& method) { throw; } -inline IfFrame FindIfFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { +inline IfFrame FindIfFrame(const ffi::String& method) { + if (ffi::Optional frame = IRBuilder::Current()->GetLastFrame()) { return frame.value(); } else { LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method @@ -51,7 +51,7 @@ inline IfFrame FindIfFrame(const String& method) { } inline tvm::relax::BlockBuilder GetBlockBuilder() { - Optional frame = IRBuilder::Current()->FindFrame(); + ffi::Optional frame = IRBuilder::Current()->FindFrame(); CHECK(frame.defined()) << "ValueError: Relax Function frame not find. 
Please ensure " "assignment is called under R.function()"; return frame.value()->block_builder; @@ -61,14 +61,14 @@ inline BlockFrame CheckBlockFrameExistAndUnended() { // We check if the current block is "ended" - if a block is ended, it is not allowed to emit new // bindings into this block, and we should throw exceptions. - Optional block_frame = IRBuilder::Current()->GetLastFrame(); + ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame(); CHECK(block_frame.defined()) << "ValueError: Block frame not find"; CHECK(!block_frame.value()->block_ended) << "ValueError: New binding is not allowed after dataflow block output."; return block_frame.value(); } -inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) { +inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, ffi::String* var_name) { // Step 0. Check frame type std::string method; std::string output_var_suffix; @@ -101,10 +101,10 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String *var_name = last_binding->var->name_hint(); // Step 3. Re-collect binding blocks to replace the last binding. 
- Array new_blocks(frame->binding_blocks.begin(), - frame->binding_blocks.end() - 1); - Array last_block_bindings(last_block->bindings.begin(), - last_block->bindings.end() - 1); + ffi::Array new_blocks(frame->binding_blocks.begin(), + frame->binding_blocks.end() - 1); + ffi::Array last_block_bindings(last_block->bindings.begin(), + last_block->bindings.end() - 1); tvm::relax::Var new_var = tvm::relax::Var(last_binding->var->name_hint() + output_var_suffix, GetStructInfo(last_binding->var)); diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index b0d5bb337f35..7c10b6cdc8d1 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -28,7 +28,7 @@ namespace script { namespace ir_builder { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { TIRFrameNode::RegisterReflection(); PrimFuncFrameNode::RegisterReflection(); BlockFrameNode::RegisterReflection(); @@ -46,7 +46,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ThenFrameNode::RegisterReflection(); ElseFrameNode::RegisterReflection(); DeclBufferFrameNode::RegisterReflection(); -}); +} void PrimFuncFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); @@ -67,11 +67,11 @@ void PrimFuncFrameNode::ExitWithScope() { if (builder->frames.empty()) { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; - } else if (Optional opt_frame = builder->FindFrame()) { + } else if (ffi::Optional opt_frame = builder->FindFrame()) { CHECK(name.has_value()) << "ValueError: The function name must be defined before exiting the " "function scope, if it's defined in a Module"; const ir::IRModuleFrame& frame = opt_frame.value(); - const String& func_name = name.value_or(""); + const ffi::String& func_name = name.value_or(""); if (!frame->global_var_map.count(func_name)) { // Case. First time visiting the function. 
ir::DeclFunction(func_name, func); @@ -86,17 +86,17 @@ void PrimFuncFrameNode::ExitWithScope() { void BlockFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - Array tir_alloc_buffers; + ffi::Array tir_alloc_buffers; for (const tvm::tir::Buffer& buffer : alloc_buffers) { tir_alloc_buffers.push_back(buffer); } - Map attrs = annotations.value_or({}); + ffi::Map attrs = annotations.value_or({}); if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) { attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access)); } - tvm::tir::Block block(iter_vars, reads.value_or(Array()), - writes.value_or(Array()), name, AsStmt(stmts), init, - tir_alloc_buffers, match_buffers, attrs); + tvm::tir::Block block(iter_vars, reads.value_or(ffi::Array()), + writes.value_or(ffi::Array()), name, AsStmt(stmts), + init, tir_alloc_buffers, match_buffers, attrs); if (no_realize) { CHECK(iter_values.empty()) << "ValueError: Block bindings are not allowed when `no_realize=True`"; @@ -123,7 +123,7 @@ void BlockInitFrameNode::ExitWithScope() { void ForFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts))); + AddToParent(this->f_make_for_loop(vars, doms, steps, AsStmt(stmts))); } void AssertFrameNode::ExitWithScope() { diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 33a687f54bc4..6639d73dafc3 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -22,6 +22,7 @@ #include #include "./utils.h" +#include "tvm/ffi/string.h" namespace tvm { namespace script { @@ -30,10 +31,11 @@ namespace tir { using tvm::tir::IterVar; -Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Optional data, - Optional> strides, Optional elem_offset, - String storage_scope, int align, int offset_factor, String buffer_type, - Optional> axis_separators) { +Buffer BufferDecl(ffi::Array shape, DataType dtype, ffi::String 
buffer_name, + ffi::Optional data, ffi::Optional> strides, + ffi::Optional elem_offset, ffi::String storage_scope, int align, + int offset_factor, ffi::String buffer_type, + ffi::Optional> axis_separators) { CHECK(buffer_type == "auto" || buffer_type == "default" || buffer_type.empty()) << "ValueError: `buffer_type` must be `auto` or `default` or empty"; Var buffer_data; @@ -50,14 +52,14 @@ Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Opt DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype; elem_offset = tvm::tir::Var("elem_offset", shape_dtype); } - return Buffer(buffer_data, dtype, shape, strides.value_or(Array()), + return Buffer(buffer_data, dtype, shape, strides.value_or(ffi::Array()), elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, (buffer_type == "auto" ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault), - axis_separators.value_or(Array())); + axis_separators.value_or(ffi::Array())); } PrimFuncFrame PrimFunc(bool is_private) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->name = std::nullopt; n->is_private = is_private; n->args.clear(); @@ -69,14 +71,14 @@ PrimFuncFrame PrimFunc(bool is_private) { return PrimFuncFrame(n); } -Var Arg(String name, Var var) { +Var Arg(ffi::String name, Var var) { PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); details::Namer::Name(var, name); frame->args.push_back(var); return var; } -Buffer Arg(String name, Buffer buffer) { +Buffer Arg(ffi::String name, Buffer buffer) { PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); details::Namer::Name(buffer, name); Var handle(buffer->name + "_handle", DataType::Handle()); @@ -85,7 +87,7 @@ Buffer Arg(String name, Buffer buffer) { return buffer; } -void FuncName(String name) { +void FuncName(ffi::String name) { PrimFuncFrame frame = FindPrimFuncFrame("T.func_name"); if (frame->name.has_value()) { LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value(); @@ -93,7 
+95,7 @@ void FuncName(String name) { frame->name = name; } -void FuncAttrs(Map new_attrs) { +void FuncAttrs(ffi::Map new_attrs) { using namespace tvm::tir; PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); for (const auto& [key, value] : new_attrs) { @@ -124,15 +126,15 @@ tvm::Type FuncRet(tvm::Type ret_type) { return ret_type; } -Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optional data, - Array strides, PrimExpr elem_offset, String storage_scope, int align, - int offset_factor, String buffer_type_str, - Optional> axis_separators) { +Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype, + ffi::Optional data, ffi::Array strides, PrimExpr elem_offset, + ffi::String storage_scope, int align, int offset_factor, + ffi::String buffer_type_str, ffi::Optional> axis_separators) { Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, offset_factor, buffer_type_str, axis_separators); if (const auto* var = param.as()) { - PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer"); - Var v = GetRef(var); + PrimFuncFrame frame = FindPrimFuncFrameRelaxed("T.match_buffer"); + Var v = ffi::GetRef(var); for (auto const& arg : frame->args) { if (arg.same_as(v)) { frame->buffer_map.Set(v, buffer); @@ -143,19 +145,19 @@ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optio } else if (const auto* buffer_load = param.as()) { BlockFrame frame = FindBlockFrame("T.match_buffer"); frame->match_buffers.push_back(tvm::tir::MatchBufferRegion( - buffer, BufferRegionFromLoad(GetRef(buffer_load)))); + buffer, BufferRegionFromLoad(ffi::GetRef(buffer_load)))); } else if (const auto* buffer_region = param.as()) { BlockFrame frame = FindBlockFrame("T.match_buffer"); frame->match_buffers.push_back( - tvm::tir::MatchBufferRegion(buffer, GetRef(buffer_region))); + tvm::tir::MatchBufferRegion(buffer, ffi::GetRef(buffer_region))); } else { LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer."; } 
return buffer; } -BlockFrame Block(String name, bool no_realize) { - ObjectPtr n = make_object(); +BlockFrame Block(ffi::String name, bool no_realize) { + ObjectPtr n = ffi::make_object(); n->name = name; n->iter_vars.clear(); n->reads = std::nullopt; @@ -170,7 +172,7 @@ BlockFrame Block(String name, bool no_realize) { return BlockFrame(n); } -BlockInitFrame Init() { return BlockInitFrame(make_object()); } +BlockInitFrame Init() { return BlockInitFrame(ffi::make_object()); } void Where(PrimExpr predicate) { BlockFrame frame = FindBlockFrame("T.where"); @@ -181,13 +183,13 @@ void Where(PrimExpr predicate) { frame->predicate = predicate; } -void Reads(Array buffer_slices) { +void Reads(ffi::Array buffer_slices) { using namespace tvm::tir; BlockFrame frame = FindBlockFrame("T.reads"); if (frame->reads.defined()) { LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads; } - Array reads; + ffi::Array reads; for (const ObjectRef& obj : buffer_slices) { if (auto buffer_region = obj.as()) { reads.push_back(buffer_region.value()); @@ -200,14 +202,14 @@ void Reads(Array buffer_slices) { frame->reads = reads; } -void Writes(Array buffer_slices) { +void Writes(ffi::Array buffer_slices) { using namespace tvm::tir; BlockFrame frame = FindBlockFrame("T.writes"); if (frame->writes.defined()) { LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is " << frame->writes; } - Array writes; + ffi::Array writes; for (const ObjectRef& obj : buffer_slices) { if (auto buffer_region = obj.as()) { writes.push_back(buffer_region.value()); @@ -221,9 +223,9 @@ void Writes(Array buffer_slices) { } /*! 
\brief Recursively merge two annotations, the new attrs will override the old ones */ -Map MergeAnnotations(const Map& new_attrs, - const Map& old_attrs) { - Map result = old_attrs; +ffi::Map MergeAnnotations(const ffi::Map& new_attrs, + const ffi::Map& old_attrs) { + ffi::Map result = old_attrs; for (const auto& [key, value] : new_attrs) { auto old_value = old_attrs.Get(key); // Case 1: the key is not in the old annotations, set the key to the new value @@ -234,15 +236,15 @@ Map MergeAnnotations(const Map& new_attrs, // Case 2: the key is in the old annotations // Case 2.1: both are dicts - auto old_dict = old_value->try_cast>(); - auto new_dict = value.try_cast>(); + auto old_dict = old_value->try_cast>(); + auto new_dict = value.try_cast>(); if (old_dict && new_dict) { // Recursively merge the two dicts auto merged_dict = MergeAnnotations(*old_dict, *new_dict); result.Set(key, merged_dict); continue; } - // Case 2.2: the values are not both dicts, check if the keys are the same + // Case 2.3: the values are not both dicts, check if the keys are the same if (!ffi::AnyEqual()(old_value.value(), value)) { LOG(FATAL) << "ValueError: Try to merge two annotations with different values for key `" << key << "`, previous one is " << old_value.value() << ", new one is " << value; @@ -251,27 +253,31 @@ Map MergeAnnotations(const Map& new_attrs, return result; } -void BlockAttrs(Map attrs) { +void BlockAttrs(ffi::Map attrs) { BlockFrame frame = FindBlockFrame("T.block_attr"); // Case 1: the block has no annotations, set the new annotations if (!frame->annotations.defined()) { frame->annotations = attrs; } else { // Case 2: the block has annotations, merge the new annotations with the old ones - frame->annotations = MergeAnnotations(attrs, frame->annotations.value()); + frame->annotations = Downcast>(MergeAnnotations(Downcast>(attrs), Downcast>(frame->annotations.value()))); } } -Buffer AllocBuffer(Array shape, DataType dtype, Optional data, - Array strides, PrimExpr 
elem_offset, String storage_scope, int align, - int offset_factor, String buffer_type_str, - Optional> axis_separators) { +Buffer AllocBuffer(ffi::Array shape, DataType dtype, ffi::Optional data, + ffi::Array strides, PrimExpr elem_offset, ffi::String storage_scope, + int align, int offset_factor, ffi::String buffer_type_str, + ffi::Optional> axis_separators) { Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, offset_factor, buffer_type_str, axis_separators); IRBuilder builder = IRBuilder::Current(); - if (Optional frame = builder->FindFrame()) { + if (ffi::Optional frame = builder->GetLastFrame()) { frame.value()->alloc_buffers.push_back(buffer); - } else if (Optional frame = builder->GetLastFrame()) { + } else if (ffi::Optional frame = builder->FindFrame()) { + frame.value()->alloc_buffers.push_back(buffer); + } else if (ffi::Optional frame = builder->GetLastFrame()) { + frame.value()->root_alloc_buffers.push_back(buffer); + } else if (ffi::Optional frame = builder->FindFrame()) { frame.value()->root_alloc_buffers.push_back(buffer); } else { LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. 
Please ensure " @@ -282,7 +288,7 @@ Buffer AllocBuffer(Array shape, DataType dtype, Optional data, namespace axis { IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) { - if (Optional opt_frame = IRBuilder::Current()->GetLastFrame()) { + if (ffi::Optional opt_frame = IRBuilder::Current()->GetLastFrame()) { BlockFrame frame = opt_frame.value(); frame->iter_vars.push_back(iter_var); frame->iter_values.push_back(binding); @@ -307,9 +313,9 @@ TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan"); TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque"); #undef TVM_TIR_IR_BUILDER_AXIS -Array Remap(String kinds, Array bindings, DataType dtype) { +ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType dtype) { using namespace tvm::tir; - Array results; + ffi::Array results; ICHECK_EQ(kinds.size(), bindings.size()); int n = bindings.size(); results.reserve(n); @@ -334,7 +340,7 @@ Array Remap(String kinds, Array bindings, DataType dtype) { } } } - ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef(v); + ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << ffi::GetRef(v); DataType dtype = v->dtype; if (c == 'S') { results.push_back(PushBlockVar(IterVar(/*dom=*/dom, @@ -359,21 +365,27 @@ Array Remap(String kinds, Array bindings, DataType dtype) { } // namespace axis -#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ - ForFrame Method(PrimExpr start, PrimExpr stop, Optional> annotations) { \ - PrimExpr min = start; \ - PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ - ObjectPtr n = make_object(); \ - int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ - n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \ - n->doms = {Range::FromMinExtent(min, extent)}; \ - n->f_make_for_loop = [annotations](Array vars, Array doms, tvm::tir::Stmt body) { \ - ICHECK_EQ(vars.size(), 1); \ - ICHECK_EQ(doms.size(), 1); \ - return tvm::tir::For(vars[0], 
doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \ - annotations.value_or(Map())); \ - }; \ - return ForFrame(n); \ +#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ + ForFrame Method(PrimExpr start, PrimExpr stop, \ + ffi::Optional> annotations, \ + ffi::Optional step) { \ + PrimExpr min = start; \ + PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ + ObjectPtr n = ffi::make_object(); \ + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ + n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \ + n->doms = {Range::FromMinExtent(min, extent)}; \ + n->steps = {step}; \ + n->f_make_for_loop = [annotations](ffi::Array vars, ffi::Array doms, \ + ffi::Array> steps, \ + tvm::tir::Stmt body) { \ + ICHECK_EQ(vars.size(), 1); \ + ICHECK_EQ(doms.size(), 1); \ + ICHECK_EQ(steps.size(), 1); \ + return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \ + annotations.value_or(ffi::Map()), steps[0]); \ + }; \ + return ForFrame(n); \ } TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial); @@ -383,60 +395,66 @@ TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled); #undef TVM_TIR_IR_BUILDER_FOR_FRAME -ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, - Optional> annotations) { +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, + ffi::Optional> annotations) { using namespace tvm::tir; PrimExpr min = start; PrimExpr extent = arith::Analyzer().Simplify(stop - start); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); int bits = std::max(min.dtype().bits(), extent.dtype().bits()); DataType dtype = DataType(min.dtype().code(), bits, 1); n->vars = {Var("v", dtype)}; n->doms = {Range::FromMinExtent(min, extent)}; - n->f_make_for_loop = [annotations, thread, dtype](Array vars, Array doms, + n->steps = {std::nullopt}; + n->f_make_for_loop = [annotations, thread, dtype](ffi::Array vars, ffi::Array doms, + ffi::Array> steps, Stmt 
body) -> For { ICHECK_EQ(vars.size(), 1); ICHECK_EQ(doms.size(), 1); + ICHECK(steps.size() == 1 && (!steps[0].has_value() || is_one(*steps[0]))); IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread); return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, - annotations.value_or(Map())); + annotations.value_or(ffi::Map()), std::nullopt); }; return ForFrame(n); } -ForFrame Grid(Array extents) { +ForFrame Grid(ffi::Array extents) { using namespace tvm::tir; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->vars.reserve(extents.size()); n->doms.reserve(extents.size()); + n->steps.resize(extents.size()); for (const auto& extent : extents) { DataType dtype = extent.dtype(); n->vars.push_back(Var("v", extent.dtype())); n->doms.push_back(Range(make_const(dtype, 0), extent)); } - n->f_make_for_loop = [](Array vars, Array doms, Stmt body) -> Stmt { + n->f_make_for_loop = [](ffi::Array vars, ffi::Array doms, + ffi::Array> steps, Stmt body) -> Stmt { ICHECK_EQ(vars.size(), doms.size()); + ICHECK_EQ(vars.size(), steps.size()); int n = vars.size(); for (int i = n - 1; i >= 0; --i) { Range dom = doms[i]; Var var = vars[i]; body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body), - /*thread_binding=*/std::nullopt, /*annotations=*/{}); + /*thread_binding=*/std::nullopt, /*annotations=*/{}, /*step=*/steps[i]); } return body; }; return ForFrame(n); } -AssertFrame Assert(PrimExpr condition, String message) { - ObjectPtr n = make_object(); +AssertFrame Assert(PrimExpr condition, ffi::String message) { + ObjectPtr n = ffi::make_object(); n->condition = condition; n->message = tvm::tir::StringImm(message); return AssertFrame(n); } -LetFrame LetStmt(PrimExpr value, Optional type_annotation, Optional var) { - ObjectPtr n = make_object(); +LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) { + ObjectPtr n = ffi::make_object(); if (var.defined()) { 
n->var = var.value(); } else if (type_annotation.defined()) { @@ -449,7 +467,7 @@ LetFrame LetStmt(PrimExpr value, Optional type_annotation, Optional v } LetFrame LegacyLetStmt(Var var, PrimExpr value) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->var = var; n->value = value; return LetFrame(n); @@ -458,8 +476,8 @@ LetFrame LegacyLetStmt(Var var, PrimExpr value) { LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { IterVar iter_var{nullptr}; - if (Optional opt_frame = IRBuilder::Current()->FindFrame()) { - if (Optional opt_iter_var = opt_frame.value()->env_threads.Get(var)) { + if (ffi::Optional opt_frame = IRBuilder::Current()->FindFrame()) { + if (ffi::Optional opt_iter_var = opt_frame.value()->env_threads.Get(var)) { iter_var = opt_iter_var.value(); } else { LOG(FATAL) << "ValueError: " << var->name_hint @@ -468,7 +486,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { } else { LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc"; } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); if (!iter_var->dom.defined()) { const_cast(iter_var.get())->dom = Range(tvm::tir::make_zero(extent.dtype()), extent); @@ -482,48 +500,50 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { return LaunchThreadFrame(n); } -LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) { +LaunchThreadFrame LaunchThread(ffi::String thread_tag, PrimExpr extent) { return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent); } -RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, ffi::String storage_scope, PrimExpr condition) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->buffer_slice = buffer_slice; n->storage_scope = storage_scope; n->condition = condition; return RealizeFrame(n); } -AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope, - Optional condition, 
Optional> annotations) { - ObjectPtr n = make_object(); +AllocateFrame Allocate(ffi::Array extents, DataType dtype, ffi::String storage_scope, + ffi::Optional condition, + ffi::Optional> annotations) { + ObjectPtr n = ffi::make_object(); n->extents = extents; n->dtype = dtype; n->storage_scope = storage_scope; n->condition = condition.value_or(tvm::Bool(true)); - n->annotations = annotations.value_or(Map()); + n->annotations = annotations.value_or(ffi::Map()); n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype), storage_scope)); return AllocateFrame(n); } -AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, - Array extents, Optional> annotations) { - ObjectPtr n = make_object(); +AllocateConstFrame AllocateConst(tvm::runtime::Tensor data, DataType dtype, + ffi::Array extents, + ffi::Optional> annotations) { + ObjectPtr n = ffi::make_object(); n->dtype = dtype; n->extents = extents; n->data = data; - n->annotations = annotations.value_or(Map()); + n->annotations = annotations.value_or(ffi::Map()); n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype))); return AllocateConstFrame(n); } -AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) { +AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value) { // convert POD value to PrimExpr if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { node = node.cast(); } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->node = std::move(node); n->attr_key = attr_key; n->value = value; @@ -531,13 +551,13 @@ AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) { } WhileFrame While(PrimExpr condition) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->condition = condition; return WhileFrame(n); } IfFrame If(PrimExpr condition) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->condition = condition; n->then_stmts = std::nullopt; n->else_stmts = std::nullopt; @@ -545,19 +565,19 @@ IfFrame 
If(PrimExpr condition) { } ThenFrame Then() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return ThenFrame(n); } ElseFrame Else() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return ElseFrame(n); } -Var EnvThread(String thread_tag, DataType dtype) { +Var EnvThread(ffi::String thread_tag, DataType dtype) { IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag); Var var = iter_var->var; - if (Optional opt_frame = IRBuilder::Current()->FindFrame()) { + if (ffi::Optional opt_frame = IRBuilder::Current()->FindFrame()) { opt_frame.value()->env_threads.Set(var, iter_var); } else { LOG(FATAL) << "EnvThread can only be used inside a PrimFunc"; @@ -565,8 +585,8 @@ Var EnvThread(String thread_tag, DataType dtype) { return var; } -void BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate = std::nullopt) { +void BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate = std::nullopt) { runtime::DataType buffer_dtype = buffer->dtype; bool is_index_scalable = indices.empty() ? 
false : indices.back().dtype().is_scalable_vector(); bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector(); @@ -631,12 +651,12 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices, AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); } -DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name, - Optional data, Optional> strides, - Optional elem_offset, String storage_scope, int align, - int offset_factor, String buffer_type, - Optional> axis_separators) { - ObjectPtr n = make_object(); +DeclBufferFrame DeclBuffer(ffi::Array shape, DataType dtype, ffi::String buffer_name, + ffi::Optional data, ffi::Optional> strides, + ffi::Optional elem_offset, ffi::String storage_scope, + int align, int offset_factor, ffi::String buffer_type, + ffi::Optional> axis_separators) { + ObjectPtr n = ffi::make_object(); n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, align, offset_factor, buffer_type, axis_separators); n->allocated = data.defined(); @@ -645,7 +665,8 @@ DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_ void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } -PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global", bool is_size_var = false) { +PrimExpr Ptr(runtime::DataType dtype, ffi::String storage_scope = "global", + bool is_size_var = false) { PointerType type_annotation(PrimType(dtype), storage_scope); return is_size_var ? 
tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation); } @@ -653,7 +674,7 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global", bool is_s using tvm::script::ir_builder::details::Namer; TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, String name) -> void { + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { tvm::tir::BufferNode* buffer = const_cast(node.as()); buffer->name = name; @@ -661,40 +682,41 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) int n = buffer->strides.size(); for (int i = 0; i < n; ++i) { PrimExpr e = buffer->strides[i]; - if (auto v = e.as()) { - Namer::Name(v.value(), name + "_s" + std::to_string(i)); + if (const auto* v = e.as()) { + ffi::String new_name = !v->name_hint.empty() ? v->name_hint : (name + "_s" + std::to_string(i)); + Namer::Name(ffi::GetRef(v), ffi::String(new_name)); } } }); TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, String name) -> void { + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { using namespace tvm::tir; SizeVarNode* var = const_cast(node.as()); - var->name_hint = name; + var->name_hint = ffi::String(name); }); TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, String name) -> void { + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { using namespace tvm::tir; VarNode* var = const_cast(node.as()); var->name_hint = name; }); TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, String name) -> void { + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { using namespace tvm::tir; IterVarNode* var = const_cast(node.as()); Namer::Name(var->var, name); }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Buffer", BufferDecl) .def("script.ir_builder.tir.PrimFunc", PrimFunc) .def("script.ir_builder.tir.Arg", - 
[](String name, ObjectRef obj) -> ObjectRef { + [](ffi::String name, ObjectRef obj) -> ObjectRef { using namespace tvm::tir; if (auto var = obj.as()) { return Arg(name, var.value()); @@ -740,10 +762,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("script.ir_builder.tir.Else", Else) .def("script.ir_builder.tir.DeclBuffer", DeclBuffer) .def("script.ir_builder.tir.LaunchThread", - [](ffi::Variant thread_tag_or_var, PrimExpr extent) { + [](ffi::Variant thread_tag_or_var, PrimExpr extent) { if (auto var = thread_tag_or_var.as()) { return LaunchThread(var.value(), extent); - } else if (auto str = thread_tag_or_var.as()) { + } else if (auto str = thread_tag_or_var.as()) { return LaunchThread(str.value(), extent); } else { LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: " @@ -755,7 +777,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("script.ir_builder.tir.BufferStore", BufferStore) .def("script.ir_builder.tir.Evaluate", Evaluate) .def("script.ir_builder.tir.Ptr", Ptr); -}); +} #define TVM_TMP_STR(x) #x @@ -766,7 +788,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def(Prefix TVM_TMP_STR(64), DType##64) #define TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix, Func) \ - def(Prefix TVM_TMP_STR(x4), Func##x4) \ + def(Prefix TVM_TMP_STR(x2), Func##x2) \ + .def(Prefix TVM_TMP_STR(x4), Func##x4) \ .def(Prefix TVM_TMP_STR(x8), Func##x8) \ .def(Prefix TVM_TMP_STR(x16), Func##x16) \ .def(Prefix TVM_TMP_STR(x32), Func##x32) \ @@ -778,7 +801,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32) \ .TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64) -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.BFloat16", BFloat16) @@ -789,89 +812,96 @@ TVM_FFI_STATIC_INIT_BLOCK({ .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt) .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int) 
.TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); -}); +} // Float8 variants -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float8E3M4", Float8E3M4) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E3M4", Float8E3M4); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float8E4M3", Float8E4M3) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3", Float8E4M3); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float8E5M2", Float8E5M2) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ); -}); +} 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU); -}); +} // Float6 variants -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN); -}); +} // Float4 variant -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN) .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("script.ir_builder.tir.TensorFloat32", TensorFloat32) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.TensorFloat32", TensorFloat32); +} + +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.ir_builder.tir.Boolean", Boolean) @@ -882,7 +912,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }) .def("script.ir_builder.tir.max", [](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); -}); +} } // namespace tir } // namespace ir_builder } // namespace script diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 9703a2adc323..655dea5fbda3 100644 --- a/src/script/ir_builder/tir/utils.h +++ 
b/src/script/ir_builder/tir/utils.h @@ -39,7 +39,7 @@ inline void AddToParent(tvm::tir::Stmt stmt) { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = stmt; } else if (const auto* tir_frame = builder->frames.back().as()) { - GetRef(tir_frame)->stmts.push_back(stmt); + ffi::GetRef(tir_frame)->stmts.push_back(stmt); } else { LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back(); } @@ -50,7 +50,7 @@ inline void AddToParent(tvm::tir::Stmt stmt) { * \param stmt The array of Stmt. * \return The SeqStmt. */ -inline tvm::tir::Stmt AsStmt(const Array& stmt) { +inline tvm::tir::Stmt AsStmt(const ffi::Array& stmt) { return tvm::tir::SeqStmt::Flatten(stmt); } @@ -59,10 +59,11 @@ inline tvm::tir::Stmt AsStmt(const Array& stmt) { * \param method The method name to be printed when throwing exception. * \return The top frame of PrimFuncFrame. */ -inline PrimFuncFrame FindPrimFuncFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { +inline PrimFuncFrame FindPrimFuncFrame(const ffi::String& method) { + if (ffi::Optional frame = IRBuilder::Current()->GetLastFrame()) { return frame.value(); - } else if (Optional frame = IRBuilder::Current()->FindFrame()) { + } else if (ffi::Optional frame = + IRBuilder::Current()->FindFrame()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a PrimFunc. " << "While " << method << " did occur within the PrimFunc \"" << frame.value()->name << "\", other frames (e.g. block/if/else/let) had been introduced since the " @@ -74,15 +75,33 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { throw; } +/*! + * \brief Find a PrimFuncFrame anywhere in the current builder stack (not necessarily the top). 
+ * This relaxed variant enables certain APIs (e.g., T.match_buffer on a PrimFunc param) + * to be invoked after non-top-level frames (let/if/for) have been introduced, while + * still being inside a PrimFunc scope. + * \param method The method name to be printed when throwing exception. + * \return The PrimFuncFrame found in the builder stack. + */ +inline PrimFuncFrame FindPrimFuncFrameRelaxed(const ffi::String& method) { + if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: " << method << " must be called under a T.prim_func(), " + << "but it occurred outside of any T.prim_func() frame"; + } + throw; +} + /*! * \brief Check whether the top frame in IRBuilder frame stack is BlockFrame. * \param method The method name to be printed when throwing exception. * \return The top frame of BlockFrame. */ -inline BlockFrame FindBlockFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->FindFrame()) { +inline BlockFrame FindBlockFrame(const ffi::String& method) { + if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); - } else if (Optional frame = IRBuilder::Current()->FindFrame()) { + } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). " << "While " << method << " did occur within the block \"" << frame.value()->name << "\", other frames (e.g. if/else/let) had been introduced since the T.block(\"" @@ -99,10 +118,10 @@ inline BlockFrame FindBlockFrame(const String& method) { * \param method The method name to be printed when throwing exception. * \return The top frame of IfFrame. 
*/ -inline IfFrame FindIfFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { +inline IfFrame FindIfFrame(const ffi::String& method) { + if (ffi::Optional frame = IRBuilder::Current()->GetLastFrame()) { return frame.value(); - } else if (Optional frame = IRBuilder::Current()->FindFrame()) { + } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.if_(). " << "While " << method << " did occur within the conditional based on (" << frame.value()->condition @@ -121,7 +140,7 @@ inline IfFrame FindIfFrame(const String& method) { * \return The converted BufferRegion. */ inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) { - Array ranges; + ffi::Array ranges; for (const PrimExpr& index : buffer_load->indices) { ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1))); } diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index aa7e0473488b..e5d72c002da0 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -26,7 +26,7 @@ namespace tvm { namespace script { namespace printer { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { DocNode::RegisterReflection(); ExprDocNode::RegisterReflection(); StmtDocNode::RegisterReflection(); @@ -54,33 +54,36 @@ TVM_FFI_STATIC_INIT_BLOCK({ ClassDocNode::RegisterReflection(); CommentDocNode::RegisterReflection(); DocStringDocNode::RegisterReflection(); -}); +} -ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } +ExprDoc ExprDocNode::Attr(ffi::String attr) const { + return AttrAccessDoc(ffi::GetRef(this), attr); +} -ExprDoc ExprDocNode::operator[](Array indices) const { - return IndexDoc(GetRef(this), indices); +ExprDoc ExprDocNode::operator[](ffi::Array indices) const { + return IndexDoc(ffi::GetRef(this), indices); } -ExprDoc ExprDocNode::Call(Array args) const { - return 
CallDoc(GetRef(this), args, Array(), Array()); +ExprDoc ExprDocNode::Call(ffi::Array args) const { + return CallDoc(ffi::GetRef(this), args, ffi::Array(), + ffi::Array()); } -ExprDoc ExprDocNode::Call(Array args, Array kwargs_keys, - Array kwargs_values) const { - return CallDoc(GetRef(this), args, kwargs_keys, kwargs_values); +ExprDoc ExprDocNode::Call(ffi::Array args, ffi::Array kwargs_keys, + ffi::Array kwargs_values) const { + return CallDoc(ffi::GetRef(this), args, kwargs_keys, kwargs_values); } -ExprDoc ExprDoc::operator[](Array indices) const { return (*get())[indices]; } +ExprDoc ExprDoc::operator[](ffi::Array indices) const { return (*get())[indices]; } -StmtBlockDoc::StmtBlockDoc(Array stmts) { - ObjectPtr n = make_object(); +StmtBlockDoc::StmtBlockDoc(ffi::Array stmts) { + ObjectPtr n = ffi::make_object(); n->stmts = stmts; this->data_ = std::move(n); } -LiteralDoc::LiteralDoc(ffi::Any value, const Optional& object_path) { - ObjectPtr n = make_object(); +LiteralDoc::LiteralDoc(ffi::Any value, const ffi::Optional& object_path) { + ObjectPtr n = ffi::make_object(); n->value = value; if (object_path.defined()) { n->source_paths.push_back(object_path.value()); @@ -88,29 +91,29 @@ LiteralDoc::LiteralDoc(ffi::Any value, const Optional& object_path) this->data_ = std::move(n); } -IdDoc::IdDoc(String name) { - ObjectPtr n = make_object(); +IdDoc::IdDoc(ffi::String name) { + ObjectPtr n = ffi::make_object(); n->name = name; this->data_ = std::move(n); } -AttrAccessDoc::AttrAccessDoc(ExprDoc value, String name) { - ObjectPtr n = make_object(); +AttrAccessDoc::AttrAccessDoc(ExprDoc value, ffi::String name) { + ObjectPtr n = ffi::make_object(); n->value = value; n->name = name; this->data_ = std::move(n); } -IndexDoc::IndexDoc(ExprDoc value, Array indices) { - ObjectPtr n = make_object(); +IndexDoc::IndexDoc(ExprDoc value, ffi::Array indices) { + ObjectPtr n = ffi::make_object(); n->value = value; n->indices = indices; this->data_ = std::move(n); } 
-CallDoc::CallDoc(ExprDoc callee, Array args, Array kwargs_keys, - Array kwargs_values) { - ObjectPtr n = make_object(); +CallDoc::CallDoc(ExprDoc callee, ffi::Array args, ffi::Array kwargs_keys, + ffi::Array kwargs_values) { + ObjectPtr n = ffi::make_object(); n->callee = callee; n->args = args; n->kwargs_keys = kwargs_keys; @@ -118,96 +121,97 @@ CallDoc::CallDoc(ExprDoc callee, Array args, Array kwargs_keys, this->data_ = std::move(n); } -OperationDoc::OperationDoc(OperationDocNode::Kind kind, Array operands) { - ObjectPtr n = make_object(); +OperationDoc::OperationDoc(OperationDocNode::Kind kind, ffi::Array operands) { + ObjectPtr n = ffi::make_object(); n->kind = kind; n->operands = operands; this->data_ = std::move(n); } -LambdaDoc::LambdaDoc(Array args, ExprDoc body) { - ObjectPtr n = make_object(); +LambdaDoc::LambdaDoc(ffi::Array args, ExprDoc body) { + ObjectPtr n = ffi::make_object(); n->args = args; n->body = body; this->data_ = std::move(n); } -TupleDoc::TupleDoc(Array elements) { - ObjectPtr n = make_object(); +TupleDoc::TupleDoc(ffi::Array elements) { + ObjectPtr n = ffi::make_object(); n->elements = elements; this->data_ = std::move(n); } -ListDoc::ListDoc(Array elements) { - ObjectPtr n = make_object(); +ListDoc::ListDoc(ffi::Array elements) { + ObjectPtr n = ffi::make_object(); n->elements = elements; this->data_ = std::move(n); } -DictDoc::DictDoc(Array keys, Array values) { - ObjectPtr n = make_object(); +DictDoc::DictDoc(ffi::Array keys, ffi::Array values) { + ObjectPtr n = ffi::make_object(); n->keys = keys; n->values = values; this->data_ = std::move(n); } -SliceDoc::SliceDoc(Optional start, Optional stop, Optional step) { - ObjectPtr n = make_object(); +SliceDoc::SliceDoc(ffi::Optional start, ffi::Optional stop, + ffi::Optional step) { + ObjectPtr n = ffi::make_object(); n->start = start; n->stop = stop; n->step = step; this->data_ = std::move(n); } -AssignDoc::AssignDoc(ExprDoc lhs, Optional rhs, Optional annotation) { 
+AssignDoc::AssignDoc(ExprDoc lhs, ffi::Optional rhs, ffi::Optional annotation) { CHECK(rhs.defined() || annotation.defined()) << "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc."; CHECK(lhs->IsInstance() || annotation == nullptr) << "ValueError: annotation can only be nonnull if lhs is an identifier."; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->lhs = lhs; n->rhs = rhs; n->annotation = annotation; this->data_ = std::move(n); } -IfDoc::IfDoc(ExprDoc predicate, Array then_branch, Array else_branch) { +IfDoc::IfDoc(ExprDoc predicate, ffi::Array then_branch, ffi::Array else_branch) { CHECK(!then_branch.empty() || !else_branch.empty()) << "ValueError: At least one of the then branch or else branch needs to be non-empty."; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->predicate = predicate; n->then_branch = then_branch; n->else_branch = else_branch; this->data_ = std::move(n); } -WhileDoc::WhileDoc(ExprDoc predicate, Array body) { - ObjectPtr n = make_object(); +WhileDoc::WhileDoc(ExprDoc predicate, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->predicate = predicate; n->body = body; this->data_ = std::move(n); } -ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, Array body) { - ObjectPtr n = make_object(); +ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->lhs = lhs; n->rhs = rhs; n->body = body; this->data_ = std::move(n); } -ScopeDoc::ScopeDoc(Optional lhs, ExprDoc rhs, Array body) { - ObjectPtr n = make_object(); +ScopeDoc::ScopeDoc(ffi::Optional lhs, ExprDoc rhs, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->lhs = lhs; n->rhs = rhs; n->body = body; this->data_ = std::move(n); } -ScopeDoc::ScopeDoc(ExprDoc rhs, Array body) { - ObjectPtr n = make_object(); +ScopeDoc::ScopeDoc(ExprDoc rhs, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->lhs = std::nullopt; n->rhs = rhs; n->body = body; @@ -215,27 +219,27 @@ 
ScopeDoc::ScopeDoc(ExprDoc rhs, Array body) { } ExprStmtDoc::ExprStmtDoc(ExprDoc expr) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->expr = expr; this->data_ = std::move(n); } -AssertDoc::AssertDoc(ExprDoc test, Optional msg) { - ObjectPtr n = make_object(); +AssertDoc::AssertDoc(ExprDoc test, ffi::Optional msg) { + ObjectPtr n = ffi::make_object(); n->test = test; n->msg = msg; this->data_ = std::move(n); } ReturnDoc::ReturnDoc(ExprDoc value) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->value = value; this->data_ = std::move(n); } -FunctionDoc::FunctionDoc(IdDoc name, Array args, Array decorators, - Optional return_type, Array body) { - ObjectPtr n = make_object(); +FunctionDoc::FunctionDoc(IdDoc name, ffi::Array args, ffi::Array decorators, + ffi::Optional return_type, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->name = name; n->args = args; n->decorators = decorators; @@ -244,57 +248,59 @@ FunctionDoc::FunctionDoc(IdDoc name, Array args, Array decor this->data_ = std::move(n); } -ClassDoc::ClassDoc(IdDoc name, Array decorators, Array body) { - ObjectPtr n = make_object(); +ClassDoc::ClassDoc(IdDoc name, ffi::Array decorators, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->name = name; n->decorators = decorators; n->body = body; this->data_ = std::move(n); } -CommentDoc::CommentDoc(String comment) { - ObjectPtr n = make_object(); +CommentDoc::CommentDoc(ffi::String comment) { + ObjectPtr n = ffi::make_object(); n->comment = comment; this->data_ = std::move(n); } -DocStringDoc::DocStringDoc(String docs) { - ObjectPtr n = make_object(); +DocStringDoc::DocStringDoc(ffi::String docs) { + ObjectPtr n = ffi::make_object(); n->comment = docs; this->data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "script.printer.DocSetSourcePaths", - [](Doc doc, Array source_paths) { doc->source_paths = 
source_paths; }); -}); + [](Doc doc, ffi::Array source_paths) { doc->source_paths = source_paths; }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("script.printer.ExprDocAttr", &ExprDocNode::Attr) .def_method("script.printer.ExprDocIndex", &ExprDocNode::operator[]) - .def_method( - "script.printer.ExprDocCall", - [](ExprDoc doc, Array args, Array kwargs_keys, - Array kwargs_values) { return doc->Call(args, kwargs_keys, kwargs_values); }); -}); + .def_method("script.printer.ExprDocCall", + [](ExprDoc doc, ffi::Array args, ffi::Array kwargs_keys, + ffi::Array kwargs_values) { + return doc->Call(args, kwargs_keys, kwargs_values); + }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.StmtDocSetComment", - [](StmtDoc doc, Optional comment) { doc->comment = comment; }); -}); + refl::GlobalDef().def( + "script.printer.StmtDocSetComment", + [](StmtDoc doc, ffi::Optional comment) { doc->comment = comment; }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.StmtBlockDoc", - [](Array stmts) { return StmtBlockDoc(stmts); }); -}); + [](ffi::Array stmts) { return StmtBlockDoc(stmts); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("script.printer.LiteralDocNone", LiteralDoc::None) @@ -302,158 +308,161 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("script.printer.LiteralDocBoolean", LiteralDoc::Boolean) .def("script.printer.LiteralDocFloat", LiteralDoc::Float) .def("script.printer.LiteralDocStr", LiteralDoc::Str); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.IdDoc", [](String name) { return IdDoc(name); }); -}); + 
refl::GlobalDef().def("script.printer.IdDoc", [](ffi::String name) { return IdDoc(name); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.AttrAccessDoc", - [](ExprDoc value, String attr) { return AttrAccessDoc(value, attr); }); -}); + [](ExprDoc value, ffi::String attr) { return AttrAccessDoc(value, attr); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.IndexDoc", - [](ExprDoc value, Array indices) { return IndexDoc(value, indices); }); -}); + refl::GlobalDef().def("script.printer.IndexDoc", [](ExprDoc value, ffi::Array indices) { + return IndexDoc(value, indices); + }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.CallDoc", [](ExprDoc callee, // - Array args, // - Array kwargs_keys, // - Array kwargs_values) { + refl::GlobalDef().def("script.printer.CallDoc", [](ExprDoc callee, // + ffi::Array args, // + ffi::Array kwargs_keys, // + ffi::Array kwargs_values) { return CallDoc(callee, args, kwargs_keys, kwargs_values); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.OperationDoc", [](int32_t kind, Array operands) { - return OperationDoc(OperationDocNode::Kind(kind), operands); - }); -}); + refl::GlobalDef().def("script.printer.OperationDoc", + [](int32_t kind, ffi::Array operands) { + return OperationDoc(OperationDocNode::Kind(kind), operands); + }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.LambdaDoc", - [](Array args, ExprDoc body) { return LambdaDoc(args, body); }); -}); + [](ffi::Array args, ExprDoc body) { return LambdaDoc(args, body); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.TupleDoc", - [](Array elements) { return TupleDoc(elements); }); -}); + [](ffi::Array elements) { return TupleDoc(elements); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.ListDoc", - [](Array elements) { return ListDoc(elements); }); -}); + [](ffi::Array elements) { return ListDoc(elements); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.DictDoc", [](Array keys, Array values) { - return DictDoc(keys, values); - }); -}); + refl::GlobalDef().def( + "script.printer.DictDoc", + [](ffi::Array keys, ffi::Array values) { return DictDoc(keys, values); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.SliceDoc", - [](Optional start, Optional stop, - Optional step) { return SliceDoc(start, stop, step); }); -}); + [](ffi::Optional start, ffi::Optional stop, + ffi::Optional step) { return SliceDoc(start, stop, step); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.AssignDoc", - [](ExprDoc lhs, Optional rhs, Optional annotation) { - return AssignDoc(lhs, rhs, annotation); - }); -}); + refl::GlobalDef().def("script.printer.AssignDoc", [](ExprDoc lhs, ffi::Optional rhs, + ffi::Optional annotation) { + return AssignDoc(lhs, rhs, annotation); + }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.IfDoc", [](ExprDoc predicate, Array then_branch, - Array else_branch) { - return IfDoc(predicate, then_branch, else_branch); - }); -}); + refl::GlobalDef().def( + "script.printer.IfDoc", + [](ExprDoc predicate, 
ffi::Array then_branch, ffi::Array else_branch) { + return IfDoc(predicate, then_branch, else_branch); + }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.WhileDoc", [](ExprDoc predicate, Array body) { + refl::GlobalDef().def("script.printer.WhileDoc", [](ExprDoc predicate, ffi::Array body) { return WhileDoc(predicate, body); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.ForDoc", [](ExprDoc lhs, ExprDoc rhs, Array body) { - return ForDoc(lhs, rhs, body); - }); -}); + refl::GlobalDef().def( + "script.printer.ForDoc", + [](ExprDoc lhs, ExprDoc rhs, ffi::Array body) { return ForDoc(lhs, rhs, body); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.ScopeDoc", - [](Optional lhs, ExprDoc rhs, Array body) { + [](ffi::Optional lhs, ExprDoc rhs, ffi::Array body) { return ScopeDoc(lhs, rhs, body); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.ExprStmtDoc", [](ExprDoc expr) { return ExprStmtDoc(expr); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "script.printer.AssertDoc", - [](ExprDoc test, Optional msg = std::nullopt) { return AssertDoc(test, msg); }); -}); + [](ExprDoc test, ffi::Optional msg = std::nullopt) { return AssertDoc(test, msg); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.ReturnDoc", [](ExprDoc value) { return ReturnDoc(value); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.FunctionDoc", 
- [](IdDoc name, Array args, Array decorators, - Optional return_type, Array body) { + [](IdDoc name, ffi::Array args, ffi::Array decorators, + ffi::Optional return_type, ffi::Array body) { return FunctionDoc(name, args, decorators, return_type, body); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.ClassDoc", - [](IdDoc name, Array decorators, Array body) { + [](IdDoc name, ffi::Array decorators, ffi::Array body) { return ClassDoc(name, decorators, body); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.CommentDoc", - [](String comment) { return CommentDoc(comment); }); -}); + [](ffi::String comment) { return CommentDoc(comment); }); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.DocStringDoc", - [](String docs) { return DocStringDoc(docs); }); -}); + [](ffi::String docs) { return DocStringDoc(docs); }); +} } // namespace printer } // namespace script diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index 7e6d76c4bf9a..77990c8048c5 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -275,7 +275,7 @@ void DocPrinter::Append(const Doc& doc, const PrinterConfig& cfg) { } } -String DocPrinter::GetString() const { +ffi::String DocPrinter::GetString() const { std::string text = output_.str(); // Remove any trailing indentation diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index b92c9dbe7aa2..53c388f84a5b 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -81,7 +81,7 @@ class DocPrinter { * * \sa Append */ - 
String GetString() const; + ffi::String GetString() const; protected: /*! @@ -267,7 +267,7 @@ class DocPrinter { std::vector line_starts_; /*! \brief Path of the object that we would like to underline */ - Array path_to_underline_; + ffi::Array path_to_underline_; /*! * \brief Candidate spans to be underlined, until we find a better match. diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 21f5e3301568..1a79806d1621 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -182,7 +182,7 @@ class PythonDocPrinter : public DocPrinter { } template - void PrintJoinedDocs(const Array& docs, const std::string& separator) { + void PrintJoinedDocs(const ffi::Array& docs, const std::string& separator) { bool is_first = true; for (auto& doc : docs) { if (is_first) { @@ -194,7 +194,7 @@ class PythonDocPrinter : public DocPrinter { } } - void PrintIndentedBlock(const Array& docs) { + void PrintIndentedBlock(const ffi::Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { NewLine(); @@ -207,7 +207,7 @@ class PythonDocPrinter : public DocPrinter { DecreaseIndent(); } - void PrintDecorators(const Array& decorators) { + void PrintDecorators(const ffi::Array& decorators) { for (const ExprDoc& decorator : decorators) { output_ << "@"; PrintDoc(decorator); @@ -285,7 +285,7 @@ class PythonDocPrinter : public DocPrinter { } } - void PrintDocString(const String& comment) { + void PrintDocString(const ffi::String& comment) { size_t start_pos = output_.tellp(); output_ << "\"\"\""; @@ -304,7 +304,7 @@ class PythonDocPrinter : public DocPrinter { underlines_exempted_.push_back({start_pos, end_pos}); } - void PrintBlockComment(const String& comment) { + void PrintBlockComment(const ffi::String& comment) { IncreaseIndent(); NewLine(); PrintDocString(comment); @@ -484,7 +484,7 @@ void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { 
} else { output_ << ", "; } - const String& keyword = doc->kwargs_keys[i]; + const ffi::String& keyword = doc->kwargs_keys[i]; output_ << keyword; output_ << "="; PrintDoc(doc->kwargs_values[i]); @@ -714,7 +714,7 @@ void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) { } } -String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { +ffi::String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { if (cfg->num_context_lines < 0) { cfg->num_context_lines = std::numeric_limits::max(); } @@ -728,10 +728,10 @@ String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { return result.substr(0, last_space); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.DocToPythonScript", DocToPythonScript); -}); +} } // namespace printer } // namespace script diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index fd478768bf32..62d4c3ad6132 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -28,7 +28,7 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](ffi::Shape n, AccessPath n_p, IRDocsifier d) -> Doc { int s = n.size(); - Array results; + ffi::Array results; results.reserve(s); for (int i = 0; i < s; ++i) { results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayItem(i))); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 70be98f4c425..aac5656f9146 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -24,7 +24,7 @@ namespace tvm { namespace script { namespace printer { -TVM_FFI_STATIC_INIT_BLOCK({ IRFrameNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { IRFrameNode::RegisterReflection(); } struct SortableFunction { int priority; @@ -130,7 +130,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](VDevice vdev, AccessPath p, IRDocsifier d) -> 
Doc { d->AddGlobalInfo("vdevice", vdev); - Map config = vdev->target->Export(); + ffi::Map config = vdev->target->Export(); return IR(d, "vdevice") ->Call({d->AsDoc(config, p), LiteralDoc::Int(vdev->vdevice_id, p->Attr("vdevice_id")), diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index 5643ab4de43a..f33170577154 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -23,10 +23,10 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch>( // - "", [](Array array, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch>( // + "", [](ffi::Array array, AccessPath p, IRDocsifier d) -> Doc { int n = array.size(); - Array results; + ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { results.push_back(d->AsDoc(array[i], p->ArrayItem(i))); @@ -35,8 +35,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch>( // - "", [](Map dict, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch>( // + "", [](ffi::Map dict, AccessPath p, IRDocsifier d) -> Doc { using POO = std::pair; std::vector items{dict.begin(), dict.end()}; bool is_str_map = true; @@ -48,12 +48,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } if (is_str_map) { std::sort(items.begin(), items.end(), [](const POO& lhs, const POO& rhs) { - return Downcast(lhs.first) < Downcast(rhs.first); + return Downcast(lhs.first) < Downcast(rhs.first); }); } int n = dict.size(); - Array ks; - Array vs; + ffi::Array ks; + ffi::Array vs; ks.reserve(n); vs.reserve(n); for (int i = 0; i < n; ++i) { diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index d79e5cd4565d..588e6066d9c0 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -37,28 +37,26 @@ namespace printer { class IRFrameNode : public FrameNode { public: - Map>* global_infos = nullptr; + ffi::Map>* global_infos = nullptr; static void 
RegisterReflection() { namespace refl = tvm::ffi::reflection; // global infos is not exposed } - - static constexpr const char* _type_key = "script.printer.IRFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(IRFrameNode, FrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.IRFrame", IRFrameNode, FrameNode); }; class IRFrame : public Frame { public: explicit IRFrame(const IRDocsifier& d) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stmts.clear(); n->d = d.get(); n->global_infos = nullptr; data_ = std::move(n); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRFrame, Frame, IRFrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IRFrame, Frame, IRFrameNode); }; /*! \brief Redirected method for the ReprPrinter */ diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index efe7bc2f937a..8ebbedfef78d 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -30,12 +30,13 @@ namespace tvm { namespace script { namespace printer { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { FrameNode::RegisterReflection(); IRDocsifierNode::RegisterReflection(); -}); +} -IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { +IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, + const ffi::String& name_hint) { if (auto it = obj2info.find(obj); it != obj2info.end()) { // TVM's IR dialects do not allow multiple definitions of the same // variable within an IRModule. 
This branch can only be reached @@ -51,7 +52,7 @@ IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const St return IdDoc(it->second.name.value()); } - String name = name_hint; + ffi::String name = name_hint; if (cfg->show_object_address) { std::stringstream stream; stream << name << "_" << obj.get(); @@ -72,7 +73,7 @@ void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreato frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); } -Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { +ffi::Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { auto it = obj2info.find(obj); if (it == obj2info.end()) { return std::nullopt; @@ -82,8 +83,8 @@ Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) { ICHECK(obj != nullptr) << "TypeError: Cannot add nullptr to metadata"; - String key = obj.GetTypeKey(); - Array& array = metadata[key]; + ffi::String key = obj.GetTypeKey(); + ffi::Array& array = metadata[key]; int index = std::find_if(array.begin(), array.end(), [&](const ffi::Any& a) { return ffi::AnyEqual()(a, obj); }) - array.begin(); @@ -94,9 +95,9 @@ ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) { "metadata")[{LiteralDoc::Str(key, std::nullopt)}][{LiteralDoc::Int(index, std::nullopt)}]; } -void IRDocsifierNode::AddGlobalInfo(const String& name, const GlobalInfo& ginfo) { +void IRDocsifierNode::AddGlobalInfo(const ffi::String& name, const GlobalInfo& ginfo) { ICHECK(ginfo.defined()) << "TypeError: Cannot add nullptr to global_infos"; - Array& array = global_infos[name]; + ffi::Array& array = global_infos[name]; array.push_back(ginfo); } @@ -191,11 +192,11 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, } IRDocsifier::IRDocsifier(const PrinterConfig& cfg) { - auto n = make_object(); + auto n = ffi::make_object(); n->cfg = cfg; n->dispatch_tokens.push_back(""); // Define builtin keywords according 
to cfg. - for (const String& keyword : cfg->GetBuiltinKeywords()) { + for (const ffi::String& keyword : cfg->GetBuiltinKeywords()) { n->defined_names.insert(keyword); } data_ = std::move(n); diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index d4580af96891..19da2cd508aa 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -23,15 +23,15 @@ namespace script { namespace printer { IfDoc PrintIfExpr(const relax::If& n, const AccessPath& n_p, const IRDocsifier& d, // - const Optional& var, const Optional& ann) { + const ffi::Optional& var, const ffi::Optional& ann) { using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); - std::vector> branches{ + std::vector> branches{ PrintSeqExpr(n->true_branch, n_p->Attr("true_branch"), d, false), PrintSeqExpr(n->false_branch, n_p->Attr("false_branch"), d, false), }; if (var.defined()) { - for (Array& stmts : branches) { + for (ffi::Array& stmts : branches) { ExprDoc ret = Downcast(stmts.back())->expr; stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); } @@ -44,7 +44,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](relax::MatchCast n, AccessPath n_p, IRDocsifier d) -> Doc { using relax::StructInfo; using relax::MatchStructInfo; - Optional ann = std::nullopt; + ffi::Optional ann = std::nullopt; if (d->cfg->show_all_struct_info) { ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); } @@ -59,9 +59,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::VarBinding n, AccessPath n_p, IRDocsifier d) -> Doc { if (const auto if_ = n->value.as()) { - Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ffi::Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); - return PrintIfExpr(GetRef(if_), n_p->Attr("value"), d, lhs, ann); + return PrintIfExpr(ffi::GetRef(if_), n_p->Attr("value"), d, lhs, ann); } 
else if (n->value->IsInstance() && !n->value->IsInstance()) { IdDoc lhs = DefineVar(n->var, d->frames.back(), d); @@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ExprStmtDoc(rhs); } else { ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); - Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ffi::Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); return AssignDoc(lhs, rhs, ann); } diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index e7e7e21380e4..6d96327e2db4 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -29,8 +29,8 @@ namespace printer { class AttrPrinter { public: - explicit AttrPrinter(AccessPath p, const IRDocsifier& d, Array* keys, - Array* values) + explicit AttrPrinter(AccessPath p, const IRDocsifier& d, ffi::Array* keys, + ffi::Array* values) : p(std::move(p)), d(d), keys(keys), values(values) {} void operator()(const tvm::Attrs& attrs) { @@ -46,7 +46,7 @@ class AttrPrinter { << "` misses reflection registration and do not support serialization"; // new printing mechanism using the new reflection ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { - String field_name = String(field_info->name); + ffi::String field_name = ffi::String(field_info->name); Any field_value = ffi::reflection::FieldGetter(field_info)(attrs); keys->push_back(field_name); values->push_back(d->AsDoc(field_value, p->Attr(field_name))); @@ -56,8 +56,8 @@ class AttrPrinter { AccessPath p; const IRDocsifier& d; - Array* keys; - Array* values; + ffi::Array* keys; + ffi::Array* values; }; ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifier& d) { @@ -69,8 +69,8 @@ ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifi } } -Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p, - const 
IRDocsifier& d) { +ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); @@ -83,9 +83,9 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& } ICHECK(n->args.size() == 2 || n->args.size() == 3); ICHECK(n->sinfo_args.size() == 1); - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; // Step 1. Print n->args[0], the callee args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); // Step 2. Print n->args[1], the input arguments @@ -96,7 +96,7 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& bool is_dtensor = false; kwargs_keys.push_back("out_sinfo"); if (const auto* o = o_sinfo.as()) { - Array fields; + ffi::Array fields; AccessPath fields_p = o_sinfo_p->Attr("fields"); for (int i = 0, l = o->fields.size(); i < l; ++i) { if (o->fields[i].as()) { @@ -115,7 +115,7 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& // for call_tir_inplace, we also need to include the inplace args if (n->op.same_as(call_tir_inplace_op)) { kwargs_keys.push_back("inplace_indices"); - Array index_fields; + ffi::Array index_fields; if (auto* call_tir_inplace_attrs = n->attrs.as()) { for (auto inplace_index : call_tir_inplace_attrs->inplace_indices) { index_fields.push_back( @@ -160,7 +160,8 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& } } -Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { +ffi::Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& assert_op = Op::Get("relax.assert_op"); if (!n->op.same_as(assert_op)) { return std::nullopt; @@ -170,7 
+171,7 @@ Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, con // is the _format_ string, or else roundtripping will fail // (the format string will be interpreted as an argument and there will be a new default format // string given) - Array args; + ffi::Array args; args.push_back(d->AsDoc(n->args[0], n_p->Attr("args")->ArrayItem(0))); ExprDoc second_arg = d->AsDoc(n->args[1], n_p->Attr("args")->ArrayItem(1)); for (size_t i = 2; i < n->args.size(); i++) { @@ -179,36 +180,40 @@ Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, con return Relax(d, "assert_op")->Call(args, {"format"}, {second_arg}); } -Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& hint_on_device_op = Op::Get("relax.hint_on_device"); if (!n->op.same_as(hint_on_device_op)) { return std::nullopt; } - Array args; + ffi::Array args; args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); - Array kwargs_keys; - Array kwargs_values; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; ICHECK(n->attrs.defined()); if (n->attrs.as()) { AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values)(n->attrs); + ExprDoc scope_val = kwargs_values.back(); + kwargs_keys.pop_back(); + kwargs_values.pop_back(); args.push_back(Relax(d, "device")->Call({}, kwargs_keys, kwargs_values)); + args.push_back(scope_val); } return Relax(d, "hint_on_device")->Call(args); } -Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& to_vdevice_op = Op::Get("relax.to_vdevice"); if (!n->op.same_as(to_vdevice_op)) { return std::nullopt; } - Array args; + ffi::Array args; args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); - 
Array kwargs_keys; - Array kwargs_values; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; ICHECK(n->attrs.defined()); if (const auto* attrs = n->attrs.as()) { VDevice vdev = attrs->dst_vdevice; @@ -216,13 +221,14 @@ Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, int dev_index = FindVDeviceIndexByTargetKind(vdev, d); kwargs_keys.push_back("dst_vdevice"); kwargs_values.push_back( - LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("dst_vdevice"))); + LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index) + ":" + vdev->memory_scope, + n_p->Attr("dst_vdevice"))); } return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } -Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& print_op = Op::Get("relax.print"); if (!n->op.same_as(print_op)) { return std::nullopt; @@ -233,7 +239,7 @@ Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, // (the format string will be interpreted as an argument and there will be a new default format // string given) ExprDoc first_arg = d->AsDoc(n->args[0], n_p->Attr("args")->ArrayItem(0)); - Array args; + ffi::Array args; for (size_t i = 1; i < n->args.size(); i++) { args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayItem(i))); } @@ -244,29 +250,29 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Call n, AccessPath n_p, IRDocsifier d) -> Doc { // Special case: call_tir, call_dps_packed, call_tir_with_grad - if (Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { + if (ffi::Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { return doc.value(); } // Special case: assert_op - if (Optional doc = PrintAssertOp(n, n_p, d)) { + if (ffi::Optional doc = PrintAssertOp(n, n_p, d)) { return doc.value(); } // Special case: hint_on_device - if (Optional doc = 
PrintHintOnDevice(n, n_p, d)) { + if (ffi::Optional doc = PrintHintOnDevice(n, n_p, d)) { return doc.value(); } // Special case: to_vdevice - if (Optional doc = PrintToVDevice(n, n_p, d)) { + if (ffi::Optional doc = PrintToVDevice(n, n_p, d)) { return doc.value(); } // Special case: print - if (Optional doc = PrintRelaxPrint(n, n_p, d)) { + if (ffi::Optional doc = PrintRelaxPrint(n, n_p, d)) { return doc.value(); } - ExprDoc prefix{nullptr}; - Array args; - Array kwargs_keys; - Array kwargs_values; + ExprDoc prefix{ffi::UnsafeInit()}; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; // Step 1. Print op if (const auto* op = n->op.as()) { prefix = Relax(d, "call_packed"); @@ -299,7 +305,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_values.push_back(LiteralDoc::Str(n->attrs->GetTypeKey(), n_p->Attr("attrs"))); } if (const auto* attrs = n->attrs.as()) { - std::vector> sorted; + std::vector> sorted; for (const auto& kv : attrs->dict) { sorted.push_back(kv); } @@ -317,7 +323,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 4. 
Print type_args if (n->sinfo_args.size() > 0) { AccessPath sinfo_args_p = n_p->Attr("sinfo_args"); - Array sinfo_args; + ffi::Array sinfo_args; for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { sinfo_args.push_back(d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayItem(i))); } diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index d8b3871b35bc..d1a29be24f5e 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -37,16 +37,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](relax::distributed::DTensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; bool require_kwargs = false; if (n->tensor_sinfo->shape.defined()) { // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->tensor_sinfo->shape.value().as()) { - auto shape_expr = GetRef(shape); + auto shape_expr = ffi::GetRef(shape); AccessPath shape_p = n_p->Attr("shape")->Attr("values"); - Array shape_docs; + ffi::Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( PrintShapeVar(shape_expr->values[i], shape_p->ArrayItem(i), d)); @@ -102,7 +102,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } if (!has_relax_frame || !f) { - Array args; + ffi::Array args; args.push_back(d->AsDoc(n->shape, n_p->Attr("shape"))); if (n->device_range.defined()) { args.push_back(d->AsDoc(n->device_range, n_p->Attr("device_range"))); @@ -116,7 +116,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (kv.second[i].same_as(n)) { std::stringstream ss; ss << kv.first << "[" << i << "]"; - return d->AsDoc(String(ss.str()), n_p); + return d->AsDoc(ffi::String(ss.str()), n_p); } } } diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index 
c411622e6409..0c8cd3c12371 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -53,7 +53,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (n->fields.empty()) { return Relax(d, "tuple")->Call({}); } - Array fields_doc; + ffi::Array fields_doc; AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); @@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::ShapeExpr n, AccessPath n_p, IRDocsifier d) -> Doc { - Array values_doc; + ffi::Array values_doc; AccessPath values_p = n_p->Attr("values"); for (int i = 0, l = n->values.size(); i < l; ++i) { values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayItem(i), d)); @@ -79,7 +79,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return Relax(d, "shape")->Call({ListDoc(values_doc)}); }); -Optional SpecialScalar(const runtime::NDArray& n, const AccessPath& p) { +ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& p) { DataType dtype = n.DataType(); const void* data = n->data; if (n->ndim != 0 || n->device.device_type != kDLCPU) { @@ -135,7 +135,7 @@ Optional SpecialScalar(const runtime::NDArray& n, const AccessPath& p) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Constant n, AccessPath n_p, IRDocsifier d) -> Doc { - if (Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { + if (ffi::Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { if (n->struct_info_.as()) { ExprDoc ann = d->AsDoc(n->struct_info_, n_p->Attr("struct_info_")); return Relax(d, "dist.const")->Call({s.value(), ann}); diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index aa6182f189fe..978c4a8243da 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -37,7 +37,7 @@ bool 
AtTopLevelFunction(const IRDocsifier& d) { return d->frames.size() == 3; } -TVM_FFI_STATIC_INIT_BLOCK({ RelaxFrameNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RelaxFrameNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](relax::Function n, AccessPath n_p, IRDocsifier d) -> Doc { @@ -47,7 +47,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) IdDoc func_name(""); // if we are binding a local definition, then calling d->Define // will result in a repeated definition and an incorrect displayed name - if (Optional name = GetBindingName(d)) { + if (ffi::Optional name = GetBindingName(d)) { func_name = IdDoc(name.value()); } else { func_name = IdDoc(FindFunctionName(d, n).value_or("main")); @@ -56,13 +56,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->is_func = true; (*f)->func_vars = &func_vars; // Step 1. Print the return type - Optional ret_type = std::nullopt; + ffi::Optional ret_type = std::nullopt; if (const auto& func_sinfo = relax::MatchStructInfo(n)) { ret_type = d->AsDoc(func_sinfo.value()->ret, // n_p->Attr("struct_info_")->Attr("ret")); } // Step 2. Print params - Array params; + ffi::Array params; { AccessPath params_p = n_p->Attr("params"); for (int i = 0, l = n->params.size(); i < l; ++i) { @@ -81,8 +81,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // For a function without an IR module whose global symbol // doesn't match the function name, we should still print the global symbol attribute. if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && - Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { - Map new_attrs; + Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { + ffi::Map new_attrs; for (auto kv : n->attrs->dict) { if (kv.first != tvm::attr::kGlobalSymbol) { new_attrs.Set(kv.first, kv.second); @@ -101,26 +101,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 5. 
Prepare the decorator (include purity if it's impure) ExprDoc decorator = Relax(d, "function"); - Array pos_args = {}; - Array dec_keys; - Array dec_values; + ffi::Array pos_args = {}; + ffi::Array dec_keys; + ffi::Array dec_values; if (!n->is_pure) { dec_keys.push_back("pure"); - dec_values.push_back(LiteralDoc::Boolean(false, Optional())); + dec_values.push_back(LiteralDoc::Boolean(false, ffi::Optional())); } // if the function is global or is not in a module and does not have a global symbol, // indicate that it's private if (AtTopLevelFunction(d) && (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { dec_keys.push_back("private"); - dec_values.push_back(LiteralDoc::Boolean(true, Optional())); + dec_values.push_back(LiteralDoc::Boolean(true, ffi::Optional())); } if (dec_keys.size()) { decorator = decorator->Call(pos_args, dec_keys, dec_values); } // Step 6. Print body - Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); + ffi::Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator}, ret_type, (*f)->stmts)); }); diff --git a/src/script/printer/relax/region.cc b/src/script/printer/relax/region.cc index 7cedc63c271c..a28967cb4194 100644 --- a/src/script/printer/relax/region.cc +++ b/src/script/printer/relax/region.cc @@ -22,18 +22,18 @@ namespace tvm { namespace script { namespace printer { -Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, const IRDocsifier& d, - bool use_ret) { +ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, + const IRDocsifier& d, bool use_ret) { With f(d); - const Array& blocks = n->blocks; + const ffi::Array& blocks = n->blocks; AccessPath blocks_p = n_p->Attr("blocks"); - Array* stmts = &(*f)->stmts; + ffi::Array* stmts = &(*f)->stmts; for (int i = 0, l = blocks.size(); i < l; ++i) { Doc block = 
d->AsDoc(blocks[i], blocks_p->ArrayItem(i)); if (const auto* stmt_block = block.as()) { stmts->insert(stmts->end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); } else if (const auto* stmt = block.as()) { - stmts->push_back(GetRef(stmt)); + stmts->push_back(ffi::GetRef(stmt)); } else { LOG(FATAL) << "TypeError: Unknown type: " << block->GetTypeKey(); } @@ -52,18 +52,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); }); -Array PrintBindingBlock(const relax::BindingBlock& n, const AccessPath& n_p, - const IRDocsifier& d, Array* non_dataflow_vars) { - const Array& bindings = n->bindings; +ffi::Array PrintBindingBlock(const relax::BindingBlock& n, const AccessPath& n_p, + const IRDocsifier& d, + ffi::Array* non_dataflow_vars) { + const ffi::Array& bindings = n->bindings; AccessPath bindings_p = n_p->Attr("bindings"); - Array stmts; + ffi::Array stmts; for (int i = 0, l = bindings.size(); i < l; ++i) { const relax::Binding& binding = bindings[i]; AccessPath binding_p = bindings_p->ArrayItem(i); ICHECK(binding->var.defined()); Doc binding_doc = d->AsDoc(binding, binding_p); if (const auto* stmt = binding_doc.as()) { - stmts.push_back(GetRef(stmt)); + stmts.push_back(ffi::GetRef(stmt)); } else if (const auto* stmt_block = binding_doc.as()) { stmts.insert(stmts.end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); } else { @@ -85,8 +86,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::DataflowBlock n, AccessPath n_p, IRDocsifier d) -> Doc { - Array non_dataflow_vars; - Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); + ffi::Array non_dataflow_vars; + ffi::Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); return ScopeDoc(std::nullopt, Relax(d, "dataflow")->Call({}), stmts); }); diff --git a/src/script/printer/relax/struct_info.cc 
b/src/script/printer/relax/struct_info.cc index 87de6a8335f5..e597df64501d 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -63,9 +63,9 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifie TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](relax::PrimStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (n->value.defined()) { kwargs_keys.push_back("value"); @@ -81,9 +81,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](relax::ShapeStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->values.defined()) { - Array shape = n->values.value(); + ffi::Array shape = n->values.value(); AccessPath shape_p = n_p->Attr("values"); - Array shape_docs; + ffi::Array shape_docs; for (int i = 0, ndim = shape.size(); i < ndim; ++i) { shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayItem(i), d)); } @@ -96,15 +96,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::TensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (n->shape.defined()) { // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->shape.value().as()) { - auto shape_expr = GetRef(shape); + auto shape_expr = ffi::GetRef(shape); AccessPath shape_p = n_p->Attr("shape")->Attr("values"); - Array shape_docs; + ffi::Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( PrintShapeVar(shape_expr->values[i], shape_p->ArrayItem(i), d)); @@ -126,8 +126,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_keys.push_back("vdevice"); std::string dev_kind = 
n->vdevice.value()->target->kind->name; int dev_index = FindVDeviceIndexByTargetKind(n->vdevice.value(), d); - kwargs_values.push_back( - LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("vdevice"))); + kwargs_values.push_back(LiteralDoc::Str( + dev_kind + ":" + std::to_string(dev_index) + ":" + n->vdevice.value()->memory_scope, + n_p->Attr("vdevice"))); } if (args.empty() && kwargs_keys.empty()) { return Relax(d, "Tensor"); @@ -141,7 +142,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (n->fields.empty()) { return Relax(d, "Tuple"); } - Array fields_doc; + ffi::Array fields_doc; AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); @@ -156,8 +157,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity")); if (n->IsOpaque()) { - Array keys; - Array values; + ffi::Array keys; + ffi::Array values; if (!n->ret->IsInstance()) { keys.push_back("ret"); @@ -175,8 +176,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } // TODO(@junrushao): track symbolic shape relation - Array params_doc; - Array params = n->params.value(); + ffi::Array params_doc; + ffi::Array params = n->params.value(); AccessPath params_p = n_p->Attr("params"); for (int i = 0, n_params = params.size(); i < n_params; ++i) { params_doc.push_back(d->AsDoc(params[i], params_p->ArrayItem(i))); diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 67f39a6f6c45..0c1a2cd26035 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -58,11 +58,11 @@ Doc PrintTIRVar(tir::Var n, AccessPath n_p, IRDocsifier d) { ICHECK(f->is_func); f->func_vars->insert(n.get()); } - IdDoc var = d->Define(n, GetRef(f), n->name_hint.empty() ? "v" : n->name_hint); + IdDoc var = d->Define(n, ffi::GetRef(f), n->name_hint.empty() ? 
"v" : n->name_hint); var->source_paths.push_back(n_p); f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), std::nullopt)); } - if (Optional doc = d->GetVarDoc(n)) { + if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << n; @@ -86,7 +86,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // - if (Optional doc = d->GetVarDoc(n)) { + if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { IdDoc ret(n->name_hint); @@ -98,7 +98,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // - Optional doc = d->GetVarDoc(mod); + ffi::Optional doc = d->GetVarDoc(mod); ICHECK(doc) << "Unable to print IRModule before definition in Relax."; if (d->cfg->module_alias.empty()) { // Use Module Name directly diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc index d4ad35a13ee5..032205244347 100644 --- a/src/script/printer/relax/type.cc +++ b/src/script/printer/relax/type.cc @@ -58,7 +58,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (n->fields.empty()) { return Relax(d, "Tuple"); } - Array fields_doc; + ffi::Array fields_doc; AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); @@ -69,8 +69,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "relax", [](tvm::FuncType n, AccessPath n_p, IRDocsifier d) -> Doc { - Array arg_types_doc; - Array arg_types = n->arg_types; + ffi::Array arg_types_doc; + ffi::Array arg_types = n->arg_types; AccessPath arg_types_p = n_p->Attr("arg_types"); for (int i = 0, n_params = arg_types.size(); i < 
n_params; ++i) { arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayItem(i))); @@ -84,10 +84,10 @@ TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::TensorTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.ReprPrintRelax", ReprPrintRelax); -}); +} } // namespace printer } // namespace script diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 37ae86220051..7dddfaecbbe7 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -50,15 +50,13 @@ class RelaxFrameNode : public FrameNode { .def_ro("is_func", &RelaxFrameNode::is_func) .def_ro("module_alias_printed", &RelaxFrameNode::module_alias_printed); } - - static constexpr const char* _type_key = "script.printer.RelaxFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(RelaxFrameNode, FrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.RelaxFrame", RelaxFrameNode, FrameNode); }; class RelaxFrame : public Frame { public: explicit RelaxFrame(const IRDocsifier& d) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stmts.clear(); n->d = d.get(); n->is_func = false; @@ -66,7 +64,7 @@ class RelaxFrame : public Frame { data_ = std::move(n); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, Frame, RelaxFrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RelaxFrame, Frame, RelaxFrameNode); }; /*! \brief Redirected method for the ReprPrinter */ @@ -81,8 +79,9 @@ inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsif return d->Define(var, frame, var->name_hint().empty() ? 
"v" : var->name_hint()); } -inline Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p, - const IRDocsifier& d, const Optional& rhs) { +inline ffi::Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p, + const IRDocsifier& d, + const ffi::Optional& rhs) { if (!v->struct_info_.defined()) { return std::nullopt; } @@ -96,7 +95,7 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& } } if (attempt_to_hide_struct_info) { - Optional inferred_sinfo = std::nullopt; + ffi::Optional inferred_sinfo = std::nullopt; if (auto opt = rhs.as()) { auto call = opt.value(); if (auto opt = call->op.as()) { @@ -133,13 +132,13 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); } -Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, const IRDocsifier& d, - bool use_ret); +ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, + const IRDocsifier& d, bool use_ret); ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d); inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifier& d) { - Array vdevices = d->global_infos["vdevice"]; + ffi::Array vdevices = d->global_infos["vdevice"]; int kind_index = 0; for (size_t i = 0; i < vdevices.size(); ++i) { auto vdev = Downcast(vdevices[i]); diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index fb4f8a9d772b..1a33d760a9d5 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -23,7 +23,8 @@ namespace script { namespace printer { Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // - Optional opt_realize, Optional opt_realize_p) { + ffi::Optional opt_realize, + ffi::Optional opt_realize_p) { With frame(d, block); ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); const tir::BlockRealizeNode* realize = @@ -35,7 +36,8 @@ Doc PrintBlock(IRDocsifier d, 
tir::Block block, AccessPath block_p, // for (Frame f : d->frames) { if (const auto* tir_f = f.as()) { if (auto for_loop = tir_f->tir.as()) { - for (Optional loop = for_loop; loop; loop = loop.value()->body.as()) { + for (ffi::Optional loop = for_loop; loop; + loop = loop.value()->body.as()) { loop_vars.insert(std::make_pair(loop.value()->loop_var.get(), loop.value())); } } @@ -81,7 +83,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: " << tir::IterVarType2String(iter_var->iter_type); } - ExprDoc dom{nullptr}; + ExprDoc dom{ffi::UnsafeInit()}; if (tir::is_zero(iter_var->dom->min)) { ExprDoc extent = d->AsDoc(iter_var->dom->extent, // iter_var_p->Attr("dom")->Attr("extent")); @@ -113,12 +115,12 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // remap_vars_indices.clear(); return; } - Array lhs; - Array loop_var_doc; + ffi::Array lhs; + ffi::Array loop_var_doc; lhs.reserve(m); loop_var_doc.reserve(m); std::string binding_type = ""; - Array binding_paths; + ffi::Array binding_paths; for (int i : remap_vars_indices) { tir::IterVar iter_var = block->iter_vars[i]; AccessPath iter_var_p = block_p->Attr("iter_vars")->ArrayItem(i); @@ -158,12 +160,12 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // } // Step 3. Handle block read/write regions { - Array reads; + ffi::Array reads; for (int i = 0, n = block->reads.size(); i < n; ++i) { reads.push_back(d->AsDoc(block->reads[i], block_p->Attr("reads")->ArrayItem(i))); } (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "reads")->Call(reads))); - Array writes; + ffi::Array writes; for (int i = 0, n = block->writes.size(); i < n; ++i) { writes.push_back(d->AsDoc(block->writes[i], block_p->Attr("writes")->ArrayItem(i))); } @@ -201,8 +203,8 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // } // Step 8. 
Handle block body AsDocBody(block->body, block_p->Attr("body"), frame->get(), d); - Array kwargs_keys; - Array kwargs_values; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (!realize) { kwargs_keys.push_back("no_realize"); kwargs_values.push_back(LiteralDoc::Boolean(true, std::nullopt)); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 0e7ae3a843cf..4057b1d09bfc 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -24,13 +24,14 @@ namespace tvm { namespace script { namespace printer { -Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, const Frame& frame, - const IRDocsifier& d, BufferVarDefinition var_definitions) { +ffi::Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, + const Frame& frame, const IRDocsifier& d, + BufferVarDefinition var_definitions) { using tvm::tir::Var; using tvm::tir::VarNode; - Map kwargs; - Array var_def_lhs; - Array var_def_rhs; + ffi::Map kwargs; + ffi::Array var_def_lhs; + ffi::Array var_def_rhs; // Step 0. Set up statistics std::unordered_map use_count; @@ -73,10 +74,10 @@ Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, }; // Step 1. Handle `buffer.shape` { - const Array& shape = buffer->shape; + const ffi::Array& shape = buffer->shape; AccessPath shape_p = buffer_p->Attr("shape"); int n = shape.size(); - Array results; + ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = shape[i]; @@ -108,10 +109,10 @@ Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, } // Step 4. 
Handle `buffer.strides` if (!buffer->strides.empty()) { - const Array& strides = buffer->strides; + const ffi::Array& strides = buffer->strides; AccessPath strides_p = buffer_p->Attr("strides"); int n = strides.size(); - Array results; + ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = strides[i]; @@ -148,7 +149,7 @@ Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, } // Step 6. Handle `buffer.scope` { - String scope = buffer.scope(); + ffi::String scope = buffer.scope(); if (scope != "global") { kwargs.Set( "scope", @@ -182,17 +183,18 @@ Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, return kwargs; } -ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Array args) { - Array kwargs_keys; - Array kwargs_values; - for (String s : {"shape", "dtype"}) { - if (Optional doc = attrs.Get(s)) { +ExprDoc BufferCall(const ExprDoc& prefix, const ffi::Map& attrs, + ffi::Array args) { + ffi::Array kwargs_keys; + ffi::Array kwargs_values; + for (ffi::String s : {"shape", "dtype"}) { + if (ffi::Optional doc = attrs.Get(s)) { args.push_back(doc.value()); } } - for (String s : {"data", "strides", "elem_offset", "scope", "align", "offset_factor", - "buffer_type", "axis_separators"}) { - if (Optional doc = attrs.Get(s)) { + for (ffi::String s : {"data", "strides", "elem_offset", "scope", "align", "offset_factor", + "buffer_type", "axis_separators"}) { + if (ffi::Optional doc = attrs.Get(s)) { kwargs_keys.push_back(s); kwargs_values.push_back(doc.value()); } @@ -200,9 +202,9 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Arr return prefix->Call(args, kwargs_keys, kwargs_values); } -ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, - const AccessPath& p, const Frame& frame, const IRDocsifier& d, - BufferVarDefinition var_definitions) { +ExprDoc BufferDecl(const tir::Buffer& buffer, const ffi::String& method, + const ffi::Array& args, const AccessPath& p, const 
Frame& frame, + const IRDocsifier& d, BufferVarDefinition var_definitions) { return BufferCall(/*prefix=*/TIR(d, method), /*attrs=*/BufferAttrs(buffer, p, frame, d, var_definitions), /*args=*/args); @@ -210,17 +212,18 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array< ExprDoc BufferAttn(const tir::Buffer& buffer, const AccessPath& p, const Frame& frame, const IRDocsifier& d) { - Map attrs = BufferAttrs(buffer, p, frame, d, BufferVarDefinition::DataPointer); + ffi::Map attrs = + BufferAttrs(buffer, p, frame, d, BufferVarDefinition::DataPointer); ExprDoc shape = attrs.Get("shape").value(); ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); return TIR(d, "Buffer")->Call({shape, dtype}, {}, {}); } -Array BufferIndices(const Array& indices, const AccessPath& p, - const IRDocsifier& d) { +ffi::Array BufferIndices(const ffi::Array& indices, const AccessPath& p, + const IRDocsifier& d) { int n = indices.size(); - Array indices_doc; + ffi::Array indices_doc; indices_doc.reserve(n); for (int i = 0; i < n; ++i) { if (const auto* ramp = indices[i].as()) { @@ -231,7 +234,7 @@ Array BufferIndices(const Array& indices, const AccessPath& p, ramp_p->Attr("base")); ExprDoc stop = d->AsDoc(ramp->base + ramp->lanes * ramp->stride, // ramp_p->Attr("lanes")); - Optional step = std::nullopt; + ffi::Optional step = std::nullopt; if (stride->value != 1) { step = d->AsDoc(ramp->stride, ramp_p->Attr("stride")); } @@ -244,9 +247,10 @@ Array BufferIndices(const Array& indices, const AccessPath& p, return indices_doc; } -Array BufferSlices(const Array& region, const AccessPath& p, const IRDocsifier& d) { +ffi::Array BufferSlices(const ffi::Array& region, const AccessPath& p, + const IRDocsifier& d) { int n = region.size(); - Array indices; + ffi::Array indices; indices.reserve(n); for (int i = 0; i < n; ++i) { Range range = region[i]; @@ -306,14 +310,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) 
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // .set_dispatch("", [](tir::Buffer buffer, AccessPath p, IRDocsifier d) -> Doc { if (!d->IsVarDefined(buffer)) { - if (Optional opt_f = FindLowestVarDef(buffer, d)) { + if (ffi::Optional opt_f = FindLowestVarDef(buffer, d)) { ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); ExprDoc rhs = BufferDecl(buffer, "Buffer", {}, p, opt_f.value(), d, BufferVarDefinition::DataPointer); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } } - if (Optional doc = d->GetVarDoc(buffer)) { + if (ffi::Optional doc = d->GetVarDoc(buffer)) { return doc.value(); } LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer; diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 78b52edf859c..e05b30753bf3 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -27,9 +27,9 @@ namespace printer { ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { Type type = var->type_annotation; AccessPath type_p = var_p->Attr("type_annotation"); - ExprDoc rhs{nullptr}; - Array kwargs_keys; - Array kwargs_values; + ExprDoc rhs{ffi::UnsafeInit()}; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (var->IsInstance()) { kwargs_keys.push_back("is_size_var"); @@ -66,7 +66,7 @@ ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRD Doc PrintVar(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { if (!d->IsVarDefined(var)) { - if (Optional opt_f = FindLowestVarDef(var, d)) { + if (ffi::Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); ExprDoc rhs = PrintVarCreation(var, var_p, d); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); @@ -74,7 +74,7 @@ Doc PrintVar(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) LOG(WARNING) << "Didn't find variable definition for: " << var->name_hint; } } - if (Optional 
doc = d->GetVarDoc(var)) { + if (ffi::Optional doc = d->GetVarDoc(var)) { return doc.value(); } LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var->name_hint; @@ -169,11 +169,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { ICHECK_EQ(r->lhs.size(), r->rhs.size()); - LambdaDoc lambda{nullptr}; + ffi::Optional lambda; { With f(d, r); int n_vars = r->lhs.size(); - Array vars; + ffi::Array vars; vars.reserve(n_vars + n_vars); for (int i = 0; i < n_vars; ++i) { vars.push_back(Downcast(DefineVar(r->lhs[i], *f, d))); @@ -182,7 +182,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) vars.push_back(Downcast(DefineVar(r->rhs[i], *f, d))); } int n_results = r->result.size(); - Array results; + ffi::Array results; results.reserve(n_results); for (int i = 0; i < n_results; ++i) { results.push_back(d->AsDoc(r->result[i], p->Attr("result")->ArrayItem(i))); @@ -194,17 +194,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } ExprDoc id = d->AsDoc(r->identity_element, p->Attr("identity_element")); - return TIR(d, "comm_reducer")->Call({lambda, id}); + return TIR(d, "comm_reducer")->Call({lambda.value(), id}); }); -LambdaDoc PrintIndexMap(const ObjectRef& map, const Array& vs, const AccessPath& vs_p, - const Array& es, const AccessPath& es_p, const IRDocsifier& d) { +LambdaDoc PrintIndexMap(const ObjectRef& map, const ffi::Array& vs, + const AccessPath& vs_p, const ffi::Array& es, + const AccessPath& es_p, const IRDocsifier& d) { With f(d, map); - Array vars; + ffi::Array vars; for (int i = 0, l = vs.size(); i < l; ++i) { vars.push_back(Downcast(DefineVar(vs[i], *f, d))); } - Array exprs; + ffi::Array exprs; for (int i = 0, l = es.size(); i < l; ++i) { exprs.push_back(d->AsDoc(es[i], es_p->ArrayItem(i))); } @@ -243,10 +244,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) static const OpAttrMap dtype_locations = Op::GetAttrMap("TScriptDtypePrintLocation"); 
tir::ScriptDtypePrintLocation dtype_print_location = tir::ScriptDtypePrintLocation::kNone; - ExprDoc prefix{nullptr}; + ffi::Optional prefix; if (auto optional_op = call->op.as()) { auto op = optional_op.value(); - String name = op_names.get(op, op->name); + ffi::String name = op_names.get(op, op->name); if (op_names.count(op) == 0) { LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; } @@ -261,7 +262,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) auto f_llvm_lookup_intrinsic_name = tvm::ffi::Function::GetGlobal("target.llvm_get_intrinsic_name"); - Array args; + ffi::Array args; args.reserve(n_args + 1); if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); @@ -269,7 +270,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0; i < n_args; ++i) { if ((i == 0) && (f_llvm_lookup_intrinsic_name)) { - String name = (*f_llvm_lookup_intrinsic_name)(id).cast(); + ffi::String name = (*f_llvm_lookup_intrinsic_name)(id).cast(); args.push_back(LiteralDoc::Str(name.c_str(), call_p->Attr("args")->ArrayItem(i))); } else { args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayItem(i))); @@ -278,14 +279,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } - return prefix->Call(args); + ffi::Array kwargs_keys; + ffi::Array kwargs_values; + for (const auto& kv : call->annotations) { + kwargs_keys.push_back(kv.first); + kwargs_values.push_back( + d->AsDoc(kv.second, call_p->Attr("annotations")->Attr(kv.first))); + } + return prefix.value()->Call(args, kwargs_keys, kwargs_values); } } else if (call->op.as()) { prefix = d->AsDoc(call->op, call_p->Attr("op")); } else { LOG(FATAL) << "call: " << call; } - Array args; + ffi::Array args; int n_args = call->args.size(); args.reserve(n_args + 1); if (dtype_print_location == 
tir::ScriptDtypePrintLocation::kFirst) { @@ -298,7 +306,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } - return prefix->Call(args); + ffi::Array kwargs_keys; + ffi::Array kwargs_values; + for (const auto& kv : call->annotations) { + kwargs_keys.push_back(kv.first); + kwargs_values.push_back( + d->AsDoc(kv.second, call_p->Attr("annotations")->Attr(kv.first))); + } + return prefix.value()->Call(args, kwargs_keys, kwargs_values); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index bfdae3b14221..b2e091f38019 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -39,7 +39,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (l->kind != tir::ForKind::kSerial || // !tir::is_zero(l->min) || // !l->annotations.empty() || // - f_var_dep(l->extent)) { + !l->HasTrivialStep() || f_var_dep(l->extent)) { break; } grid.push_back(l); @@ -50,8 +50,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 2. Construct `T.grid` if (grid.size() > 1) { int n = grid.size(); - Array lhs; - Array rhs; + ffi::Array lhs; + ffi::Array rhs; lhs.reserve(n); rhs.reserve(n); for (int i = 0; i < n; ++i) { @@ -65,11 +65,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 3. 
If not `T.grid`, print loop kind accordingly ExprDoc lhs = DefineVar(loop->loop_var, *f, d); - Optional min = std::nullopt; - Optional max = std::nullopt; - Optional annotations = std::nullopt; - Optional thread = std::nullopt; - if (tir::is_zero(loop->min)) { + ffi::Optional min = std::nullopt; + ffi::Optional max = std::nullopt; + ffi::Optional annotations = std::nullopt; + ffi::Optional thread = std::nullopt; + if (tir::is_zero(loop->min) && loop->HasTrivialStep()) { max = d->AsDoc(loop->extent, loop_p->Attr("extent")); } else { min = d->AsDoc(loop->min, loop_p->Attr("min")); @@ -78,10 +78,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (!loop->annotations.empty()) { annotations = d->AsDoc(loop->annotations, loop_p->Attr("annotations")); } - ExprDoc prefix{nullptr}; + bool use_range_sugar = false; + ExprDoc prefix{ffi::UnsafeInit()}; if (loop->kind == tir::ForKind::kSerial) { if (loop->annotations.empty()) { prefix = IdDoc("range"); + use_range_sugar = true; } else { prefix = TIR(d, "serial"); } @@ -98,9 +100,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } else { LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind); } - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (min.defined()) { args.push_back(min.value()); } @@ -115,6 +117,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_keys.push_back("annotations"); kwargs_values.push_back(annotations.value()); } + if (!loop->HasTrivialStep()) { + ExprDoc step = d->AsDoc(*loop->step, loop_p->Attr("step")); + if (use_range_sugar) { + args.push_back(step); + } else { + kwargs_keys.push_back("step"); + kwargs_values.push_back(step); + } + } ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values); AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d); return ForDoc(lhs, rhs, (*f)->stmts); diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 
688c58e6de09..c5083b57c2d0 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ++buffer_data_counter.at(data_var); } // Step 1. Handle `func->params` - Array args; + ffi::Array args; args.reserve(n_args); std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { @@ -107,8 +107,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (func->attrs.defined() && !func->attrs->dict.empty()) { // for global symbol, don't display it if it matches the func name if (func->attrs->dict.count(tvm::attr::kGlobalSymbol) && - Downcast(func->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { - Map new_attrs; + Downcast(func->attrs->dict.at(tvm::attr::kGlobalSymbol)) == + func_name->name) { + ffi::Map new_attrs; for (auto kv : func->attrs->dict) { if (kv.first != tvm::attr::kGlobalSymbol) { new_attrs.Set(kv.first, kv.second); @@ -142,7 +143,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } // Step 4. 
Handle `func->body` - Optional implicit_root_block = [&]() -> Optional { + ffi::Optional implicit_root_block = [&]() -> ffi::Optional { const tir::BlockRealizeNode* root_block_realize = func->body.as(); if (root_block_realize && !root_block_realize->iter_values.size() && tir::is_one(root_block_realize->predicate)) { @@ -178,7 +179,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } else { AsDocBody(func->body, p->Attr("body"), f->get(), d); } - Optional ret_type = std::nullopt; + ffi::Optional ret_type = std::nullopt; if (func->ret_type.defined()) { const auto* as_tuple = func->ret_type.as(); if (!as_tuple || as_tuple->fields.size()) { @@ -189,9 +190,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc decorator = TIR(d, "prim_func"); // mark private if there is no global symbol if (!func->attrs.defined() || !func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { - Array pos_args; + ffi::Array pos_args; decorator = decorator->Call(pos_args, {"private"}, - {LiteralDoc::Boolean(true, Optional())}); + {LiteralDoc::Boolean(true, ffi::Optional())}); } return HeaderWrapper(d, FunctionDoc( @@ -207,7 +208,7 @@ TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "tir", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // - if (Optional doc = d->GetVarDoc(n)) { + if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { IdDoc ret(n->name_hint); @@ -219,7 +220,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "tir", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // - Optional doc = d->GetVarDoc(mod); + ffi::Optional doc = d->GetVarDoc(mod); ICHECK(doc) << "Unable to print IRModule before definition in TIR."; return doc.value(); }); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index a99d4236158f..431dc7dcc3e5 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -24,7 
+24,7 @@ namespace tvm { namespace script { namespace printer { -TVM_FFI_STATIC_INIT_BLOCK({ TIRFrameNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { TIRFrameNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](IntImm imm, AccessPath imm_p, IRDocsifier d) -> Doc { @@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](PointerType ty, AccessPath ty_p, IRDocsifier d) -> Doc { - ExprDoc element_type{nullptr}; + ExprDoc element_type{ffi::UnsafeInit()}; if (const auto* prim_type = ty->element_type.as()) { element_type = LiteralDoc::DataType(prim_type->dtype, // ty_p->Attr("element_type")->Attr("dtype")); @@ -91,7 +91,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Target target, AccessPath p, IRDocsifier d) -> Doc { - Map config = target->Export(); + ffi::Map config = target->Export(); return TIR(d, "target")->Call({d->AsDoc(config, p)}); }); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 5a52de1849f1..1b0774be3686 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -23,8 +23,8 @@ namespace tvm { namespace script { namespace printer { -Doc DoConciseScoping(const Optional& lhs, const ExprDoc& rhs, Array* stmts, - bool concise_scoping) { +Doc DoConciseScoping(const ffi::Optional& lhs, const ExprDoc& rhs, + ffi::Array* stmts, bool concise_scoping) { if (concise_scoping) { if (lhs.defined()) { stmts->insert(stmts->begin(), AssignDoc(lhs.value(), rhs, std::nullopt)); @@ -64,7 +64,7 @@ bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IR return false; } -Optional FindReturnValue(const tir::Stmt& node) { +ffi::Optional FindReturnValue(const tir::Stmt& node) { auto eval = node.as(); if (!eval) return std::nullopt; @@ -99,8 +99,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", 
[](tir::LetStmt stmt, AccessPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); // Step 1. Type annotation - Optional type_doc = d->AsDoc(stmt->var->type_annotation, // - p->Attr("var")->Attr("type_annotation")); + ffi::Optional type_doc = d->AsDoc(stmt->var->type_annotation, // + p->Attr("var")->Attr("type_annotation")); if (const auto* tuple_type = stmt->var->type_annotation.as()) { if (tuple_type->fields.empty()) { type_doc = std::nullopt; @@ -110,7 +110,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); // Step 3. LHS and body With f(d, stmt); - Array* stmts = &(*f)->stmts; + ffi::Array* stmts = &(*f)->stmts; bool var_defined = d->IsVarDefined(stmt->var); if (!var_defined) { DefineVar(stmt->var, *f, d); @@ -139,7 +139,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); if (concise) { - Array* stmts = &(*f)->stmts; + ffi::Array* stmts = &(*f)->stmts; stmts->insert(stmts->begin(), AssertDoc(cond, msg)); return StmtBlockDoc(*stmts); } @@ -177,8 +177,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); - Array then_branch; - Array else_branch; + ffi::Array then_branch; + ffi::Array else_branch; if (stmt->then_case.defined()) { With f(d, stmt->then_case); AsDocBody(stmt->then_case, p->Attr("then_case"), f->get(), d); @@ -226,9 +226,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return DeclBufferDoc(Downcast(stmt->body), stmt_p->Attr("body"), d, BufferVarDefinition::DataPointer); } - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; args.push_back(d->AsDoc(stmt->extents, stmt_p->Attr("extents"))); args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype"))); 
args.push_back(LiteralDoc::Str(tir::GetPtrStorageScope(stmt->buffer_var), @@ -252,7 +252,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); template -ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { +ExprDoc PrintTensor(::tvm::runtime::Tensor arr) { // FIXME(@junrushao): this is a hack and can be wrong in most of the cases constexpr int NUM_PRINT = 200; int ndim = arr->ndim; @@ -260,7 +260,7 @@ ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { for (int i = 0; i < ndim; i++) { tot_dim *= arr->shape[i]; } - Array result; + ffi::Array result; T* data_ptr = reinterpret_cast(arr->data); runtime::DataType dtype = arr.DataType(); for (int i = 0; i < tot_dim; i++) { @@ -280,42 +280,42 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](tir::AllocateConst stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); - String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); - Array args; - Array kwargs_keys; - Array kwargs_values; - ExprDoc data_doc{nullptr}; + ffi::String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; + ExprDoc data_doc{ffi::UnsafeInit()}; if (stmt->dtype.is_int()) { if (stmt->dtype.bits() == 8) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else if (stmt->dtype.bits() == 16) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else if (stmt->dtype.bits() == 32) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else if (stmt->dtype.bits() == 64) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else { LOG(FATAL) << "DataType not supported"; } } else if (stmt->dtype.is_uint()) { if (stmt->dtype.bits() == 8) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else if 
(stmt->dtype.bits() == 16) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else if (stmt->dtype.bits() == 32) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else if (stmt->dtype.bits() == 64) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else { LOG(FATAL) << "DataType not supported"; } } else if (stmt->dtype.is_float()) { if (stmt->dtype.bits() == 16) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else if (stmt->dtype.bits() == 32) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else if (stmt->dtype.bits() == 64) { - data_doc = PrintNDArray(stmt->data.value()); + data_doc = PrintTensor(stmt->data.value()); } else { LOG(FATAL) << "DataType not supported"; } @@ -332,11 +332,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); }); -ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional value, // +ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, ffi::Optional value, // AccessPath p, IRDocsifier d) { ExprDoc buffer = d->AsDoc(stmt->buffer, p->Attr("buffer")); { - Array bounds; + ffi::Array bounds; bounds.reserve(stmt->bounds.size()); for (int i = 0, n = stmt->bounds.size(); i < n; ++i) { Range range = stmt->bounds[i]; @@ -348,9 +348,9 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional args{buffer}; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args{buffer}; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (value.defined()) { args.push_back(value.value()); } @@ -373,11 +373,11 @@ void InsertEnvThread(const tir::IterVar& iter_var, const AccessPath& iter_var_p, } ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p, - Optional* define_var, const IRDocsifier& d) 
{ + ffi::Optional* define_var, const IRDocsifier& d) { tir::IterVar iter_var = Downcast(attr_stmt->node); AccessPath iter_var_p = attr_stmt_p->Attr("node"); - ExprDoc var_doc{nullptr}; + ExprDoc var_doc{ffi::UnsafeInit()}; if (d->IsVarDefined(iter_var->var)) { var_doc = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); } else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) { @@ -408,9 +408,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); - Optional lhs = std::nullopt; - Optional rhs = std::nullopt; - Optional define_var = std::nullopt; + ffi::Optional lhs = std::nullopt; + ffi::Optional rhs = std::nullopt; + ffi::Optional define_var = std::nullopt; tir::Stmt body = stmt->body; AccessPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "realize_scope") { diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 4474a83ca8ff..8cb5636d1516 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -55,9 +55,7 @@ class TIRFrameNode : public FrameNode { .def_ro("tir", &TIRFrameNode::tir) .def_ro("allow_concise_scoping", &TIRFrameNode::allow_concise_scoping); } - - static constexpr const char* _type_key = "script.printer.TIRFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(TIRFrameNode, FrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.TIRFrame", TIRFrameNode, FrameNode); }; /*! \brief Managed reference to TIRFrameNode */ @@ -65,14 +63,14 @@ class TIRFrame : public Frame { public: /*! \brief Constructor */ explicit TIRFrame(const IRDocsifier& d, const ObjectRef& tir) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stmts.clear(); n->d = d.get(); n->tir = tir; data_ = std::move(n); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TIRFrame, Frame, TIRFrameNode); }; /*! 
@@ -84,7 +82,7 @@ class TIRFrame : public Frame { * \return The IdDoc corresponding to the variable */ inline ExprDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { - if (Optional doc = d->GetVarDoc(var)) { + if (ffi::Optional doc = d->GetVarDoc(var)) { return doc.value(); } return d->Define(var, frame, var->name_hint.empty() ? "v" : var->name_hint); @@ -111,7 +109,7 @@ inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const I */ inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { if (const auto* seq_stmt = stmt.as()) { - Array body = seq_stmt->seq; + ffi::Array body = seq_stmt->seq; for (int i = 0, n = body.size(); i < n; ++i) { f->allow_concise_scoping = (i == n - 1); Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayItem(i)); @@ -139,7 +137,7 @@ inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, cons * \param d The IRDocsifier * \return The frame that could place the var definition */ -inline Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) { +inline ffi::Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) { if (!d->common_prefix.count(var.get())) { return std::nullopt; } @@ -159,11 +157,11 @@ inline Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& const std::vector& path = d->common_prefix.at(var.get()); for (auto it = path.rbegin(); it != path.rend(); ++it) { if (tir_to_frame.count(*it)) { - return GetRef(tir_to_frame.at(*it)); + return ffi::GetRef(tir_to_frame.at(*it)); } } if (fallback_frame != nullptr) { - return GetRef(fallback_frame); + return ffi::GetRef(fallback_frame); } return std::nullopt; } @@ -214,9 +212,9 @@ enum class BufferVarDefinition { * the buffer. 
* \return The ExprDoc corresponding to the buffer declaration */ -ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, - const AccessPath& p, const Frame& frame, const IRDocsifier& d, - BufferVarDefinition var_definitions); +ExprDoc BufferDecl(const tir::Buffer& buffer, const ffi::String& method, + const ffi::Array& args, const AccessPath& p, const Frame& frame, + const IRDocsifier& d, BufferVarDefinition var_definitions); /*! * \brief Declare and define a buffer as annotation diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 1e3a258579a2..8e9b9cdf1049 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -56,9 +56,9 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra if (!cfg->verbose_expr) { f->stmts.clear(); } - f->stmts.push_back(ExprStmtDoc(GetRef(expr_doc))); + f->stmts.push_back(ExprStmtDoc(ffi::GetRef(expr_doc))); } else if (const auto* stmt_doc = doc.as()) { - f->stmts.push_back(GetRef(stmt_doc)); + f->stmts.push_back(ffi::GetRef(stmt_doc)); } else if (const auto* stmt_block = doc.as()) { for (const StmtDoc& d : stmt_block->stmts) { f->stmts.push_back(d); @@ -72,8 +72,8 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra if (d->cfg->show_meta) { os << "metadata = tvm.ir.load_json(\"\"\"" << support::StrEscape( - SaveJSON(Map(d->metadata.begin(), d->metadata.end())), false, - false) + SaveJSON(ffi::Map(d->metadata.begin(), d->metadata.end())), + false, false) << "\"\"\")\n"; } else { f->stmts.push_back( @@ -91,19 +91,19 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra } /*! \brief Creates the IR common prefix, which is by default `I` */ -inline ExprDoc IR(const IRDocsifier& d, const String& attr) { +inline ExprDoc IR(const IRDocsifier& d, const ffi::String& attr) { d->ir_usage.insert("ir"); return IdDoc(d->cfg->ir_prefix)->Attr(attr); } /*! 
\brief Creates the TIR common prefix, which is by default `T` */ -inline ExprDoc TIR(const IRDocsifier& d, const String& attr) { +inline ExprDoc TIR(const IRDocsifier& d, const ffi::String& attr) { d->ir_usage.insert("tir"); return IdDoc(d->cfg->tir_prefix)->Attr(attr); } /*! \brief Creates the Relax common prefix, which is by default `R` */ -inline ExprDoc Relax(const IRDocsifier& d, const String& attr) { +inline ExprDoc Relax(const IRDocsifier& d, const ffi::String& attr) { d->ir_usage.insert("relax"); return IdDoc(d->cfg->relax_prefix)->Attr(attr); } @@ -115,7 +115,7 @@ inline std::string DType2Str(const runtime::DataType& dtype) { /*! \brief Add headers as comments to doc if needed */ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) { if (d->ir_usage.size()) { - Array stmts; + ffi::Array stmts; if (d->ir_usage.count("ir")) { stmts.push_back(CommentDoc("from tvm.script import ir as " + d->cfg->ir_prefix)); } @@ -137,23 +137,23 @@ inline bool HasMultipleLines(const std::string& str) { return str.find_first_of('\n') != std::string::npos; } -inline Optional GetBindingName(const IRDocsifier& d) { - return d->cfg->binding_names.empty() ? Optional(std::nullopt) +inline ffi::Optional GetBindingName(const IRDocsifier& d) { + return d->cfg->binding_names.empty() ? 
ffi::Optional(std::nullopt) : d->cfg->binding_names.back(); } -inline Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f) { - if (Optional name = GetBindingName(d)) { +inline ffi::Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f) { + if (ffi::Optional name = GetBindingName(d)) { return name.value(); } - if (Optional sym = f->GetAttr(tvm::attr::kGlobalSymbol)) { + if (ffi::Optional sym = f->GetAttr(tvm::attr::kGlobalSymbol)) { return sym.value(); } return std::nullopt; } -inline String GenerateUniqueName(std::string name_hint, - const std::unordered_set& defined_names) { +inline ffi::String GenerateUniqueName(std::string name_hint, + const std::unordered_set& defined_names) { for (char& c : name_hint) { if (c != '_' && !std::isalnum(c)) { c = '_'; diff --git a/src/support/array.h b/src/support/array.h index f49439aeb3ff..6e2aeca3e11f 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -35,7 +35,7 @@ namespace support { * \return A boolean indicating if they are the same */ template -inline bool ArrayWithSameContent(const Array& a, const Array& b) { +inline bool ArrayWithSameContent(const ffi::Array& a, const ffi::Array& b) { if (a.size() != b.size()) { return false; } @@ -76,7 +76,7 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector * \return The result vector */ template -inline std::vector AsVector(const Array& vec); +inline std::vector AsVector(const ffi::Array& vec); /*! * \brief Convert a std::vector to tvm::Array @@ -85,7 +85,7 @@ inline std::vector AsVector(const Array& vec); * \return The result Array */ template -inline Array AsArray(const std::vector& vec); +inline ffi::Array AsArray(const std::vector& vec); /*! 
* \brief Convert a tvm::Array to std::list @@ -93,7 +93,7 @@ inline Array AsArray(const std::vector& vec); * \return The result list */ template -inline std::list AsList(const Array& array) { +inline std::list AsList(const ffi::Array& array) { std::list list; for (const auto& v : array) list.push_back(v); return list; @@ -105,8 +105,8 @@ inline std::list AsList(const Array& array) { * \return The result list */ template -inline Array AsArray(const std::list& list) { - Array array; +inline ffi::Array AsArray(const std::list& list) { + ffi::Array array; for (const auto& v : list) array.push_back(v); return array; } @@ -116,8 +116,8 @@ inline Array AsArray(const std::list& list) { * \param shape The shape tuple * \return An array of the shape tuple */ -inline Array AsArray(const ffi::Shape& shape) { - Array result; +inline ffi::Array AsArray(const ffi::Shape& shape) { + ffi::Array result; result.reserve(shape->size); for (ffi::Shape::index_type i : shape) { result.push_back(Integer(i)); @@ -134,12 +134,12 @@ inline Array AsArray(const ffi::Shape& shape) { * \return The concatenated array */ template -inline Array ConcatArrayList(Iterator begin, Iterator end) { +inline ffi::Array ConcatArrayList(Iterator begin, Iterator end) { int size = 0; for (Iterator it = begin; it != end; ++it) { size += (*it).size(); } - Array result; + ffi::Array result; result.reserve(size); for (Iterator it = begin; it != end; ++it) { const auto& item = *it; @@ -157,17 +157,17 @@ struct AsVectorImpl {}; template struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { + inline std::vector operator()(const ffi::Array& vec) const { return std::vector(vec.begin(), vec.end()); } }; template struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { + inline std::vector operator()(const ffi::Array& array) const { ffi::Any ret_value; ret_value = array; - Array as_int_vec = ret_value.cast>(); + ffi::Array as_int_vec = ret_value.cast>(); std::vector 
results; for (const auto& value : as_int_vec) { @@ -179,10 +179,10 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { + inline std::vector operator()(const ffi::Array& array) const { ffi::Any ret_value; ret_value = array; - Array as_int_vec = ret_value.cast>(); + ffi::Array as_int_vec = ret_value.cast>(); std::vector results; for (const auto& value : as_int_vec) { @@ -194,10 +194,10 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { + inline std::vector operator()(const ffi::Array& array) const { ffi::Any ret_value; ret_value = array; - Array as_int_vec = ret_value.cast>(); + ffi::Array as_int_vec = ret_value.cast>(); std::vector results; for (const auto& value : as_int_vec) { @@ -217,15 +217,15 @@ struct AsArrayImpl {}; template struct AsArrayImpl { - inline Array operator()(const std::vector& vec) const { - return Array(vec.begin(), vec.end()); + inline ffi::Array operator()(const std::vector& vec) const { + return ffi::Array(vec.begin(), vec.end()); } }; template struct AsArrayImpl { - inline Array operator()(const std::vector& vec) const { - Array result; + inline ffi::Array operator()(const std::vector& vec) const { + ffi::Array result; result.reserve(vec.size()); for (auto x : vec) { ffi::Any ret_value; @@ -238,8 +238,8 @@ struct AsArrayImpl { template struct AsArrayImpl { - inline Array operator()(const std::vector& vec) const { - Array result; + inline ffi::Array operator()(const std::vector& vec) const { + ffi::Array result; result.reserve(vec.size()); for (auto x : vec) { ffi::Any ret_value; @@ -252,8 +252,8 @@ struct AsArrayImpl { template struct AsArrayImpl { - inline Array operator()(const std::vector& vec) const { - Array result; + inline ffi::Array operator()(const std::vector& vec) const { + ffi::Array result; result.reserve(vec.size()); for (auto x : vec) { ffi::Any ret_value; @@ -267,12 +267,12 @@ struct AsArrayImpl { } 
// namespace details template -inline std::vector AsVector(const Array& vec) { +inline std::vector AsVector(const ffi::Array& vec) { return details::AsVectorImpl()(vec); } template -inline Array AsArray(const std::vector& vec) { +inline ffi::Array AsArray(const std::vector& vec) { return details::AsArrayImpl()(vec); } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 70c23c546bbb..8875046874e4 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -37,8 +37,8 @@ namespace tvm { // Attrs used to python API struct TestAttrs : public AttrsNodeReflAdapter { int axis; - String name; - Array padding; + ffi::String name; + ffi::Array padding; TypedEnvFunc func; static void RegisterReflection() { @@ -47,18 +47,16 @@ struct TestAttrs : public AttrsNodeReflAdapter { .def_ro("axis", &TestAttrs::axis, "axis field", refl::DefaultValue(10)) .def_ro("name", &TestAttrs::name, "name") .def_ro("padding", &TestAttrs::padding, "padding of input", - refl::DefaultValue(Array({0, 0}))) + refl::DefaultValue(ffi::Array({0, 0}))) .def_ro("func", &TestAttrs::func, "some random env function", refl::DefaultValue(TypedEnvFunc(nullptr))); } - - static constexpr const char* _type_key = "attrs.TestAttrs"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestAttrs, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("attrs.TestAttrs", TestAttrs, BaseAttrsNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ TestAttrs::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { TestAttrs::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("testing.GetShapeSize", @@ -106,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ "if the python module is properly loaded"; *ret = (*identity_func)(args[0]); }); -}); +} // in src/api_test.cc void ErrorTest(int x, int y) { @@ -118,10 +116,10 @@ void ErrorTest(int x, int y) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; refl::GlobalDef().def("testing.ErrorTest", ErrorTest); -}); +} class FrontendTestModuleNode : public ffi::ModuleObj { public: @@ -129,7 +127,7 @@ class FrontendTestModuleNode : public ffi::ModuleObj { static constexpr const char* kAddFunctionName = "__add_function"; - virtual ffi::Optional GetFunction(const String& name); + virtual ffi::Optional GetFunction(const ffi::String& name); private: std::unordered_map functions_; @@ -137,8 +135,8 @@ class FrontendTestModuleNode : public ffi::ModuleObj { constexpr const char* FrontendTestModuleNode::kAddFunctionName; -ffi::Optional FrontendTestModuleNode::GetFunction(const String& name) { - ffi::Module self_strong_ref = GetRef(this); +ffi::Optional FrontendTestModuleNode::GetFunction(const ffi::String& name) { + ffi::Module self_strong_ref = ffi::GetRef(this); if (name == kAddFunctionName) { return ffi::Function::FromTyped( [this, self_strong_ref](std::string func_name, ffi::Function pf) { @@ -157,11 +155,11 @@ ffi::Optional FrontendTestModuleNode::GetFunction(const String& n } ffi::Module NewFrontendTestModule() { - auto n = make_object(); + auto n = ffi::make_object(); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("testing.FrontendTestModule", NewFrontendTestModule) @@ -172,16 +170,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ std::this_thread::sleep_for(duration); }) .def("testing.ReturnsVariant", - [](int x) -> Variant { + [](int x) -> ffi::Variant { if (x % 2 == 0) { return IntImm(DataType::Int(64), x / 2); } else { - return String("argument was odd"); + return ffi::String("argument was odd"); } }) .def("testing.AcceptsVariant", - [](Variant arg) -> String { - if (auto opt_str = arg.as()) { + [](ffi::Variant arg) -> ffi::String { + if (auto opt_str = arg.as()) { return ffi::StaticTypeKey::kTVMFFIStr; } else { return arg.get().GetTypeKey(); @@ -189,13 +187,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) 
.def("testing.AcceptsBool", [](bool arg) -> bool { return arg; }) .def("testing.AcceptsInt", [](int arg) -> int { return arg; }) - .def("testing.AcceptsObjectRefArray", [](Array arg) -> Any { return arg[0]; }) + .def("testing.AcceptsObjectRefArray", [](ffi::Array arg) -> Any { return arg[0]; }) .def("testing.AcceptsMapReturnsValue", - [](Map map, Any key) -> Any { return map[key]; }) - .def("testing.AcceptsMapReturnsMap", [](Map map) -> ObjectRef { return map; }) + [](ffi::Map map, Any key) -> Any { return map[key]; }) + .def("testing.AcceptsMapReturnsMap", [](ffi::Map map) -> ObjectRef { return map; }) .def("testing.AcceptsPrimExpr", [](PrimExpr expr) -> ObjectRef { return expr; }) .def("testing.AcceptsArrayOfPrimExpr", - [](Array arr) -> ObjectRef { + [](ffi::Array arr) -> ObjectRef { for (ObjectRef item : arr) { CHECK(item->IsInstance()) << "Array contained " << item->GetTypeKey() << " when it should contain PrimExpr"; @@ -203,14 +201,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ return arr; }) .def("testing.AcceptsArrayOfVariant", - [](Array> arr) -> ObjectRef { + [](ffi::Array> arr) -> ObjectRef { for (auto item : arr) { CHECK(item.as() || item.as()) << "Array should contain either PrimExpr or ffi::Function"; } return arr; }) - .def("testing.AcceptsMapOfPrimExpr", [](Map map) -> ObjectRef { + .def("testing.AcceptsMapOfPrimExpr", [](ffi::Map map) -> ObjectRef { for (const auto& kv : map) { ObjectRef value = kv.second; CHECK(value->IsInstance()) @@ -218,7 +216,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } return map; }); -}); +} /** * Simple event logger that can be used for testing purposes @@ -226,7 +224,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ class TestingEventLogger { public: struct Entry { - String event; + ffi::String event; double time_us; }; @@ -235,7 +233,7 @@ class TestingEventLogger { start_ = std::chrono::high_resolution_clock::now(); } - void Record(String event) { + void Record(ffi::String event) { auto tend = std::chrono::high_resolution_clock::now(); double time_us = 
static_cast((tend - start_).count()) / 1e3; entries_.emplace_back(Entry{event, time_us}); @@ -259,13 +257,13 @@ class TestingEventLogger { std::vector entries_; }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("testing.record_event", [](ffi::PackedArgs args, ffi::Any* rv) { - if (args.size() != 0 && args[0].try_cast()) { - TestingEventLogger::ThreadLocal()->Record(args[0].cast()); + if (args.size() != 0 && args[0].try_cast()) { + TestingEventLogger::ThreadLocal()->Record(args[0].cast()); } else { TestingEventLogger::ThreadLocal()->Record("X"); } @@ -274,5 +272,5 @@ TVM_FFI_STATIC_INIT_BLOCK({ "testing.reset_events", [](ffi::PackedArgs args, ffi::Any* rv) { TestingEventLogger::ThreadLocal()->Reset(); }) .def("testing.dump_events", []() { TestingEventLogger::ThreadLocal()->Dump(); }); -}); +} } // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index c35ef140547a..d0646ee8b06f 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -274,7 +274,6 @@ TVM_DLL ffi::Map GetLibInfo() { {"BUILD_DUMMY_LIBTVM", TVM_INFO_BUILD_DUMMY_LIBTVM}, {"COMPILER_RT_PATH", TVM_INFO_COMPILER_RT_PATH}, {"CUDA_VERSION", TVM_INFO_CUDA_VERSION}, - {"DLPACK_PATH", TVM_INFO_DLPACK_PATH}, {"DMLC_PATH", TVM_INFO_DMLC_PATH}, {"GIT_COMMIT_HASH", TVM_INFO_GIT_COMMIT_HASH}, {"GIT_COMMIT_TIME", TVM_INFO_GIT_COMMIT_TIME}, @@ -340,6 +339,7 @@ TVM_DLL ffi::Map GetLibInfo() { {"USE_ROCM", TVM_INFO_USE_ROCM}, {"USE_RCCL", TVM_INFO_USE_RCCL}, {"USE_RPC", TVM_INFO_USE_RPC}, + {"TVM_BUILD_PYTHON_MODULE", TVM_INFO_TVM_BUILD_PYTHON_MODULE}, {"USE_RTTI", TVM_INFO_USE_RTTI}, {"USE_RUST_EXT", TVM_INFO_USE_RUST_EXT}, {"USE_SORT", TVM_INFO_USE_SORT}, @@ -366,9 +366,9 @@ TVM_DLL ffi::Map GetLibInfo() { return result; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("support.GetLibInfo", GetLibInfo); -}); +} } // namespace 
tvm diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h index ae4a0386d404..f63aaf92faca 100644 --- a/src/support/nd_int_set.h +++ b/src/support/nd_int_set.h @@ -50,7 +50,7 @@ inline NDIntSet NDIntSetFromRegion(const tir::Region& region) { * \param shape The shape which is an array of the length of each dimension. * \return The constructed set. */ -inline NDIntSet NDIntSetFromShape(const Array& shape) { +inline NDIntSet NDIntSetFromShape(const ffi::Array& shape) { PrimExpr zero = Integer(0); NDIntSet result; result.reserve(shape.size()); @@ -65,7 +65,7 @@ inline NDIntSet NDIntSetFromShape(const Array& shape) { * \param indices The N-dimensional indices representing the point. * \return The constructed set. */ -inline NDIntSet NDIntSetFromPoint(const Array& indices) { +inline NDIntSet NDIntSetFromPoint(const ffi::Array& indices) { NDIntSet result; result.reserve(indices.size()); for (const PrimExpr& index : indices) { @@ -106,7 +106,7 @@ inline NDIntSet NDIntSetUnion(const std::vector& nd_int_sets) { } NDIntSet result; result.reserve(ndim); - Array int_sets(n, arith::IntSet{nullptr}); + ffi::Array int_sets(n, arith::IntSet{nullptr}); for (int dim = 0; dim < ndim; ++dim) { for (int i = 0; i < n; ++i) { int_sets.Set(i, nd_int_sets[i][dim]); diff --git a/src/support/scalars.cc b/src/support/scalars.cc index b2581ecb3c99..692746852694 100644 --- a/src/support/scalars.cc +++ b/src/support/scalars.cc @@ -19,7 +19,7 @@ /*! * \file src/support/scalars.cc - * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms. + * \brief Helpers for converting between scalars in native, text, TIR immediate and Tensor forms. 
*/ #include "./scalars.h" @@ -38,9 +38,9 @@ static const DataType kFloat32 = DataType::Float(32); static const DataType kFloat64 = DataType::Float(64); static const DataType kBool = DataType::Bool(); -runtime::NDArray IntImmToNDArray(const IntImm& int_imm) { +runtime::Tensor IntImmToTensor(const IntImm& int_imm) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; - auto data = runtime::NDArray::Empty({}, int_imm->dtype, dev); + auto data = runtime::Tensor::Empty({}, int_imm->dtype, dev); if (int_imm.dtype() == kInt16) { auto* array = reinterpret_cast(data->data); array[0] = static_cast(int_imm->value); @@ -56,9 +56,9 @@ runtime::NDArray IntImmToNDArray(const IntImm& int_imm) { return data; } -runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) { +runtime::Tensor FloatImmToTensor(const FloatImm& float_imm) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; - auto data = runtime::NDArray::Empty({}, float_imm->dtype, dev); + auto data = runtime::Tensor::Empty({}, float_imm->dtype, dev); if (float_imm.dtype() == kFloat16) { auto* array = reinterpret_cast(data->data); array[0] = __gnu_f2h_ieee(static_cast(float_imm->value)); @@ -74,15 +74,15 @@ runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) { return data; } -runtime::NDArray BoolToNDArray(bool value) { +runtime::Tensor BoolToTensor(bool value) { DLDevice dev = {DLDeviceType::kDLCPU, 0}; - auto data = runtime::NDArray::Empty({}, kBool, dev); + auto data = runtime::Tensor::Empty({}, kBool, dev); auto array = reinterpret_cast(data->data); array[0] = value; return data; } -std::string NDArrayScalarToString(const runtime::NDArray& data) { +std::string TensorScalarToString(const runtime::Tensor& data) { std::ostringstream os; DataType dtype(data->dtype); ICHECK_EQ(data->device.device_type, kDLCPU) << "Scalars must reside on the CPU to be printed"; @@ -108,7 +108,7 @@ std::string NDArrayScalarToString(const runtime::NDArray& data) { auto value = static_cast(data->data)[0]; os << (value ? 
"True" : "False"); } else { - LOG(FATAL) << "Unrecognized NDArray scalar dtype: " << DLDataTypeToString(dtype); + LOG(FATAL) << "Unrecognized Tensor scalar dtype: " << DLDataTypeToString(dtype); } return os.str(); } diff --git a/src/support/scalars.h b/src/support/scalars.h index d9f2d7c54316..fa5a3482f5f6 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -19,7 +19,7 @@ /*! * \file src/support/scalars.h - * \brief Helpers for converting between scalars in native, text, TIR immediate and NDArray forms. + * \brief Helpers for converting between scalars in native, text, TIR immediate and Tensor forms. */ #ifndef TVM_SUPPORT_SCALARS_H_ @@ -28,18 +28,18 @@ #include #include "tvm/ir/expr.h" -#include "tvm/runtime/ndarray.h" +#include "tvm/runtime/tensor.h" namespace tvm { namespace support { -/*! \brief Returns NDArray 'scalar' for given TIR immediate. */ -runtime::NDArray IntImmToNDArray(const IntImm& int_imm); -runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm); -runtime::NDArray BoolToNDArray(bool value); +/*! \brief Returns Tensor 'scalar' for given TIR immediate. */ +runtime::Tensor IntImmToTensor(const IntImm& int_imm); +runtime::Tensor FloatImmToTensor(const FloatImm& float_imm); +runtime::Tensor BoolToTensor(bool value); -/*! \brief Returns literal text for NDArray 'scalar'. */ -std::string NDArrayScalarToString(const runtime::NDArray& data); +/*! \brief Returns literal text for Tensor 'scalar'. */ +std::string TensorScalarToString(const runtime::Tensor& data); /*! \brief Returns literal text for given TIR immediate. */ std::string IntImmToString(const IntImm& int_imm); diff --git a/src/target/build_common.h b/src/target/build_common.h index 9e52f6f8ffa6..cf1e3344fc3c 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -60,12 +60,12 @@ inline std::unordered_map ExtractFuncInfo(co ? 
runtime::FunctionInfo::ArgExtraTags::kTensorMap : runtime::FunctionInfo::ArgExtraTags::kNone); } - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { for (const auto& tag : opt.value()) { info.launch_param_tags.push_back(tag); } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); if (global_symbol) { fmap[static_cast(global_symbol.value())] = info; } diff --git a/src/target/codegen.cc b/src/target/codegen.cc index bd45ce32e053..30238318ffed 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -341,18 +341,18 @@ ffi::Module PackImportsToLLVM(const ffi::Module& mod, bool system_lib, .cast(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.Build", Build); -}); +} // Export a few auxiliary function to the runtime namespace. -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.ModuleImportsBlobName", []() -> std::string { return ffi::symbol::tvm_ffi_library_bin; }) - .def("runtime.ModulePackImportsToNDArray", + .def("runtime.ModulePackImportsToTensor", [](const ffi::Module& mod) { std::string buffer = PackImportsToBytes(mod); ffi::Shape::index_type size = buffer.size(); @@ -363,13 +363,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ DLDevice dev; dev.device_type = kDLCPU; dev.device_id = 0; - auto array = runtime::NDArray::Empty({size}, uchar, dev); + auto array = runtime::Tensor::Empty({size}, uchar, dev); array.CopyFromBytes(buffer.data(), size); return array; }) .def("runtime.ModulePackImportsToC", PackImportsToC) .def("runtime.ModulePackImportsToLLVM", PackImportsToLLVM); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index 4a0d5777252e..6b166d89db21 100644 --- 
a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -28,7 +28,7 @@ namespace datatype { using ffi::Any; using ffi::PackedArgs; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("dtype.register_custom_type", @@ -47,7 +47,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("runtime._datatype_get_type_registered", [](ffi::PackedArgs args, ffi::Any* ret) { *ret = Registry::Global()->GetTypeRegistered(args[0].cast()); }); -}); + // Register tfloat32 as a custom datatype with type code 130 + Registry::Global()->Register("tfloat32", 130); +} Registry* Registry::Global() { static Registry inst; diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 3103e6f5b9c3..3fc1a83d6fd5 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -34,6 +34,9 @@ using tir::FLowerIntrinsic; TVM_REGISTER_OP("tir.exp").set_attr("default.FLowerIntrinsic", DispatchPureExtern); +TVM_REGISTER_OP("tir.exp2") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); + TVM_REGISTER_OP("tir.erf").set_attr("default.FLowerIntrinsic", DispatchPureExtern); @@ -181,7 +184,8 @@ TVM_REGISTER_OP("tir.isfinite") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - return isfinite(call->args[0]); + PrimExpr x = call->args[0]; + return !isinf(x) && !isnan(x); }); TVM_REGISTER_OP("tir.isinf") diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index ac45476f7702..fbe03a6081e3 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -79,11 +79,11 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) { name = T()(dtype, name.substr(4)); if (name.length() != 0) { - Array new_args = {StringImm(name)}; + ffi::Array new_args = {StringImm(name)}; for (auto arg : call->args) { new_args.push_back(arg); } - return Call(call->dtype, builtin::call_pure_extern(), new_args); + return 
Call(call->dtype, builtin::call_pure_extern(), new_args, call->annotations); } else { return e; } diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index 7937f72bea43..adac65914469 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -85,7 +85,7 @@ void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { } const auto* attr_value = op->value.as(); - ICHECK(attr_value) << "Expect " << attr_key << " to have a String value but was " + ICHECK(attr_value) << "Expect " << attr_key << " to have a ffi::String value but was " << op->value->GetTypeKey(); std::string aarch64_attr_key = attr_key.substr(7); @@ -107,13 +107,13 @@ void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { this->VisitStmt(op->body); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.codegen.llvm.target_aarch64", [](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenAArch64()); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 9439af440b82..034b982f64b3 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -280,7 +280,7 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) { llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); auto fbitcode = tvm::ffi::Function::GetGlobalRequired("tvm_callback_rocm_bitcode_path"); - auto bitcode_files = fbitcode().cast>(); + auto bitcode_files = fbitcode().cast>(); for (auto& bitcode_path : bitcode_files) { std::unique_ptr mlib = llvm_instance.LoadIR(bitcode_path); @@ -361,14 +361,14 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) { return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() 
.def("target.build.rocm", BuildAMDGPU) .def_packed("tvm.codegen.llvm.target_rocm", [](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenAMDGPU()); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 3adcfc82bba8..180e1aea7345 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -75,10 +75,10 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { int total_size = call->dtype.bits() * call->dtype.lanes(); if (!call->dtype.is_fixed_length_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { - Array vcnt_args; + ffi::Array vcnt_args; vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(e); - return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args); + return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args, call->annotations); } // Popcount lowering rule: @@ -98,43 +98,43 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { // Popcount 8bit->8bit const CallNode* c0 = input8.as(); ICHECK(c0 != nullptr); - Array vcnt8_args; + ffi::Array vcnt8_args; vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args); + PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args, call->annotations); // Accumulation 8->16bit - Array vcnt16_args; + ffi::Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args); + PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args, call->annotations); if (call->dtype.bits() == 16) { return vcnt16; } // Accumulation 16->32bit - Array vcnt32_args; + ffi::Array vcnt32_args; 
vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args); + PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args, call->annotations); if (call->dtype.bits() == 32) { return vcnt32; } // Accumulation 32->64bit - Array vcnt64_args; + ffi::Array vcnt64_args; vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(vcnt32); - return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); + return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args, call->annotations); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.codegen.llvm.target_arm", [](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenARM()); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 5ce8b1ec6584..bc67cdad2fd3 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -71,7 +71,7 @@ CodeGenCPU::CodeGenCPU() = default; CodeGenCPU::~CodeGenCPU() = default; void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) { CodeGenLLVM::Init(module_name, llvm_target, system_lib_prefix, dynamic_lookup, target_c_runtime); system_lib_prefix_ = system_lib_prefix; @@ -175,7 +175,7 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, } llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(llvm::StringRef name, - const Array& param_types, + const ffi::Array& param_types, const Type& return_type) { #if TVM_LLVM_VERSION < 50 return nullptr; @@ -211,7 +211,7 @@ llvm::DISubprogram* 
CodeGenCPU::CreateDebugFunction(llvm::StringRef name, } llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const GlobalVar& gvar, const PrimFunc& func) { - std::string name = func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); + std::string name = func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); return CreateDebugFunction(name, func->params.Map(GetType), func->ret_type); } @@ -220,7 +220,7 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { EmitDebugLocation(func->span); CodeGenLLVM::AddFunction(gvar, func); if (f_tvm_register_system_symbol_ != nullptr) { - if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { export_system_symbols_.emplace_back( std::make_pair(global_symbol.value().operator std::string(), function_)); } @@ -229,6 +229,11 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { } void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { + if (module_->getFunction(ffi::symbol::tvm_ffi_main) != nullptr) { + // main already exists, no need to create a wrapper function + // main takes precedence over other entry functions + return; + } // create a wrapper function with tvm_ffi_main name and redirects to the entry function llvm::Function* target_func = module_->getFunction(entry_func_name); ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module"; @@ -385,8 +390,8 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value } } -llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { +llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg) { std::vector arg_values; for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { arg_values.push_back(MakeValue(args[i])); @@ -506,6 
+511,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { std::swap(analyzer_, parent_->analyzer_); std::swap(var_map_, parent_->var_map_); std::swap(di_subprogram_, parent_->di_subprogram_); + std::swap(loop_frame_jump_tgts_, parent_->loop_frame_jump_tgts_); } void ExitWithScope() { @@ -513,11 +519,13 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { std::swap(analyzer_, parent_->analyzer_); std::swap(var_map_, parent_->var_map_); std::swap(di_subprogram_, parent_->di_subprogram_); + std::swap(loop_frame_jump_tgts_, parent_->loop_frame_jump_tgts_); } llvm::Function* function_{nullptr}; llvm::DISubprogram* di_subprogram_{nullptr}; std::unordered_map var_map_; + std::vector> loop_frame_jump_tgts_; std::unique_ptr analyzer_{std::make_unique()}; CodeGenCPU* parent_; }; @@ -526,7 +534,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { // - Make sure the generated compute function is clearly separately(though it can get inlined) // - Set noalias on all the pointer arguments, some of them are loaded from ffi::PackedArgs. // This is easier than set the alias scope manually. 
- Array vargs = tir::UndefinedVars(op->body, {}); + ffi::Array vargs = tir::UndefinedVars(op->body, {}); std::vector arg_values; std::vector arg_types; for (Var v : vargs) { @@ -593,7 +601,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { AddDebugInformation(fcompute, vargs.Map(GetType)); } -CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array& vfields, +CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const ffi::Array& vfields, uint64_t* num_bytes, std::string struct_name) { if (vfields.size() == 0) { @@ -619,7 +627,7 @@ CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array& vfields, return TypedPointer(ctype, cvalue); } -void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array& vfields, +void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const ffi::Array& vfields, std::unordered_map* vmap) { for (size_t i = 0; i < vfields.size(); ++i) { llvm::Type* field_type = cdata.type->getStructElementType(i); @@ -639,7 +647,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::strin SetTargetAttributes(f); // allocate and setup the closure, call the closure. - Array vfields = tir::UndefinedVars(body, {}); + ffi::Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name); #if TVM_LLVM_VERSION >= 90 @@ -715,7 +723,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod } // allocate and setup the closure, call the closure. 
uint64_t nbytes; - Array vfields = tir::UndefinedVars(body, {}); + ffi::Array vfields = tir::UndefinedVars(body, {}); TypedPointer cdata = PackClosureData(vfields, &nbytes); llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); @@ -825,7 +833,7 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { return phi; } -CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, +CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const ffi::Array& args, const DataType& r_type, const int64_t begin, const int64_t end, bool use_env_lookup) { @@ -857,8 +865,9 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& call_args.push_back(GetPackedFuncHandle(func_name)); call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result}); } else { + // directly call into symbol, needs to prefix with tvm_ffi_symbol_prefix callee_ftype = ftype_tvm_ffi_c_func_; - callee_value = module_->getFunction(func_name); + callee_value = module_->getFunction(ffi::symbol::tvm_ffi_symbol_prefix + func_name); if (callee_value == nullptr) { callee_value = llvm::Function::Create(ftype_tvm_ffi_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); @@ -1143,14 +1152,15 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { void CodeGenCPU::VisitStmt_(const ForNode* op) { EmitDebugLocation(op); - ICHECK(is_zero(op->min)); if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) { CodeGenLLVM::VisitStmt_(op); } else if (op->kind == ForKind::kParallel) { + ICHECK(is_zero(op->min)) << "Parallel launch require canonical loop with zero start index"; + ICHECK(op->HasTrivialStep()) << "Parallel launch require canonical loop with trivial loop step"; if (parallel_env_.penv == nullptr) { - CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, - op->thread_binding, op->annotations), - 0, 
std::string("loop_parallel_") + op->loop_var->name_hint.c_str()); + auto copy_node = For(ffi::make_object(*op)); + CreateParallelLaunch(copy_node, 0, + std::string("loop_parallel_") + op->loop_var->name_hint.c_str()); } else { // already in parallel env. ICHECK(parallel_env_.task_id.defined()); @@ -1162,13 +1172,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { ICHECK(!parallel_env_.in_parallel_loop) << "Nested parallel loop is not supported by threadpool, try fuse them instead"; parallel_env_.in_parallel_loop = true; + PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent); if (parallel_env_.stride_pattern) { - CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), - op->loop_var, op->body); + CreateSerialFor(MakeValue(task_id), MakeValue(end), MakeValue(num_task), op->loop_var, + op->body); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; PrimExpr begin = min(task_id * step, op->extent); - PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent); + end = min((task_id + make_const(t, 1)) * step, end); CreateSerialFor(MakeValue(begin), MakeValue(end), llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); } @@ -1180,13 +1191,13 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.codegen.llvm.target_cpu", [](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenCPU()); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index f8c6b362badf..d5401b966220 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -65,7 +65,7 @@ class CodeGenCPU : public CodeGenLLVM { virtual ~CodeGenCPU(); void Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool 
dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) override; void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; @@ -74,8 +74,8 @@ class CodeGenCPU : public CodeGenLLVM { void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const ForNode* op) override; llvm::Value* CreateIntrinsic(const CallNode* op) override; - llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg) override; + llvm::Value* CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg) override; protected: void AddStartupFunction() final; @@ -122,10 +122,10 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); - TypedPointer PackClosureData(const Array& fields, uint64_t* num_bytes, + TypedPointer PackClosureData(const ffi::Array& fields, uint64_t* num_bytes, std::string struct_name = ""); TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); - void UnpackClosureData(TypedPointer cdata, const Array& fields, + void UnpackClosureData(TypedPointer cdata, const ffi::Array& fields, std::unordered_map* vmap); // Make packed call. struct PackedCall { @@ -133,7 +133,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* ret_type_index; llvm::BasicBlock* end_block; }; - PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, + PackedCall MakeCallPackedLowered(const ffi::Array& args, const DataType& r_type, const int64_t begin, const int64_t end, bool use_string_lookup); // create call into tvm packed function. 
llvm::Value* CreateCallPacked(const CallNode* op); @@ -151,7 +151,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); llvm::DISubprogram* CreateDebugFunction(const GlobalVar& gvar, const PrimFunc& f); - llvm::DISubprogram* CreateDebugFunction(llvm::StringRef name, const Array& param_types, + llvm::DISubprogram* CreateDebugFunction(llvm::StringRef name, const ffi::Array& param_types, const Type& return_type); // Context for injection lookup @@ -161,7 +161,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::GlobalVariable* gv_tvm_ffi_set_last_error_c_str_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr}; - std::unordered_map gv_func_map_; + std::unordered_map gv_func_map_; // context for direct dynamic lookup llvm::Function* f_tvm_ffi_func_call_{nullptr}; llvm::Function* f_tvm_get_func_from_env_{nullptr}; @@ -181,7 +181,7 @@ class CodeGenCPU : public CodeGenLLVM { bool target_c_runtime_; // The system lib prefix if it is not nullopt, then we should do // system lib registration with the given prefix. 
The prefix can be "" - Optional system_lib_prefix_; + ffi::Optional system_lib_prefix_; }; } // namespace codegen diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 67fccd8b073a..773e2a2e1d91 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -71,7 +71,7 @@ namespace codegen { class CodeGenHexagon final : public CodeGenCPU { public: void Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) override; void InitTarget() final; @@ -79,10 +79,10 @@ class CodeGenHexagon final : public CodeGenCPU { llvm::Value* VisitExpr_(const BufferLoadNode* op) override; llvm::Value* CreateIntrinsic(const CallNode* op) override; - llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg) override; - llvm::Value* CreateCallExternQHL(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg); + llvm::Value* CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg) override; + llvm::Value* CreateCallExternQHL(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg); llvm::Module* GetModulePtr() const { return module_.get(); } @@ -105,7 +105,7 @@ class CodeGenHexagon final : public CodeGenCPU { bool IsQHLFunction(const std::string& func); - llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, Array indices); + llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, ffi::Array indices); llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args); std::vector fqhl_list_ = { "tvm_vect_qhmath_hvx_cos_ahf", "tvm_vect_qhmath_hvx_tanh_ahf", @@ -116,7 +116,7 @@ class CodeGenHexagon final : public CodeGenCPU { }; void CodeGenHexagon::Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional 
system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) { CodeGenCPU::Init(module_name, llvm_target, system_lib_prefix, dynamic_lookup, target_c_runtime); } @@ -149,8 +149,9 @@ void CodeGenHexagon::InitTarget() { CodeGenCPU::InitTarget(); } -llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { +llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, + bool skip_first_arg) { int num_lanes = args[1].dtype().lanes(); int vector_length = native_vector_bits_ / args[1].dtype().bits(); num_lanes = ((num_lanes + vector_length - 1) / vector_length) * vector_length; @@ -184,8 +185,9 @@ bool CodeGenHexagon::IsQHLFunction(const std::string& func) { return std::find(fqhl_list_.begin(), fqhl_list_.end(), func) != fqhl_list_.end(); } -llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { +llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, + bool skip_first_arg) { int num_lanes = args[1].dtype().lanes(); int vector_length = native_vector_bits_ / args[1].dtype().bits(); if (IsQHLFunction(global_symbol) && (num_lanes > vector_length)) @@ -328,7 +330,7 @@ llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID, } llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_type, - Array indices) { + ffi::Array indices) { PrimExpr index = indices[0]; if (!index.dtype().is_fixed_length_vector()) { return nullptr; @@ -453,8 +455,8 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { return vec; }; std::string llvm_options_str = "llvm"; - if (const auto& llvm_options = target->GetAttr>("llvm-options")) { - for (const String& s : llvm_options.value()) llvm_options_str += "," + s; + if (const auto& llvm_options = 
target->GetAttr>("llvm-options")) { + for (const ffi::String& s : llvm_options.value()) llvm_options_str += "," + s; } // Postprocess the LLVM options string: replace '@' with '=', and ',' with ' '. for (int i = 0, e = llvm_options_str.size(); i != e; ++i) { @@ -494,7 +496,7 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { } auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()); entry_func = global_symbol.value(); } @@ -572,10 +574,10 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { ICHECK(f.has_value()) << "tvm.contrib.hexagon.link_shared does not to exist, " "do import tvm.contrib.hexagon"; - Array o_names = {StringImm(o_name)}; - Map extra_args; + ffi::Array o_names = {StringImm(o_name)}; + ffi::Map extra_args; if (target->attrs.count("mcpu")) { - std::string mcpu = Downcast(target->attrs.at("mcpu")); + std::string mcpu = Downcast(target->attrs.at("mcpu")); #if TVM_LLVM_VERSION >= 180 ICHECK(llvm::StringRef(mcpu).starts_with("hexagon")) #else @@ -590,7 +592,7 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { return HexagonModuleCreate(so_name, "so", ExtractFuncInfo(mod), asm_str, obj_str, ir_str, bc_str); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.hexagon", BuildHexagon) @@ -598,7 +600,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenHexagon()); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 5b2cb5cc95e3..131c8212c597 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -138,7 +138,7 @@ std::unique_ptr CodeGenLLVM::Create(LLVMTarget* llvm_target) { } void CodeGenLLVM::Init(const 
std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) { llvm_target_ = llvm_target; llvm::LLVMContext* ctx = llvm_target_->GetContext(); @@ -148,6 +148,7 @@ void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, // types t_void_ = llvm::Type::getVoidTy(*ctx); t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), GetGlobalAddressSpace()); + t_int1_ = llvm::Type::getInt1Ty(*ctx); t_int_ = llvm::Type::getInt32Ty(*ctx); t_char_ = llvm::Type::getInt8Ty(*ctx); t_int8_ = llvm::Type::getInt8Ty(*ctx); @@ -240,7 +241,7 @@ void CodeGenLLVM::InitFuncState() { std::tuple CodeGenLLVM::GetLinkage( const GlobalVar& gvar, const PrimFunc& func) { - if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { return {global_symbol.value(), llvm::Function::ExternalLinkage}; } @@ -576,6 +577,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { llvm::LLVMContext* ctx = llvm_target_->GetContext(); if (dtype.is_int() || dtype.is_uint()) { etype = llvm::Type::getIntNTy(*ctx, dtype.bits()); + } else if (dtype.is_bool()) { + etype = t_int1_; } else if (dtype.is_float()) { switch (dtype.bits()) { case 16: @@ -717,8 +720,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp auto it = alloc_storage_info_.find(buf_var); if (it != alloc_storage_info_.end()) { const StorageInfo& info = it->second; - *p_native_bits = - NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef(buf_var)))); + *p_native_bits = NativeVectorBits( + runtime::StorageScope::Create(GetPtrStorageScope(ffi::GetRef(buf_var)))); max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; @@ -775,6 +778,12 @@ std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Modul return debug_info; } +void 
CodeGenLLVM::PushLoopFrame(llvm::BasicBlock* backedge_tgt, llvm::BasicBlock* exit_tgt) { + loop_frame_jump_tgts_.emplace_back(backedge_tgt, exit_tgt); +} + +void CodeGenLLVM::PopLoopFrame() { loop_frame_jump_tgts_.pop_back(); } + llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { int num_elems = GetVectorNumElements(vec); if (extent == num_elems && begin == 0) return vec; @@ -878,6 +887,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_); auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_); auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_); + auto* for_next = llvm::BasicBlock::Create(*ctx, "for_next_" + loop_var_name, function_); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_begin); @@ -892,8 +902,13 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va builder_->SetInsertPoint(for_body); EmitDebugLocation(body->span); + PushLoopFrame(for_next, for_end); this->VisitStmt(body); + PopLoopFrame(); var_map_.erase(loop_var.get()); + + builder_->CreateBr(for_next); + builder_->SetInsertPoint(for_next); llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride); loop_value->addIncoming(loop_next, builder_->GetInsertBlock()); builder_->CreateBr(for_begin); @@ -910,7 +925,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va if (to.is_handle()) { return builder_->CreateBitCast(value, target); - } else if (to.is_uint() && to.bits() == 1) { + } else if (to.is_bool()) { if (from.is_float()) { llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.); return builder_->CreateFCmpONE(value, zero); @@ -931,7 +946,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } } else if (from.is_int() && to.is_float()) { return 
builder_->CreateSIToFP(value, target); - } else if (from.is_uint() && to.is_float()) { + } else if ((from.is_uint() || from.is_bool()) && to.is_float()) { return builder_->CreateUIToFP(value, target); } else { ICHECK(from.is_float() && to.is_float()); @@ -1060,8 +1075,8 @@ llvm::Value* CodeGenLLVM::CreateLookupReturnAddress(unsigned int level) { return call; } -llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { +llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg) { std::vector arg_value; std::vector arg_type; for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { @@ -1367,7 +1382,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { arg_value.push_back(MakeValue(op->args[i])); arg_type.push_back(arg_value.back()->getType()); } - llvm::Type* return_type = GetLLVMType(GetRef(op)); + llvm::Type* return_type = GetLLVMType(ffi::GetRef(op)); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " << llvmGetIntrinName(id); @@ -1406,7 +1421,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); - Array indices = load->indices; + ffi::Array indices = load->indices; if (const RampNode* r = indices[indices.size() - 1].as()) { indices.Set(indices.size() - 1, r->base); } @@ -1466,6 +1481,26 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy", function_); builder_->SetInsertPoint(ret_dummy); return ret_dummy; + } else if (op->op.same_as(builtin::continue_loop())) { + ICHECK(!loop_frame_jump_tgts_.empty()) + << "the tir.continue_loop should be inserted under at least one For or While stmts."; + 
builder_->CreateBr(loop_frame_jump_tgts_.back().first); + // LLVM allows exactly one terminator in a single basic block + // append a new dummy basic block to avoid error. + llvm::BasicBlock* post_dummy = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "post_cont_dummy", function_); + builder_->SetInsertPoint(post_dummy); + return post_dummy; + } else if (op->op.same_as(builtin::break_loop())) { + ICHECK(!loop_frame_jump_tgts_.empty()) + << "the tir.break_loop should be inserted under at least one For or While stmts."; + builder_->CreateBr(loop_frame_jump_tgts_.back().second); + // LLVM allows exactly one terminator in a single basic block + // append a new dummy basic block to avoid error. + llvm::BasicBlock* post_dummy = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "post_break_dummy", function_); + builder_->SetInsertPoint(post_dummy); + return post_dummy; } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); @@ -1697,7 +1732,8 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { } void CodeGenLLVM::BufferAccessHelper( - Buffer buffer, Array indices, Optional predicate, DataType value_dtype, + Buffer buffer, ffi::Array indices, ffi::Optional predicate, + DataType value_dtype, std::function make_instruction) { @@ -1855,20 +1891,20 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); - return this->CreateCallExtern(GetType(GetRef(op)), global_symbol->value, op->args, - true); + return this->CreateCallExtern(GetType(ffi::GetRef(op)), global_symbol->value, + op->args, true); } else if (op_attr_global_symbol_.count(call_op)) { // call extern if the op itself have a global symbol. 
- return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], - op->args, false); + return this->CreateCallExtern(GetType(ffi::GetRef(op)), + op_attr_global_symbol_[call_op], op->args, false); } else { - VLOG(2) << "CreateIntrinsic: " << GetRef(op); + VLOG(2) << "CreateIntrinsic: " << ffi::GetRef(op); auto x = CreateIntrinsic(op); VLOG(2) << "CreateIntrinsic done"; return x; } } else if (auto* ptr_gvar = op->op.as()) { - auto gvar = GetRef(ptr_gvar); + auto gvar = ffi::GetRef(ptr_gvar); auto it = functions_.find(ptr_gvar); ICHECK(it != functions_.end()) << "Call to undefined GlobalVar \"" << gvar << "\""; llvm::Function* callee = it->second; @@ -1987,7 +2023,6 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { void CodeGenLLVM::VisitStmt_(const ForNode* op) { EmitDebugLocation(op); - ICHECK(is_zero(op->min)); analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); if (op->kind == ForKind::kUnrolled) { LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, " @@ -1995,8 +2030,11 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { } else { ICHECK(op->kind == ForKind::kSerial); } - CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); + PrimExpr step = op->step.value_or(make_const(op->extent->dtype, 1)); + PrimExpr end = is_zero(op->min) ? 
op->extent : analyzer_->Simplify(op->min + op->extent); + llvm::Value* begin_value = MakeValue(op->min); + llvm::Value* end_value = MakeValue(end); + CreateSerialFor(begin_value, end_value, MakeValue(step), op->loop_var, op->body); } void CodeGenLLVM::VisitStmt_(const WhileNode* op) { @@ -2009,7 +2047,9 @@ void CodeGenLLVM::VisitStmt_(const WhileNode* op) { builder_->SetInsertPoint(while_cond); builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge); builder_->SetInsertPoint(while_body); + PushLoopFrame(while_cond, while_merge); this->VisitStmt(op->body); + PopLoopFrame(); builder_->CreateBr(while_cond); builder_->SetInsertPoint(while_merge); } @@ -2041,7 +2081,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { EmitDebugLocation(op); auto data = op->data.value(); - auto array = NDArrayToLLVMArray(llvm_target_->GetContext(), data); + auto array = TensorToLLVMArray(llvm_target_->GetContext(), data); std::string symbol_name = op->buffer_var->name_hint; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); @@ -2188,7 +2228,7 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } -void CodeGenLLVM::EmitDebugLocation(const Optional& span) { +void CodeGenLLVM::EmitDebugLocation(const ffi::Optional& span) { #if TVM_LLVM_VERSION >= 50 if (di_subprogram_ == nullptr) { // debug info is not always generated outside of CPU codegen @@ -2213,7 +2253,8 @@ void CodeGenLLVM::EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullpt void CodeGenLLVM::EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); } // Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv -void CodeGenLLVM::AddDebugInformation(llvm::Function* f_llvm, const Array& tvm_param_types) { +void CodeGenLLVM::AddDebugInformation(llvm::Function* f_llvm, + const 
ffi::Array& tvm_param_types) { #if TVM_LLVM_VERSION >= 50 ICHECK(di_subprogram_); f_llvm->setSubprogram(di_subprogram_); @@ -2274,9 +2315,11 @@ void CodeGenLLVM::AddDebugInformation(llvm::Value* llvm_value, const Var& tir_va #if TVM_LLVM_VERSION >= 50 if (!di_subprogram_) return; + auto dbg_dtype = GetDebugType(GetType(tir_var)); + // no invalid dtypes + if (!dbg_dtype) return; auto local_var = dbg_info_->di_builder_->createAutoVariable( - di_subprogram_, std::string(tir_var->name_hint), dbg_info_->file_, 0, - GetDebugType(GetType(tir_var))); + di_subprogram_, std::string(tir_var->name_hint), dbg_info_->file_, 0, dbg_dtype); auto* di_loc = llvm::DILocation::get(*llvm_target_->GetContext(), 0, 0, di_subprogram_); @@ -2330,6 +2373,8 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) return nullptr; } + if (dtype.is_scalable_vector()) return nullptr; + return dbg_info_->di_builder_->createBasicType(DLDataTypeToString(dtype).operator std::string(), dtype.bits() * dtype.lanes(), dwarf_type); @@ -2351,9 +2396,9 @@ static void CodegenLLVMRegisterReflection() { []() -> std::string { return llvm::sys::getProcessTriple(); }) .def("tvm.codegen.llvm.GetHostCPUName", []() -> std::string { return llvm::sys::getHostCPUName().str(); }) - .def("tvm.codegen.llvm.GetHostCPUFeatures", []() -> Map { + .def("tvm.codegen.llvm.GetHostCPUFeatures", []() -> ffi::Map { #if TVM_LLVM_VERSION >= 190 - Map ret; + ffi::Map ret; auto features = llvm::sys::getHostCPUFeatures(); for (auto it = features.begin(); it != features.end(); ++it) { std::string name = it->getKey().str(); @@ -2364,7 +2409,7 @@ static void CodegenLLVMRegisterReflection() { #else llvm::StringMap features; if (llvm::sys::getHostCPUFeatures(features)) { - Map ret; + ffi::Map ret; for (auto it = features.begin(); it != features.end(); ++it) { std::string name = it->getKey().str(); bool value = it->getValue(); @@ -2378,7 +2423,7 @@ static void CodegenLLVMRegisterReflection() { }); } 
-TVM_FFI_STATIC_INIT_BLOCK({ CodegenLLVMRegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { CodegenLLVMRegisterReflection(); } } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e1667b637578..efec7ad6ada7 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -125,7 +125,8 @@ class CodeGenLLVM : public ExprFunctor, * this option influences whether global ctors are used. */ virtual void Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime); + ffi::Optional system_lib_prefix, bool dynamic_lookup, + bool target_c_runtime); /*! * \brief Turn on fast math flags for floating point operations. @@ -266,7 +267,7 @@ class CodeGenLLVM : public ExprFunctor, /*! * \brief Convert tvm::ffi::String into llvm::StringRef */ - static llvm::StringRef MakeStringRef(const String& string) { + static llvm::StringRef MakeStringRef(const ffi::String& string) { return llvm::StringRef(string.c_str(), string.size()); } /*! @@ -293,8 +294,8 @@ class CodeGenLLVM : public ExprFunctor, virtual llvm::Value* CreateIntrinsic(const CallNode* op); // create extern function call // skip first arg mode used for call extern intrinsic. - virtual llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg); + virtual llvm::Value* CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg); /*! \brief Insert a printf() call to the generated LLVM * @@ -359,7 +360,8 @@ class CodeGenLLVM : public ExprFunctor, * - Should return the generated expression. 
*/ void BufferAccessHelper( - Buffer buffer, Array indices, Optional predicate, DataType value_dtype, + Buffer buffer, ffi::Array indices, ffi::Optional predicate, + DataType value_dtype, std::function make_instruction); @@ -534,6 +536,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::Type* t_void_{nullptr}; llvm::PointerType* t_void_p_{nullptr}; llvm::Type* t_int_{nullptr}; + llvm::Type* t_int1_{nullptr}; llvm::Type* t_char_{nullptr}; llvm::Type* t_int8_{nullptr}; llvm::Type* t_int16_{nullptr}; @@ -585,7 +588,7 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); void EmitDebugLocation(); - void EmitDebugLocation(const Optional& span); + void EmitDebugLocation(const ffi::Optional& span); void EmitDebugLocation(const StmtNode* op); // Get the DWARF type corresponding to the LLVM type |ty|. The current API in practice only @@ -594,7 +597,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::DIType* GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm); // Adds the DWARF debug information for |function| to |dbg_info_|. - void AddDebugInformation(llvm::Function* f_llvm, const Array& tvm_param_types); + void AddDebugInformation(llvm::Function* f_llvm, const ffi::Array& tvm_param_types); // Adds the DWARF debug information for |tir_var| to |dbg_info_|. void AddDebugInformation(llvm::Value* llvm_value, const Var& tir_var, llvm::Instruction* insert_before = nullptr); @@ -615,6 +618,13 @@ class CodeGenLLVM : public ExprFunctor, * initializes file and compilation_unit_ to TVM defaults. */ static std::unique_ptr CreateDebugInfo(llvm::Module* module); + + void PushLoopFrame(llvm::BasicBlock* backedge_tgt, llvm::BasicBlock* exit_tgt); + void PopLoopFrame(); + + // loop frame's jump target for continue and break generation + // store basic block pair (blk to backedge, blk to exit) for each frame. 
+ std::vector> loop_frame_jump_tgts_; }; inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) { diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a1c967e644cb..17a90477d2fc 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -316,7 +316,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { } int GetCUDAComputeVersion(const Target& target) { - Optional mcpu = target->GetAttr("mcpu"); + ffi::Optional mcpu = target->GetAttr("mcpu"); ICHECK(mcpu.has_value()) << "InternalError: \"-mcpu\" is undefined in the NVPTX target"; std::string sm_version = mcpu.value(); return std::stoi(sm_version.substr(3)); @@ -377,14 +377,14 @@ ffi::Module BuildNVPTX(IRModule mod, Target target) { return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.nvptx", BuildNVPTX) .def_packed("tvm.codegen.llvm.target_nvptx", [](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenNVPTX()); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_params.cc b/src/target/llvm/codegen_params.cc index 81ed4462318f..e2e5323445c8 100644 --- a/src/target/llvm/codegen_params.cc +++ b/src/target/llvm/codegen_params.cc @@ -70,7 +70,7 @@ void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t num_ele [&](T t) { return LLVMConstantGetter::getElement(element_type, t); }); } -llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::NDArray arr) { +llvm::ConstantArray* TensorToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::Tensor arr) { llvm::Type* element_type = nullptr; auto arr_type = arr.DataType(); diff --git a/src/target/llvm/codegen_params.h b/src/target/llvm/codegen_params.h index 9d05621469a7..b59630fb6150 100644 --- a/src/target/llvm/codegen_params.h +++ 
b/src/target/llvm/codegen_params.h @@ -24,7 +24,7 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_PARAMS_H_ #define TVM_TARGET_LLVM_CODEGEN_PARAMS_H_ -#include +#include namespace llvm { class ConstantArray; @@ -35,15 +35,15 @@ namespace tvm { namespace codegen { /*! - * \brief Convert an NDArray to an LLVM array of constants. + * \brief Convert an Tensor to an LLVM array of constants. * - * The supplied NDArray is flattened, and each element is converted to the appropriate LLVM type. + * The supplied Tensor is flattened, and each element is converted to the appropriate LLVM type. * * \param ctx LLVM context used to create the various primitive datatypes. - * \param arr NDArray to convert. + * \param arr Tensor to convert. * \return LLVM array containing the array data. */ -llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, tvm::runtime::NDArray arr); +llvm::ConstantArray* TensorToLLVMArray(llvm::LLVMContext* ctx, tvm::runtime::Tensor arr); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index b4c7cf190136..8a63149ebb0b 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -68,7 +68,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { DTypeToLLVMType(DataType::Float(32, from.lanes())), { MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), - {op->value})), + {op->value}, {})), MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), @@ -83,7 +83,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8, DTypeToLLVMType(DataType::Float(32, from.lanes())), {MakeValue(tir::Call(DataType::Int(16, from.lanes()), - tir::builtin::reinterpret(), {op->value}))}); + tir::builtin::reinterpret(), {op->value}, {}))}); } #endif } @@ 
-133,13 +133,13 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr return CreateVecSlice(CreateVecConcat(split_results), 0, num_elems); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.codegen.llvm.target_x86-64", [](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenX86_64()); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index b38ff0674943..5415b8ed6f97 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -39,11 +39,11 @@ namespace llvm { using tir::FLowerIntrinsic; inline PrimExpr TVMExternCall(const tir::CallNode* call, const std::string& fname) { - Array new_args = {tir::StringImm(fname)}; + ffi::Array new_args = {tir::StringImm(fname)}; for (PrimExpr arg : call->args) { new_args.push_back(arg); } - return tir::Call(call->dtype, tir::builtin::call_pure_extern(), new_args); + return tir::Call(call->dtype, tir::builtin::call_pure_extern(), new_args, call->annotations); } template @@ -51,7 +51,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { using namespace tir; const CallNode* call = e.as(); ICHECK(call != nullptr); - Array new_args; + ffi::Array new_args; #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); @@ -72,7 +72,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { new_args.push_back(IntImm(DataType::UInt(32), id)); new_args.push_back(IntImm(DataType::UInt(32), num_sign)); new_args.insert(new_args.end(), call->args.begin(), call->args.end()); - return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), new_args); + return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), new_args, call->annotations); } TVM_REGISTER_OP("tir.fma").set_attr( @@ 
-183,8 +183,8 @@ TVM_REGISTER_OP("tir.sigmoid") const PrimExpr v1 = tir::Max(x, MinBound); const PrimExpr v2 = tir::Min(v1, MaxBound); - Array new_args = {v2}; - const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); + ffi::Array new_args = {v2}; + const tir::Call new_call = tir::Call(call->dtype, call->op, new_args, call->annotations); // Enable QHL library for FP16 data type if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 17de699e00b4..cedf41aeb79f 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -264,12 +264,12 @@ TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimEx const tir::CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 1); - Array cargs; + ffi::Array cargs; cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); cargs.push_back(call->args[0]); cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef // LLVM requires that the return type must match the first argument type - auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs); + auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs, call->annotations); return cast(call->dtype, clz); }); diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index aa4f68d0b090..9b0826c10348 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -41,7 +41,7 @@ template inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); - Array cargs; + ffi::Array cargs; // intrin id. 
cargs.push_back(IntImm(DataType::UInt(32), id)); ICHECK_EQ(call->args.size(), num_signature) @@ -51,14 +51,14 @@ inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), cargs); + return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), cargs, call->annotations); } template inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); - Array cargs; + ffi::Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); ICHECK_EQ(call->args.size(), num_signature) @@ -67,7 +67,7 @@ inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - return tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs); + return tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, call->annotations); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 48fc64172215..8a50e906969a 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -49,11 +49,11 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { intrinsic_name << "__nv_" << name.substr(4); if (call->dtype.bits() == 32) intrinsic_name << "f"; - Array new_args = {StringImm(intrinsic_name.str())}; + ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { new_args.push_back(arg); } - return Call(call->dtype, builtin::call_pure_extern(), new_args); + return Call(call->dtype, builtin::call_pure_extern(), new_args, call->annotations); } namespace llvm { diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 30afcee92acc..9fc0a0da82d2 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -52,12 +52,12 @@ inline PrimExpr 
DispatchPureExternOCML(const PrimExpr& e) { std::ostringstream intrinsic_name; intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits(); - Array new_args = {StringImm(intrinsic_name.str())}; + ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { new_args.push_back(arg); } - return Call(call->dtype, builtin::call_pure_extern(), new_args); + return Call(call->dtype, builtin::call_pure_extern(), new_args, call->annotations); } inline PrimExpr DispatchShuffle(const PrimExpr& e) { @@ -72,9 +72,9 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); PrimExpr zero = tir::make_zero(DataType::Int(32)); PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}); + {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}, call->annotations); PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}); + {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}, call->annotations); // compute lane to get from PrimExpr width = call->args[3]; @@ -96,7 +96,7 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32; PrimExpr source = is_int32 ? 
var : reinterpret(DataType::Int(32), var); PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}); + {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}, call->annotations); if (!is_int32) { res = reinterpret(var.dtype(), res); } diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index e494a2bbf9e9..32bada242ceb 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -203,19 +203,19 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) : LLVMTargetInfo(instance, target->Export()) {} LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) { - triple_ = Downcast(target.Get("mtriple").value_or(String("default"))); + triple_ = Downcast(target.Get("mtriple").value_or(ffi::String("default"))); if (triple_.empty() || triple_ == "default") { triple_ = llvm::sys::getDefaultTargetTriple(); } - cpu_ = Downcast(target.Get("mcpu").value_or(String(defaults::cpu))); + cpu_ = Downcast(target.Get("mcpu").value_or(ffi::String(defaults::cpu))); - if (const auto& v = Downcast>>(target.Get("mattr"))) { - for (const String& s : v.value()) { + if (const auto& v = Downcast>>(target.Get("mattr"))) { + for (const ffi::String& s : v.value()) { attrs_.push_back(s); } } // llvm module target - if (Downcast(target.Get("kind").value()) == "llvm") { + if (Downcast(target.Get("kind").value()) == "llvm") { // legalize -mcpu with the target -mtriple auto arches = GetAllLLVMTargetArches(); bool has_arch = @@ -225,16 +225,16 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) // give the code a chance to run with a less-specific target. 
LOG(ERROR) << "Using LLVM " << LLVM_VERSION_STRING << " with `-mcpu=" << cpu_ << "` is not valid in `-mtriple=" << triple_ << "`" - << ", using default `-mcpu=" << String(defaults::cpu) << "`"; + << ", using default `-mcpu=" << ffi::String(defaults::cpu) << "`"; // LLVM default cpu fallback - cpu_ = String(defaults::cpu); + cpu_ = ffi::String(defaults::cpu); } } - if (const auto& v = Downcast>>(target.Get("cl-opt"))) { + if (const auto& v = Downcast>>(target.Get("cl-opt"))) { llvm::StringMap& options = llvm::cl::getRegisteredOptions(); bool parse_error = false; - for (const String& s : v.value()) { + for (const ffi::String& s : v.value()) { Option opt = ParseOptionString(s); if (opt.type == Option::OptType::Invalid) { parse_error = true; @@ -252,8 +252,8 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default; - if (const auto& v = Downcast>(target.Get("mfloat-abi"))) { - String value = v.value(); + if (const auto& v = Downcast>(target.Get("mfloat-abi"))) { + ffi::String value = v.value(); if (value == "hard") { float_abi = llvm::FloatABI::Hard; } else if (value == "soft") { @@ -264,8 +264,8 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } // LLVM JIT engine options - if (const auto& v = Downcast>(target.Get("jit").value_or(nullptr))) { - String value = v.value(); + if (const auto& v = Downcast>(target.Get("jit").value_or(nullptr))) { + ffi::String value = v.value(); if ((value == "mcjit") || (value == "orcjit")) { jit_engine_ = value; } else { @@ -274,7 +274,8 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } // TVM & LLVM vector width options - if (const auto& w = Downcast>(target.Get("vector-width").value_or(nullptr))) { + if (const auto& w = + Downcast>(target.Get("vector-width").value_or(nullptr))) { vector_width_ = w.value(); if ((vector_width_ <= 0) || (vector_width_ > 65536)) { LOG(FATAL) << 
"Invalid -vector-width value: " << vector_width_; @@ -288,7 +289,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) code_model_ = llvm::CodeModel::Medium; #if TVM_LLVM_VERSION >= 140 // get VLEN from the LLVM backend (zvlXXXb) - Map features = GetAllLLVMCpuFeatures(); + ffi::Map features = GetAllLLVMCpuFeatures(); // check vector ISA if (features.count("v") > 0) { vector_width_ = 0; @@ -320,7 +321,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) target_options_.NoNaNsFPMath = true; target_options_.FloatABIType = float_abi; if (target.find("mabi") != target.end()) { - target_options_.MCOptions.ABIName = Downcast(target.Get("mabi").value()); + target_options_.MCOptions.ABIName = Downcast(target.Get("mabi").value()); } auto maybe_level = target.Get("opt-level"); @@ -833,8 +834,8 @@ void LLVMTargetInfo::GetOptionValue(LLVMTargetInfo::Option* opt) const { } } -const Array LLVMTargetInfo::GetAllLLVMTargets() const { - Array llvm_targets; +const ffi::Array LLVMTargetInfo::GetAllLLVMTargets() const { + ffi::Array llvm_targets; // iterate all archtypes for (auto a = llvm::Triple::ArchType(llvm::Triple::ArchType::UnknownArch + 1); a < llvm::Triple::ArchType::LastArchType; a = llvm::Triple::ArchType(a + 1)) { @@ -848,8 +849,8 @@ const Array LLVMTargetInfo::GetAllLLVMTargets() const { return llvm_targets; } -const Array LLVMTargetInfo::GetAllLLVMTargetArches() const { - Array cpu_arches; +const ffi::Array LLVMTargetInfo::GetAllLLVMTargetArches() const { + ffi::Array cpu_arches; // get the subtarget info module auto llvm_instance = CreateLLVMTargetInstance(triple_, true); std::unique_ptr target_machine = @@ -873,7 +874,7 @@ const Array LLVMTargetInfo::GetAllLLVMTargetArches() const { return cpu_arches; } -const Map LLVMTargetInfo::GetAllLLVMCpuFeatures() const { +const ffi::Map LLVMTargetInfo::GetAllLLVMCpuFeatures() const { std::string feats = ""; for (const auto& attr : attrs_) { feats += feats.empty() ? 
attr : ("," + attr); @@ -892,7 +893,7 @@ const Map LLVMTargetInfo::GetAllLLVMCpuFeatures() const { MCInfo->getAllProcessorFeatures(); #endif // TVM doesn't have an FFI friendly Set, so use a Map instead for now - Map cpu_features; + ffi::Map cpu_features; for (const auto& feat : llvm_features) { if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { cpu_features.Set(feat.Key, ""); diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index a68637cc844e..a41c57d6fae6 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -324,14 +324,14 @@ class LLVMTargetInfo { * \brief Get all supported targets from the LLVM backend * \return list with all valid targets */ - const Array GetAllLLVMTargets() const; + const ffi::Array GetAllLLVMTargets() const; /*! * \brief Get all CPU arches from target * \return list with all valid cpu architectures * \note The arches are fetched from the LLVM backend using the target `-mtriple`. */ - const Array GetAllLLVMTargetArches() const; + const ffi::Array GetAllLLVMTargetArches() const; /*! * \brief Get all CPU features from target @@ -340,7 +340,7 @@ class LLVMTargetInfo { * \note The features are fetched from the LLVM backend using the target `-mtriple` * and the `-mcpu` architecture, but also consider the `-mattr` attributes. */ - const Map GetAllLLVMCpuFeatures() const; + const ffi::Map GetAllLLVMCpuFeatures() const; /*! * \brief Check the target if has a specific cpu feature diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index f90729a45f06..5f7494558eaa 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -95,7 +95,7 @@ class LLVMModuleNode final : public ffi::ModuleObj { const char* kind() const final { return "llvm"; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; /*! 
\brief Get the property of the runtime module .*/ // TODO(tvm-team): Make it serializable @@ -103,15 +103,15 @@ class LLVMModuleNode final : public ffi::ModuleObj { return ffi::Module::kRunnable | ffi::Module::kCompilationExportable; } - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; - String InspectSource(const String& format) const final; + ffi::String InspectSource(const ffi::String& format) const final; void Init(const IRModule& mod, const Target& target); void Init(std::unique_ptr module, std::unique_ptr llvm_instance); void LoadIR(const std::string& file_name); - bool ImplementsFunction(const String& name) final; + bool ImplementsFunction(const ffi::String& name) final; void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; } @@ -135,7 +135,7 @@ class LLVMModuleNode final : public ffi::ModuleObj { // (EngineBuilder takes ownership of the module). 
std::unique_ptr module_owning_ptr_; /* \brief names of the external functions declared in this module */ - Array function_names_; + ffi::Array function_names_; std::string jit_engine_; }; @@ -155,7 +155,7 @@ LLVMModuleNode::~LLVMModuleNode() { module_owning_ptr_.reset(); } -Optional LLVMModuleNode::GetFunction(const String& name) { +ffi::Optional LLVMModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "__tvm_is_system_module") { bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); @@ -189,9 +189,10 @@ Optional LLVMModuleNode::GetFunction(const String& name) { TVMFFISafeCallType faddr; With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); - faddr = reinterpret_cast(GetFunctionAddr(name, *llvm_target)); + ffi::String name_with_prefix = ffi::symbol::tvm_ffi_symbol_prefix + name; + faddr = reinterpret_cast(GetFunctionAddr(name_with_prefix, *llvm_target)); if (faddr == nullptr) return std::nullopt; - ffi::Module self_strong_ref = GetRef(this); + ffi::Module self_strong_ref = ffi::GetRef(this); return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, ffi::Any* rv) { TVM_FFI_ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), @@ -235,7 +236,8 @@ bool LLVMAddPassesToEmitFile(llvm::TargetMachine* tm, llvm::legacy::PassManager* } // namespace -void LLVMModuleNode::WriteToFile(const String& file_name_str, const String& format) const { +void LLVMModuleNode::WriteToFile(const ffi::String& file_name_str, + const ffi::String& format) const { // CHECK(imports_.empty()) << "SaveToFile does not handle imported modules"; std::string file_name = file_name_str; std::string fmt = runtime::GetFileFormat(file_name, format); @@ -274,7 +276,7 @@ ffi::Bytes LLVMModuleNode::SaveToBytes() const { LOG(FATAL) << "LLVMModule: SaveToBytes not supported"; } -String 
LLVMModuleNode::InspectSource(const String& format) const { +ffi::String LLVMModuleNode::InspectSource(const ffi::String& format) const { std::string fmt = runtime::GetFileFormat("", format); std::string type_str; llvm::SmallString<256> str; @@ -324,7 +326,8 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { std::string entry_func; - Optional system_lib_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix); + ffi::Optional system_lib_prefix = + mod->GetAttr(tvm::attr::kSystemLibPrefix); for (auto kv : mod->functions) { if (!kv.second->IsInstance()) { @@ -332,7 +335,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { continue; } auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc); ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally."; @@ -385,8 +388,9 @@ void LLVMModuleNode::LoadIR(const std::string& file_name) { Init(std::move(module), std::move(llvm_instance)); } -bool LLVMModuleNode::ImplementsFunction(const String& name) { - return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); +bool LLVMModuleNode::ImplementsFunction(const ffi::String& name) { + return std::find(function_names_.begin(), function_names_.end(), + ffi::symbol::tvm_ffi_symbol_prefix + name) != function_names_.end(); } void LLVMModuleNode::InitMCJIT() { @@ -443,7 +447,7 @@ void LLVMModuleNode::InitMCJIT() { *ctx_addr = this; } - ffi::Module::VisitContextSymbols([this, &llvm_target](const String& name, void* symbol) { + ffi::Module::VisitContextSymbols([this, &llvm_target](const ffi::String& name, void* symbol) { if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(name, *llvm_target))) { *ctx_addr = symbol; } @@ -491,7 +495,7 @@ void LLVMModuleNode::InitORCJIT() { } // data layout - String module_name = 
module_->getModuleIdentifier(); + ffi::String module_name = module_->getModuleIdentifier(); llvm::DataLayout layout(tm->createDataLayout()); ICHECK(layout == module_->getDataLayout()) << "Data layout mismatch between module(" @@ -515,7 +519,13 @@ void LLVMModuleNode::InitORCJIT() { const llvm::Triple& triple) -> std::unique_ptr { #endif #if _WIN32 +#if TVM_LLVM_VERSION >= 210 + auto GetMemMgr = [](const llvm::MemoryBuffer&) { + return std::make_unique(); + }; +#else auto GetMemMgr = []() { return std::make_unique(); }; +#endif auto ObjLinkingLayer = std::make_unique(session, std::move(GetMemMgr)); #else @@ -587,7 +597,7 @@ void LLVMModuleNode::InitORCJIT() { reinterpret_cast(GetGlobalAddr(ffi::symbol::tvm_ffi_library_ctx, *llvm_target))) { *ctx_addr = this; } - ffi::Module::VisitContextSymbols([this, &llvm_target](const String& name, void* symbol) { + ffi::Module::VisitContextSymbols([this, &llvm_target](const ffi::String& name, void* symbol) { if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(name, *llvm_target))) { *ctx_addr = symbol; } @@ -650,7 +660,7 @@ static void LLVMReflectionRegister() { refl::GlobalDef() .def("target.build.llvm", [](IRModule mod, Target target) -> ffi::Module { - auto n = make_object(); + auto n = ffi::make_object(); n->Init(mod, target); return ffi::Module(n); }) @@ -658,7 +668,7 @@ static void LLVMReflectionRegister() { [](std::string target_str, std::string module_name) -> ffi::Module { auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, target_str); - auto n = make_object(); + auto n = ffi::make_object(); // Generate a LLVM module from an input target string auto module = std::make_unique(module_name, *llvm_target->GetContext()); llvm_target->SetTargetMetadata(module.get()); @@ -681,9 +691,9 @@ static void LLVMReflectionRegister() { #endif }) .def("target.llvm_get_intrinsic_name", - [](int64_t id) -> String { return llvmGetIntrinName(id); }) + [](int64_t id) -> ffi::String { return llvmGetIntrinName(id); }) 
.def("target.llvm_get_system_x86_vendor", - []() -> String { + []() -> ffi::String { #if TVM_LLVM_VERSION >= 120 #if defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64) using namespace llvm::sys::detail::x86; @@ -712,22 +722,22 @@ static void LLVMReflectionRegister() { return llvm_backend.GetVectorWidth(); }) .def("target.llvm_get_system_triple", - []() -> String { return llvm::sys::getDefaultTargetTriple(); }) + []() -> ffi::String { return llvm::sys::getDefaultTargetTriple(); }) .def("target.llvm_get_system_cpu", - []() -> String { return llvm::sys::getHostCPUName().str(); }) + []() -> ffi::String { return llvm::sys::getHostCPUName().str(); }) .def("target.llvm_get_targets", - []() -> Array { + []() -> ffi::Array { auto llvm_instance = std::make_unique(); LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); return llvm_backend.GetAllLLVMTargets(); }) .def("target.llvm_get_cpu_archlist", - [](const Target& target) -> Array { + [](const Target& target) -> ffi::Array { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target if (target.defined()) { if (target->kind->name != "llvm") { - return Array{}; + return ffi::Array{}; } } auto llvm_instance = std::make_unique(); @@ -735,7 +745,7 @@ static void LLVMReflectionRegister() { return llvm_backend.GetAllLLVMTargetArches(); }) .def("target.llvm_get_cpu_features", - [](const Target& target) -> Map { + [](const Target& target) -> ffi::Map { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target if (target.defined()) { @@ -748,7 +758,7 @@ static void LLVMReflectionRegister() { return llvm_backend.GetAllLLVMCpuFeatures(); }) .def("target.llvm_cpu_has_feature", - [](const String feature, const Target& target) -> bool { + [](const ffi::String feature, const Target& target) -> bool { auto use_target = target.defined() ? 
target : Target::Current(false); // ignore non "llvm" target if (target.defined()) { @@ -763,7 +773,7 @@ static void LLVMReflectionRegister() { return has_feature; }) .def("target.target_has_feature", - [](const String feature, const Target& target) -> bool { + [](const ffi::String feature, const Target& target) -> bool { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target if (target.defined()) { @@ -778,7 +788,7 @@ static void LLVMReflectionRegister() { .def("target.llvm_version_major", []() -> int { return TVM_LLVM_VERSION / 10; }) .def("ffi.Module.load_from_file.ll", [](std::string filename, std::string fmt) -> ffi::Module { - auto n = make_object(); + auto n = ffi::make_object(); n->SetJITEngine("orcjit"); n->LoadIR(filename); return ffi::Module(n); @@ -793,7 +803,7 @@ static void LLVMReflectionRegister() { .def("codegen.codegen_blob", [](std::string data, bool system_lib, std::string llvm_target_string, std::string c_symbol_prefix) -> ffi::Module { - auto n = make_object(); + auto n = ffi::make_object(); auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, llvm_target_string); std::unique_ptr blob = @@ -804,7 +814,7 @@ static void LLVMReflectionRegister() { }); } -TVM_FFI_STATIC_INIT_BLOCK({ LLVMReflectionRegister(); }); +TVM_FFI_STATIC_INIT_BLOCK() { LLVMReflectionRegister(); } } // namespace codegen } // namespace tvm diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 6072a483877c..8d2589aaec13 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -131,7 +131,7 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { CodeGenCUDA cg; cg.Init(output_ssa); - Map functions; + ffi::Map functions; for (auto [gvar, base_func] : mod->functions) { ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto prim_func = Downcast(base_func); @@ -173,10 +173,10 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { return 
CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.cuda", BuildCUDA); -}); -TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", String); +} +TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", ffi::String); } // namespace codegen } // namespace tvm diff --git a/src/target/parsers/aprofile.cc b/src/target/parsers/aprofile.cc index 65bd6a66aedb..4edff94baeda 100644 --- a/src/target/parsers/aprofile.cc +++ b/src/target/parsers/aprofile.cc @@ -35,8 +35,8 @@ namespace target { namespace parsers { namespace aprofile { -double GetArchVersion(Array mattr) { - for (const String& attr : mattr) { +double GetArchVersion(ffi::Array mattr) { + for (const ffi::String& attr : mattr) { std::string attr_string = attr; size_t attr_len = attr_string.size(); if (attr_len >= 4 && attr_string.substr(0, 2) == "+v" && attr_string.back() == 'a') { @@ -47,14 +47,14 @@ double GetArchVersion(Array mattr) { return 0.0; } -double GetArchVersion(Optional> attr) { +double GetArchVersion(ffi::Optional> attr) { if (!attr) { return false; } return GetArchVersion(attr.value()); } -bool IsAArch32(Optional mtriple, Optional mcpu) { +bool IsAArch32(ffi::Optional mtriple, ffi::Optional mcpu) { if (mtriple) { bool is_mprofile = mcpu && support::StartsWith(mcpu.value(), "cortex-m"); return support::StartsWith(mtriple.value(), "arm") && !is_mprofile; @@ -62,7 +62,7 @@ bool IsAArch32(Optional mtriple, Optional mcpu) { return false; } -bool IsAArch64(Optional mtriple) { +bool IsAArch64(ffi::Optional mtriple) { if (mtriple) { return support::StartsWith(mtriple.value(), "aarch64"); } @@ -70,28 +70,32 @@ bool IsAArch64(Optional mtriple) { } bool IsArch(TargetJSON attrs) { - Optional mtriple = Downcast>(attrs.Get("mtriple").value_or(nullptr)); - Optional mcpu = Downcast>(attrs.Get("mcpu").value_or(nullptr)); + ffi::Optional mtriple = + 
Downcast>(attrs.Get("mtriple").value_or(nullptr)); + ffi::Optional mcpu = + Downcast>(attrs.Get("mcpu").value_or(nullptr)); return IsAArch32(mtriple, mcpu) || IsAArch64(mtriple); } -bool CheckContains(Array array, String predicate) { - return std::any_of(array.begin(), array.end(), [&](String var) { return var == predicate; }); +bool CheckContains(ffi::Array array, ffi::String predicate) { + return std::any_of(array.begin(), array.end(), [&](ffi::String var) { return var == predicate; }); } static TargetFeatures GetFeatures(TargetJSON target) { #ifdef TVM_LLVM_VERSION - String kind = Downcast(target.Get("kind").value()); + ffi::String kind = Downcast(target.Get("kind").value()); ICHECK_EQ(kind, "llvm") << "Expected target kind 'llvm', but got '" << kind << "'"; - Optional mtriple = Downcast>(target.Get("mtriple").value_or(nullptr)); - Optional mcpu = Downcast>(target.Get("mcpu").value_or(nullptr)); + ffi::Optional mtriple = + Downcast>(target.Get("mtriple").value_or(nullptr)); + ffi::Optional mcpu = + Downcast>(target.Get("mcpu").value_or(nullptr)); // Check that LLVM has been compiled with the correct target support auto llvm_instance = std::make_unique(); - codegen::LLVMTargetInfo llvm_backend(*llvm_instance, {{"kind", String("llvm")}}); - Array targets = llvm_backend.GetAllLLVMTargets(); + codegen::LLVMTargetInfo llvm_backend(*llvm_instance, {{"kind", ffi::String("llvm")}}); + ffi::Array targets = llvm_backend.GetAllLLVMTargets(); if ((IsAArch64(mtriple) && !CheckContains(targets, "aarch64")) || (IsAArch32(mtriple, mcpu) && !CheckContains(targets, "arm"))) { LOG(WARNING) << "Cannot parse target features for target: " << target @@ -100,9 +104,9 @@ static TargetFeatures GetFeatures(TargetJSON target) { } codegen::LLVMTargetInfo llvm_target(*llvm_instance, target); - Map features = llvm_target.GetAllLLVMCpuFeatures(); + ffi::Map features = llvm_target.GetAllLLVMCpuFeatures(); - auto has_feature = [features](const String& feature) { + auto has_feature = 
[features](const ffi::String& feature) { return features.find(feature) != features.end(); }; @@ -120,15 +124,15 @@ static TargetFeatures GetFeatures(TargetJSON target) { return {}; } -static Array MergeKeys(Optional> existing_keys) { - const Array kExtraKeys = {"arm_cpu", "cpu"}; +static ffi::Array MergeKeys(ffi::Optional> existing_keys) { + const ffi::Array kExtraKeys = {"arm_cpu", "cpu"}; if (!existing_keys) { return kExtraKeys; } - Array keys = existing_keys.value(); - for (String key : kExtraKeys) { + ffi::Array keys = existing_keys.value(); + for (ffi::String key : kExtraKeys) { if (std::find(keys.begin(), keys.end(), key) == keys.end()) { keys.push_back(key); } @@ -138,7 +142,8 @@ static Array MergeKeys(Optional> existing_keys) { TargetJSON ParseTarget(TargetJSON target) { target.Set("features", GetFeatures(target)); - target.Set("keys", MergeKeys(Downcast>>(target.Get("keys")))); + target.Set("keys", + MergeKeys(Downcast>>(target.Get("keys")))); return target; } diff --git a/src/target/parsers/cpu.cc b/src/target/parsers/cpu.cc index ee9bf814d323..ac187a03bbdc 100644 --- a/src/target/parsers/cpu.cc +++ b/src/target/parsers/cpu.cc @@ -28,24 +28,24 @@ namespace target { namespace parsers { namespace cpu { -Optional DetectSystemTriple() { +ffi::Optional DetectSystemTriple() { #ifdef TVM_LLVM_VERSION auto pf = tvm::ffi::Function::GetGlobal("target.llvm_get_system_triple"); ICHECK(pf.has_value()) << "The target llvm_get_system_triple was not found, " "please compile with USE_LLVM = ON"; - return (*pf)().cast(); + return (*pf)().cast(); #endif return {}; } TargetJSON ParseTarget(TargetJSON target) { - String kind = Downcast(target.Get("kind").value()); - Optional mtriple = Downcast>(target.Get("mtriple")); - Optional mcpu = Downcast>(target.Get("mcpu")); + ffi::String kind = Downcast(target.Get("kind").value()); + ffi::Optional mtriple = Downcast>(target.Get("mtriple")); + ffi::Optional mcpu = Downcast>(target.Get("mcpu")); // Try to fill in the blanks by 
detecting target information from the system if (kind == "llvm" && !mtriple.has_value() && !mcpu.has_value()) { - String system_triple = DetectSystemTriple().value_or(""); + ffi::String system_triple = DetectSystemTriple().value_or(""); target.Set("mtriple", system_triple); } diff --git a/src/target/parsers/mprofile.cc b/src/target/parsers/mprofile.cc index acd878c667c0..bd3bf5848a68 100644 --- a/src/target/parsers/mprofile.cc +++ b/src/target/parsers/mprofile.cc @@ -41,7 +41,7 @@ static const char* dspCPUs[] = {"cortex-m55", "cortex-m4", "cortex-m7", static const char* mveCPUs[] = {"cortex-m55", "cortex-m85"}; template -static inline bool MatchesCpu(Optional mcpu, const Container& cpus) { +static inline bool MatchesCpu(ffi::Optional mcpu, const Container& cpus) { if (!mcpu) { return false; } @@ -50,31 +50,32 @@ static inline bool MatchesCpu(Optional mcpu, const Container& cpus) { return std::find_if(std::begin(cpus), std::end(cpus), matches_cpu) != std::end(cpus); } -static inline bool HasFlag(String attr, std::string flag) { +static inline bool HasFlag(ffi::String attr, std::string flag) { std::string attr_str = attr; return attr_str.find(flag) != std::string::npos; } -static inline bool HasFlag(Optional attr, std::string flag) { +static inline bool HasFlag(ffi::Optional attr, std::string flag) { if (!attr) { return false; } return HasFlag(attr.value(), flag); } -static inline bool HasFlag(Optional> attr, std::string flag) { +static inline bool HasFlag(ffi::Optional> attr, std::string flag) { if (!attr) { return false; } - Array attr_array = attr.value(); + ffi::Array attr_array = attr.value(); - auto matching_attr = std::find_if(attr_array.begin(), attr_array.end(), - [flag](String attr_str) { return HasFlag(attr_str, flag); }); + auto matching_attr = + std::find_if(attr_array.begin(), attr_array.end(), + [flag](ffi::String attr_str) { return HasFlag(attr_str, flag); }); return matching_attr != attr_array.end(); } bool IsArch(TargetJSON attrs) { - Optional mcpu 
= Downcast>(attrs.Get("mcpu")); + ffi::Optional mcpu = Downcast>(attrs.Get("mcpu")); if (mcpu) { bool matches_base = MatchesCpu(mcpu, baseCPUs); bool matches_dsp = MatchesCpu(mcpu, dspCPUs); @@ -85,8 +86,9 @@ bool IsArch(TargetJSON attrs) { } static TargetFeatures GetFeatures(TargetJSON target) { - Optional mcpu = Downcast>(target.Get("mcpu")); - Optional> mattr = Downcast>>(target.Get("mattr")); + ffi::Optional mcpu = Downcast>(target.Get("mcpu")); + ffi::Optional> mattr = + Downcast>>(target.Get("mattr")); bool nomve = HasFlag(mcpu, "+nomve") || HasFlag(mattr, "+nomve"); bool nodsp = HasFlag(mcpu, "+nodsp") || HasFlag(mattr, "+nodsp"); @@ -104,15 +106,15 @@ static TargetFeatures GetFeatures(TargetJSON target) { return kNoExt; } -static Array MergeKeys(Optional> existing_keys) { - const Array kExtraKeys = {"arm_cpu", "cpu"}; +static ffi::Array MergeKeys(ffi::Optional> existing_keys) { + const ffi::Array kExtraKeys = {"arm_cpu", "cpu"}; if (!existing_keys) { return kExtraKeys; } - Array keys = existing_keys.value(); - for (String key : kExtraKeys) { + ffi::Array keys = existing_keys.value(); + for (ffi::String key : kExtraKeys) { if (std::find(keys.begin(), keys.end(), key) == keys.end()) { keys.push_back(key); } @@ -122,7 +124,8 @@ static Array MergeKeys(Optional> existing_keys) { TargetJSON ParseTarget(TargetJSON target) { target.Set("features", GetFeatures(target)); - target.Set("keys", MergeKeys(Downcast>>(target.Get("keys")))); + target.Set("keys", + MergeKeys(Downcast>>(target.Get("keys")))); return target; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index acc05cf96c08..2fe8e44dac57 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -76,7 +76,7 @@ void CodeGenC::ReserveKeywordsAsUnique() { name_supply_->ReserveName("return"); } -void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFunc& func, +void CodeGenC::PrintFunctionSignature(const ffi::String& function_name, 
const PrimFunc& func, std::ostream& os) { PrintFuncPrefix(os); PrintType(func->ret_type, os); @@ -136,8 +136,8 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { return; } - auto function_name = [&]() -> String { - if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto function_name = [&]() -> ffi::String { + if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { auto name = global_symbol.value(); ICHECK(!func_name_supply_->ContainsName(name)) << "Function " << gvar << " must use global symbol " << name @@ -149,7 +149,9 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { return gvar->name_hint; } }(); - + if (function_name == ffi::symbol::tvm_ffi_main) { + has_tvm_ffi_main_func_ = true; + } internal_functions_.insert({gvar, function_name}); InitFuncState(func); @@ -157,7 +159,7 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { fwd_decl_stream << ";\n"; } -String CodeGenC::GetFunctionName(const GlobalVar& gvar) { +ffi::String CodeGenC::GetFunctionName(const GlobalVar& gvar) { auto it = internal_functions_.find(gvar); ICHECK(it != internal_functions_.end()) << "Attempted to find name of " << gvar @@ -209,6 +211,7 @@ void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) static bool CheckOutermostBracketMatch(const std::string& s); void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) { + PrintIndent(); PrintType(t, stream); stream << ' ' << target << " = "; if (CheckOutermostBracketMatch(src)) { @@ -358,7 +361,22 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << ")"; return os.str(); } else { - TVM_FFI_THROW(RuntimeError) << "Unsupported type index: " << kind; + ICHECK_LT(kind, builtin::kTVMValueKindBound_); + std::ostringstream os; + os << "(((TVMFFIAny*)"; + this->PrintExpr(buffer, os); + os << ")[" << index << "]."; + if (t.is_handle()) { + os 
<< "v_ptr"; + } else if (t.is_float()) { + os << "v_float64"; + } else if (t.is_int()) { + os << "v_int64"; + } else { + LOG(FATAL) << "Do not know how to handle type" << t; + } + os << ")"; + return os.str(); } } @@ -478,7 +496,7 @@ void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT PrintConst(op, os, this); } void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) - os << "\"" << op->value << "\""; + os << EscapeString(op->value); } template @@ -590,8 +608,9 @@ void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->a, os); } -void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg, + std::ostream& os) { // NOLINT(*) os << global_symbol << "("; for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { this->PrintExpr(args[i], os); @@ -609,15 +628,19 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::ret())) { os << "return "; PrintExpr(op->args[0], os); + } else if (op->op.same_as(builtin::continue_loop())) { + os << "continue;"; + } else if (op->op.same_as(builtin::break_loop())) { + os << "break;"; } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); - this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); + this->PrintCallExtern(GetType(ffi::GetRef(op)), func->value, op->args, true, os); // If the call_extern refers to an function within the IRModule, then // the forward declaration is already provided from DeclareFunction. 
if (!func_name_supply_->ContainsName(func->value)) { - Array arg_types; + ffi::Array arg_types; for (size_t i = 1; i < op->args.size(); i++) { arg_types.push_back(GetType(op->args[i])); } @@ -626,7 +649,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } } else if (op_attr_global_symbol_.count(call_op)) { // call extern if the op itself have a global symbol. - this->PrintCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], + this->PrintCallExtern(GetType(ffi::GetRef(op)), op_attr_global_symbol_[call_op], op->args, false, os); } else if (op->op.same_as(builtin::bitwise_and())) { PrintBinaryIntrinsic(op, " & ", os, this); @@ -650,32 +673,42 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (op->op.same_as(builtin::shift_right())) { PrintBinaryIntrinsic(op, " >> ", os, this); } else if (op->op.same_as(builtin::if_then_else())) { - // conditional that skips eval if cond evals to false + // Conditional that skips eval if cond evals to false. + // When inside a select, combine conditions to prevent OOB access. std::string result = name_supply_->FreshName("condval"); std::string cond = PrintExpr(op->args[0]); + std::string outer_cond = select_condition_stack_.empty() ? 
"" : select_condition_stack_.back(); + this->PrintIndent(); PrintType(op->dtype, this->stream); this->stream << " " << result << ";\n"; + + // Generate if condition (combine with outer select condition if present) this->PrintIndent(); - this->stream << "if (" << cond << ") {\n"; - { - int then_scope = this->BeginScope(); - std::string true_val = PrintExpr(op->args[1]); - this->PrintIndent(); - this->stream << result << " = " << true_val << ";\n"; - this->EndScope(then_scope); - this->PrintIndent(); - this->stream << "} else {\n"; - } - { - int else_scope = this->BeginScope(); - std::string false_val = PrintExpr(op->args[2]); - this->PrintIndent(); - this->stream << result << " = " << false_val << ";\n"; - this->EndScope(else_scope); - this->PrintIndent(); - this->stream << "}\n"; + if (outer_cond.empty()) { + this->stream << "if (" << cond << ") {\n"; + } else { + this->stream << "if ((" << outer_cond << ") && (" << cond << ")) {\n"; } + + // True branch + int then_scope = this->BeginScope(); + std::string true_val = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << result << " = " << true_val << ";\n"; + this->EndScope(then_scope); + + // False branch + this->PrintIndent(); + this->stream << (outer_cond.empty() ? 
"} else {\n" : "} else if (" + outer_cond + ") {\n"); + int else_scope = this->BeginScope(); + std::string false_val = PrintExpr(op->args[2]); + this->PrintIndent(); + this->stream << result << " = " << false_val << ";\n"; + this->EndScope(else_scope); + this->PrintIndent(); + this->stream << "}\n"; + os << result; } else if (op->op.same_as(builtin::address_of())) { const BufferLoadNode* load = op->args[0].as(); @@ -703,12 +736,10 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) CHECK_EQ(target_dtype.lanes() * target_dtype.bits(), source_dtype.lanes() * source_dtype.bits()) << "reinterpret expects source and target to have the same number of bits"; - int ssa_scope = BeginScope(); std::string rhs = SSAGetID(PrintExpr(op->args[0]), source_dtype); os << "(*("; this->PrintType(target_dtype, os); os << " *)(&(" << rhs << ")))"; - EndScope(ssa_scope); } else if (op->op.same_as(builtin::isnan())) { os << "("; this->PrintExpr(op->args[0], os); @@ -730,7 +761,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (auto opt = op->op.as()) { auto gvar = opt.value(); auto callee_name = GetFunctionName(gvar); - PrintCallExtern(GetType(GetRef(op)), callee_name, op->args, false, os); + PrintCallExtern(GetType(ffi::GetRef(op)), callee_name, op->args, false, os); } else { LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, " << "nor a GlobalVar reference to another function in the IRModule"; @@ -776,7 +807,7 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) { decl_stream << " __attribute__((section(\".rodata.tvm\"), " << "aligned(" << constants_byte_alignment_->value << "))) " << symbol_name << "[" << num_elements << "] = {\n"; - NDArrayDataToC(data, 4, decl_stream); + TensorDataToC(data, 4, decl_stream); decl_stream << "};\n" << "#ifdef __cplusplus\n" @@ -911,13 +942,19 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { } void CodeGenC::VisitExpr_(const 
LetNode* op, std::ostream& os) { // NOLINT(*) - auto it = let_binding_.find(op->var); - if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, op->value)) - << "Let cannot bind the same var to two different values"; - } else { - let_binding_[op->var] = op; - } + // auto it = let_binding_.find(op->var); + // if (it != let_binding_.end()) { + // std::cerr << "CHECK: " << op->var << "(" << op->var.get() << "): " << op->var << " = " << op->value << " : " << std::hex << (unsigned long long)(it->second) << "\n"; + // std::cerr << " var=" << op->var.get() << "\n"; + // std::cerr << " val=" << op->value.get() << "\n"; + // ICHECK(deep_equal_(it->second->value, op->value)) + // << "Let cannot bind the same var to two different values: " << op->var << " " << op->value; + // } else { + // std::cerr << "BIND: " << op->var << "(" << op->var.get() << "): " << op->var << " = " << op->value << " : " << std::hex << (unsigned long long)(op) << "\n"; + // std::cerr << " var=" << op->var.get() << "\n"; + // std::cerr << " val=" << op->value.get() << "\n"; + // let_binding_[op->var] = op; + // } std::string value = PrintExpr(op->value); if (print_ssa_form_) { ICHECK(!var_idmap_.count(op->var.get())); @@ -1032,12 +1069,20 @@ void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLIN } void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) + std::string cond = PrintExpr(op->condition); os << "("; - PrintExpr(op->condition, os); + os << cond; os << " ? 
"; + // Push condition before processing true_value so that nested if_then_else + // can guard their branches with this condition + select_condition_stack_.push_back(cond); PrintExpr(op->true_value, os); + select_condition_stack_.pop_back(); os << " : "; + // Push negated condition for false_value + select_condition_stack_.push_back("!(" + cond + ")"); PrintExpr(op->false_value, os); + select_condition_stack_.pop_back(); os << ")"; } @@ -1113,13 +1158,21 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) { } void CodeGenC::VisitStmt_(const ForNode* op) { - std::string extent = PrintExpr(op->extent); + std::string begin_str = PrintExpr(op->min); + PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent); + std::string end_str = PrintExpr(end); + std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : ""; PrintIndent(); std::string vid = AllocVarID(op->loop_var.get()); - ICHECK(is_zero(op->min)); stream << "for ("; PrintType(op->loop_var.dtype(), stream); - stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; + stream << ' ' << vid << " = " << begin_str << "; " << vid << " < " << end_str << "; "; + if (step_str.empty()) { + stream << "++" << vid; + } else { + stream << vid << " += " << step_str; + } + stream << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); @@ -1200,6 +1253,21 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { // cast int to enum cast = "(DLDeviceType)"; } + // Special-case: Assigning a string literal to the Any union's v_ptr + // triggers const correctness issues when compiling as C++. + // If the destination is the Any union value (kTVMFFIAnyUnionValue), + // the store dtype is a handle (thus maps to v_ptr), and the source value + // is a StringImm, cast the string literal to (void*) to avoid + // discarding const qualifier errors under C++. 
+ if (kind == builtin::kTVMFFIAnyUnionValue && store_dtype.is_handle()) { + if (const auto* str_imm = call->args[3].as()) { + (void)str_imm; // silence unused warning + // prepend cast if not already added + if (cast.empty()) { + cast = "(void*)"; + } + } + } this->PrintIndent(); this->stream << ref << " = " << cast << value << ";\n"; return; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 8c5e1ffd897b..50bd98afccc5 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -90,7 +90,7 @@ class CodeGenC : public ExprFunctor, * \param gvar The GlobalVar of the function * \returns The string name of the function */ - String GetFunctionName(const GlobalVar& gvar); + ffi::String GetFunctionName(const GlobalVar& gvar); /*! * \brief Finalize the compilation and return the code. @@ -131,7 +131,7 @@ class CodeGenC : public ExprFunctor, * * \param os The output stream */ - virtual void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + virtual void PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os); /*! @@ -271,8 +271,8 @@ class CodeGenC : public ExprFunctor, * \param ret_type The return type of the function * \param os The output stream. */ - virtual void GenerateForwardFunctionDeclarations(String global_symbol, - const Array& arg_types, + virtual void GenerateForwardFunctionDeclarations(ffi::String global_symbol, + const ffi::Array& arg_types, const Type& ret_type) {} /*! @@ -283,8 +283,9 @@ class CodeGenC : public ExprFunctor, * \param skip_first_arg Whether to skip the first arguments. * \param os The output stream. */ - virtual void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os); // NOLINT(*) + virtual void PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg, + std::ostream& os); // NOLINT(*) /*! 
* \brief If buffer is allocated as type t. * \param buf_var The buffer variable. @@ -319,6 +320,16 @@ class CodeGenC : public ExprFunctor, Integer constants_byte_alignment_ = 16; /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; + /*! \brief whether the module has a main function declared */ + bool has_tvm_ffi_main_func_{false}; + /*! \brief Stack of select conditions for if_then_else codegen. + * + * When processing select(cond, true_value, false_value), we push the condition + * before processing true_value. This allows nested if_then_else to guard their + * branches with the outer select condition, preventing potential out-of-bounds + * access when the outer condition is false. + */ + std::vector select_condition_stack_; private: /*! \brief set of volatile buf access */ @@ -337,7 +348,7 @@ class CodeGenC : public ExprFunctor, * functions, this is the name of the function's GlobalVar, possibly * altered to prevent duplicate names. */ - std::unordered_map internal_functions_; + std::unordered_map internal_functions_; /* \brief Name supply to generate unique function names */ NameSupply func_name_supply_; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index e18ba0128d6b..15bee36e31d9 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -32,10 +32,15 @@ #include #include +// For escaping strings embedded into generated C sources +#include "../../support/str_escape.h" + namespace tvm { namespace codegen { -CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx"); } +CodeGenCHost::CodeGenCHost() { + module_name_ = name_supply_->FreshName(ffi::symbol::tvm_ffi_library_ctx); +} void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str, const std::unordered_set& devices) { @@ -48,6 +53,8 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d decl_stream << 
"#include \"tvm/runtime/c_backend_api.h\"\n"; decl_stream << "#include \"tvm/ffi/c_api.h\"\n"; decl_stream << "#include \n"; + // snprintf for richer assert messages with actual values + decl_stream << "#include \n"; decl_stream << "#include \n"; CodeGenCHost::InitGlobalContext(); CodeGenC::Init(output_ssa); @@ -65,14 +72,14 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, bool emit_fwd_func_decl) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (global_symbol) { function_names_.push_back(global_symbol.value()); } emit_fwd_func_decl_ = emit_fwd_func_decl; CodeGenC::AddFunction(gvar, func); - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && !has_tvm_ffi_main_func_) { ICHECK(global_symbol.has_value()) << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; @@ -88,8 +95,8 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, } } -void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol, - const Array& arg_types, +void CodeGenCHost::GenerateForwardFunctionDeclarations(ffi::String global_symbol, + const ffi::Array& arg_types, const Type& ret_type) { if (!emit_fwd_func_decl_) { return; @@ -235,7 +242,7 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) { } else { // directly use the original symbol ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered())); - packed_func_name = func_name->value; + packed_func_name = ffi::symbol::tvm_ffi_symbol_prefix + func_name->value; } std::string args_stack = PrintExpr(op->args[1]); @@ -321,9 +328,33 @@ void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*) PrintIndent(); stream << "if (!(" << cond << ")) {\n"; int assert_if_scope = 
this->BeginScope(); - PrintIndent(); - stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" - << op->message.as()->value << "\", NULL);\n"; + { + // Prepare the base error message + const auto* msg_node = op->message.as(); + ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm"; + const std::string& raw_msg = msg_node->value; + const std::string esc_msg = + tvm::support::StrEscape(raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true, + /*escape_whitespace_special_chars=*/true); + + // If the assertion is an equality check, append the actual LHS/RHS values + if (const auto* eq = op->condition.as()) { + std::string lhs = PrintExpr(eq->a); + std::string rhs = PrintExpr(eq->b); + PrintIndent(); + stream << "char __tvm_assert_msg_buf[512];\n"; + PrintIndent(); + stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; got: %lld, expected: %lld\", \"" + << esc_msg << "\", (long long)(" << lhs << "), (long long)(" << rhs + << "));\n"; + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", __tvm_assert_msg_buf);\n"; + } else { + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg + << "\");\n"; + } + } PrintIndent(); stream << "return -1;\n"; this->EndScope(assert_if_scope); @@ -357,13 +388,14 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, ffi::Module BuildCHost(IRModule mod, Target target) { bool output_ssa = false; - bool emit_asserts = false; + // Enable emission of runtime asserts in generated C host code + bool emit_asserts = true; bool emit_fwd_func_decl = true; std::unordered_set devices; - if (mod->GetAttr>("device_contexts") != nullptr) { - Map device_contexts = - mod->GetAttr>("device_contexts").value(); + if (mod->GetAttr>("device_contexts") != nullptr) { + ffi::Map device_contexts = + mod->GetAttr>("device_contexts").value(); for (auto const& context : device_contexts) { devices.insert(context.second.data()); } @@ -407,9 +439,9 @@ 
ffi::Module BuildCHost(IRModule mod, Target target) { return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.c", BuildCHost); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 4a2f530e2f98..feb0f715d847 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -44,6 +44,7 @@ class CodeGenCHost : public CodeGenC { const std::unordered_set& devices); void InitGlobalContext(); + void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override; void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl); /*! @@ -69,20 +70,23 @@ class CodeGenCHost : public CodeGenC { void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*) - void GenerateForwardFunctionDeclarations(String global_symbol, const Array& arg_types, + void GenerateForwardFunctionDeclarations(ffi::String global_symbol, + const ffi::Array& arg_types, const Type& ret_type) override; - Array GetFunctionNames() { return function_names_; } + ffi::Array GetFunctionNames() { return function_names_; } private: std::string module_name_; /* \brief mapping global packed func to the unique name */ std::unordered_map declared_globals_; /* \brief names of the functions declared in this module */ - Array function_names_; + ffi::Array function_names_; /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; /*! \brief whether to emit forwared function declarations in the resulting C code */ bool emit_fwd_func_decl_; + /*! 
\brief whether to generate the entry function if encountered */ + bool has_main_func_ = false; std::string GetPackedName(const CallNode* op); void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 951415c3b353..bac0af79ca46 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -140,7 +140,7 @@ void CodeGenCUDA::Init(bool output_ssa) { ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } -void CodeGenCUDA::PrintFunctionSignature(const String& function_name, const PrimFunc& func, +void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os) { auto calling_conv = func->GetAttr(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault)); @@ -319,7 +319,6 @@ std::string CodeGenCUDA::Finish() { } void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { - ICHECK(is_const_int(op->min, 0)); if (op->kind == tir::ForKind::kUnrolled) { PrintIndent(); stream << "#pragma unroll\n"; @@ -640,12 +639,12 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, static const char access[] = {'x', 'y', 'z', 'w'}; ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - std::string type_name = t.is_int() ? "char" : "unsigned char"; + std::string type_name = t.is_int() ? "signed char" : "unsigned char"; if (t.lanes() == 2 || t.lanes() == 3) { os << vec << "." << access[i % t.lanes()]; } else { std::string ac = t.lanes() == 4 ? vec : (vec + "." 
+ access[i / 4]); - os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; + os << "(reinterpret_cast(&(" << ac << "))[" << (i % 4) << "])"; } } else if (t.is_float16()) { if (t.lanes() <= 4) { @@ -697,12 +696,9 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, << "(" << value << ");\n"; } else { std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); - stream << ac << "="; - // Do not read the first undef lane. - if (i != 0) { - stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |"; - } - stream << "(" << value << " << " << i % 4 * 8 << ");\n"; + std::string type_name = t.is_int() ? "signed char" : "unsigned char"; + stream << "reinterpret_cast<" << type_name << "*>(&(" << ac << "))[" << (i % 4) << "] = (" + << type_name << ")(" << value << ");\n"; } } else if (t.is_float16()) { if (t.lanes() <= 4) { @@ -866,8 +862,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { os << sret; } -void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg, + std::ostream& os) { // NOLINT(*) DataType ret_dtype = GetRuntimeDataType(ret_type); if (ret_dtype.is_fixed_length_vector()) { // @@ -1292,7 +1289,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (tgt_dtype.is_float4_e2m1fn()) { // We view the source as an uint16, and then extract bits of two fp4 numbers, // and finally reinterpret the result as fp4x2. 
- value = tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value}); + value = tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value}, op->annotations); tir::Var temp_var("temp_var", DataType::UInt(16)); value = tir::Let( temp_var, value, @@ -1300,7 +1297,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ((temp_var >> 4) & IntImm(DataType::UInt(16), 0xF0)))); } else { value = tir::Cast(DataType::UInt(16), - tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value})); + tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value}, op->annotations)); tir::Var temp_var("temp_var", DataType::UInt(16)); value = tir::Let(temp_var, value, (temp_var & IntImm(DataType::UInt(16), 0xF)) | @@ -1311,7 +1308,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (tgt_dtype.is_float4_e2m1fn()) { // We view the source as an uint32, and then extract bits of four fp4 numbers, // and finally reinterpret the result as fp4x4. - value = tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value}); + value = tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value}, op->annotations); tir::Var temp_var("temp_var", DataType::UInt(32)); value = tir::Let(temp_var, value, tir::Cast(DataType::UInt(16), @@ -1321,7 +1318,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000)))); } else { value = tir::Cast(DataType::UInt(32), - tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value})); + tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value}, op->annotations)); tir::Var temp_var("temp_var", DataType::UInt(32)); value = tir::Let(temp_var, value, (temp_var & IntImm(DataType::UInt(32), 0xF)) | @@ -1329,7 +1326,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) | ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12)); } - os << 
PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); + os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}, op->annotations)); } else { LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; } @@ -1614,13 +1611,17 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) // Type code is kBFloat if (op->dtype.is_bfloat16()) { os << "__float2bfloat16_rn"; - os << '(' << std::scientific << op->value << 'f' << ')'; + os << '(' << std::hexfloat << op->value << 'f'; + os << "/*" << std::scientific << op->value << "*/"; + os << ')'; return; } // Type code is kFloat8_e5m2 or kE4M4Float if (op->dtype.is_float8() || op->dtype.is_float4()) { p->PrintType(op->dtype, os); - os << '(' << std::scientific << op->value << 'f' << ')'; + os << '(' << std::hexfloat << op->value << 'f'; + os << "/*" << std::scientific << op->value << "*/"; + os << ')'; return; } // Type code is kFloat @@ -1655,7 +1656,8 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) temp << "CUDART_NAN_F"; p->need_math_constants_h_ = true; } else { - temp << std::scientific << op->value << 'f'; + temp << std::hexfloat << op->value << 'f'; + temp << "/*" << std::scientific << op->value << "*/"; } p->MarkConst(temp.str()); os << temp.str(); diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 6441f87909db..02fc0603a52f 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -46,7 +46,7 @@ class CodeGenCUDA final : public CodeGenC { enable_fp4_ || need_math_constants_h_ || need_mma_h_); } // override behavior - void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + void PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os) final; void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const ForNode* op) final; @@ -74,7 +74,7 
@@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const AttrStmtNode* op) final; protected: - void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + void PrintCallExtern(Type ret_type, ffi::String global_symbol, const ffi::Array& args, bool skip_first_arg, std::ostream& os) final; // NOLINT(*) private: diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index dc019c28a7a0..2645c8e49693 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -77,7 +77,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { name_supply_->FreshName("v_"); // add to alloc buffer type. - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; @@ -146,10 +146,18 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { decl_stream << "};\n\n"; } // Setup the thread group info. + // Reserve the CUDA-style alias names so user code or downstream passes + // cannot accidentally collide with them, even though the kernel itself + // emits Metal builtin names directly (no `blockIdx`/`threadIdx` aliases). 
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + ICHECK_EQ(name_supply_->FreshName("threadgroup_position_in_grid"), + "threadgroup_position_in_grid"); + ICHECK_EQ(name_supply_->FreshName("thread_position_in_threadgroup"), + "thread_position_in_threadgroup"); int work_dim = 0; - auto launch_params = func->GetAttr>(tir::attr::kKernelLaunchParams).value(); + auto launch_params = + func->GetAttr>(tir::attr::kKernelLaunchParams).value(); for (const auto& tag : launch_params) { if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); @@ -158,13 +166,16 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { } if (work_dim != 0) { - // use ushort by default for now + // Emit Metal builtin names directly as the kernel parameter identifiers + // rather than using CUDA-style `blockIdx`/`threadIdx` aliases. This keeps + // body references aligned with Apple's MSL convention and avoids forcing + // downstream passes to canonicalize the alias back to the Metal builtin. stream << " "; PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); - stream << " blockIdx [[threadgroup_position_in_grid]],\n"; + stream << " threadgroup_position_in_grid [[threadgroup_position_in_grid]],\n"; stream << " "; PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); - stream << " threadIdx [[thread_position_in_threadgroup]]\n"; + stream << " thread_position_in_threadgroup [[thread_position_in_threadgroup]]\n"; } thread_work_dim_ = work_dim; @@ -179,11 +190,24 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { void CodeGenMetal::BindThreadIndex(const IterVar& iv) { ICHECK(!var_idmap_.count(iv->var.get())); - // if we only have threadIdx.x - // metal will directly print as threadIdx + // The thread_tag is the CUDA-style name (e.g. "threadIdx.x", "blockIdx.y"). 
+ // Translate to the Metal builtin reference so emitted body references + // resolve directly against the kernel parameters declared in AddFunction + // (which now use the Metal builtin names verbatim instead of the + // blockIdx/threadIdx aliases). The .x/.y/.z suffix is preserved. std::string vname = iv->thread_tag; - if (thread_work_dim_ <= 1) { - vname = vname.substr(0, iv->thread_tag.length() - 2); + std::string axis; + if (vname.length() >= 2 && vname[vname.length() - 2] == '.') { + axis = vname.substr(vname.length() - 2); // ".x" / ".y" / ".z" + vname = vname.substr(0, vname.length() - 2); + } + if (vname == "threadIdx") { + vname = "thread_position_in_threadgroup"; + } else if (vname == "blockIdx") { + vname = "threadgroup_position_in_grid"; + } + if (thread_work_dim_ > 1) { + vname += axis; } var_idmap_[iv->var.get()] = CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); @@ -359,7 +383,7 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) CHECK(!op->op.as()) << "CodegenMetal does not support inter-function calls, " - << "but expression " << GetRef(op) << " calls PrimFunc " << op->op; + << "but expression " << ffi::GetRef(op) << " calls PrimFunc " << op->op; auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) { ICHECK(col->IsInstance() && row->IsInstance()) << "Only constant shape is supported for simdgroup matrix, but got " << col << "x" << row; @@ -442,7 +466,7 @@ ffi::Module BuildMetal(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; - auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()); std::string func_name = global_symbol.value(); @@ -467,9 +491,9 @@ ffi::Module BuildMetal(IRModule mod, Target target) { 
return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.metal", BuildMetal); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 1342464665f3..8ea55b8ff5d8 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -230,6 +230,12 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } + } else if (t.is_bool()) { + os << "uint"; + if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) { + os << lanes; + return; + } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { os << 'u'; @@ -475,10 +481,10 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { // Enable atomics extension if used. if (func->value == "atomic_add" && op->dtype.is_float()) { enable_atomics_ = true; - this->PrintCallExtern(GetType(GetRef(op)), "atomic_add_float_emu", op->args, true, - os); + this->PrintCallExtern(GetType(ffi::GetRef(op)), "atomic_add_float_emu", op->args, + true, os); } else if (func->value == "nearbyint") { - this->PrintCallExtern(GetType(GetRef(op)), "round", op->args, true, os); + this->PrintCallExtern(GetType(ffi::GetRef(op)), "round", op->args, true, os); } else { if (func->value == "atomic_add") { enable_atomics_ = true; @@ -635,7 +641,7 @@ void CodeGenOpenCL::SetTextureScope( ffi::Module BuildOpenCL(IRModule mod, Target target) { #if TVM_ENABLE_SPIRV - Optional device = target->GetAttr("device"); + ffi::Optional device = target->GetAttr("device"); if (device && device.value() == "spirv") { auto [smap, spirv_text] = LowerToSPIRV(mod, target); return runtime::OpenCLModuleCreate(smap, spirv_text, ExtractFuncInfo(mod)); @@ -644,7 +650,7 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) { bool output_ssa = 
false; - Map functions; + ffi::Map functions; for (auto [gvar, base_func] : mod->functions) { ICHECK(base_func->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto prim_func = Downcast(base_func); @@ -674,25 +680,25 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) { return OpenCLModuleCreate(code.str(), "cl", ExtractFuncInfo(mod), code.str()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.opencl", BuildOpenCL); -}); +} -String DeviceScopeCompatibilityFromTarget(Target target, String memory_scope) { +ffi::String DeviceScopeCompatibilityFromTarget(Target target, ffi::String memory_scope) { auto prototype_keys = target->GetKeys(); bool is_adreno = std::find(prototype_keys.begin(), prototype_keys.end(), "adreno") != prototype_keys.end(); if (is_adreno) { - return String("global"); + return ffi::String("global"); } return memory_scope; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("DeviceScopeCompatibility.opencl", DeviceScopeCompatibilityFromTarget); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_params.cc b/src/target/source/codegen_params.cc index cd2bcd769c04..d840ebec7df3 100644 --- a/src/target/source/codegen_params.cc +++ b/src/target/source/codegen_params.cc @@ -160,8 +160,8 @@ void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, } } -void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os, - const std::string& eol) { +void TensorDataToC(::tvm::runtime::Tensor arr, int indent_chars, std::ostream& os, + const std::string& eol) { auto arr_type = arr.DataType(); CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw " << arr_type.lanes(); diff --git a/src/target/source/codegen_params.h b/src/target/source/codegen_params.h index 
6df800ed1721..5c8c129006b3 100644 --- a/src/target/source/codegen_params.h +++ b/src/target/source/codegen_params.h @@ -24,7 +24,7 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_PARAMS_H_ #define TVM_TARGET_SOURCE_CODEGEN_PARAMS_H_ -#include +#include #include #include @@ -36,8 +36,8 @@ namespace codegen { * \brief Write a C representation of arr to os. * * This function generates a comma-separated, indented list of C integer listeals suitable for use - * in an initializer. The NDArray is flattened and then the list is produced element by element. - * For the int16_t NDArray [-3, -2, -1, 0, 1, 2, 3, ...], and indent_chars = 4, the following output + * in an initializer. The Tensor is flattened and then the list is produced element by element. + * For the int16_t Tensor [-3, -2, -1, 0, 1, 2, 3, ...], and indent_chars = 4, the following output * is produced: * -0x0003, -0x0002, -0x0001, +0x0000, +0x0001, +0x0002, +0x0003 * @@ -45,8 +45,8 @@ namespace codegen { * \param indent_chars Number of chars to indent * \param os Output stream where the array data should be written. 
*/ -void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os, - const std::string& eol = "\n"); +void TensorDataToC(::tvm::runtime::Tensor arr, int indent_chars, std::ostream& os, + const std::string& eol = "\n"); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 60fa786d5287..c986d0f72f72 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -47,7 +47,6 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { e.vid = name_supply_->FreshName("v_"); e.scope_id = static_cast(scope_mark_.size() - 1); ssa_assign_map_[src] = e; - this->PrintIndent(); PrintSSAAssign(e.vid, src, t); return e.vid; } @@ -109,6 +108,11 @@ void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT( os << "void"; return; } + // default c may be have bool type, can be handled in subclass + if (type.is_bool()) { + os << "int"; + return; + } if (type.is_float()) { if (type.bits() == 32) { os << "float"; diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index f077f8c3a83b..104bf2cbdc34 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -150,9 +150,9 @@ ffi::Module SourceModuleCreate(std::string code, std::string fmt); * \param const_vars. The constant variables that the c source module needs. * \return The created module. */ -ffi::Module CSourceModuleCreate(const String& code, const String& fmt, - const Array& func_names, - const Array& const_vars = {}); +ffi::Module CSourceModuleCreate(const ffi::String& code, const ffi::String& fmt, + const ffi::Array& func_names, + const ffi::Array& const_vars = {}); /*! * \brief Wrap the submodules in a metadata module. 
@@ -163,9 +163,9 @@ ffi::Module CSourceModuleCreate(const String& code, const String& fmt, * \param target The target that all the modules are compiled for * \return The wrapped module. */ -ffi::Module CreateMetadataModule(const std::unordered_map& params, - ffi::Module target_module, const Array& ext_modules, - Target target); +ffi::Module CreateMetadataModule(const std::unordered_map& params, + ffi::Module target_module, + const ffi::Array& ext_modules, Target target); /*! * \brief Create a source module for viewing and limited saving for device. diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 28d158c3c21e..cf8176001a8a 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -63,7 +63,7 @@ class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { private: void VisitExpr_(const VarNode* op) final { StmtExprVisitor::VisitExpr_(op); - Var buffer_var = GetRef(op); + Var buffer_var = ffi::GetRef(op); if (buffer_var.dtype().is_handle()) { info_.write_access_set.insert(buffer_var); } @@ -137,7 +137,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim"); // add to alloc buffer type. 
- auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; @@ -233,7 +233,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re << "var " << val_pod_args << " : " << type_pod_args << ";\n\n"; // setup thread tags and param access in launch param tags; - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { for (const auto& thread_tag : opt.value()) { func_info.launch_param_tags.push_back(thread_tag); } @@ -667,13 +667,21 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) { } void CodeGenWebGPU::VisitStmt_(const ForNode* op) { - std::string extent = PrintExpr(op->extent); + std::string begin_str = PrintExpr(op->min); + PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent); + std::string end_str = PrintExpr(end); + std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : ""; std::string vid = AllocVarID(op->loop_var.get()); - ICHECK(is_zero(op->min)); PrintIndent(); stream << "for (var " << vid << " : "; PrintType(op->loop_var.dtype(), stream); - stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n"; + stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " << vid; + if (step_str.empty()) { + stream << "++"; + } else { + stream << " += " << step_str; + } + stream << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); @@ -716,7 +724,7 @@ class WebGPUSourceModuleNode final : public ffi::ModuleObj { /*! 
\brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; } @@ -729,7 +737,7 @@ class WebGPUSourceModuleNode final : public ffi::ModuleObj { return ffi::Bytes(buffer); } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { if (format == "func_info") { std::ostringstream stream; dmlc::JSONWriter(&stream).Write(fmap_); @@ -770,7 +778,7 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) { auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); @@ -780,15 +788,15 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) { smap[f_name] = code; } - auto n = make_object(smap, fmap); + auto n = ffi::make_object(smap, fmap); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.webgpu", [](IRModule mod, Target target) { return BuildWebGPU(mod, target); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index e762bde69f4d..b2533079bc10 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -52,6 +52,10 @@ struct CUDAMath { default: return ""; } + } else if (t.is_tfloat32()) { + if (name 
== "fabs") { + return "abs"; + } } else if (t.is_bfloat16()) { if (name == "fabs") { return "__habs"; @@ -136,7 +140,7 @@ struct CUDAWarpIntrinsic { static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { const CallNode* call = e.as(); - return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args); + return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, call->annotations); } template @@ -144,8 +148,8 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { const CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); + ffi::Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; + return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args, call->annotations); } TVM_REGISTER_OP("tir.clz").set_attr( @@ -170,37 +174,37 @@ TVM_REGISTER_OP("tir.nearbyint") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.exp").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.exp2") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.exp10") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.erf").set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.log").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.log2") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.log10") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.tan").set_attr("cuda.FLowerIntrinsic", - 
DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.cos").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.cosh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.sin").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.sinh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index b7561e86715e..489c39237d00 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -48,8 +48,8 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { const CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - Array metal_args{{call->args[1], call->args[2]}}; - return Call(call->dtype, T()(call->dtype, Downcast(call->op)), metal_args); + ffi::Array metal_args{{call->args[1], call->args[2]}}; + return Call(call->dtype, T()(call->dtype, Downcast(call->op)), metal_args, call->annotations); } TVM_REGISTER_OP("tir.clz").set_attr("metal.FLowerIntrinsic", diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index bd9e148b187d..81d69cf99f6d 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -109,8 +109,9 @@ static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { arith::Analyzer analyzer; ICHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; - Array opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; - return Call(call->dtype, builtin::call_pure_extern(), opencl_args); + ffi::Array opencl_args{ + {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; + return Call(call->dtype, builtin::call_pure_extern(), opencl_args, call->annotations); } 
TVM_REGISTER_OP("tir.tvm_warp_shuffle") diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 1350357d866c..0112ad961de4 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include @@ -56,14 +56,14 @@ class SourceModuleNode : public ffi::ModuleObj { SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} const char* kind() const final { return "source"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; } - String InspectSource(const String& format) const final { return code_; } + ffi::String InspectSource(const ffi::String& format) const final { return code_; } - Array GetWriteFormats() const override { return {fmt_}; } + ffi::Array GetWriteFormats() const override { return {fmt_}; } protected: std::string code_; @@ -71,7 +71,7 @@ class SourceModuleNode : public ffi::ModuleObj { }; ffi::Module SourceModuleCreate(std::string code, std::string fmt) { - auto n = make_object(code, fmt); + auto n = ffi::make_object(code, fmt); return ffi::Module(n); } @@ -79,14 +79,15 @@ ffi::Module SourceModuleCreate(std::string code, std::string fmt) { class CSourceModuleNode : public ffi::ModuleObj { public: CSourceModuleNode(const std::string& code, const std::string& fmt, - const Array& func_names, const Array& const_vars) + const ffi::Array& func_names, + const ffi::Array& const_vars) : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) { if (fmt_.empty()) fmt_ = "c"; } const char* kind() const final { return "c"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); // Currently c-source 
module is used as demonstration purposes with binary metadata module // that expects get_symbol interface. When c-source module is used as external module, it @@ -106,9 +107,9 @@ class CSourceModuleNode : public ffi::ModuleObj { } } - String InspectSource(const String& format) const final { return code_; } + ffi::String InspectSource(const ffi::String& format) const final { return code_; } - Array GetWriteFormats() const override { return {fmt_}; } + ffi::Array GetWriteFormats() const override { return {fmt_}; } ffi::Bytes SaveToBytes() const final { std::string buffer; @@ -138,17 +139,17 @@ class CSourceModuleNode : public ffi::ModuleObj { CHECK(stream->Read(&tmp_func_names)) << "Loading func names failed"; CHECK(stream->Read(&tmp_const_vars)) << "Loading const vars failed"; - Array func_names; - for (auto func_name : tmp_func_names) func_names.push_back(String(func_name)); + ffi::Array func_names; + for (auto func_name : tmp_func_names) func_names.push_back(ffi::String(func_name)); - Array const_vars; - for (auto const_var : tmp_const_vars) const_vars.push_back(String(const_var)); + ffi::Array const_vars; + for (auto const_var : tmp_const_vars) const_vars.push_back(ffi::String(const_var)); - auto n = make_object(code, fmt, func_names, const_vars); + auto n = ffi::make_object(code, fmt, func_names, const_vars); return ffi::Module(n); } - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") { @@ -163,28 +164,29 @@ class CSourceModuleNode : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable; } - bool ImplementsFunction(const String& name) final { + bool ImplementsFunction(const ffi::String& name) final { return 
std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); } protected: std::string code_; std::string fmt_; - Array const_vars_; - Array func_names_; + ffi::Array const_vars_; + ffi::Array func_names_; }; -ffi::Module CSourceModuleCreate(const String& code, const String& fmt, - const Array& func_names, const Array& const_vars) { - auto n = make_object(code.operator std::string(), fmt.operator std::string(), - func_names, const_vars); +ffi::Module CSourceModuleCreate(const ffi::String& code, const ffi::String& fmt, + const ffi::Array& func_names, + const ffi::Array& const_vars) { + auto n = ffi::make_object(code.operator std::string(), + fmt.operator std::string(), func_names, const_vars); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ffi.Module.load_from_bytes.c", CSourceModuleNode::LoadFromBytes); -}); +} /*! * \brief A concrete class to get access to base methods of CodegenSourceBase. @@ -210,12 +212,12 @@ class DeviceSourceModuleNode final : public ffi::ModuleObj { std::function fget_source) : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { if (fget_source_ != nullptr) { return fget_source_(format); } else { @@ -227,7 +229,7 @@ class DeviceSourceModuleNode final : public ffi::ModuleObj { /*! 
\brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -257,20 +259,20 @@ ffi::Module DeviceSourceModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source) { - auto n = make_object(data, fmt, fmap, type_key, fget_source); + auto n = ffi::make_object(data, fmt, fmap, type_key, fget_source); return ffi::Module(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.SourceModuleCreate", SourceModuleCreate) - .def("runtime.CSourceModuleCreate", [](String code, String fmt, - Optional> func_names, - Optional> const_vars) { + .def("runtime.CSourceModuleCreate", [](ffi::String code, ffi::String fmt, + ffi::Optional> func_names, + ffi::Optional> const_vars) { return CSourceModuleCreate(code, fmt, func_names.value_or({}), const_vars.value_or({})); }); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index bd44607a98eb..f71b7ef8d6fa 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -37,11 +37,11 @@ ffi::Module BuildSPIRV(IRModule mod, Target target) { return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.vulkan", [](IRModule mod, Target target) { return BuildSPIRV(mod, target); }); -}); +} } // namespace codegen } // namespace tvm diff --git 
a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index ddbc22d88a04..136f969896f5 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -430,7 +430,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value dst_ptr = builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index)); spirv::Value src_ptr = VisitExpr(op->args[5]); - spirv::SType type_bool = builder_->GetSType(DataType::UInt(1)); + spirv::SType type_bool = builder_->GetSType(DataType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val = builder_->UIntImm(type_bool, 0); spirv::Value loaded = @@ -492,7 +492,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); uint32_t mask = spv::MemoryAccessMaskNone; spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask); - spirv::SType type_bool = builder_->GetSType(DataType::UInt(1)); + spirv::SType type_bool = builder_->GetSType(DataType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val = builder_->UIntImm(type_bool, 0); builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val, @@ -672,10 +672,21 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { } void CodeGenSPIRV::VisitStmt_(const ForNode* op) { - ICHECK(is_zero(op->min)); analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); spirv::Value init_value = MakeValue(op->min); - spirv::Value extent_value = MakeValue(op->extent); + PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent); + spirv::Value end_value = MakeValue(end); + spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); + + // loop step + spirv::Value step; + if (op->HasTrivialStep()) { + step = op->loop_var.dtype().is_int() ? 
builder_->IntImm(loop_var.stype, 1) + : builder_->UIntImm(loop_var.stype, 1); + } else { + step = MakeValue(tvm::cast(end->dtype, *op->step)); + } + // Must get init label after making value(to make sure they are correct) spirv::Label init_label = builder_->CurrentLabel(); spirv::Label head_label = builder_->NewLabel(); @@ -690,9 +701,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // Loop head builder_->StartLabel(head_label); - spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); loop_var.SetIncoming(0, init_value, init_label); - spirv::Value loop_cond = builder_->LT(loop_var, extent_value); + spirv::Value loop_cond = builder_->LT(loop_var, end_value); uint32_t control = (op->kind == ForKind::kUnrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone); builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); @@ -707,9 +717,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // loop continue builder_->StartLabel(continue_label); - spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) - : builder_->UIntImm(loop_var.stype, 1); - spirv::Value next_value = builder_->Add(loop_var, one); + + spirv::Value next_value = builder_->Add(loop_var, step); loop_var.SetIncoming(1, next_value, builder_->CurrentLabel()); builder_->MakeInst(spv::OpBranch, head_label); // loop merge diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 3010b74dd976..a457d95209a9 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -34,17 +34,17 @@ namespace codegen { namespace spirv { // num_signature means number of arguments used to query signature template -PrimExpr CallGLSLIntrin(PrimExpr e, const Array& args) { +PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); - Array cargs; + ffi::Array cargs; // intrin id. 
cargs.push_back(IntImm(DataType::UInt(32), id)); for (PrimExpr arg : args) { cargs.push_back(arg); } - return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs); + return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs, call->annotations); } template diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 545e677af9f2..bac66a3aacf7 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -76,7 +76,7 @@ void IRBuilder::InitPreDefs() { ext_glsl450_ = ExtInstImport("GLSL.std.450"); t_int32_ = DeclareType(DataType::Int(32)); t_uint32_ = DeclareType(DataType::UInt(32)); - t_bool_ = DeclareType(DataType::UInt(1)); + t_bool_ = DeclareType(DataType::Bool()); t_fp32_ = DeclareType(DataType::Float(32)); const_i32_zero_ = IntImm(t_int32_, 0); @@ -115,7 +115,7 @@ std::vector IRBuilder::Finalize() { SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { if (dtype == DataType::Int(32)) { return t_int32_; - } else if (dtype == DataType::UInt(1)) { + } else if (dtype == DataType::Bool()) { return t_bool_; } else if (dtype == DataType::Float(32)) { return t_fp32_; @@ -467,7 +467,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { } ICHECK_LE(dtype.type.bits(), 64); Value ret = NewValue(dtype, kConstant); - if (dtype.type == DataType::UInt(1)) { + if (dtype.type == DataType::Bool()) { // bool types. if (*pvalue) { ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret); @@ -501,8 +501,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) SType t; t.id = id_counter_++; t.type = dtype; - if (dtype.bits() == 1) { - ICHECK(dtype.is_uint()); + if (dtype.is_bool()) { ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_); } else if (dtype.is_int()) { ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_); @@ -584,7 +583,7 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // future. 
Requiring StorageBuffer8BitAccess in order to declare an // Int8 prevents use of an 8-bit loop iterator on a device that // supports Int8 but doesn't support 8-bit buffer access. - if (dtype.bits() == 8) { + if (dtype.bits() == 8 && !dtype.is_bool()) { ICHECK(spirv_support_.supports_storage_buffer_8bit_access) << "Vulkan target does not support StorageBuffer8BitAccess. " << "If your device supports 8-bit buffer access, " @@ -822,19 +821,19 @@ Value IRBuilder::Mod(Value a, Value b) { } } -#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ - ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, bool_type, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, bool_type, a, b); \ - } else { \ - ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + ICHECK_EQ(a.stype.id, b.stype.id); \ + ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ + if (a.stype.type.is_int()) { \ + return MakeValue(spv::OpS##_Op, bool_type, a, b); \ + } else if (a.stype.type.is_uint()) { \ + return MakeValue(spv::OpU##_Op, bool_type, a, b); \ + } else { \ + ICHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_OP(LT, LessThan); @@ -842,17 +841,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual); DEFINE_BUILDER_CMP_OP(GT, GreaterThan); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); -#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ - 
ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, bool_type, a, b); \ - } else { \ - ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + ICHECK_EQ(a.stype.id, b.stype.id); \ + ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, bool_type, a, b); \ + } else { \ + ICHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_UOP(EQ, Equal); @@ -860,7 +859,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual); Value IRBuilder::Select(Value cond, Value a, Value b) { ICHECK_EQ(a.stype.id, b.stype.id); - ICHECK_EQ(cond.stype.type.element_of(), DataType::UInt(1)); + ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool()); return MakeValue(spv::OpSelect, a.stype, cond, a, b); } diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index a17a694da4dd..91b45b85bbd0 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -94,8 +94,9 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { supports_integer_dot_product = target->GetAttr("supports_integer_dot_product").value(); } // Check whether integer dot product is enabled in mattr. 
- if (const Optional>& v = target->GetAttr>("mattr")) { - for (const String& s : v.value()) { + if (const ffi::Optional>& v = + target->GetAttr>("mattr")) { + for (const ffi::String& s : v.value()) { if (s.compare("+dotprod") == 0) { supports_integer_dot_product = true; break; diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index f0226466f625..a4cec2c0fd65 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -129,7 +129,7 @@ std::pair, std::string> Lo auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/tag.cc b/src/target/tag.cc index f305c84e09a4..dfe179f7ac16 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -32,24 +32,24 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ TargetTagNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { TargetTagNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.TargetTagListTags", TargetTag::ListTags) .def("target.TargetTagAddTag", TargetTag::AddTag); -}); +} /********** Registry-related code **********/ using TargetTagRegistry = AttrRegistry; -TargetTagRegEntry& TargetTagRegEntry::RegisterOrGet(const String& target_tag_name) { +TargetTagRegEntry& TargetTagRegEntry::RegisterOrGet(const ffi::String& target_tag_name) { return TargetTagRegistry::Global()->RegisterOrGet(target_tag_name); } -Optional TargetTag::Get(const String& target_tag_name) { +ffi::Optional TargetTag::Get(const ffi::String& target_tag_name) { const TargetTagRegEntry* reg = 
TargetTagRegistry::Global()->Get(target_tag_name); if (reg == nullptr) { return std::nullopt; @@ -57,15 +57,15 @@ Optional TargetTag::Get(const String& target_tag_name) { return Target(reg->tag_->config); } -Map TargetTag::ListTags() { - Map result; - for (const String& tag : TargetTagRegistry::Global()->ListAllNames()) { +ffi::Map TargetTag::ListTags() { + ffi::Map result; + for (const ffi::String& tag : TargetTagRegistry::Global()->ListAllNames()) { result.Set(tag, TargetTag::Get(tag).value()); } return result; } -Target TargetTag::AddTag(String name, Map config, bool override) { +Target TargetTag::AddTag(ffi::String name, ffi::Map config, bool override) { TargetTagRegEntry& tag = TargetTagRegEntry::RegisterOrGet(name).set_name(); ICHECK(override || tag.tag_->config.empty()) << "Tag \"" << name << "\" has been previously defined as: " << tag.tag_->config; @@ -77,73 +77,78 @@ Target TargetTag::AddTag(String name, Map config, bool overrid #if TVM_LLVM_HAS_AARCH64_TARGET TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") - .set_config({{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a72")}, - {"mattr", Array{"+neon"}}, + .set_config({{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("cortex-a72")}, + {"mattr", ffi::Array{"+neon"}}, {"num-cores", 4}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a72")}, - {"mattr", Array{"+neon"}}, - {"num-cores", 4}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("cortex-a72")}, + {"mattr", ffi::Array{"+neon"}}, + {"num-cores", 4}}}}); #if TVM_LLVM_VERSION >= 110 TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") - .set_config({{"kind", String("cuda")}, - {"arch", String("sm_72")}, + .set_config({{"kind", ffi::String("cuda")}, + {"arch", ffi::String("sm_72")}, 
{"max_shared_memory_per_block", 49152}, {"max_threads_per_block", 1024}, {"thread_warp_size", 32}, {"registers_per_block", 65536}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("carmel")}, - {"num-cores", 8}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("carmel")}, + {"num-cores", 8}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") - .set_config({{"kind", String("cuda")}, - {"arch", String("sm_87")}, + .set_config({{"kind", ffi::String("cuda")}, + {"arch", ffi::String("sm_87")}, {"max_shared_memory_per_block", 49152}, {"max_threads_per_block", 1024}, {"thread_warp_size", 32}, {"registers_per_block", 65536}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("carmel")}, - {"num-cores", 6}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("carmel")}, + {"num-cores", 6}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb") - .set_config({{"kind", String("cuda")}, - {"arch", String("sm_87")}, + .set_config({{"kind", ffi::String("cuda")}, + {"arch", ffi::String("sm_87")}, {"max_shared_memory_per_block", 49152}, {"max_threads_per_block", 1024}, {"thread_warp_size", 32}, {"registers_per_block", 65536}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a78")}, - {"num-cores", 8}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("cortex-a78")}, + {"num-cores", 8}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") - .set_config({{"kind", String("cuda")}, - {"arch", String("sm_87")}, + .set_config({{"kind", ffi::String("cuda")}, + {"arch", ffi::String("sm_87")}, {"max_shared_memory_per_block", 49152}, {"max_threads_per_block", 1024}, 
{"thread_warp_size", 32}, {"registers_per_block", 65536}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a78")}, - {"num-cores", 12}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("cortex-a78")}, + {"num-cores", 12}}}}); #endif // TVM_LLVM_VERSION >= 110 #endif // TVM_LLVM_HAS_AARCH64_TARGET #define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ TVM_REGISTER_TARGET_TAG(Name).set_config({ \ - {"kind", String("cuda")}, \ - {"keys", Array{"cuda", "gpu"}}, \ - {"arch", String(Arch)}, \ + {"kind", ffi::String("cuda")}, \ + {"keys", ffi::Array{"cuda", "gpu"}}, \ + {"arch", ffi::String(Arch)}, \ {"max_shared_memory_per_block", SharedMem}, \ {"max_threads_per_block", 1024}, \ {"thread_warp_size", 32}, \ @@ -421,10 +426,10 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); #undef TVM_REGISTER_CUDA_TAG -#define TVM_REGISTER_TAG_AWS_C5(Name, Cores, Arch) \ - TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ - {"keys", Array{"x86", "cpu"}}, \ - {"mcpu", String(Arch)}, \ +#define TVM_REGISTER_TAG_AWS_C5(Name, Cores, Arch) \ + TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", ffi::String("llvm")}, \ + {"keys", ffi::Array{"x86", "cpu"}}, \ + {"mcpu", ffi::String(Arch)}, \ {"num-cores", Cores}}); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); @@ -439,25 +444,25 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #undef TVM_REGISTER_TAG_AWS_C5 #if TVM_LLVM_VERSION >= 190 -#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ - TVM_REGISTER_TARGET_TAG(Name).set_config( \ - {{"kind", String("metal")}, \ - {"max_threads_per_block", ThreadsPerBlock}, \ - {"max_shared_memory_per_block", SharedMem}, \ - {"thread_warp_size", WarpSize}, \ - {"host", Map{{"kind", String("llvm")}, \ - {"mtriple", String("arm64-apple-macos")}, \ - 
{"mcpu", String("apple-m4")}}}}); +#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ + TVM_REGISTER_TARGET_TAG(Name).set_config( \ + {{"kind", ffi::String("metal")}, \ + {"max_threads_per_block", ThreadsPerBlock}, \ + {"max_shared_memory_per_block", SharedMem}, \ + {"thread_warp_size", WarpSize}, \ + {"host", ffi::Map{{"kind", ffi::String("llvm")}, \ + {"mtriple", ffi::String("arm64-apple-macos")}, \ + {"mcpu", ffi::String("apple-m4")}}}}); #else -#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ - TVM_REGISTER_TARGET_TAG(Name).set_config( \ - {{"kind", String("metal")}, \ - {"max_threads_per_block", ThreadsPerBlock}, \ - {"max_shared_memory_per_block", SharedMem}, \ - {"thread_warp_size", WarpSize}, \ - {"host", Map{{"kind", String("llvm")}, \ - {"mtriple", String("arm64-apple-macos")}, \ - {"mcpu", String("apple-latest")}}}}); +#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ + TVM_REGISTER_TARGET_TAG(Name).set_config( \ + {{"kind", ffi::String("metal")}, \ + {"max_threads_per_block", ThreadsPerBlock}, \ + {"max_shared_memory_per_block", SharedMem}, \ + {"thread_warp_size", WarpSize}, \ + {"host", ffi::Map{{"kind", ffi::String("llvm")}, \ + {"mtriple", ffi::String("arm64-apple-macos")}, \ + {"mcpu", ffi::String("apple-latest")}}}}); #endif #if TVM_LLVM_HAS_AARCH64_TARGET diff --git a/src/target/target.cc b/src/target/target.cc index 1c56fa5bd210..23ee76fc898d 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -43,31 +43,33 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ TargetNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { TargetNode::RegisterReflection(); } class TargetInternal { public: static void EnterScope(Target target) { target.EnterWithScope(); } static void ExitScope(Target target) { target.ExitWithScope(); } - static Map Export(Target target) { return target->Export(); } + static ffi::Map Export(Target target) { return 
target->Export(); } static const TargetKindNode::ValueTypeInfo& FindTypeInfo(const TargetKind& kind, const std::string& key); - static Optional StringifyAttrsToRaw(const Map& attrs); + static ffi::Optional StringifyAttrsToRaw( + const ffi::Map& attrs); static Any ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info); static Any ParseType(const Any& obj, const TargetKindNode::ValueTypeInfo& info); - static ObjectPtr FromString(const String& tag_or_config_or_target_str); - static ObjectPtr FromConfigString(const String& config_str); - static ObjectPtr FromRawString(const String& target_str); - static ObjectPtr FromConfig(Map config); + static ObjectPtr FromString(const ffi::String& tag_or_config_or_target_str); + static ObjectPtr FromConfigString(const ffi::String& config_str); + static ObjectPtr FromRawString(const ffi::String& target_str); + static ObjectPtr FromConfig(ffi::Map config); static void ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv); static Target WithHost(const Target& target, const Target& target_host) { - ObjectPtr n = make_object(*target.get()); + ObjectPtr n = ffi::make_object(*target.get()); n->host = target_host; return (Target)n; } private: - static std::unordered_map QueryDevice(int device_id, const TargetNode* target); + static std::unordered_map QueryDevice(int device_id, + const TargetNode* target); static bool IsQuoted(const std::string& str); static std::string Quote(const std::string& str); static std::string JoinString(const std::vector& array, char separator); @@ -91,8 +93,8 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) { *host = (*target)->GetHost().value_or(Target()); } -static std::vector DeduplicateKeys(const std::vector& keys) { - std::vector new_keys; +static std::vector DeduplicateKeys(const std::vector& keys) { + std::vector new_keys; for (size_t i = 0; i < keys.size(); ++i) { bool found = false; for (size_t j = 0; j < i; ++j) { @@ -118,8 +120,8 @@ static T 
ObjTypeCheck(const Any& obj, const std::string& expected_type) { return opt.value(); } -static TargetKind GetTargetKind(const String& name) { - Optional kind = TargetKind::Get(name); +static TargetKind GetTargetKind(const ffi::String& name) { + ffi::Optional kind = TargetKind::Get(name); if (!kind.defined()) { TVM_FFI_THROW(TypeError) << "Target kind \"" + name + "\" is not defined"; } @@ -228,7 +230,7 @@ std::vector TargetInternal::SplitString(const std::string& str, cha } std::string TargetInternal::Interpret(const std::string& str) { - // String interpretation deals with quotes (') and escapes(\). + // ffi::String interpretation deals with quotes (') and escapes(\). // - An escape character must be followed by another character forming an // "escape sequence". (Trailing escape is not allowed.) An escape prevents // interpretation of the character that follows. This happens regardless of @@ -386,9 +388,9 @@ Any TargetInternal::ParseType(const std::string& str, const TargetKindNode::Valu auto end = interp_str.find_last_not_of(' '); if (start == std::string::npos || end == std::string::npos) { // The whole string is made of spaces. 
- return String(); + return ffi::String(); } - return String(interp_str.substr(start, (end - start + 1))); + return ffi::String(interp_str.substr(start, (end - start + 1))); } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target @@ -402,10 +404,10 @@ Any TargetInternal::ParseType(const std::string& str, const TargetKindNode::Valu result.push_back(parsed); } catch (const Error& e) { std::string index = "[" + std::to_string(result.size()) + "]"; - throw Error(e.kind(), e.message() + index, e.traceback()); + throw Error(e.kind(), e.message() + index, e.backtrace()); } } - return Array(result); + return ffi::Array(result); } TVM_FFI_THROW(TypeError) << "Unsupported type \"" + info.type_key << "\" for parsing from string: " + interp_str; @@ -420,12 +422,12 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf return ObjTypeCheck(obj, "bool"); } else if (info.type_index == ffi::TypeIndex::kTVMFFIStr) { // Parsing string - return ObjTypeCheck(obj, "String"); + return ObjTypeCheck(obj, "String"); } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target if (auto opt = obj.as()) { return opt.value(); - } else if (auto str = obj.try_cast()) { + } else if (auto str = obj.try_cast()) { return Target(TargetInternal::FromString(str.value())); } else if (const auto* ptr = obj.as()) { for (const auto& kv : *ptr) { @@ -434,7 +436,7 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf << "Target object requires key of dict to be str, but get: " << kv.first.GetTypeKey(); } } - Map config = GetRef>(ptr); + ffi::Map config = ffi::GetRef>(ptr); return Target(TargetInternal::FromConfig({config.begin(), config.end()})); } TVM_FFI_THROW(TypeError) << "Expect type 'dict' or 'str' to construct Target, but get: " + @@ -448,10 +450,10 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf result.push_back(TargetInternal::ParseType(e, 
*info.key)); } catch (const Error& e) { std::string index = '[' + std::to_string(result.size()) + ']'; - throw Error(e.kind(), index + e.message(), e.traceback()); + throw Error(e.kind(), index + e.message(), e.backtrace()); } } - return Array(result); + return ffi::Array(result); } else if (info.type_index == ffi::MapObj::RuntimeTypeIndex()) { // Parsing map const auto* map = ObjTypeCheck(obj, "Map"); @@ -461,18 +463,18 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf try { key = TargetInternal::ParseType(kv.first, *info.key); } catch (const Error& e) { - throw Error(e.kind(), e.message() + ", during parse key of map", e.traceback()); + throw Error(e.kind(), e.message() + ", during parse key of map", e.backtrace()); } try { val = TargetInternal::ParseType(kv.second, *info.val); } catch (const Error& e) { std::ostringstream os; os << ", during parseing value of map[\"" << key << "\"]"; - throw Error(e.kind(), e.message() + os.str(), e.traceback()); + throw Error(e.kind(), e.message() + os.str(), e.backtrace()); } result[key] = val; } - return Map(result); + return ffi::Map(result); } if (info.type_index != obj.type_index()) { TVM_FFI_THROW(TypeError) << "Parsing type \"" << info.type_key @@ -489,7 +491,7 @@ std::string TargetInternal::StringifyAtomicType(const Any& obj) { return std::to_string(obj.cast()); } else if (obj.type_index() == ffi::TypeIndex::kTVMFFIInt) { return std::to_string(obj.cast()); - } else if (auto opt_str = obj.as()) { + } else if (auto opt_str = obj.as()) { std::string s = opt_str.value(); auto u = Uninterpret(s); if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) { @@ -516,9 +518,10 @@ std::string TargetInternal::StringifyArray(const ffi::ArrayObj& array) { return JoinString(elements, ','); } -Optional TargetInternal::StringifyAttrsToRaw(const Map& attrs) { +ffi::Optional TargetInternal::StringifyAttrsToRaw( + const ffi::Map& attrs) { std::ostringstream os; - std::vector keys; + std::vector keys; 
for (const auto& kv : attrs) { keys.push_back(kv.first); } @@ -531,7 +534,7 @@ Optional TargetInternal::StringifyAttrsToRaw(const Map // skip undefined attrs if (obj == nullptr) continue; if (const auto* array = obj.as()) { - value = String(StringifyArray(*array)); + value = ffi::String(StringifyArray(*array)); } else { value = StringifyAtomicType(obj); } @@ -539,7 +542,7 @@ Optional TargetInternal::StringifyAttrsToRaw(const Map result.push_back("-" + key + "=" + value); } } - return String(JoinString(result, ' ')); + return ffi::String(JoinString(result, ' ')); } const std::string& TargetNode::str() const { @@ -549,7 +552,7 @@ const std::string& TargetNode::str() const { if (!this->keys.empty()) { os << " -keys="; bool is_first = true; - for (const String& s : keys) { + for (const ffi::String& s : keys) { if (is_first) { is_first = false; } else { @@ -558,7 +561,7 @@ const std::string& TargetNode::str() const { os << s; } } - if (Optional attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { + if (ffi::Optional attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { os << ' ' << attrs_str.value(); } @@ -569,38 +572,38 @@ const std::string& TargetNode::str() const { /********** Small member methods **********/ -Target::Target(const String& tag_or_config_or_target_str) { +Target::Target(const ffi::String& tag_or_config_or_target_str) { ObjectPtr target; try { target = TargetInternal::FromString(tag_or_config_or_target_str); } catch (const Error& e) { std::ostringstream os; os << ". Target creation from string failed: " << tag_or_config_or_target_str; - throw Error("ValueError", e.message() + os.str(), e.traceback()); + throw Error("ValueError", e.message() + os.str(), e.backtrace()); } data_ = std::move(target); } -Target::Target(const Map& config) { +Target::Target(const ffi::Map& config) { ObjectPtr target; try { target = TargetInternal::FromConfig({config.begin(), config.end()}); } catch (const Error& e) { std::ostringstream os; os << ". 
Target creation from config dict failed: " << config; - throw Error("ValueError", std::string(e.message()) + os.str(), e.traceback()); + throw Error("ValueError", std::string(e.message()) + os.str(), e.backtrace()); } data_ = std::move(target); } Target::Target(Target target, Target host) { - ObjectPtr n = make_object(*target.get()); + ObjectPtr n = ffi::make_object(*target.get()); n->host = std::move(host); data_ = std::move(n); } -Target::Target(TargetKind kind, Optional host, String tag, Array keys, - Map attrs) { +Target::Target(TargetKind kind, ffi::Optional host, ffi::String tag, + ffi::Array keys, ffi::Map attrs) { auto data = ffi::make_object(); data->kind = std::move(kind); data->host = std::move(host); @@ -619,7 +622,7 @@ std::vector TargetNode::GetKeys() const { } std::unordered_set TargetNode::GetLibs() const { - Optional> libs = this->GetAttr>("libs"); + ffi::Optional> libs = this->GetAttr>("libs"); if (!libs.defined()) { return {}; } @@ -630,8 +633,8 @@ std::unordered_set TargetNode::GetLibs() const { return result; } -Map TargetNode::Export() const { - Map result = { +ffi::Map TargetNode::Export() const { + ffi::Map result = { {"kind", this->kind->name}, {"tag", this->tag}, {"keys", this->keys}, @@ -645,11 +648,11 @@ Map TargetNode::Export() const { return result; } -Optional TargetNode::GetHost() const { return this->host.as(); } +ffi::Optional TargetNode::GetHost() const { return this->host.as(); } Target Target::WithoutHost() const { if ((*this)->GetHost()) { - auto output = make_object(*get()); + auto output = ffi::make_object(*get()); output->host = std::nullopt; return Target(output); } else { @@ -658,7 +661,7 @@ Target Target::WithoutHost() const { } int TargetNode::GetTargetDeviceType() const { - if (Optional device_type = GetAttr("target_device_type")) { + if (ffi::Optional device_type = GetAttr("target_device_type")) { return Downcast(device_type)->value; } return kind->default_device_type; @@ -669,7 +672,7 @@ bool TargetNode::HasKey(const 
std::string& query_key) const { [&query_key](const auto& key) { return key == query_key; }); } -String TargetNode::ToDebugString() const { +ffi::String TargetNode::ToDebugString() const { std::ostringstream os; os << "Target("; os << "id=" << std::hex << reinterpret_cast(this); @@ -747,9 +750,9 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { const auto& arg = args[0]; if (auto opt_target = arg.as()) { *rv = Target(opt_target.value()); - } else if (auto opt_str = arg.try_cast()) { + } else if (auto opt_str = arg.try_cast()) { *rv = Target(opt_str.value()); - } else if (auto opt_map = arg.try_cast>()) { + } else if (auto opt_map = arg.try_cast>()) { *rv = Target(opt_map.value()); } else { LOG(FATAL) << "TypeError: Cannot create target with type: " << args[0].GetTypeKey(); @@ -768,10 +771,10 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but gets: " << args.size(); } -ObjectPtr TargetInternal::FromString(const String& tag_or_config_or_target_str) { - if (Optional target = TargetTag::Get(tag_or_config_or_target_str)) { +ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or_target_str) { + if (ffi::Optional target = TargetTag::Get(tag_or_config_or_target_str)) { Target value = target.value(); - return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(value); + return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(value); } if (!tag_or_config_or_target_str.empty() && tag_or_config_or_target_str.data()[0] == '{') { return TargetInternal::FromConfigString(tag_or_config_or_target_str); @@ -779,25 +782,25 @@ ObjectPtr TargetInternal::FromString(const String& tag_or_config_or_targ return TargetInternal::FromRawString(tag_or_config_or_target_str); } -ObjectPtr TargetInternal::FromConfigString(const String& config_str) { +ObjectPtr TargetInternal::FromConfigString(const ffi::String& config_str) { const auto 
loader = tvm::ffi::Function::GetGlobal("target._load_config_dict"); ICHECK(loader.has_value()) << "AttributeError: \"target._load_config_dict\" is not registered. Please check " "if the python module is properly loaded"; - auto config = (*loader)(config_str).cast>>(); + auto config = (*loader)(config_str).cast>>(); if (!config.defined()) { TVM_FFI_THROW(ValueError) << "Cannot load config dict with python JSON loader"; } return TargetInternal::FromConfig({config.value().begin(), config.value().end()}); } -ObjectPtr TargetInternal::FromRawString(const String& target_str) { +ObjectPtr TargetInternal::FromRawString(const ffi::String& target_str) { ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string"; // Split the string by empty spaces std::vector options = SplitString(std::string(target_str), ' '); std::string name = options[0]; // Create the target config - std::unordered_map config = {{"kind", String(name)}}; + std::unordered_map config = {{"kind", ffi::String(name)}}; TargetKind kind = GetTargetKind(name); for (size_t iter = 1, end = options.size(); iter < end;) { std::string key, value; @@ -807,7 +810,7 @@ ObjectPtr TargetInternal::FromRawString(const String& target_str) { iter += ParseKVPair(RemovePrefixDashes(options[iter]), s_next, &key, &value); } catch (const Error& e) { throw Error(e.kind(), e.message() + ", during parsing target `" + target_str + "`", - e.traceback()); + e.backtrace()); } try { // check if `key` has been used @@ -817,26 +820,26 @@ ObjectPtr TargetInternal::FromRawString(const String& target_str) { config[key] = TargetInternal::ParseType(value, TargetInternal::FindTypeInfo(kind, key)); } catch (const Error& e) { throw Error(e.kind(), std::string(e.message()) + ", during parsing target[\"" + key + "\"]", - e.traceback()); + e.backtrace()); } } return TargetInternal::FromConfig(config); } -ObjectPtr TargetInternal::FromConfig(Map config) { - const String kKind = "kind"; - const String kTag = "tag"; - const String kKeys = 
"keys"; - const String kDeviceName = "device"; - const String kHost = "host"; - const String kFeatures = "features"; - ObjectPtr target = make_object(); +ObjectPtr TargetInternal::FromConfig(ffi::Map config) { + const ffi::String kKind = "kind"; + const ffi::String kTag = "tag"; + const ffi::String kKeys = "keys"; + const ffi::String kDeviceName = "device"; + const ffi::String kHost = "host"; + const ffi::String kFeatures = "features"; + ObjectPtr target = ffi::make_object(); ICHECK(!config.count(kFeatures)) << "Target Features should be generated by Target parser"; // parse 'kind' if (config.count(kKind)) { - if (auto kind = config[kKind].try_cast()) { + if (auto kind = config[kKind].try_cast()) { target->kind = GetTargetKind(kind.value()); ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr)) << "Cannot use both set_attrs_preprocessor and set_target_parser"; @@ -846,7 +849,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { VLOG(9) << "TargetInternal::FromConfig - Running target_parser"; config = target->kind->target_parser(config); if (config.count(kFeatures)) { - target->features = Downcast>(config[kFeatures]); + target->features = Downcast>(config[kFeatures]); config.erase(kFeatures); } } @@ -861,7 +864,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // parse "tag" if (config.count(kTag)) { - if (auto tag = config[kTag].try_cast()) { + if (auto tag = config[kTag].try_cast()) { target->tag = tag.value(); config.erase(kTag); } else { @@ -873,13 +876,13 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // parse "keys" { - std::vector keys; + std::vector keys; bool has_user_keys = config.count(kKeys); if (has_user_keys) { // user provided keys if (const auto* cfg_keys = config[kKeys].as()) { for (const Any& e : *cfg_keys) { - if (auto key = e.try_cast()) { + if (auto key = e.try_cast()) { keys.push_back(key.value()); } else { TVM_FFI_THROW(TypeError) << "Expect 'keys' to be an array of strings, but it " @@ 
-893,7 +896,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // add device name if (config.count(kDeviceName)) { - if (auto device = config.at(kDeviceName).try_cast()) { + if (auto device = config.at(kDeviceName).try_cast()) { keys.push_back(device.value()); } } @@ -915,16 +918,16 @@ ObjectPtr TargetInternal::FromConfig(Map config) { target->host = std::nullopt; } // parse attrs - std::unordered_map attrs; + std::unordered_map attrs; for (const auto& cfg_kv : config) { - const String& key = cfg_kv.first; + const ffi::String& key = cfg_kv.first; const ffi::Any& value = cfg_kv.second; try { const TargetKindNode::ValueTypeInfo& info = TargetInternal::FindTypeInfo(target->kind, key); attrs[key] = TargetInternal::ParseType(value, info); } catch (const Error& e) { throw Error(e.kind(), std::string(e.message()) + ", during parsing target[\"" + key + "\"]", - e.traceback()); + e.backtrace()); } } @@ -950,8 +953,8 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // do extra pre-processing if (target->kind->preprocessor != nullptr) { - target->attrs = - target->kind->preprocessor(Map(attrs)).cast>(); + target->attrs = target->kind->preprocessor(ffi::Map(attrs)) + .cast>(); } else { target->attrs = attrs; } @@ -959,9 +962,9 @@ ObjectPtr TargetInternal::FromConfig(Map config) { return target; } // namespace tvm -std::unordered_map TargetInternal::QueryDevice(int device_id, - const TargetNode* target) { - std::unordered_map output; +std::unordered_map TargetInternal::QueryDevice(int device_id, + const TargetNode* target) { + std::unordered_map output; Device device{static_cast(target->GetTargetDeviceType()), device_id}; @@ -984,7 +987,7 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, } for (const auto& kv : target->kind->key2vtype_) { - const String& key = kv.first; + const ffi::String& key = kv.first; ffi::Any ret; api->GetTargetProperty(device, key, &ret); @@ -996,7 +999,7 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, 
/********** Registry **********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("target.Target", TargetInternal::ConstructorDispatcher) @@ -1007,14 +1010,15 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("target.WithHost", TargetInternal::WithHost) .def("target.TargetGetDeviceType", [](const Target& target) { return target->GetTargetDeviceType(); }) - .def("target.TargetGetFeature", [](const Target& target, const String& feature_key) -> Any { - if (auto opt_any = target->GetFeature(feature_key)) { - return opt_any.value(); - } else { - return Any(); - } - }); -}); + .def("target.TargetGetFeature", + [](const Target& target, const ffi::String& feature_key) -> Any { + if (auto opt_any = target->GetFeature(feature_key)) { + return opt_any.value(); + } else { + return Any(); + } + }); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { diff --git a/src/target/target_info.cc b/src/target/target_info.cc index 578276162678..1966024dd4b7 100644 --- a/src/target/target_info.cc +++ b/src/target/target_info.cc @@ -26,7 +26,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ MemoryInfoNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { MemoryInfoNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 4d61c035fbe5..99a5684af521 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -36,7 +36,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; TargetKindNode::RegisterReflection(); refl::TypeAttrDef() @@ -45,12 +45,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ // simply save as the string return node->name; }) - .def("__data_from_json__", [](const String& name) { + .def("__data_from_json__", [](const ffi::String& name) { auto kind = 
TargetKind::Get(name); ICHECK(kind.has_value()) << "Cannot find target kind \'" << name << '\''; return kind.value(); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { @@ -62,32 +62,33 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) using TargetKindRegistry = AttrRegistry; -Array TargetKindRegEntry::ListTargetKinds() { +ffi::Array TargetKindRegEntry::ListTargetKinds() { return TargetKindRegistry::Global()->ListAllNames(); } -Map TargetKindRegEntry::ListTargetKindOptions(const TargetKind& target_kind) { - Map options; +ffi::Map TargetKindRegEntry::ListTargetKindOptions( + const TargetKind& target_kind) { + ffi::Map options; for (const auto& kv : target_kind->key2vtype_) { options.Set(kv.first, kv.second.type_key); } return options; } -TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) { +TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const ffi::String& target_kind_name) { return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name); } -void TargetKindRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { +void TargetKindRegEntry::UpdateAttr(const ffi::String& key, ffi::Any value, int plevel) { TargetKindRegistry::Global()->UpdateAttr(key, kind_, value, plevel); } const AttrRegistryMapContainerMap& TargetKind::GetAttrMapContainer( - const String& attr_name) { + const ffi::String& attr_name) { return TargetKindRegistry::Global()->GetAttrMap(attr_name); } -Optional TargetKind::Get(const String& target_kind_name) { +ffi::Optional TargetKind::Get(const ffi::String& target_kind_name) { const TargetKindRegEntry* reg = TargetKindRegistry::Global()->Get(target_kind_name); if (reg == nullptr) { return std::nullopt; @@ -140,12 +141,13 @@ static bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, ffi::A return true; } -void CheckOrSetAttr(Map* attrs, const String& name, const String& value) { +void CheckOrSetAttr(ffi::Map* attrs, const 
ffi::String& name, + const ffi::String& value) { auto iter = attrs->find(name); if (iter == attrs->end()) { attrs->Set(name, value); } else { - auto str = (*iter).second.try_cast(); + auto str = (*iter).second.try_cast(); ICHECK(str && str.value() == value) << "ValueError: Expects \"" << name << "\" to be \"" << value << "\", but gets: " << (*iter).second; } @@ -162,7 +164,7 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) { // Update -arch=sm_xx if (target.count("arch")) { // If -arch has been specified, validate the correctness - String archStr = Downcast(target.at("arch")); + ffi::String archStr = Downcast(target.at("arch")); ICHECK(support::StartsWith(archStr, "sm_")) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr; } else { @@ -175,7 +177,7 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) { } else { archInt = std::stod(version.cast()) * 10 + 0.1; } - target.Set("arch", String("sm_") + std::to_string(archInt)); + target.Set("arch", ffi::String("sm_") + std::to_string(archInt)); } return target; } @@ -190,7 +192,7 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) { // Update -mcpu=sm_xx if (target.count("mcpu")) { // If -mcpu has been specified, validate the correctness - String mcpu = Downcast(target.at("mcpu")); + ffi::String mcpu = Downcast(target.at("mcpu")); ICHECK(support::StartsWith(mcpu, "sm_")) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu; } else { @@ -203,7 +205,7 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) { } else { arch = std::stod(version.cast()) * 10 + 0.1; } - target.Set("mcpu", String("sm_") + std::to_string(arch)); + target.Set("mcpu", ffi::String("sm_") + std::to_string(arch)); } return target; } @@ -218,7 +220,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { // Update -mcpu=gfx std::string arch = "gfx900"; if (target.count("mcpu")) { - String mcpu = Downcast(target.at("mcpu")); + ffi::String mcpu = Downcast(target.at("mcpu")); arch = ExtractStringWithPrefix(mcpu, "gfx"); 
ICHECK(!arch.empty()) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu; } else { @@ -226,7 +228,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { if (const auto f_get_rocm_arch = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_get_arch")) { arch = (*f_get_rocm_arch)().cast(); } - target.Set("mcpu", String(arch)); + target.Set("mcpu", ffi::String(arch)); } // Update -mattr before ROCm 3.5: // Before ROCm 3.5 we needed code object v2, starting @@ -241,9 +243,9 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { version = val.cast(); } if (version < 305) { - Array mattr; + ffi::Array mattr; if (target.count("mattr")) { - mattr = Downcast>(target.at("mattr")); + mattr = Downcast>(target.at("mattr")); } mattr.push_back("-code-object-v3"); target.Set("mattr", mattr); @@ -257,7 +259,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { * \return The updated attributes */ TargetJSON TestTargetParser(TargetJSON target) { - Map features = {{"is_test", true}}; + ffi::Map features = {{"is_test", true}}; target.Set("features", features); return target; } @@ -265,11 +267,11 @@ TargetJSON TestTargetParser(TargetJSON target) { /********** Register Target kinds and attributes **********/ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) - .add_attr_option>("mattr") - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option("mfloat-abi") - .add_attr_option("mabi") + .add_attr_option>("mattr") + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option("mfloat-abi") + .add_attr_option("mabi") .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags .add_attr_option("fast-math") // implies all the below @@ -281,9 +283,9 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("fast-math-reassoc") .add_attr_option("opt-level") // LLVM command line flags, see below - .add_attr_option>("cl-opt") + .add_attr_option>("cl-opt") // LLVM JIT engine mcjit/orcjit - .add_attr_option("jit") + 
.add_attr_option("jit") // TVM & LLVM custom vector bit width .add_attr_option("vector-width") .set_default_keys({"cpu"}) @@ -314,16 +316,16 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) // Hence the type is "uint". TVM_REGISTER_TARGET_KIND("c", kDLCPU) - .add_attr_option("mcpu") - .add_attr_option("march") + .add_attr_option("mcpu") + .add_attr_option("march") .add_attr_option("workspace-byte-alignment") .add_attr_option("constants-byte-alignment") .set_default_keys({"cpu"}) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) - .add_attr_option("mcpu") - .add_attr_option("arch") + .add_attr_option("mcpu") + .add_attr_option("arch") .add_attr_option("max_shared_memory_per_block") .add_attr_option("max_threads_per_block") .add_attr_option("thread_warp_size", 32) @@ -334,17 +336,17 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) - .add_attr_option("mcpu") - .add_attr_option("mtriple") + .add_attr_option("mcpu") + .add_attr_option("mtriple") .add_attr_option("max_num_threads", 1024) .add_attr_option("thread_warp_size", 32) .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateNVPTXAttrs); TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option>("mattr") + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option>("mattr") // TODO(masahi): Support querying from a target device // On RDNA cards, thread_warp_size should be 32 .add_attr_option("max_num_threads", 256) @@ -354,6 +356,19 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .set_default_keys({"rocm", "gpu"}) .set_target_parser(UpdateROCmAttrs); +TVM_REGISTER_TARGET_KIND("hip", kDLROCM) + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option>("mattr") + // TODO(masahi): Support querying from a target device + // On RDNA cards, thread_warp_size should be 32 + .add_attr_option("max_num_threads", 
256) + .add_attr_option("max_threads_per_block", 256) + .add_attr_option("max_shared_memory_per_block", 65536) + .add_attr_option("thread_warp_size", 64) + .set_default_keys({"hip", "gpu"}) + .set_target_parser(UpdateROCmAttrs); + TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) .add_attr_option("max_threads_per_block", 256) .add_attr_option("max_shared_memory_per_block", 16384) @@ -382,7 +397,7 @@ TVM_REGISTER_TARGET_KIND("metal", kDLMetal) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) - .add_attr_option>("mattr") + .add_attr_option>("mattr") // Feature support .add_attr_option("supports_float16") .add_attr_option("supports_float32", true) @@ -412,9 +427,9 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("max_per_stage_descriptor_storage_buffer") .add_attr_option("max_shared_memory_per_block") // Other device properties - .add_attr_option("device_type") - .add_attr_option("device_name") - .add_attr_option("driver_name") + .add_attr_option("device_type") + .add_attr_option("device_name") + .add_attr_option("driver_name") .add_attr_option("driver_version") .add_attr_option("vulkan_api_version") .add_attr_option("max_spirv_version") @@ -426,31 +441,29 @@ TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) - .add_attr_option>("mattr") - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option>("llvm-options") + .add_attr_option>("mattr") + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option>("llvm-options") .add_attr_option("num-cores") .add_attr_option("vtcm-capacity") .set_default_keys({"hexagon", "cpu"}); TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev); -TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU); - TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break - .add_attr_option>("devices"); + .add_attr_option>("devices"); TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break 
.set_target_parser(TestTargetParser); /********** Registry **********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.TargetKindGetAttr", - [](TargetKind kind, String attr_name) -> ffi::Any { + [](TargetKind kind, ffi::String attr_name) -> ffi::Any { auto target_attr_map = TargetKind::GetAttrMap(attr_name); ffi::Any rv; if (target_attr_map.count(kind)) { @@ -460,10 +473,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def("target.ListTargetKinds", TargetKindRegEntry::ListTargetKinds) .def("target.ListTargetKindOptions", TargetKindRegEntry::ListTargetKindOptions) - .def("target.ListTargetKindOptionsFromName", [](String target_kind_name) { + .def("target.ListTargetKindOptionsFromName", [](ffi::String target_kind_name) { TargetKind kind = TargetKind::Get(target_kind_name).value(); return TargetKindRegEntry::ListTargetKindOptions(kind); }); -}); +} } // namespace tvm diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index ac67afcfafe5..54529acb409c 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -28,7 +28,7 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ VirtualDeviceNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { VirtualDeviceNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -71,7 +71,7 @@ VirtualDevice::VirtualDevice(int device_type_int, int virtual_device_id, Target ICHECK(!target.defined() || device_type_int == target->GetTargetDeviceType()) << "target " << target->ToDebugString() << " has device type " << target->GetTargetDeviceType() << " but virtual device has device type " << device_type_int; - auto node = make_object(); + auto node = ffi::make_object(); node->device_type_int = device_type_int; node->virtual_device_id = virtual_device_id; node->target = std::move(target); @@ -85,7 +85,8 @@ VirtualDevice::VirtualDevice(int device_type_int, 
int virtual_device_id, Target } /* static */ -Optional VirtualDevice::Join(const VirtualDevice& lhs, const VirtualDevice& rhs) { +ffi::Optional VirtualDevice::Join(const VirtualDevice& lhs, + const VirtualDevice& rhs) { if (lhs == rhs) { return lhs; } @@ -191,10 +192,10 @@ VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) { virtual_device->target, virtual_device->memory_scope); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.VirtualDevice_ForDeviceTargetAndMemoryScope", VirtualDevice::ForDeviceTargetAndMemoryScope); -}); +} } // namespace tvm diff --git a/src/target/z3/z3_prover_off.cc b/src/target/z3/z3_prover_off.cc new file mode 100644 index 000000000000..5ed04d792be7 --- /dev/null +++ b/src/target/z3/z3_prover_off.cc @@ -0,0 +1,40 @@ +#include +#include +#include + +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" +#include "tvm/tir/analysis.h" +#include "tvm/arith/analyzer.h" + +namespace tvm::arith { + +using namespace tir; +using namespace ffi; + +class Z3Prover::Impl {}; + +TVM_DLL bool Z3Prover::CanProve(const PrimExpr & expr) { return false; } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { return [](){}; } +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + return "; Z3 Prover is disabled."; +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) {} +void Z3Prover::SetRLimit(unsigned rlimit) {} +ffi::String Z3Prover::GetModel(const PrimExpr & expr) { + return "; Z3 Prover is disabled."; +} +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive) { + return -1; // Z3 disabled, return error +} + +void Z3Prover::CopyFrom(const Z3Prover & other) {} +ffi::String 
Z3Prover::GetStats() { + return "; Z3 Prover is disabled."; +} +Z3Prover::Z3Prover(Analyzer*): impl_(nullptr) {} +TVM_DLL Z3Prover::~Z3Prover() {} + +} // namespace tvm::arith diff --git a/src/target/z3/z3_prover_on.cc b/src/target/z3/z3_prover_on.cc new file mode 100644 index 000000000000..f4832334ce95 --- /dev/null +++ b/src/target/z3/z3_prover_on.cc @@ -0,0 +1,767 @@ +#include +#include +#include +#include +#include "z3++.h" + +#include +#include +#include + +#include "tvm/ffi/cast.h" +#include "tvm/ffi/object.h" +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" +#include "tvm/node/structural_equal.h" +#include "tvm/node/structural_hash.h" +#include "tvm/runtime/data_type.h" +#include "tvm/tir/analysis.h" +#include "tvm/tir/expr_functor.h" +#include "tvm/arith/analyzer.h" +#include "tvm/tir/op_attr_types.h" + +namespace tvm::arith { + +using namespace tir; +using namespace ffi; + +namespace { + +struct Namespace { + std::unordered_set used_names; + /// @brief Get a new name that is not used before + /// This function is used to generate z3 variable names + /// + /// Z3 may deduplicate variables with the same name, which + /// causes issues when different TVM variables are mapped to + /// the same z3 variable. + /// + /// This function generates unique names by appending + /// suffixes to the original expression string representation. + /// + /// such as : "x", "x$1", "x$2", ... 
+ std::string GetNewName(const PrimExpr & expr) { + std::stringstream ss; + ss << expr; + auto name = ss.str(); + if(used_names.count(name) == 0) { + used_names.insert(name); + return name; + } + int idx = 1; + std::string check_name = name + "$" + std::to_string(idx); + while(used_names.count(check_name)) { + idx ++; + check_name = name + "$" + std::to_string(idx); + } + used_names.insert(check_name); + return check_name; + } +}; + +} // namespace + +class Z3Prover::Impl : ExprFunctor { +public: + using Base = ExprFunctor; + using Self = Z3Prover::Impl; + + Analyzer* analyzer; + /// @brief Z3 context, a shared ptr, because tilelang want to copy the Analyzer + // We use a thread_local static Z3 context so all analyzers within the same thread + // can share a common context, because Z3 initialization is slow on some CPUs + // (e.g., AMD EPYC 7502 32-Core). Using thread_local ensures thread safety. + inline static thread_local std::shared_ptr ctx { new z3::context() }; + + /// @brief Z3 solver instance + z3::solver solver {*ctx}; + + /// @brief Memorize pure expressions + std::unordered_map memo_; + + bool is_assume = false; + + /// @brief Namespace for variable naming + Namespace ns; + + /// @brief Timeout in milliseconds + unsigned timeout_ms {UINT_MAX}; + + /// @brief Max steps + unsigned rlimit {UINT_MAX}; + + /// @brief Create a z3 solver with custom options + static z3::solver CreateSolver(z3::context & ctx) { + z3::solver solver(ctx); + // here we disable model generation to speed up the solving process + solver.set("model", false); + // ensure determinstic behavior + solver.set("random_seed", (unsigned)42); + return solver; + } + + Impl(Analyzer * parent): analyzer(parent) { + scope_stack_.push_back({}); + solver = CreateSolver(*ctx); + // default timeout 5ms + // Z3's implementation of timeout, when setting timeout T ms, it will stop at T - 1 ms + // SetTimeoutMs(5); + // use rlimit, not timeout to ensure determinstic behavior + SetRLimit(1e4); + } + + /// 
@brief Create a Free z3 expression from PrimExprNode + z3::expr Create(const PrimExprNode *op) { + auto ref = ffi::GetRef(op); + auto dtype = op->dtype; + std::string name = ns.GetNewName(ref); + /// TVM max_val can't handle uint64 max correctly, so we special case it here + if(dtype.is_bool()) { + return ctx->bool_const(name.c_str()); + } + else { + z3::expr e = ctx->int_const(name.c_str()); + if(dtype.is_uint() && dtype.bits() == 64) { + solver.add(ctx->int_val(0) <= e && e <= ctx->int_val((uint64_t)UINT64_MAX)); + } else { + auto min_val = Downcast(min_value(dtype))->value; + auto max_val = Downcast(max_value(dtype))->value; + solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val)); + } + return e; + } + } + + struct Scope { + enum Kind { + BindValue, + BindRange, + Constraint, + } kind; + Var var; + PrimExpr value; + PrimExpr min; + PrimExpr extent; + PrimExpr constraint; + }; + + /// @brief scope_stack memorizes existing constraint and bindings + /// to generate SMTLIB2 representation with comments + std::vector> scope_stack_; + + /// @brief Enter a constraint scope + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume=false) { + scope_stack_.push_back({}); + scope_stack_.back().push_back(Scope{Scope::Constraint, Var(), PrimExpr(), PrimExpr(), PrimExpr(), constraint}); + solver.push(); + this->is_assume = is_assume; + solver.add(VisitBool(constraint)); + this->is_assume = false; + auto side_effect_exprs = std::move(side_effect_exprs_); + side_effect_exprs_.clear(); + if(is_assume) { + return [this, side_effect_exprs]() { + solver.pop(); + for (const auto& expr : side_effect_exprs) { + memo_.erase(expr); + } + scope_stack_.pop_back(); + }; + } else { + for(const auto & expr: side_effect_exprs) { + memo_.erase(expr); + } + return [this]() { + solver.pop(); + scope_stack_.pop_back(); + }; + } + } + + /// @brief Check trivil bad cases, return true if the expr is a bad case + /// Z3 prover may take a long time to initialize (at 
least 200us), + /// This optimization can speedup 30% of the test cases in our unit tests + bool CheckTrivilBadCases(const PrimExpr & expr) { + if(IsFreeNode(expr)) { + return true; + } + auto checkTrivilCmp = [this](const PrimExpr & lhs, const PrimExpr & rhs) { + if(IsFreeNode(lhs) && rhs->IsInstance()) { + return true; + } + if(IsFreeNode(rhs) && lhs->IsInstance()) { + return true; + } + if(IsFreeNode(lhs) && IsFreeNode(rhs)) { + return true; + } + // cast('xxx', free_var) == constant + if(auto cast = lhs.as()) { + if(IsFreeNode(cast->value) && rhs->IsInstance()) { + return true; + } + } + // constant == cast('xxx', free_var) + if(auto cast = rhs.as()) { + if(IsFreeNode(cast->value) && lhs->IsInstance()) { + return true; + } + } + return false; + }; + if(auto eq = expr.as()) { + auto lhs = eq->a; + auto rhs = eq->b; + return checkTrivilCmp(lhs, rhs); + } else if(auto ne = expr.as()) { + auto lhs = ne->a; + auto rhs = ne->b; + return checkTrivilCmp(lhs, rhs); + } + return false; + } + + /// @brief Check if the expression can be proved + bool CanProve(const PrimExpr &expr) { + if (CheckTrivilBadCases(expr)) return false; + if (!IsValidDType(expr->dtype)) return false; + z3::expr_vector constr(*ctx); + constr.push_back(!ConvertBool(expr)); + auto result = solver.check(constr); + constr.pop_back(); + return result == z3::unsat; + } + + /// @brief Binded + /// @brief Bind a variable to a value or a range + void Bind(const Var & var, const PrimExpr & value, bool allow_override = false) { + if (!IsValidDType(var->dtype)) return; + scope_stack_.back().push_back(Scope{ + Scope::BindValue, + var, + value + }); + // we add the binding whenever the value is pure, + // because non-pure parts are handling by creating free variables in VisitExpr + memo_.emplace(var, ConvertInt(value)); + } + + /// @brief Bind a variable to a range + void Bind(const Var & var, const Range & range, bool allow_override = false) { + if (!IsValidDType(var->dtype)) return; + 
scope_stack_.back().push_back(Scope{ + Scope::BindRange, + var, + PrimExpr(), + range->min, + range->extent + }); + // 1. Create a placeholder for the var, and save it in the memo + // if the var is overrided later, we can just update the memo, and the old placeholder will be ignored + auto var_expr = Create(var.as()); + memo_.emplace(var, var_expr); + + // 2. Add constraint on the placeholder + // when min_expr >= max_expr, the range is empty, which is under undefined behavior + // instead of adding an unsat constraint, we just skip the range constraint to leave it a free var + if(tir::is_const_int(range->min) && tir::is_const_int(range->min + range->extent)) { + int64_t min_value = *tir::as_const_int(range->min); + int64_t max_value = *tir::as_const_int(range->min + range->extent); + if(min_value < max_value) { + solver.add(ctx->int_val(min_value) <= var_expr); + solver.add(var_expr < ctx->int_val(max_value)); + } + } else { + solver.add(ConvertBool(range->extent <= 0 || (range->min <= var && var < range->min + range->extent))); + } + } + + void CopyFrom(const Self & other_) { + // 1. create a new solver + // because this->solver depends on this->ctx + // we need to deconstruct the old solver, and create a new one depending on other_.ctx + solver = CreateSolver(*other_.ctx); + // 2. copy the context + // the context is a shared_ptr, we can just copy the pointer + ctx = other_.ctx; + // 3. copy other objects + ns = other_.ns; + for(auto & item: other_.memo_) { + memo_.emplace(item.first, item.second); + } + for(auto a: other_.solver.assertions()) { + solver.add(a); + } + // 4. copy timeout options + // but other solver options are not copied + SetTimeoutMs(other_.timeout_ms); + SetRLimit(other_.rlimit); + // 5. 
copy the scope stack, which containing comments for SMTLIB2 generation + scope_stack_ = other_.scope_stack_; + } + + /// @brief Set timeout in milliseconds + void SetTimeoutMs(unsigned timeout_ms) { + this->timeout_ms = timeout_ms; + solver.set("timeout", timeout_ms); + } + + /// @brief Set max steps + void SetRLimit(unsigned rlimit) { + this->rlimit = rlimit; + solver.set("rlimit", rlimit); + } + + /// @brief Get the SMTLIB2 representation of the current solver state + ffi::String GetSMTLIB2() { + std::stringstream ss; + ss << "(set-option :timeout " << timeout_ms << ")\n"; + AddScopeDebugMsg(ss); + ss << solver.to_smt2(); + return ss.str(); + } + + void AddScopeDebugMsg(std::ostream & ss) { + for(const auto &scope: scope_stack_) { + ss << "; Entering Scope\n"; + for(const auto & s: scope) { + switch(s.kind) { + case Scope::Constraint: + ss << "; constraint: " << s.constraint << "\n"; + break; + case Scope::BindValue: + ss << "; bind value: " << s.var << " = " << s.value << "\n"; + break; + case Scope::BindRange: + ss << "; bind range: " << s.var << " in [" << s.min << ", " << s.min + s.extent << ")\n"; + break; + } + } + } + } + + /// @brief Get the SMTLIB2 representation of the current solver state with additional expr trying to prove + ffi::String GetSMTLIB2(const PrimExpr & expr) { + std::stringstream ss; + ss << "(set-option :timeout " << timeout_ms << ")\n"; + AddScopeDebugMsg(ss); + ss << "; Trying to prove: " << expr << "\n"; + solver.push(); + solver.add(!ConvertBool(expr)); + ss << solver.to_smt2(); + solver.pop(); + return ss.str(); + } + + /// @brief Get the statistics of the solver + ffi::String GetStats() { + std::stringstream ss; + ss << solver.statistics(); + return ss.str(); + } + + ffi::String GetModel(const PrimExpr & expr) { + solver.set("model", true); + solver.push(); + solver.add(!ConvertBool(expr)); + auto result = solver.check(); + ffi::String model_str; + if (result == z3::sat) { + z3::model m = solver.get_model(); + std::map model_map; + 
for(unsigned i = 0; i < m.size(); i++) { + z3::func_decl d = m[i]; + model_map.emplace(d.name().str(), m.get_const_interp(d)); + } + std::stringstream ss; + for(const auto & [k, v]: model_map) { + ss << " " << k << " = " << v << "\n"; + } + model_str = ss.str(); + } + solver.pop(); + solver.set("model", false); + return model_str; + } + + /*! + * \brief Count the number of distinct integer values satisfying current constraints. + * + * Uses Z3's model enumeration (AllSAT pattern) to count solutions: + * 1. Find a satisfying assignment + * 2. Add a blocking clause to exclude it + * 3. Repeat until UNSAT + * + * \param var The variable to count values for + * \param max_count Safety limit on enumeration + * \param min_consecutive Minimum consecutive count requirement (0 to disable) + * \return Number of satisfying values, -1 on error, -2 if min_consecutive constraint not met + */ + int64_t CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive = 1) { + if (!IsValidDType(var->dtype)) { + return -1; + } + + solver.set("model", true); + solver.push(); + + // Convert the TVM variable to Z3 expression + z3::expr z3_var = VisitInt(var); + + int64_t count = 0; + std::vector found_values; + + while (count < max_count) { + auto result = solver.check(); + if (result != z3::sat) { + break; // No more solutions + } + + z3::model m = solver.get_model(); + z3::expr val_expr = m.eval(z3_var, true); + + // Extract the integer value from Z3 expression + int64_t val; + if (val_expr.is_numeral()) { + val = val_expr.get_numeral_int64(); + } else { + // If we can't get a concrete value, stop enumeration + break; + } + + found_values.push_back(val); + count++; + + // Add blocking clause: var != val (exclude this solution) + solver.add(z3_var != ctx->int_val(val)); + } + + solver.pop(); + solver.set("model", false); + + // Clear any side effects from visiting the variable + for (const auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + 
side_effect_exprs_.clear(); + + // Check minimum consecutive constraint if enabled + if (min_consecutive > 0 && count > 0) { + // Sort the values to check consecutive groups + std::sort(found_values.begin(), found_values.end()); + + // Check that all values form groups of at least min_consecutive consecutive numbers + int64_t consecutive_count = 1; + for (size_t i = 1; i < found_values.size(); i++) { + if (found_values[i] == found_values[i - 1] + 1) { + // Consecutive value + consecutive_count++; + } else { + // Gap found, check if the previous group meets the minimum + if (consecutive_count < min_consecutive) { + return -2; // Previous group too small + } + consecutive_count = 1; // Start new group + } + } + // Check the last group + if (consecutive_count < min_consecutive) { + return -2; // Last group too small + } + } + + return count; + } + +private: + + using Z3BinOp = z3::expr(*)(const z3::expr &, const z3::expr &); + + std::vector side_effect_exprs_; + + z3::expr ConvertBool(const PrimExpr & e, bool is_assume=false) { + this->is_assume = is_assume; + auto res = VisitBool(e); + for(auto & expr: side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + this->is_assume = false; + return res; + } + + z3::expr ConvertInt(const PrimExpr & e, bool is_assume=false) { + this->is_assume = is_assume; + auto res = VisitInt(e); + for(auto & expr: side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + this->is_assume = false; + return res; + } + + /// @brief Visit expression with memoization + z3::expr VisitExpr(const PrimExpr & e) override { + if(memo_.count(e)) { + return memo_.at(e); + } + auto res = Base::VisitExpr(e); + auto side_effect = SideEffect(e); + if(side_effect <= CallEffectKind::kPure) { + memo_.emplace(e, res); + } else if(side_effect <= CallEffectKind::kReadState) { + memo_.emplace(e, res); + side_effect_exprs_.emplace_back(e); + } else { + if(is_assume) { + memo_.emplace(e, res); + } + 
side_effect_exprs_.emplace_back(e); + } + return res; + } + + /// @brief Check if the expression is a free node having no constraints + bool IsFreeNode(const PrimExpr & e) { + if(memo_.count(e)) { + return false; + } + return e->IsInstance() + || e->IsInstance() + || e->IsInstance() + || e->IsInstance() + || (e->IsInstance() && !IsValidDType(Downcast(e)->value->dtype)); + } + + /// @brief Check if the dtype is valid for z3 integer operations + static bool IsValidDType(const DataType & dtype) { + return (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) && dtype.lanes() == 1; + } + + /// @brief Visit the expression and convert it into z3 integer expression + z3::expr VisitInt(const PrimExpr &expr) { + auto e = VisitExpr(expr); + if (e.is_bool()) { + return z3::ite(e, ctx->int_val(1), ctx->int_val(0)); + } else { + return e; + } + } + + /// @brief Visit the expression and convert it into z3 boolean expression + z3::expr VisitBool(const PrimExpr &e) { + auto expr = VisitExpr(e); + if (expr.is_bool()) { + return expr; + } else { + return expr != ctx->int_val(0); + } + } + + /// @brief Helper function to visit binary arithmetic operations + z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode *op, const PrimExpr &a, const PrimExpr &b) { + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + return signed_op(VisitInt(a), VisitInt(b)); + } else { + return Create(op); + } + } + + z3::expr VisitExpr_(const LetNode *op) override { + if (IsValidDType(op->var->dtype)) { + memo_.emplace(op->var, VisitInt(op->value)); + } + return VisitExpr(op->body); + } + z3::expr VisitExpr_(const CastNode * op) override { + // if the inner dtype is valid, we just visit it + if (IsValidDType(op->value->dtype) && IsValidDType(op->dtype)) { + return VisitInt(op->value); + } else { + // otherwise, we create a new free z3 variable + return Create(op); + } + } + z3::expr VisitExpr_(const VarNode *op) override { + return Create(op); + } + z3::expr VisitExpr_(const BufferLoadNode *op) 
override { return Create(op); } + z3::expr VisitExpr_(const ProducerLoadNode *op) override { return Create(op); } + z3::expr VisitExpr_(const ReduceNode *op) override { return Create(op); } + z3::expr VisitExpr_(const MinNode *op) override { + auto a = VisitInt(op->a); + auto b = VisitInt(op->b); + return z3::ite(a < b, a, b); + } + z3::expr VisitExpr_(const MaxNode *op) override { + auto a = VisitInt(op->a); + auto b = VisitInt(op->b); + return z3::ite(a > b, a, b); + } + static z3::expr floordiv(const z3::expr & a, const z3::expr & b) { return z3::ite(b > 0, a / b, -((-a) / b)); } + static z3::expr floormod(const z3::expr & a, const z3::expr & b) { return z3::ite(b > 0, a % b, -((-a) % b)); } + z3::expr VisitExpr_(const AddNode *op) override { return VisitArith(z3::operator +, op, op->a, op->b); } + z3::expr VisitExpr_(const SubNode *op) override { return VisitArith(z3::operator -, op, op->a, op->b); } + z3::expr VisitExpr_(const MulNode *op) override { return VisitArith(z3::operator *, op, op->a, op->b); } + z3::expr VisitExpr_(const DivNode *op) override { return VisitArith(z3::operator /, op, op->a, op->b); } + z3::expr VisitExpr_(const ModNode *op) override { return VisitArith(z3::operator %, op, op->a, op->b); } + z3::expr VisitExpr_(const FloorDivNode *op) override { return VisitArith(floordiv, op, op->a, op->b); } + z3::expr VisitExpr_(const FloorModNode *op) override { return VisitArith(floormod, op, op->a, op->b); } + z3::expr VisitExpr_(const EQNode *op) override { return VisitArith(z3::operator==, op, op->a, op->b); } + z3::expr VisitExpr_(const NENode *op) override { return VisitArith(z3::operator!=, op, op->a, op->b); } + z3::expr VisitExpr_(const LTNode *op) override { return VisitArith(z3::operator<, op, op->a, op->b); } + z3::expr VisitExpr_(const LENode *op) override { return VisitArith(z3::operator<=, op, op->a, op->b); } + z3::expr VisitExpr_(const GTNode *op) override { return VisitArith(z3::operator>, op, op->a, op->b); } + z3::expr 
VisitExpr_(const GENode *op) override { return VisitArith(z3::operator>=, op, op->a, op->b); } + z3::expr VisitExpr_(const AndNode *op) override { return VisitBool(op->a) && VisitBool(op->b); } + z3::expr VisitExpr_(const OrNode *op) override { return VisitBool(op->a) || VisitBool(op->b); } + z3::expr VisitExpr_(const NotNode *op) override { return !VisitBool(op->a); } + z3::expr VisitExpr_(const SelectNode *op) override { return z3::ite(VisitBool(op->condition), VisitInt(op->true_value), VisitInt(op->false_value)); } + z3::expr VisitExpr_(const IntImmNode *op) override { return ctx->int_val(op->value); } + + // Bitwise operations + z3::expr VisitExpr_(const CallNode *op) override { + // Check if this is a bitwise operation + if (op->op.same_as(tir::builtin::bitwise_and())) { + return VisitBitwiseOp(z3::operator&, op); + } else if (op->op.same_as(tir::builtin::bitwise_or())) { + return VisitBitwiseOp(z3::operator|, op); + } else if (op->op.same_as(tir::builtin::bitwise_xor())) { + return VisitBitwiseOp(z3::operator^, op); + } else if (op->op.same_as(tir::builtin::bitwise_not())) { + return VisitBitwiseNotOp(op); + } else if (op->op.same_as(tir::builtin::shift_left())) { + return VisitShiftOp(z3::shl, op); + } else if (op->op.same_as(tir::builtin::shift_right())) { + return VisitShiftOp(z3::ashr, op); + } else { + // For other call nodes, create a free variable + return Create(op); + } + } + + /// @brief Helper function to visit binary bitwise operations + z3::expr VisitBitwiseOp(z3::expr(*op_func)(const z3::expr &, const z3::expr &), const CallNode *op) { + if (op->args.size() != 2) { + LOG(FATAL) << "Binary bitwise operation expects 2 arguments, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr &a = op->args[0]; + const PrimExpr &b = op->args[1]; + unsigned bit_width = std::max(op->args[0].dtype().bits(), op->args[1].dtype().bits()); + + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + return 
z3::bv2int(op_func(z3::int2bv(bit_width, VisitInt(a)), z3::int2bv(bit_width, VisitInt(b))), true); + } else { + return Create(op); + } + } + + /// @brief Helper function to visit unary bitwise not operation + z3::expr VisitBitwiseNotOp(const CallNode *op) { + if (op->args.size() != 1) { + LOG(FATAL) << "Bitwise not operation expects 1 argument, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr &a = op->args[0]; + + if (IsValidDType(a->dtype)) { + // Cast integer to bit-vector, apply bitwise not, then cast back. + unsigned bit_width = a.dtype().bits(); + z3::expr a_int = VisitInt(a); + z3::expr a_bv = z3::int2bv(bit_width, a_int); + return z3::bv2int(~a_bv, true); + } else { + return Create(op); + } + } + + /// @brief Helper function to visit shift operations + z3::expr VisitShiftOp(z3::expr(*op_func)(const z3::expr &, const z3::expr &), const CallNode *op) { + if (op->args.size() != 2) { + LOG(FATAL) << "Shift operation expects 2 arguments, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr &a = op->args[0]; + const PrimExpr &b = op->args[1]; + + // Shift operations require integer types for both operands + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + // For shift operations, we need to ensure the shift amount is non-negative + // and within reasonable bounds + z3::expr a_expr = VisitInt(a); + z3::expr b_expr = VisitInt(b); + + // Add constraint that shift amount should be non-negative + // This is a common assumption in many programming languages + solver.add(b_expr >= 0); + + // Also limit shift amount to avoid unrealistic large shifts + // We'll limit to 64 bits (reasonable for most use cases) + solver.add(b_expr < 64); + + unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits()); + z3::expr a_bv = z3::int2bv(bit_width, a_expr); + z3::expr b_bv = z3::int2bv(bit_width, b_expr); + + // Perform the shift in bit-vector domain, then cast back to int. 
+ z3::expr result_bv = op_func(a_bv, b_bv); + return z3::bv2int(result_bv, true); + } else { + return Create(op); + } + } + + z3::expr VisitExprDefault_(const Object* op) override { + LOG(FATAL) << "Z3Prover only support integers, but got " << op->GetTypeKey() << "."; + TVM_FFI_UNREACHABLE(); + } +}; + +TVM_DLL bool Z3Prover::CanProve(const PrimExpr & expr) { + return impl_->CanProve(expr); +} +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) { + return impl_->Bind(var, new_range, allow_override); +} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + return impl_->Bind(var, expr, allow_override); +} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + return impl_->EnterConstraint(constraint, is_assume); +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + if(expr.has_value()) { + return impl_->GetSMTLIB2(expr.value()); + } else { + return impl_->GetSMTLIB2(); + } +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) { + impl_->SetTimeoutMs(timeout_ms); +} +void Z3Prover::SetRLimit(unsigned max_step) { + impl_->SetRLimit(max_step); +} +void Z3Prover::CopyFrom(const Z3Prover & other) { + impl_->CopyFrom(*other.impl_); +} +ffi::String Z3Prover::GetStats() { + return impl_->GetStats(); +} +ffi::String Z3Prover::GetModel(const PrimExpr & expr) { + return impl_->GetModel(expr); +} +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive) { + return impl_->CountSatisfyingValues(var, max_count, min_consecutive); +} +Z3Prover::Z3Prover(Analyzer* parent): impl_(new Impl{parent}) {} +TVM_DLL Z3Prover::~Z3Prover() { + delete impl_; +} + +} // namespace tvm::arith diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 01b80386e2c0..fa7424a7cda0 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -39,11 +39,11 @@ namespace tvm { namespace te 
{ using namespace tir; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { OperationNode::RegisterReflection(); BaseComputeOpNode::RegisterReflection(); ComputeOpNode::RegisterReflection(); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -84,10 +84,10 @@ DataType ComputeOpNode::output_dtype(size_t idx) const { return body[idx].dtype(); } -Array BaseComputeOpNode::output_shape(size_t idx) const { +ffi::Array BaseComputeOpNode::output_shape(size_t idx) const { ICHECK_LT(idx, num_outputs()); // for now, all outputs of a BaseComputeOp have the same shape - Array shape; + ffi::Array shape; for (const auto& ivar : this->axis) { const Range& r = ivar->dom; shape.push_back(r->extent); @@ -95,8 +95,8 @@ Array BaseComputeOpNode::output_shape(size_t idx) const { return shape; } -Tensor compute(Array shape, FCompute fcompute, std::string name, std::string tag, - Map attrs) { +Tensor compute(ffi::Array shape, FCompute fcompute, std::string name, std::string tag, + ffi::Map attrs) { // compute dimension. size_t ndim = shape.size(); std::vector axis; @@ -112,8 +112,8 @@ Tensor compute(Array shape, FCompute fcompute, std::string name, std:: return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0); } -Array compute(Array shape, FBatchCompute fcompute, std::string name, - std::string tag, Map attrs) { +ffi::Array compute(ffi::Array shape, FBatchCompute fcompute, std::string name, + std::string tag, ffi::Map attrs) { // compute dimension. 
size_t ndim = shape.size(); std::vector axis; @@ -127,19 +127,19 @@ Array compute(Array shape, FBatchCompute fcompute, std::string } Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args)); - Array outputs; + ffi::Array outputs; for (int idx = 0; idx < op->num_outputs(); ++idx) { outputs.push_back(op.output(idx)); } return outputs; } -ComputeOp::ComputeOp(std::string name, std::string tag, Map attrs, - Array axis, Array body) { +ComputeOp::ComputeOp(std::string name, std::string tag, ffi::Map attrs, + ffi::Array axis, ffi::Array body) { if (!attrs.defined()) { - attrs = Map(); + attrs = ffi::Map(); } - auto n = make_object(); + auto n = ffi::make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -153,18 +153,18 @@ ComputeOp::ComputeOp(std::string name, std::string tag, Map at data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("te.ComputeOp", - [](std::string name, std::string tag, Optional> attrs, - Array axis, Array body) { - return ComputeOp(name, tag, attrs.value_or({}), axis, body); - }); -}); + refl::GlobalDef().def("te.ComputeOp", [](std::string name, std::string tag, + ffi::Optional> attrs, + ffi::Array axis, ffi::Array body) { + return ComputeOp(name, tag, attrs.value_or({}), axis, body); + }); +} // The schedule related logics -Array ComputeOpNode::InputTensors() const { - Array ret; +ffi::Array ComputeOpNode::InputTensors() const { + ffi::Array ret; std::unordered_set visited; for (auto& e : body) { tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 7408eb46eb51..fa84ab3863fb 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -105,19 +105,19 @@ class BufferSubstituter : public StmtExprMutator { /*! \brief Helper data structure to store information. 
*/ struct CreateFuncInfo { /*! \brief The Tensor arg_list. */ - Array arg_list; + ffi::Array arg_list; /*! \brief The map from each Tensor to its corresponding buffer. */ std::unordered_map tensor2buffers; /*! \brief The transformer from ProducerLoad to BufferLoad. */ ProducerToBufferTransformer transformer; /*! \brief The buffers should be allocated at function root. */ - Array root_alloc; + ffi::Array root_alloc; /*! \brief The NameSupply to make block name unique. */ NameSupply name_supply; - String FreshName(String base_name) { return name_supply->FreshName(base_name); } + ffi::String FreshName(ffi::String base_name) { return name_supply->FreshName(base_name); } - explicit CreateFuncInfo(Array arg_list) + explicit CreateFuncInfo(ffi::Array arg_list) : arg_list(std::move(arg_list)), transformer(tensor2buffers) {} bool IsArg(const te::Tensor& tensor) const { @@ -131,7 +131,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { PrimFunc Process(PrimFunc func) { for (int i = 0, n = func->params.size(); i < n; ++i) { if (auto v = func->params[i].as()) { - if (Optional buffer = func->buffer_map.Get(v.value())) { + if (ffi::Optional buffer = func->buffer_map.Get(v.value())) { buffer2index_[buffer.value()] = i; } } @@ -141,7 +141,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { if (this->layout_free_buffer_indices_.empty()) { return func; } - Array indices; + ffi::Array indices; indices.reserve(this->layout_free_buffer_indices_.size()); for (int i : this->layout_free_buffer_indices_) { indices.push_back(i); @@ -153,8 +153,8 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { Block block = Downcast(StmtMutator::VisitStmt_(_block)); BlockNode* n = block.CopyOnWrite(); if (auto opt_ann = n->annotations.Get(topi_attr)) { - Array new_buffers; - for (Buffer buffer : Downcast>(opt_ann.value())) { + ffi::Array new_buffers; + for (Buffer buffer : Downcast>(opt_ann.value())) { auto it = buffer2index_.find(buffer); if (it != 
buffer2index_.end()) { layout_free_buffer_indices_.insert(it->second); @@ -168,7 +168,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { n->annotations.Set(topi_attr, new_buffers); } } - for (const String& attr : this->blocklist) { + for (const ffi::String& attr : this->blocklist) { auto it = n->annotations.find(attr); if (it != n->annotations.end()) { n->annotations.erase(attr); @@ -179,9 +179,9 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { std::unordered_map buffer2index_; std::set layout_free_buffer_indices_; - String topi_attr = "layout_free_placeholders"; - std::vector blocklist = {"const_matrix", "auto_scheduler_simplify_const_tensor_indices", - "workload"}; + ffi::String topi_attr = "layout_free_placeholders"; + std::vector blocklist = {"const_matrix", + "auto_scheduler_simplify_const_tensor_indices", "workload"}; }; /**! @@ -191,7 +191,8 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { **/ using NestedIterLevels = std::vector>; -NestedIterLevels GenerateNestedIterLevels(const Array& axes, arith::Analyzer* analyzer) { +NestedIterLevels GenerateNestedIterLevels(const ffi::Array& axes, + arith::Analyzer* analyzer) { int global_max_depth = 0; std::unordered_map depth; std::unordered_map var2iter; @@ -244,9 +245,9 @@ NestedIterLevels GenerateNestedIterLevels(const Array& axes, arith::Ana * \param info Generation context info. * \returns The output buffer objects, ordered by compute op's outputs. **/ -Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncInfo* info) { +ffi::Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncInfo* info) { // Step 1. Collect output tensors in TE operation. 
- Array tensors; + ffi::Array tensors; if (compute_op->body[0]->IsInstance()) { auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { StructuralEqual eq; @@ -265,8 +266,8 @@ Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncI ICHECK(reduce_); ICHECK(f_reducer_equal(reduce_, reduce)) << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " - << "but the first argument has body " << GetRef(reduce_) << ", while the " << k - << "-th argument has body " << GetRef(reduce); + << "but the first argument has body " << ffi::GetRef(reduce_) << ", while the " + << k << "-th argument has body " << ffi::GetRef(reduce); tensors.push_back(compute_op.output(k)); } } else { @@ -278,7 +279,7 @@ Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncI // - Declare buffers // - Update `op2buffers` // - Add the non-argument tensors to `alloc_buffer` of the root block - Array buffers; + ffi::Array buffers; for (const te::Tensor& tensor : tensors) { Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; @@ -296,9 +297,9 @@ Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncI * \param info Generation context info. * \returns The block annotation dict. 
**/ -Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, - CreateFuncInfo* info) { - Map annotations; +ffi::Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, + CreateFuncInfo* info) { + ffi::Map annotations; auto mutate_attr = [&info](const ffi::Any& value) -> ffi::Any { if (auto tensor_value = value.try_cast()) { return info->tensor2buffers.at(tensor_value.value()); @@ -307,11 +308,11 @@ Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, } }; for (const auto& pair : compute_op->attrs) { - const String& key = pair.first; + const ffi::String& key = pair.first; const Any& value = pair.second; // TensorIR will not allow Tensor data structure if (value.as()) { - const auto array_value = Downcast>(value); + const auto array_value = Downcast>(value); annotations.Set(key, array_value.Map(mutate_attr)); } else { annotations.Set(key, mutate_attr(value)); @@ -331,17 +332,17 @@ Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, * \param info Generation context info. * \returns Init stmt. **/ -Stmt GenerateInitStmt(const Array& indices, const Array& buffers, - const ReduceNode* reduce, const Map& var_map, +Stmt GenerateInitStmt(const ffi::Array& indices, const ffi::Array& buffers, + const ReduceNode* reduce, const ffi::Map& var_map, CreateFuncInfo* info) { // helper to transform the expr and remap iters to the block domain auto f_transform_and_remap = [&](const PrimExpr& e) { return Substitute(info->transformer(e), var_map); }; - Optional init = std::nullopt; + ffi::Optional init = std::nullopt; Stmt body; int n_buffers = buffers.size(); - Array init_stmts; + ffi::Array init_stmts; init_stmts.reserve(n_buffers); for (int i = 0; i < n_buffers; ++i) { const Buffer& buffer = buffers[i]; @@ -361,9 +362,9 @@ Stmt GenerateInitStmt(const Array& indices, const Array& buffe * \param analyzer Arithmetic analyzer in context. * \returns Init stmt. 
**/ -Stmt GenerateBodyStmt(const Array& indices, const Array& buffers, - const Map& var_map, PrimExpr expr_body, CreateFuncInfo* info, - arith::Analyzer* analyzer) { +Stmt GenerateBodyStmt(const ffi::Array& indices, const ffi::Array& buffers, + const ffi::Map& var_map, PrimExpr expr_body, + CreateFuncInfo* info, arith::Analyzer* analyzer) { // helper to transform the expr and remap iters to the block domain auto f_transform_and_remap = [&](const PrimExpr& e) { return Substitute(info->transformer(e), var_map); @@ -373,8 +374,8 @@ Stmt GenerateBodyStmt(const Array& indices, const Array& buffe // Case 1. Reduce compute int n_buffers = buffers.size(); - Array lhs; - Array rhs; + ffi::Array lhs; + ffi::Array rhs; lhs.reserve(n_buffers); rhs.reserve(n_buffers); @@ -389,8 +390,8 @@ Stmt GenerateBodyStmt(const Array& indices, const Array& buffe ICHECK_EQ(left->dtype, right->dtype); } - Array temp_vars; - Array body_stmts; + ffi::Array temp_vars; + ffi::Array body_stmts; temp_vars.reserve(n_buffers); body_stmts.reserve(n_buffers); @@ -433,16 +434,16 @@ struct NestedScopeInfo { // loop var and range in the scope. std::vector> loop_vars; // block iters for current level's block. - Array block_iters; + ffi::Array block_iters; // block bindings for current level's block. - Array bindings; + ffi::Array bindings; // store indices for current level's block. - Array store_indices; + ffi::Array store_indices; // mapping from original TE compute axes to new block vars. - Map axes_remap; + ffi::Map axes_remap; // helper to add new block var - void AddBlockIter(const Optional& origin_axis, const IterVar& iter, + void AddBlockIter(const ffi::Optional& origin_axis, const IterVar& iter, const PrimExpr& value) { block_iters.push_back(iter); bindings.push_back(value); @@ -455,9 +456,9 @@ struct NestedScopeInfo { } // helper to renew leaf block var defs to ensure SSA. 
- void Renew(const Array& origin_axes) { + void Renew(const ffi::Array& origin_axes) { block_iters.MutateByApply([](const IterVar& itervar) { - auto n = make_object(*itervar.get()); + auto n = ffi::make_object(*itervar.get()); n->var = n->var.copy_with_suffix(""); return IterVar(n); }); @@ -474,7 +475,7 @@ struct NestedScopeInfo { Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info, arith::Analyzer* analyzer) { // Step 1. Collect all iter axes in original TE compute op - Array axes = compute_op->axis; + ffi::Array axes = compute_op->axis; axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end()); // Step 2. Prepare nested iteration scopes. @@ -528,12 +529,12 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in } // Step 3. Generate output buffers for each output tensor - Array buffers = GenerateOutputBuffers(compute_op, info); + ffi::Array buffers = GenerateOutputBuffers(compute_op, info); // Step 4. Generate leaf block stmts. - Array seq_stmt; + ffi::Array seq_stmt; auto leaf = scopes.back(); - Map annotations = GenerateBlockAnnotations(compute_op, info); + ffi::Map annotations = GenerateBlockAnnotations(compute_op, info); const ReduceNode* reduce = compute_op->body[0].as(); if (reduce) { PrimExpr expr_body = compute_op->body[0]; @@ -585,7 +586,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in auto block_name = info->FreshName(compute_op->name + "_l" + std::to_string(i)); const auto& block_iters = cur.block_iters; - Optional init{std::nullopt}; + ffi::Optional init{std::nullopt}; if (reduce && std::any_of(block_iters.begin(), block_iters.end(), [](const IterVar& iter) { return iter->iter_type == IterVarType::kCommReduce; })) { @@ -649,7 +650,10 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf // reads/writes filled in. 
BufferSubstituter substituter(var_map, input_buffer_map); - Stmt body = substituter(extern_op->body); + Stmt substituted_body = substituter(extern_op->body); + + ProducerToBufferTransformer transformer(info->tensor2buffers); + Stmt body = transformer(substituted_body); // Step 4. Generate opaque block as body. return BlockRealize(/*iter_values=*/{}, @@ -666,13 +670,13 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf /*annotations=*/extern_op->attrs)); } -Array CollectOrderedOps(const Array& arg_list) { - Array arg_ops; +ffi::Array CollectOrderedOps(const ffi::Array& arg_list) { + ffi::Array arg_ops; for (const te::Tensor& arg : arg_list) { arg_ops.push_back(arg->op); } te::ReadGraph g = te::CreateReadGraph(arg_ops); - Array order = te::PostDFSOrder(arg_ops, g); + ffi::Array order = te::PostDFSOrder(arg_ops, g); for (const te::Operation& op : order) { if (!(op->IsInstance() || op->IsInstance() || @@ -683,7 +687,7 @@ Array CollectOrderedOps(const Array& arg_list) { return order; } -void InitializeBufferBinds(const Array& ordered_ops, CreateFuncInfo* info) { +void InitializeBufferBinds(const ffi::Array& ordered_ops, CreateFuncInfo* info) { // Process any TE operations which contain user defined buffers for (const auto& op : ordered_ops) { // Initialize the tensor2buffer binds map with buffers defined by the te.extern @@ -698,8 +702,8 @@ void InitializeBufferBinds(const Array& ordered_ops, CreateFuncIn } } -void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Array* root_stmts, - arith::Analyzer* analyzer) { +void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, + ffi::Array* root_stmts, arith::Analyzer* analyzer) { if (const auto* placeholder = op.as()) { // Case 1. 
PlaceholderOp (te.placeholder) ICHECK_EQ(op->num_outputs(), 1); @@ -727,10 +731,10 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Array& arg_list, - const Array& root_stmts, CreateFuncInfo* info) { - Array parameters; - Map buffer_map; +PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_list, + const ffi::Array& root_stmts, CreateFuncInfo* info) { + ffi::Array parameters; + ffi::Map buffer_map; for (const te::Tensor& tensor : arg_list) { Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); parameters.push_back(arg); @@ -742,25 +746,25 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, /*body=*/SeqStmt::Flatten(root_stmts), /*ret_type=*/VoidType(), /*buffer_map=*/std::move(buffer_map)), - {{"global_symbol", String("main")}, {"tir.noalias", true}}); + {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); ICHECK(fcomplete.has_value()); func = (*fcomplete)(std::move(func), info->root_alloc).cast(); return func; } -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants, +PrimFunc CreatePrimFuncWithConstants(const ffi::Array& arg_list, + const ffi::Array& constants, std::optional index_dtype_override) { // Information used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(arg_list); // Root body stmts. - Array root_stmts; + ffi::Array root_stmts; // Analyzer arith::Analyzer analyzer; // Step 1. Create ordered array of operations and validate they are supported. - Array order = CollectOrderedOps(arg_list); + ffi::Array order = CollectOrderedOps(arg_list); // Step 2. 
Initialize buffer binds map InitializeBufferBinds(order, &info); @@ -780,15 +784,15 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, return result; } -PrimFunc CreatePrimFunc(const Array& arg_list, +PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override) { return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("te.CreatePrimFunc", [](ffi::PackedArgs args, ffi::Any* ret) { - Array arg_list = args[0].cast>(); + ffi::Array arg_list = args[0].cast>(); std::optional index_dtype_override{std::nullopt}; // Add conversion to make std::optional compatible with FFI. if (args[1] != nullptr) { @@ -796,13 +800,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ } *ret = CreatePrimFunc(arg_list, index_dtype_override); }); -}); +} // Relax version impl -PrimFunc GenerateAndCompletePrimFunc(const Array& arg_tir_var_list, - const Array& root_stmts, CreateFuncInfo* info) { - Array parameters; - Map buffer_map; +PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_tir_var_list, + const ffi::Array& root_stmts, CreateFuncInfo* info) { + ffi::Array parameters; + ffi::Map buffer_map; for (const ObjectRef& arg : arg_tir_var_list) { if (auto opt_tensor = arg.as()) { te::Tensor tensor = opt_tensor.value(); @@ -819,32 +823,32 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_tir_var_list, /*body=*/SeqStmt::Flatten(root_stmts), /*ret_type=*/VoidType(), /*buffer_map=*/std::move(buffer_map)), - {{"global_symbol", String("main")}, {"tir.noalias", true}}); + {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); ICHECK(fcomplete.has_value()); func = (*fcomplete)(std::move(func), info->root_alloc).cast(); return func; } -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants, +PrimFunc 
CreatePrimFuncWithConstants(const ffi::Array& arg_list, + const ffi::Array& constants, std::optional index_dtype_override) { - Array tensor_arg_list; + ffi::Array tensor_arg_list; for (const ObjectRef& x : arg_list) { if (auto tensor_node = x.as()) { - te::Tensor tensor = GetRef(tensor_node); + te::Tensor tensor = ffi::GetRef(tensor_node); tensor_arg_list.push_back(tensor); } } // Infomations used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(tensor_arg_list); // Root body stmts. - Array root_stmts; + ffi::Array root_stmts; // Analyzer arith::Analyzer analyzer; // Step 1. Create ordered array of operations and validate they are supported. - Array order = CollectOrderedOps(tensor_arg_list); + ffi::Array order = CollectOrderedOps(tensor_arg_list); // Step 2. Initialize buffer binds map InitializeBufferBinds(order, &info); @@ -862,7 +866,7 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, return result; } -PrimFunc CreatePrimFunc(const Array& arg_list, +PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override) { return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index eb4a6183dd5c..f7ad7e0e1e0e 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -30,7 +30,7 @@ namespace tvm { namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ -PrimFunc CreatePrimFunc(const Array& arg_list, +PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override = std::nullopt); /*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the @@ -38,12 +38,12 @@ PrimFunc CreatePrimFunc(const Array& arg_list, * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants * will be embedded in the body as AllocateConstNode. 
*/ -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants, +PrimFunc CreatePrimFuncWithConstants(const ffi::Array& arg_list, + const ffi::Array& constants, std::optional index_dtype_override = std::nullopt); /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ -PrimFunc CreatePrimFunc(const Array& arg_list, +PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override); /*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the @@ -51,8 +51,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list, * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants * will be embedded in the body as AllocateConstNode. */ -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants, +PrimFunc CreatePrimFuncWithConstants(const ffi::Array& arg_list, + const ffi::Array& constants, std::optional index_dtype_override); } // namespace tir diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 23f43a99d8e6..def64595412d 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -31,7 +31,7 @@ namespace tvm { namespace te { using namespace tir; -TVM_FFI_STATIC_INIT_BLOCK({ ExternOpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ExternOpNode::RegisterReflection(); } // ExternOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -44,15 +44,17 @@ int ExternOpNode::num_outputs() const { return static_cast(output_placehold DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders[i]->dtype; } -Array ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; } +ffi::Array ExternOpNode::output_shape(size_t i) const { + return output_placeholders[i]->shape; +} -ExternOp::ExternOp(std::string name, std::string tag, Map attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body) { 
+ExternOp::ExternOp(std::string name, std::string tag, ffi::Map attrs, + ffi::Array inputs, ffi::Array input_placeholders, + ffi::Array output_placeholders, Stmt body) { if (!attrs.defined()) { - attrs = Map(); + attrs = ffi::Map(); } - auto n = make_object(); + auto n = ffi::make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -72,18 +74,19 @@ ExternOp::ExternOp(std::string name, std::string tag, Map attr data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("te.ExternOp", - [](std::string name, std::string tag, Optional> attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body) { - return ExternOp(name, tag, attrs.value_or({}), inputs, input_placeholders, - output_placeholders, body); - }); -}); + refl::GlobalDef().def( + "te.ExternOp", + [](std::string name, std::string tag, ffi::Optional> attrs, + ffi::Array inputs, ffi::Array input_placeholders, + ffi::Array output_placeholders, Stmt body) { + return ExternOp(name, tag, attrs.value_or({}), inputs, input_placeholders, + output_placeholders, body); + }); +} -Array ExternOpNode::InputTensors() const { return inputs; } +ffi::Array ExternOpNode::InputTensors() const { return inputs; } } // namespace te } // namespace tvm diff --git a/src/te/operation/graph.cc b/src/te/operation/graph.cc index f477f9129b2a..bddea5f7f2d4 100644 --- a/src/te/operation/graph.cc +++ b/src/te/operation/graph.cc @@ -37,7 +37,7 @@ namespace te { // construct a read graph that gives readers of each operation // that the root depend on -ReadGraph CreateReadGraph(const Array& roots) { +ReadGraph CreateReadGraph(const ffi::Array& roots) { ReadGraph rmap; std::vector stack; std::unordered_set visited; @@ -50,7 +50,7 @@ ReadGraph CreateReadGraph(const Array& roots) { while (!stack.empty()) { Operation op = stack.back(); stack.pop_back(); - Array deps = op->InputTensors(); + 
ffi::Array deps = op->InputTensors(); rmap.Set(op, deps); for (Tensor t : deps) { if (t->op.defined() && visited.count(t->op.get()) == 0) { @@ -63,7 +63,7 @@ ReadGraph CreateReadGraph(const Array& roots) { } void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set* visited, - Array* post_order) { + ffi::Array* post_order) { if (visited->count(op)) return; visited->insert(op); for (const auto& t : g.at(op)) { @@ -72,23 +72,23 @@ void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_setpush_back(op); } -Array PostDFSOrder(const Array& roots, const ReadGraph& g) { +ffi::Array PostDFSOrder(const ffi::Array& roots, const ReadGraph& g) { std::unordered_set visited; - Array post_order; + ffi::Array post_order; for (Operation op : roots) { PostDFSOrder(op, g, &visited, &post_order); } return post_order; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("schedule.CreateReadGraph", CreateReadGraph) - .def("schedule.PostDFSOrder", [](const Array& roots, const ReadGraph& g) { + .def("schedule.PostDFSOrder", [](const ffi::Array& roots, const ReadGraph& g) { return PostDFSOrder(roots, g); }); -}); +} } // namespace te } // namespace tvm diff --git a/src/te/operation/graph.h b/src/te/operation/graph.h index 51ab8e1aa7bb..dc2b211cf3cb 100644 --- a/src/te/operation/graph.h +++ b/src/te/operation/graph.h @@ -33,7 +33,7 @@ namespace te { /*! * \brief data structure of Operation->Tensors it reads */ -using ReadGraph = Map>; +using ReadGraph = ffi::Map>; /*! * \brief Get read graph of each operation to all the @@ -43,7 +43,7 @@ using ReadGraph = Map>; * \param roots The root operation. * \return The result map. */ -ReadGraph CreateReadGraph(const Array& roots); +ReadGraph CreateReadGraph(const ffi::Array& roots); /*! * \brief Get a post DFS ordered of operations in the graph. 
@@ -54,7 +54,7 @@ ReadGraph CreateReadGraph(const Array& roots); * \note PostDFSOrder is a special case of Topoligical order, * and can be used when topoligical order is needed. */ -Array PostDFSOrder(const Array& roots, const ReadGraph& g); +ffi::Array PostDFSOrder(const ffi::Array& roots, const ReadGraph& g); } // namespace te } // namespace tvm diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 160f89f1eb84..6c7d60841c0f 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -29,7 +29,7 @@ namespace tvm { namespace te { -TVM_FFI_STATIC_INIT_BLOCK({ PlaceholderOpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { PlaceholderOpNode::RegisterReflection(); } // PlaceholderOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -45,31 +45,31 @@ DataType PlaceholderOpNode::output_dtype(size_t i) const { return dtype; } -Array PlaceholderOpNode::output_shape(size_t i) const { +ffi::Array PlaceholderOpNode::output_shape(size_t i) const { ICHECK_EQ(i, 0U); return shape; } -PlaceholderOp::PlaceholderOp(std::string name, Array shape, DataType dtype) { - auto n = make_object(); +PlaceholderOp::PlaceholderOp(std::string name, ffi::Array shape, DataType dtype) { + auto n = ffi::make_object(); n->name = name; n->shape = shape; n->dtype = dtype; data_ = std::move(n); } -Tensor placeholder(Array shape, DataType dtype, std::string name) { +Tensor placeholder(ffi::Array shape, DataType dtype, std::string name) { return PlaceholderOp(name, shape, dtype).output(0); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("te.Placeholder", [](Variant> shape_arg, + refl::GlobalDef().def("te.Placeholder", [](ffi::Variant> shape_arg, DataType dtype, std::string name) { - auto shape = [&]() -> Array { + auto shape = [&]() -> ffi::Array { if (auto arg_expr = shape_arg.as()) { return {arg_expr.value()}; - } else if (auto arg_array = 
shape_arg.as>()) { + } else if (auto arg_array = shape_arg.as>()) { return arg_array.value(); } else { LOG(FATAL) << "Variant did not contain either allowed type"; @@ -77,9 +77,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ }(); return placeholder(shape, dtype, name); }); -}); +} -Array PlaceholderOpNode::InputTensors() const { return {}; } +ffi::Array PlaceholderOpNode::InputTensors() const { return {}; } } // namespace te } // namespace tvm diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index cd621c11dfc7..fbc65e8a61fb 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -30,7 +30,7 @@ namespace tvm { namespace te { using namespace tir; -TVM_FFI_STATIC_INIT_BLOCK({ ScanOpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ScanOpNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -42,18 +42,19 @@ int ScanOpNode::num_outputs() const { return static_cast(update.size()); } DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } -Array ScanOpNode::output_shape(size_t i) const { +ffi::Array ScanOpNode::output_shape(size_t i) const { ICHECK_LT(i, state_placeholder.size()); return state_placeholder[i]->shape; } -ScanOp::ScanOp(std::string name, std::string tag, Optional> attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array inputs) { +ScanOp::ScanOp(std::string name, std::string tag, + ffi::Optional> attrs, IterVar axis, + ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, ffi::Array inputs) { if (!attrs.defined()) { - attrs = Map(); + attrs = ffi::Map(); } - auto n = make_object(); + auto n = ffi::make_object(); ICHECK_EQ(init.size(), update.size()); ICHECK_EQ(init.size(), state_placeholder.size()); arith::Analyzer analyzer; @@ -99,32 +100,34 @@ ScanOp::ScanOp(std::string name, std::string tag, Optional data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "te.ScanOp", [](std::string name, std::string tag, Optional> attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array inputs) { + "te.ScanOp", + [](std::string name, std::string tag, ffi::Optional> attrs, + IterVar axis, ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, ffi::Array inputs) { return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs); }); -}); +} -Array scan(Array init, Array update, Array state_placeholder, - Array inputs, std::string name, std::string tag, - Optional> attrs) { +ffi::Array scan(ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, ffi::Array inputs, + std::string name, std::string tag, + ffi::Optional> attrs) { IterVar scan_axis = IterVar(Range::FromMinExtent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), Var(name + ".idx"), kOrdered); Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); - Array res; + ffi::Array res; for (int i = 0; i < op->num_outputs(); ++i) { res.push_back(op.output(i)); } return res; } -Array ScanOpNode::InputTensors() const { - Array ret; +ffi::Array ScanOpNode::InputTensors() const { + ffi::Array ret; for (Tensor t : init) { ret.push_back(t); } diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 06dc0ccbc92c..8035564b27f4 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -37,7 +37,7 @@ void TensorNode::RegisterReflection() { .def_ro("value_index", &TensorNode::value_index); } -TVM_FFI_STATIC_INIT_BLOCK({ TensorNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { TensorNode::RegisterReflection(); } IterVar thread_axis(Range dom, std::string tag) { return IterVar(dom, Var(tag, dom.defined() ? 
dom->extent.dtype() : DataType::Int(32)), @@ -51,8 +51,9 @@ IterVar reduce_axis(Range dom, std::string name) { Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor -inline PrimExpr Tensor::IndexTensor(Array indices, bool support_negative_indices) const { - Array shape = (*this)->shape; +inline PrimExpr Tensor::IndexTensor(ffi::Array indices, + bool support_negative_indices) const { + ffi::Array shape = (*this)->shape; if (shape.size() != 0) { ICHECK_EQ(shape.size(), indices.size()) @@ -70,30 +71,32 @@ inline PrimExpr Tensor::IndexTensor(Array indices, bool support_negati return ProducerLoad((*this), indices); } -PrimExpr Tensor::operator()(Array indices) const { - Array arr(indices.begin(), indices.end()); +PrimExpr Tensor::operator()(ffi::Array indices) const { + ffi::Array arr(indices.begin(), indices.end()); return operator()(arr); } -PrimExpr Tensor::operator()(Array indices) const { return IndexTensor(indices, false); } +PrimExpr Tensor::operator()(ffi::Array indices) const { + return IndexTensor(indices, false); +} -PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { - Array arr(indices.begin(), indices.end()); +PrimExpr Tensor::IndexWithNegativeIndices(ffi::Array indices) const { + ffi::Array arr(indices.begin(), indices.end()); return IndexWithNegativeIndices(arr); } -PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { +PrimExpr Tensor::IndexWithNegativeIndices(ffi::Array indices) const { return IndexTensor(indices, true); } -String TensorNode::GetNameHint() const { +ffi::String TensorNode::GetNameHint() const { return op->num_outputs() == 1 ? 
op->name : (op->name + ".v" + std::to_string(value_index)); } -PrimExpr TensorNode::ToPrimExpr() const { return GetRef(this)(); } +PrimExpr TensorNode::ToPrimExpr() const { return ffi::GetRef(this)(); } Tensor Operation::output(size_t i) const { - auto node = make_object(); + auto node = ffi::make_object(); node->op = *this; node->value_index = i; node->dtype = (*this)->output_dtype(i); @@ -101,8 +104,8 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { - auto n = make_object(); +Tensor::Tensor(ffi::Array shape, DataType dtype, Operation op, int value_index) { + auto n = ffi::make_object(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; @@ -110,13 +113,13 @@ Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_in data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("te.Tensor", - [](Array shape, DataType dtype, Operation op, int value_index) { - return Tensor(shape, dtype, op, value_index); - }); -}); + refl::GlobalDef().def( + "te.Tensor", [](ffi::Array shape, DataType dtype, Operation op, int value_index) { + return Tensor(shape, dtype, op, value_index); + }); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -125,7 +128,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Other tensor ops. 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("te.TensorEqual", &Tensor::operator==) @@ -137,7 +140,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](Operation op, int64_t output) { return op.output(static_cast(output)); }) .def_method("te.OpNumOutputs", &OperationNode::num_outputs) .def_method("te.OpInputTensors", &OperationNode::InputTensors); -}); +} } // namespace te } // namespace tvm diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 2503d12df195..531db7d5c7b9 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -40,21 +40,21 @@ namespace tir { */ class BlockReadWriteDetector : public StmtExprVisitor { public: - explicit BlockReadWriteDetector(const Map& buffer_var_map) + explicit BlockReadWriteDetector(const ffi::Map& buffer_var_map) : buffer_var_map_(buffer_var_map) {} /*! \brief Return read regions of the block */ - Array CollectReads( + ffi::Array CollectReads( const std::unordered_set* excluded_buffers = nullptr); /*! \brief Return write regions of the block */ - Array CollectWrites( + ffi::Array CollectWrites( const std::unordered_set* excluded_buffers = nullptr); /*! * \brief Return opaque buffer regions of the block * \note The buffer accessed by load/store or call with buffer.data will * be marked as opaque. */ - Array CollectOpaques(); + ffi::Array CollectOpaques(); /*! \brief overload operator() to make sure it accepts a block node */ void operator()(const Stmt& stmt); @@ -78,7 +78,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { /*! \brief The opaque regions of the current block */ std::vector> opaque_regions_; /*! \brief The outside buffer data mapping to its buffer */ - Map buffer_var_map_; + ffi::Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; /*! 
\brief let bindings inside the block */ @@ -97,7 +97,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { Buffer buffer, std::vector region); /*! \brief Helper function to collect access regions. */ - Array CollectRegions( + ffi::Array CollectRegions( const std::vector& buffers, const std::vector>& regions, const std::unordered_set* excluded_buffers = nullptr); @@ -136,21 +136,21 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) { StmtExprVisitor::operator()(stmt); } -Array BlockReadWriteDetector::CollectReads( +ffi::Array BlockReadWriteDetector::CollectReads( const std::unordered_set* excluded_buffers) { return CollectRegions(read_buffers_, read_regions_, excluded_buffers); } -Array BlockReadWriteDetector::CollectWrites( +ffi::Array BlockReadWriteDetector::CollectWrites( const std::unordered_set* excluded_buffers) { return CollectRegions(writes_buffers_, write_regions_, excluded_buffers); } -Array BlockReadWriteDetector::CollectOpaques() { +ffi::Array BlockReadWriteDetector::CollectOpaques() { return CollectRegions(opaque_buffers_, opaque_regions_); } -void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } +void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(ffi::GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector relaxed_region; @@ -198,7 +198,7 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { const VarNode* buffer_var = op->args[1].as(); const IntImmNode* access_mask = op->args[4].as(); if (buffer_var && access_mask) { - auto it = buffer_var_map_.find(GetRef(buffer_var)); + auto it = buffer_var_map_.find(ffi::GetRef(buffer_var)); if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); @@ -208,12 +208,13 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { for (const Range& range : region) { int_set.push_back(arith::EvalSet(range, 
dom_map_)); } - // read access, write access or opaque access - if ((access_mask->value & 1) && (access_mask->value & 2)) { - Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); - } else if (access_mask->value & 1) { + // Conservatively treat rw_mask as the union of reads and writes. + // This avoids forcing TVM Script users to manually annotate access + // regions for common patterns (e.g., atomic read-modify-write). + if (access_mask->value & 1) { Update(&read_buffers_, &read_regions_, buffer, int_set); - } else if (access_mask->value & 2) { + } + if (access_mask->value & 2) { Update(&writes_buffers_, &write_regions_, buffer, int_set); } } @@ -279,6 +280,7 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { } Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region); } + StmtVisitor::VisitStmt_(op); } std::vector BlockReadWriteDetector::ConvertMatchedRegion( @@ -320,7 +322,12 @@ void BlockReadWriteDetector::Update(std::vector* buffers, if ((*buffers)[i].same_as(buffer)) { ICHECK_EQ((*regions)[i].size(), region.size()) << "Inconsistent buffer dimension"; for (size_t j = 0; j < region.size(); ++j) { - (*regions)[i][j] = arith::Union({(*regions)[i][j], region[j]}); + try { + (*regions)[i][j] = arith::Union({(*regions)[i][j], region[j]}); + } catch (const std::exception& e) { + // Fallback to full region for this dimension if Union fails + (*regions)[i][j] = arith::IntSet::FromRange(Range::FromMinExtent(0, buffer->shape[j])); + } } return; } @@ -329,26 +336,33 @@ void BlockReadWriteDetector::Update(std::vector* buffers, regions->push_back(std::move(region)); } -Array BlockReadWriteDetector::CollectRegions( +ffi::Array BlockReadWriteDetector::CollectRegions( const std::vector& buffers, const std::vector>& regions, const std::unordered_set* excluded_buffers) { ICHECK_EQ(buffers.size(), regions.size()); - Array res; + ffi::Array res; res.reserve(buffers.size()); for (size_t i = 0; i < regions.size(); ++i) { if 
(excluded_buffers != nullptr && excluded_buffers->count(buffers[i].get())) { continue; } - Array region; + ffi::Array region; region.reserve(regions[i].size()); ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; - if (range.CanProveSinglePoint(&ana_)) { - PrimExpr min = range.min(); - region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); - } else { - region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + // Try to prove single point access, fallback to cover range if analysis fails + // (e.g., due to divide-by-zero in symbolic simplification) + try { + if (range.CanProveSinglePoint(&ana_)) { + PrimExpr min = range.min(); + region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); + } else { + region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + } + } catch (const std::exception& e) { + // Fallback to full buffer range if symbolic analysis fails + region.push_back(Range::FromMinExtent(0, buffers[i]->shape[j])); } } res.push_back(BufferRegion(buffers[i], region)); @@ -371,11 +385,11 @@ void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { } } -Array> GetBlockAccessRegion(const Block& block, - const Map& buffer_var_map) { +ffi::Array> GetBlockAccessRegion( + const Block& block, const ffi::Map& buffer_var_map) { BlockReadWriteDetector detector(buffer_var_map); detector(block); - Array writes = detector.CollectWrites(); + ffi::Array writes = detector.CollectWrites(); std::unordered_set excluded_buffers; // exclude write buffers from read regions for reductions if init block is defined. 
if (block->init.defined()) { @@ -383,27 +397,27 @@ Array> GetBlockAccessRegion(const Block& block, excluded_buffers.insert(write_access->buffer.get()); } } - Array reads = detector.CollectReads(&excluded_buffers); - Array opaques = detector.CollectOpaques(); + ffi::Array reads = detector.CollectReads(&excluded_buffers); + ffi::Array opaques = detector.CollectOpaques(); return {reads, writes, opaques}; } -Array> GetBlockReadWriteRegion(const Block& block, - const Map& buffer_var_map) { +ffi::Array> GetBlockReadWriteRegion( + const Block& block, const ffi::Map& buffer_var_map) { BlockReadWriteDetector detector(buffer_var_map); detector(block); - Array opaques = detector.CollectOpaques(); + ffi::Array opaques = detector.CollectOpaques(); std::unordered_set excluded_buffers; for (const BufferRegion& opaque_access : opaques) { excluded_buffers.insert(opaque_access->buffer.get()); } - Array writes = detector.CollectWrites(&excluded_buffers); + ffi::Array writes = detector.CollectWrites(&excluded_buffers); if (block->init.defined()) { for (const BufferRegion& write_access : writes) { excluded_buffers.insert(write_access->buffer.get()); } } - Array reads = detector.CollectReads(&excluded_buffers); + ffi::Array reads = detector.CollectReads(&excluded_buffers); for (const BufferRegion& opaque_access : opaques) { reads.push_back(opaque_access); writes.push_back(opaque_access); @@ -411,12 +425,12 @@ Array> GetBlockReadWriteRegion(const Block& block, return {reads, writes}; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.analysis.GetBlockAccessRegion", GetBlockAccessRegion) .def("tir.analysis.GetBlockReadWriteRegion", GetBlockReadWriteRegion); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 2ecd32b65a2e..f8665362fa5e 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ 
b/src/tir/analysis/buffer_access_lca_detector.cc @@ -42,7 +42,7 @@ namespace tir { */ class LCADetector : public StmtExprVisitor { public: - static Map> Detect(const PrimFunc& func) { + static ffi::Map> Detect(const PrimFunc& func) { LCADetector detector; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -60,16 +60,39 @@ class LCADetector : public StmtExprVisitor { detector.UpdateWithBlockidx(); // Prepare the return - Map> buffer_lca; + ffi::Map> buffer_lca; for (const auto& kv : detector.buffer_lca_) { - const Buffer& buffer = GetRef(kv.first); - const Optional stmt = - kv.second ? GetRef>(kv.second->stmt) : std::nullopt; + const Buffer& buffer = ffi::GetRef(kv.first); + const ffi::Optional stmt = + kv.second ? ffi::GetRef>(kv.second->stmt) : std::nullopt; buffer_lca.Set(buffer, stmt); } return buffer_lca; } + static ffi::Map> DetectVar(const PrimFunc& func) { + LCADetector detector; + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get()); + } + + ScopeInfo root(nullptr, nullptr, 0); + detector.ancestor_scopes_.push_back(&root); + + detector(func->body); + + // Prepare the return + ffi::Map> var_lca; + for (const auto& kv : detector.buffer_var_lca_) { + const Var& var = ffi::GetRef(kv.first); + const ffi::Optional stmt = + kv.second ? ffi::GetRef>(kv.second->stmt) : std::nullopt; + var_lca.Set(var, stmt); + } + return var_lca; + } + private: /*! * \brief The AST node information for querying LCA. 
@@ -271,6 +294,7 @@ class LCADetector : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); } void VisitBufferVar(const VarNode* op) { + UpdateVarLCA(op, ancestor_scopes_.back()); auto it = buffer_var_map_.find(op); if (it != buffer_var_map_.end()) { UpdateBufferLCA(it->second, ancestor_scopes_.back()); @@ -279,6 +303,8 @@ class LCADetector : public StmtExprVisitor { void UpdateBufferLCA(const BufferNode* buffer, const ScopeInfo* scope) { buffer_var_map_.emplace(buffer->data.get(), buffer); + // Also record LCA for the underlying data var to capture BufferLoad/Store cases. + UpdateVarLCA(buffer->data.get(), scope); if (match_buffers_.find(buffer) == match_buffers_.end()) { // Ingore buffer created by block match_buffer const ScopeInfo*& lca = buffer_lca_[buffer]; @@ -286,10 +312,15 @@ class LCADetector : public StmtExprVisitor { } } + void UpdateVarLCA(const VarNode* var, const ScopeInfo* scope) { + const ScopeInfo*& lca = buffer_var_lca_[var]; + lca = LowestCommonAncestor(lca, scope); + } + void UpdateWithBlockidx() { for (const auto& it : buffer_lca_) { const runtime::StorageScope& scope = - runtime::StorageScope::Create(GetRef(it.first).scope()); + runtime::StorageScope::Create(ffi::GetRef(it.first).scope()); if (scope.rank == runtime::StorageRank::kGlobal) { const ScopeInfo*& lca = buffer_lca_[it.first]; for (const ScopeInfo* blockidx_scope : blockidx_scopes_) { @@ -333,6 +364,8 @@ class LCADetector : public StmtExprVisitor { std::unordered_map buffer_lca_ = {}; /*! \brief The map from Buffer data to the Buffer. */ std::unordered_map buffer_var_map_ = {}; + /*! \brief The map from Buffer data var to its LCA ForNode/BlockNode. */ + std::unordered_map buffer_var_lca_ = {}; /*! \brief The match buffers inside blocks. */ std::unordered_set match_buffers_ = {}; /*! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. 
*/ @@ -343,13 +376,18 @@ class LCADetector : public StmtExprVisitor { support::Arena arena_; }; -Map> DetectBufferAccessLCA(const PrimFunc& func) { +ffi::Map> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); } -TVM_FFI_STATIC_INIT_BLOCK({ +ffi::Map> DetectBufferVarAccessLCA(const PrimFunc& func) { + return LCADetector::DetectVar(func); +} + +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.detect_buffer_access_lca", DetectBufferAccessLCA); -}); + refl::GlobalDef().def("tir.analysis.detect_buffer_var_access_lca", DetectBufferVarAccessLCA); +} } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index feaa491cc8a2..557f42c5ba10 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -41,7 +41,7 @@ template class AllocationCalculator : public StmtExprVisitor { public: AllocationCalculator() = default; - tvm::Map operator()(const PrimFunc& func); + tvm::ffi::Map operator()(const PrimFunc& func); private: void VisitStmt_(const T* op) override; @@ -50,11 +50,11 @@ class AllocationCalculator : public StmtExprVisitor { }; template -tvm::Map AllocationCalculator::operator()(const PrimFunc& func) { +tvm::ffi::Map AllocationCalculator::operator()(const PrimFunc& func) { this->VisitStmt(func->body); - tvm::Map res; + tvm::ffi::Map res; for (auto [k, v] : _max_size) { - res.Set(String(k), Integer(v)); + res.Set(ffi::String(k), Integer(v)); } return res; } @@ -80,28 +80,30 @@ void AllocationCalculator::VisitStmt_(const T* op) { _current_size[storage_scope] -= size; } -tvm::Map > CalculateAllocatedBytes(const PrimFunc& func) { - tvm::Map > results; +tvm::ffi::Map > CalculateAllocatedBytes( + const PrimFunc& func) { + tvm::ffi::Map > results; results.Set("main", AllocationCalculator()(func)); return results; } -tvm::Map > 
CalculateAllocatedBytes(const IRModule& mod) { - tvm::Map > results; +tvm::ffi::Map > CalculateAllocatedBytes( + const IRModule& mod) { + tvm::ffi::Map > results; for (const auto& kv : mod->functions) { if (auto prim_func = kv.second.as()) { - String func_name = kv.first->name_hint; + ffi::String func_name = kv.first->name_hint; results.Set(func_name, AllocationCalculator()(prim_func.value())); } } return results; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.analysis.calculate_allocated_bytes", - [](ObjectRef obj) -> tvm::Map > { + [](ObjectRef obj) -> tvm::ffi::Map > { if (auto func = obj.as()) { return CalculateAllocatedBytes(func.value()); } else if (auto mod = obj.as()) { @@ -112,7 +114,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ throw; } }); -}); +} bool VerifyVTCMLimit(const IRModule& mod, Integer limit) { auto all_sizes = CalculateAllocatedBytes(mod); @@ -144,8 +146,8 @@ int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; } -Array GetVTCMCompactionPasses() { - auto pass_list = Array(); +ffi::Array GetVTCMCompactionPasses() { + auto pass_list = ffi::Array(); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); @@ -160,15 +162,15 @@ Array GetVTCMCompactionPasses() { return pass_list; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.get_vtcm_compaction_passes", []() { return GetVTCMCompactionPasses(); }); -}); +} namespace transform { -Pass VerifyVTCMLimit(Optional default_target) { +Pass VerifyVTCMLimit(ffi::Optional default_target) { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { if (auto opt = kv.second.as()) { @@ 
-198,10 +200,10 @@ Pass VerifyVTCMLimit(Optional default_target) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.VerifyVTCMLimit", VerifyVTCMLimit); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index a9c2b9ecc609..8d001dd1e459 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -63,14 +63,14 @@ bool HasBufferLoad(PrimExpr expr) { return visitor.found_buffer_load; } -Optional SubstituteParamValues(const Array& param_vars, - const Array& param_values, - const PrimExpr& expr) { +ffi::Optional SubstituteParamValues(const ffi::Array& param_vars, + const ffi::Array& param_values, + const PrimExpr& expr) { ICHECK_EQ(param_vars.size(), param_values.size()) << "Expression was defined as having " << param_vars.size() << " parameters, but received " << param_values.size() << " arguments."; - Map var_map; + ffi::Map var_map; for (size_t i = 0; i < param_values.size(); i++) { var_map.Set(param_vars[i], param_values[i]); } @@ -151,7 +151,7 @@ class BufferConstraintApply : public IRMutatorWithAnalyzer { public: using Parent = IRMutatorWithAnalyzer; - BufferConstraintApply(const Map>& axis_var_lookup, + BufferConstraintApply(const ffi::Map>& axis_var_lookup, const std::vector& knowns, Analyzer* analyzer) : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {} @@ -163,10 +163,10 @@ class BufferConstraintApply : public IRMutatorWithAnalyzer { continue; } - Optional lane_var = std::nullopt; + ffi::Optional lane_var = std::nullopt; IntImm num_lanes; - Array indices = op->indices.Map([&](const auto& index) { + ffi::Array indices = op->indices.Map([&](const auto& index) { if (index.dtype().lanes() == 1) { return index; } else { @@ -192,11 
+192,11 @@ class BufferConstraintApply : public IRMutatorWithAnalyzer { } } - return GetRef(op); + return ffi::GetRef(op); } private: - const Map>& axis_var_lookup_; + const ffi::Map>& axis_var_lookup_; const std::vector& knowns_; }; @@ -339,13 +339,13 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { void VisitExpr_(const BufferLoadNode* op) override { Parent::VisitExpr_(op); - BufferLoad load = GetRef(op); + BufferLoad load = ffi::GetRef(op); VisitAccess(load, BufferTouch::AccessType::Read, load); } void VisitStmt_(const BufferStoreNode* op) override { Parent::VisitStmt_(op); - VisitAccess(GetRef(op), BufferTouch::AccessType::Write, op->value); + VisitAccess(ffi::GetRef(op), BufferTouch::AccessType::Write, op->value); // Appending a control block ensures that all control blocks have // at most one statement that changes the buffer contents. auto prev_block = CurrentControlBlock(); @@ -554,7 +554,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { With analyzer_context; size_t old_num_constraints{0}; size_t new_num_constraints{0}; - Optional assume{std::nullopt}; + ffi::Optional assume{std::nullopt}; // Disable default-generated copy/move assignment and constructors InternalConstraintContext(const InternalConstraintContext&) = delete; @@ -623,7 +623,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { // binding. When making a predicate in terms of the buffer indices, // these need to be substituted out. 
// std::unordered_map let_bindings_using_loop_; - Map let_bindings_using_loop_; + ffi::Map let_bindings_using_loop_; // Track in order to know what conditions limit the buffer access std::vector conditions_; @@ -635,17 +635,17 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { ControlFlowGraph* out_; }; -std::pair> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( - const tir::Buffer& buf, Array index_variables, Array indices, +std::pair> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( + const tir::Buffer& buf, ffi::Array index_variables, ffi::Array indices, BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { const auto& current_block = *this; Analyzer local_analyzer; - Optional lane_var = std::nullopt; + ffi::Optional lane_var = std::nullopt; IntImm num_lanes; - Array index_expressions = indices.Map([&](const auto& index) { + ffi::Array index_expressions = indices.Map([&](const auto& index) { if (index.dtype().lanes() == 1) { return index; } else { @@ -656,9 +656,9 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make } }); - Array loop_vars; + ffi::Array loop_vars; - Map loop_ranges; + ffi::Map loop_ranges; for (const auto& loop_entry : current_block.active_loop_iterators) { loop_vars.push_back(loop_entry.loop_var); loop_ranges.Set(loop_entry.loop_var, loop_entry.loop_range); @@ -675,7 +675,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make IntConstraintsTransform transform = [&]() { ICHECK_EQ(index_variables.size(), index_expressions.size()); - Array relations; + ffi::Array relations; for (size_t i = 0; i < index_expressions.size(); i++) { PrimExpr expr = index_expressions[i]; @@ -689,16 +689,16 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make return arith::SolveLinearEquations(system); }(); - Map loop_var_to_axis_var = transform->src_to_dst; - Map free_params = transform->dst->ranges; + ffi::Map loop_var_to_axis_var = transform->src_to_dst; + ffi::Map free_params = transform->dst->ranges; PrimExpr 
transform_predicate = std::accumulate(transform->dst->relations.begin(), transform->dst->relations.end(), PrimExpr(Bool(true)), [](PrimExpr a, PrimExpr b) { return a && b; }); transform_predicate = SimplifyAsAndOfOrs(transform_predicate, &local_analyzer); - auto find_removable_params = [&]() -> Map { - Map removable_params; + auto find_removable_params = [&]() -> ffi::Map { + ffi::Map removable_params; // The arith::SolveLinearEquations is more general than the // utilities in iter_affine_map.h, but can introduce free @@ -712,13 +712,13 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make return; } - Var var = GetRef(var_ptr); + Var var = ffi::GetRef(var_ptr); if (free_params.count(var) == 0) { return; } - bool uses_free_param = - UsesVar(b, [&](const VarNode* v) { return free_params.count(GetRef(v)) > 0; }); + bool uses_free_param = UsesVar( + b, [&](const VarNode* v) { return free_params.count(ffi::GetRef(v)) > 0; }); if (uses_free_param) { return; } @@ -746,7 +746,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make return local_analyzer.Simplify(Substitute(expr, removable_params)); }; - Map new_map; + ffi::Map new_map; for (const auto [loop_var, expr] : loop_var_to_axis_var) { static_cast(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 new_map.Set(loop_var, update(expr)); @@ -808,7 +808,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph, const tir::Buffer& buf, - const Array& indices, + const ffi::Array& indices, BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { ICHECK(graph); @@ -949,7 +949,7 @@ std::ostream& operator<<(std::ostream& os, const BufferState& state) { } PrimExpr BufferState::SubstituteKnownBufferValues( - PrimExpr expr, const Map>& axis_var_lookup, + PrimExpr expr, const ffi::Map>& axis_var_lookup, Analyzer* analyzer) const { BufferConstraintApply mutator(axis_var_lookup, constraints_, analyzer); 
return mutator(std::move(expr)); @@ -961,7 +961,7 @@ void BufferState::AddCondition(const PrimExpr& condition) { } } -void BufferState::Substitute(const Map& var_remap, Analyzer* analyzer) { +void BufferState::Substitute(const ffi::Map& var_remap, Analyzer* analyzer) { if (var_remap.size()) { for (auto& prior : constraints_) { PrimExpr updated = tvm::tir::Substitute(prior.predicate, var_remap); @@ -1026,12 +1026,12 @@ class BufferRegionCollector : public ExprVisitor { public: struct Region { PrimExpr region_predicate; - std::unordered_map> known_values; + std::unordered_map> known_values; }; - static std::vector Collect(const Map>& axis_var_lookup, + static std::vector Collect(const ffi::Map>& axis_var_lookup, const std::vector& knowns, - const std::vector>& exprs, + const std::vector>& exprs, Analyzer* analyzer) { BufferRegionCollector collector(axis_var_lookup, knowns, analyzer); for (const auto& expr : exprs) { @@ -1046,7 +1046,7 @@ class BufferRegionCollector : public ExprVisitor { private: using Parent = ExprVisitor; - BufferRegionCollector(const Map>& axis_var_lookup, + BufferRegionCollector(const ffi::Map>& axis_var_lookup, const std::vector& knowns, Analyzer* analyzer) : analyzer_(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) { regions_.push_back(Region{Bool(true), {}}); @@ -1058,7 +1058,7 @@ class BufferRegionCollector : public ExprVisitor { // Helper struct for the known values of this BufferLoad struct Known { PrimExpr predicate; - Optional value; + ffi::Optional value; }; std::vector new_regions; @@ -1077,7 +1077,7 @@ class BufferRegionCollector : public ExprVisitor { touch_predicate = SimplifyAsAndOfOrs(touch_predicate, analyzer_); if (!is_zero(touch_predicate)) { - Optional known_value = + ffi::Optional known_value = SubstituteParamValues(axis_vars, op->indices, constraint.value); new_regions.push_back(Known{touch_predicate, known_value}); @@ -1112,14 +1112,14 @@ class BufferRegionCollector : public ExprVisitor { Analyzer* analyzer_; 
std::vector regions_; - const Map>& axis_var_lookup_; + const ffi::Map>& axis_var_lookup_; const std::vector& knowns_; }; class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { public: static PrimExpr Apply( - const std::unordered_map>& known_values, + const std::unordered_map>& known_values, PrimExpr expr, Analyzer* analyzer) { BufferRegionValueReplacer mutator(known_values, analyzer); PrimExpr result = mutator(expr); @@ -1134,7 +1134,7 @@ class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { using Parent = IRMutatorWithAnalyzer; BufferRegionValueReplacer( - const std::unordered_map>& known_values, + const std::unordered_map>& known_values, Analyzer* analyzer) : Parent(analyzer), known_values_(known_values) {} @@ -1145,17 +1145,17 @@ class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { if (it != known_values_.end() && it->second) { return it->second.value(); } else { - return GetRef(op); + return ffi::GetRef(op); } } - const std::unordered_map>& known_values_; + const std::unordered_map>& known_values_; }; -void BufferState::ApplyTouches(const Map>& axis_var_lookup, +void BufferState::ApplyTouches(const ffi::Map>& axis_var_lookup, const std::vector& touch_points, Analyzer* analyzer) { std::vector new_knowns; - Map keep_prior_known_at; + ffi::Map keep_prior_known_at; for (auto& touch : touch_points) { if (touch.touch_type == BufferTouch::AccessType::Read) { @@ -1209,7 +1209,7 @@ void BufferState::ApplyTouches(const Map>& axis_var_lookup, for (size_t i = 0; i < new_knowns.size(); i++) { if (new_knowns[i].buffer.same_as(constraint.buffer)) { - Optional overwritten_with = new_knowns[i].value; + ffi::Optional overwritten_with = new_knowns[i].value; if (overwritten_with && analyzer->CanProveEqual(prev_value, overwritten_with.value())) { expand_known_at = SimplifyAsAndOfOrs(expand_known_at || new_knowns[i].predicate, analyzer); @@ -1237,18 +1237,18 @@ void BufferState::ApplyTouches(const Map>& axis_var_lookup, constraints_.end()); } -void 
BufferState::BackpropUnusedIndices(const Map>& axis_var_lookup, +void BufferState::BackpropUnusedIndices(const ffi::Map>& axis_var_lookup, const std::vector& touch_points, Analyzer* analyzer) { std::vector new_knowns; - Map keep_prior_known_at; + ffi::Map keep_prior_known_at; - Map regions_written; - Map regions_read; + ffi::Map regions_written; + ffi::Map regions_read; for (auto it = touch_points.rbegin(); it != touch_points.rend(); it++) { const auto& touch = *it; - Map* to_update{nullptr}; + ffi::Map* to_update{nullptr}; if (touch.touch_type == BufferTouch::AccessType::Write) { to_update = ®ions_written; @@ -1264,7 +1264,7 @@ void BufferState::BackpropUnusedIndices(const Map>& axis_var_ } auto update_map = [&](auto& map) { - Map new_map; + ffi::Map new_map; for (auto [buffer, predicate] : map) { new_map.Set(buffer, SimplifyAsAndOfOrs(predicate, analyzer)); } @@ -1303,7 +1303,7 @@ void BufferState::BackpropUnusedIndices(const Map>& axis_var_ constraints_.end()); } -void BufferState::RemoveFreeParameters(const Map& free_predicate_parameters, +void BufferState::RemoveFreeParameters(const ffi::Map& free_predicate_parameters, Analyzer* analyzer) { for (auto& known : constraints_) { known.predicate = NarrowPredicateExpression(known.predicate, free_predicate_parameters); @@ -1325,7 +1325,7 @@ bool BufferState::IsEquivalentTo(const BufferState& other, Analyzer* analyzer) c return true; } -Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const { +ffi::Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const { if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { return (*it).second; } else { @@ -1333,12 +1333,13 @@ Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) cons } } -Array ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array& indices) { +ffi::Array ControlFlowGraph::GetIndexVariables(const Buffer& buf, + const ffi::Array& indices) { if (auto it = axis_var_lookup_.find(buf); it != 
axis_var_lookup_.end()) { return (*it).second; } - Array vars; + ffi::Array vars; for (size_t i = 0; i < indices.size(); i++) { std::stringstream ss; ss << buf->name << "_axis_" << i; @@ -1620,7 +1621,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, const Stmt& context) const { - Optional> index_variables = GetIndexVariables(store->buffer); + ffi::Optional> index_variables = GetIndexVariables(store->buffer); if (!index_variables) { return false; } diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h index f4babffbb74c..7bde341c38fa 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tir/analysis/control_flow_graph.h @@ -186,7 +186,7 @@ class BufferState { * the original expression is returned. */ PrimExpr SubstituteKnownBufferValues(PrimExpr expr, - const Map>& axis_var_lookup, + const ffi::Map>& axis_var_lookup, arith::Analyzer* analyzer) const; /*! \brief Apply a condition to all known constraints @@ -205,7 +205,7 @@ class BufferState { * * \param var_remap The variable remapping to apply. */ - void Substitute(const Map& var_remap, arith::Analyzer* analyzer); + void Substitute(const ffi::Map& var_remap, arith::Analyzer* analyzer); /*! \brief Simplify the predicate of all constraints * @@ -226,7 +226,7 @@ class BufferState { * * \param analyzer The analyzer to use for simplifications */ - void ApplyTouches(const Map>& axis_var_lookup, + void ApplyTouches(const ffi::Map>& axis_var_lookup, const std::vector& touch_points, arith::Analyzer* analyzer); /*! 
\brief Update unused buffer locations based on buffer touches @@ -245,7 +245,7 @@ class BufferState { * * \param analyzer The analyzer to use for simplifications */ - void BackpropUnusedIndices(const Map>& axis_var_lookup, + void BackpropUnusedIndices(const ffi::Map>& axis_var_lookup, const std::vector& touch_points, arith::Analyzer* analyzer); @@ -255,7 +255,7 @@ class BufferState { * * \param analyzer The analyzer with which to simplify after removal */ - void RemoveFreeParameters(const Map& free_predicate_parameters, + void RemoveFreeParameters(const ffi::Map& free_predicate_parameters, arith::Analyzer* analyzer); /*! \brief Check if two buffer states are equivalent @@ -462,7 +462,7 @@ class ControlFlowGraph { * * \returns Variables representing a position along the buffer's axis. */ - Array GetIndexVariables(const Buffer& buf, const Array& indices); + ffi::Array GetIndexVariables(const Buffer& buf, const ffi::Array& indices); /*! \brief Return index variables representing locations within a * buffer, if they have been generated before. @@ -473,7 +473,7 @@ class ControlFlowGraph { * * \returns Variables representing a position along the buffer's axis. */ - Optional> GetIndexVariables(const Buffer& buf) const; + ffi::Optional> GetIndexVariables(const Buffer& buf) const; /*! \brief Propagate known values from known BufferStore/assume * subsequent control flow blocks @@ -501,7 +501,7 @@ class ControlFlowGraph { * e.g. Replacing loop iterator `i` with `i-1` when following an * edge from the end of a loop to the beginning of the loop. */ - Map var_remap; + ffi::Map var_remap; /*! \brief Condition that must to true after following this edge * @@ -509,7 +509,7 @@ class ControlFlowGraph { * loop_min` when following the an edge from the end of a loop to * the beginning of the loop. 
*/ - Optional post_condition; + ffi::Optional post_condition; }; friend std::ostream& operator<<(std::ostream& os, const ControlFlowEdge& edge); @@ -525,7 +525,7 @@ class ControlFlowGraph { std::vector active_loop_iterators; /*! \brief Loop-dependent Let bindings that may appear within the block */ - Map let_bindings_using_loop; + ffi::Map let_bindings_using_loop; /*! \brief Predicate that must be true to have reached this block */ PrimExpr scope_predicate{Bool(true)}; @@ -577,7 +577,8 @@ class ControlFlowGraph { * \returns The newly generated BufferTouch */ BufferTouch MakeBufferTouch(ControlFlowGraph* graph, const Buffer& buf, - const Array& indices, BufferTouch::AccessType touch_type, + const ffi::Array& indices, + BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const; /* \brief Construct a BufferTouch instance as if it occurred in @@ -602,11 +603,11 @@ class ControlFlowGraph { * all free parameters that may occur in the BufferTouch's * predicate. */ - std::pair> MakeBufferTouch(const Buffer& buf, - Array index_variables, - Array indices, - BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) const; + std::pair> MakeBufferTouch(const Buffer& buf, + ffi::Array index_variables, + ffi::Array indices, + BufferTouch::AccessType touch_type, + PrimExpr known_value_expr) const; }; friend std::ostream& operator<<(std::ostream& os, const ControlFlowBlock& pattern); @@ -629,10 +630,10 @@ class ControlFlowGraph { * the free parameters allows them to be removed later, by requiring * a predicate to be true for all values of the free parameters. */ - Map free_predicate_parameters_; + ffi::Map free_predicate_parameters_; /*! \brief Ranges of iterators found in the analyzed statement */ - Map iterator_ranges_; + ffi::Map iterator_ranges_; /* \brief A map from buffer to the variables representing positions * along the buffer's axes. 
@@ -642,7 +643,7 @@ class ControlFlowGraph { * variables to represent the buffer's axes, reducing the amount of * variable substitution required. */ - Map> axis_var_lookup_; + ffi::Map> axis_var_lookup_; /* \brief Assumptions that do not depend on buffer values * diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 5d85ef31e88e..60a3e0d448d2 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -66,7 +66,7 @@ class ExprDeepEqualChecker : private ExprFunctor& lhs, const Array& rhs) { + bool ArrayDeepEqual(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); i++) { if (!VisitExpr(lhs[i], rhs[i])) return false; @@ -74,7 +74,7 @@ class ExprDeepEqualChecker : private ExprFunctor& lhs, const Array& rhs) { + bool ArrayDeepEqual(const ffi::Array& lhs, const ffi::Array& rhs) { // for iter var, we require pointer equality if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); i++) { @@ -83,7 +83,7 @@ class ExprDeepEqualChecker : private ExprFunctor& lhs, const Optional& rhs) { + bool OptionalDeepEqual(const ffi::Optional& lhs, const ffi::Optional& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() && rhs.defined()) return false; if (lhs.defined() && !rhs.defined()) return false; @@ -196,12 +196,12 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { return ExprDeepEqualChecker::Check(lhs, rhs); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.analysis.expr_deep_equal", [](const PrimExpr& lhs, const PrimExpr& rhs) { return ExprDeepEqual()(lhs, rhs); }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index 3f012d5f15af..3fe33cdf2af2 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc 
@@ -37,7 +37,7 @@ int32_t DataType2Int(const tvm::DataType& dtype) { return converter.dst; } -String Int2DataTypeStr(int32_t dtype) { +ffi::String Int2DataTypeStr(int32_t dtype) { union { DLDataType dst; int32_t src; @@ -193,10 +193,20 @@ class FlopEstimator : private ExprFunctor, return cond; } + TResult VisitStmt_(const AssertStmtNode* op) override { + TResult result = VisitExpr(op->condition); + if (op->message.defined()) { + result += VisitExpr(op->message); + } + result += VisitStmt(op->body); + return result; + } + TResult VisitExpr_(const VarNode* op) override { return TResult(); } TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); } TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } + TResult VisitExpr_(const StringImmNode* op) override { return TResult(); } TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } @@ -247,7 +257,7 @@ double EstimateTIRFlops(const IRModule& mod) { return PostprocessResults(result) + cached_result; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.EstimateTIRFlops", [](ObjectRef obj) -> double { if (auto mod = obj.as()) { @@ -260,7 +270,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ throw; } }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index c23eed2da997..71f92900d892 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -42,8 +42,8 @@ namespace tir { std::variant IdentifyMemCpyImpl(const For& loop, arith::Analyzer* analyzer) { - Map loop_intervals; - Map loop_ranges; + ffi::Map loop_intervals; + ffi::Map loop_ranges; 
PrimExpr total_loop_iterations = 1; // Walk through the loop nest, stopping at the first loop whose body @@ -82,8 +82,8 @@ std::variant IdentifyMemCpyImpl(const For& loop, // Now, we have a BufferStore whose value is a BufferLoad. Because // non-flat physical indices are target-dependent, only handle cases // where the buffer will be flattened to a 1-d physical buffer. - Array flattened_dst = store->buffer.OffsetOf(store->indices); - Array flattened_src = load->buffer.OffsetOf(load->indices); + ffi::Array flattened_dst = store->buffer.OffsetOf(store->indices); + ffi::Array flattened_src = load->buffer.OffsetOf(load->indices); if (flattened_dst.size() != 1 || flattened_src.size() != 1) { return static_cast( @@ -283,22 +283,22 @@ std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* an } // Expose the IdentifyMemCpy functionality to Python API for purpose of unit testing. -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis._identify_memcpy", [](const Stmt& stmt) { - Array output; + ffi::Array output; struct Visitor : arith::IRVisitorWithAnalyzer { - explicit Visitor(Array* output) : output(output) {} - Array* output; + explicit Visitor(ffi::Array* output) : output(output) {} + ffi::Array* output; private: using IRVisitorWithAnalyzer::VisitStmt_; void VisitStmt_(const ForNode* op) override { - For loop = GetRef(op); + For loop = ffi::GetRef(op); auto result = IdentifyMemCpyImpl(loop, &(Visitor::analyzer_)); if (auto* ptr = std::get_if(&result)) { - output->push_back(Array{ptr->source, ptr->dest}); + output->push_back(ffi::Array{ptr->source, ptr->dest}); } else if (auto* ptr = std::get_if(&result)) { output->push_back(StringImm(*ptr)); } else { @@ -314,7 +314,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return output; }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc index 9e85e4cc86c7..a6a3fc4bc7f3 
100644 --- a/src/tir/analysis/is_pure_function.cc +++ b/src/tir/analysis/is_pure_function.cc @@ -79,7 +79,7 @@ class PurityChecker : TIRVisitorWithPath { LOG_IF(FATAL, assert_on_error_) << "AssertionError: " << "Pure functions must not contain calls to impure operators, " - << "but " << GetRef(call) << " calls operator " << call->op + << "but " << ffi::GetRef(call) << " calls operator " << call->op << ", which has side effect " << effect; } } @@ -94,10 +94,10 @@ bool IsPureFunction(const PrimFunc& func, bool assert_on_error) { return PurityChecker::Check(func, assert_on_error); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.is_pure_function", IsPureFunction); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc index 72626d27188d..06deb7934ad0 100644 --- a/src/tir/analysis/oob_checker.cc +++ b/src/tir/analysis/oob_checker.cc @@ -41,9 +41,9 @@ struct OOBLocation { class OOBError : public ScheduleError { public: OOBError(IRModule mod, std::vector locations) : mod_(mod), locations_(locations) {} - String FastErrorString() const final { return "Out of bound memory access"; } + ffi::String FastErrorString() const final { return "Out of bound memory access"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::stringstream s; for (const auto& oob : locations_) { s << "Out of bounds memory access on buffer " << oob.buf->name << " dimension " @@ -56,7 +56,7 @@ class OOBError : public ScheduleError { return s.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { + ffi::Array LocationsOfInterest() const final { std::vector locs; for (auto loc : locations_) { locs.push_back(loc.index); @@ -124,10 +124,10 @@ transform::Pass OOBChecker() { return transform::CreatePrimFuncPass(pass_func, 0, "tir.analysis.OOBChecker", {}); } 
-TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.OOBChecker", OOBChecker); -}); +} } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc index 2fe2ce5235a7..9f6f4da7eaf3 100644 --- a/src/tir/analysis/stmt_finding.cc +++ b/src/tir/analysis/stmt_finding.cc @@ -98,7 +98,7 @@ Stmt GetEnclosingLoop(const BlockNode* block, Stmt func_body) { } } - LOG(FATAL) << "Enclosing loop not found for a block " << GetRef(block); + LOG(FATAL) << "Enclosing loop not found for a block " << ffi::GetRef(block); TVM_FFI_UNREACHABLE(); } @@ -140,16 +140,16 @@ const BlockNode* FindAnchorBlock(const IRModule& mod) { return nullptr; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.find_anchor_block", [](const IRModule& mod) { auto ret = FindAnchorBlock(mod); if (ret) { - return Optional(GetRef(ret)); + return ffi::Optional(ffi::GetRef(ret)); } - return Optional(std::nullopt); + return ffi::Optional(std::nullopt); }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index 95da50204b97..becae607fb39 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -27,7 +27,7 @@ namespace tvm { namespace tir { -VarUseDefAnalyzer::VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent) +VarUseDefAnalyzer::VarUseDefAnalyzer(const ffi::Array& defined_vars, bool visit_thread_extent) : visit_thread_extent_(visit_thread_extent) { for (const Var v : defined_vars) { use_count_[v.get()] = 0; @@ -104,7 +104,7 @@ void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { } void VarUseDefAnalyzer::VisitExpr_(const VarNode* op) { - this->HandleUse(GetRef(op)); + this->HandleUse(ffi::GetRef(op)); 
StmtExprVisitor::VisitExpr_(op); } @@ -123,7 +123,7 @@ void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) { void VarUseDefAnalyzer::VisitBuffer(const Buffer& buffer) { this->HandleUse(buffer->data); - auto visit_arr = [&](Array arr) { + auto visit_arr = [&](ffi::Array arr) { for (const auto& element : arr) { this->VisitExpr(element); } @@ -151,7 +151,7 @@ void VarUseDefAnalyzer::HandleUse(const Var& var) { ++it->second; } } else { - undefined_.push_back(GetRef(v)); + undefined_.push_back(ffi::GetRef(v)); use_count_[v] = -1; } } @@ -176,43 +176,43 @@ void VarUseDefAnalyzer::HandleUse(const Buffer& buf) { ++it->second; } } else { - undefined_buffers_.push_back(GetRef(ptr)); + undefined_buffers_.push_back(ffi::GetRef(ptr)); buffer_use_count_[ptr] = -1; } VisitBuffer(buf); } -Array UndefinedVars(const Stmt& stmt, const Array& args) { +ffi::Array UndefinedVars(const Stmt& stmt, const ffi::Array& args) { VarUseDefAnalyzer m(args); m(stmt); return m.undefined_; } -Array UndefinedVars(const PrimExpr& expr) { +ffi::Array UndefinedVars(const PrimExpr& expr) { VarUseDefAnalyzer m({}); m(expr); return m.undefined_; } -Array UndefinedVars(const PrimExpr& expr, const Array& args) { +ffi::Array UndefinedVars(const PrimExpr& expr, const ffi::Array& args) { VarUseDefAnalyzer m(args); m(expr); return m.undefined_; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tir.analysis.UndefinedVars", [](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_stmt = args[0].as()) { - *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); + *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); } else if (auto opt_expr = args[0].as()) { - *rv = UndefinedVars(opt_expr.value(), args[1].cast>()); + *rv = UndefinedVars(opt_expr.value(), args[1].cast>()); } else { LOG(FATAL) << "either UndefinedVars(stmt, args) or UndefinedVars(expr, args) is expected"; } }); -}); +} } // namespace tir } // namespace tvm diff 
--git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 64985b11a9fa..51323d65d5b2 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -40,12 +40,12 @@ namespace tir { */ class VarUseDefAnalyzer : public StmtExprVisitor { public: - explicit VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent = true); + explicit VarUseDefAnalyzer(const ffi::Array& defined_vars, bool visit_thread_extent = true); // The fields are publically readible to // be accessible to the users. bool visit_thread_extent_{true}; - Array undefined_; - Array undefined_buffers_; + ffi::Array undefined_; + ffi::Array undefined_buffers_; std::unordered_map use_count_; std::unordered_map def_count_; diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index c1f8b327ecea..e0273069cc46 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -39,10 +39,11 @@ namespace tir { class GPUCodeVerifier : public StmtExprVisitor { public: - std::vector Verify(Stmt stmt, int64_t max_local_memory_per_block, - int64_t max_shared_memory_per_block, int64_t max_threads_per_block, - int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z, - int64_t max_vthread, int64_t max_vector_bytes, int64_t max_kernels) { + std::vector Verify(Stmt stmt, int64_t max_local_memory_per_block, + int64_t max_shared_memory_per_block, + int64_t max_threads_per_block, int64_t max_thread_x, + int64_t max_thread_y, int64_t max_thread_z, int64_t max_vthread, + int64_t max_vector_bytes, int64_t max_kernels) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); @@ -187,7 +188,7 @@ class GPUCodeVerifier : public StmtExprVisitor { StmtVisitor::VisitStmt_(op); } - void CheckBufferIndicesVectorizable(const Array 
indices) { + void CheckBufferIndicesVectorizable(const ffi::Array indices) { for (const auto index : indices) { if (const auto* ramp = index.as()) { if (!is_one(ramp->stride) && @@ -263,7 +264,7 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t max_vector_bytes_; size_t max_kernels_; - std::vector errors_; + std::vector errors_; void Reset_() { local_memory_per_block_ = 0; @@ -274,7 +275,8 @@ class GPUCodeVerifier : public StmtExprVisitor { } }; -std::vector VerifyGPUCode_(const PrimFunc& func, Map constraints) { +std::vector VerifyGPUCode_(const PrimFunc& func, + ffi::Map constraints) { GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; @@ -317,19 +319,19 @@ std::vector VerifyGPUCode_(const PrimFunc& func, Map c max_vthread, max_vector_bytes, max_kernels); } -bool VerifyGPUCode(const PrimFunc& func, Map constraints) { +bool VerifyGPUCode(const PrimFunc& func, ffi::Map constraints) { auto errs = VerifyGPUCode_(func, constraints); return errs.size() == 0; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.verify_gpu_code", VerifyGPUCode); -}); +} namespace transform { -Pass VerifyGPUCode(Map constraints) { +Pass VerifyGPUCode(ffi::Map constraints) { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { if (auto func = kv.second.as()) { @@ -350,10 +352,10 @@ Pass VerifyGPUCode(Map constraints) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.VerifyGPUCode", VerifyGPUCode); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 6a93fa0206d4..a82de34716c8 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -63,7 +63,7 @@ class 
MemoryAccessVerifier final : protected StmtExprVisitor { } /// Verification result - std::vector Errors() const { return errs_; } + std::vector Errors() const { return errs_; } protected: /// Visitor implementation @@ -158,7 +158,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Status of visitor //@{ bool in_thread_env_{false}; - std::vector errs_; + std::vector errs_; //@} tir::PrimFunc func_{nullptr}; ///< Function to be verified. int dev_type_{kDLCPU}; ///< Device type @@ -167,7 +167,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } // namespace /// Interface of VerifyMemory pass -std::vector VerifyMemory_(const PrimFunc& func) { +std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; @@ -187,10 +187,10 @@ std::vector VerifyMemory_(const PrimFunc& func) { bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.verify_memory", VerifyMemory); -}); +} namespace transform { @@ -215,10 +215,10 @@ Pass VerifyMemory() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.VerifyMemory", VerifyMemory); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index 0d5f3f6cb491..eafe28bd63a9 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -81,7 +81,7 @@ class SSAVerifier final : public StmtExprVisitor { } void VisitExpr_(const VarNode* node) final { - auto var = GetRef(node); + auto var = ffi::GetRef(node); if (match_scope_) { MarkDef(var, var, true); } @@ -140,10 +140,10 @@ bool 
VerifySSA(const PrimFunc& func) { return visitor.is_ssa_; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.verify_ssa", VerifySSA); -}); +} namespace transform { @@ -159,10 +159,10 @@ Pass VerifySSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.VerifySSA", VerifySSA); -}); +} } // namespace transform diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 2efd3648a5bb..c10931d1bd10 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -248,20 +248,25 @@ class UndefinedVarVerifier : public Verifier { bool redefine_is_allowed = redefine_allowed_within_function_.count(var); { auto it = currently_defined_.find(var); - Verify(it == currently_defined_.end() || redefine_is_allowed) - << "ValueError: " - << "TIR is ill-formed, " - << "due to multiple nested definitions of variable " << var - << ". It was first defined at " << it->second << ", and was re-defined at " << path; + auto verify = Verify(it == currently_defined_.end() || redefine_is_allowed); + verify << "ValueError: " + << "TIR is ill-formed, " + << "due to multiple nested definitions of variable " << var << "."; + if (it != currently_defined_.end()) { + verify << " It was first defined at " << it->second << ", and was re-defined at " << path; + } } { auto it = previously_defined_.find(var); - Verify(it == previously_defined_.end() || redefine_is_allowed) - << "ValueError: " - << "TIR is ill-formed, " - << "due to multiple definitions of variable " << var << ". 
It was first defined at " - << it->second << ", and was later re-defined at " << path; + auto verify = Verify(it == previously_defined_.end() || redefine_is_allowed); + verify << "ValueError: " + << "TIR is ill-formed, " + << "due to multiple definitions of variable " << var << "."; + if (it != previously_defined_.end()) { + verify << " It was first defined at " << it->second << ", and was later re-defined at " + << path; + } } currently_defined_.insert({var, path}); @@ -275,7 +280,7 @@ class UndefinedVarVerifier : public Verifier { } void VisitExpr_(const VarNode* op, AccessPath path) override { - auto var = GetRef(op); + auto var = ffi::GetRef(op); auto active_def = currently_defined_.find(var); auto verify = Verify(active_def != currently_defined_.end()); @@ -342,7 +347,7 @@ class SingleEnvThreadVerifier : public Verifier { } } - std::unordered_map> env_thread_vars_; + std::unordered_map> env_thread_vars_; }; bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { @@ -371,7 +376,7 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { return true; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.VerifyWellFormed", [](const ObjectRef& obj, bool assert_mode) { @@ -384,7 +389,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ << obj->GetTypeKey(); } }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/block_dependence_info.cc b/src/tir/ir/block_dependence_info.cc index 87847aed2d88..3cda278d0a71 100644 --- a/src/tir/ir/block_dependence_info.cc +++ b/src/tir/ir/block_dependence_info.cc @@ -24,7 +24,7 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ BlockDependenceInfoNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { BlockDependenceInfoNode::RegisterReflection(); } /** * @brief A helper class to collect and build Block Dependences using BlockScope class @@ -42,7 +42,7 @@ class BlockDependenceInfoCollector : private StmtVisitor { } 
void MakeBlockScope(StmtSRef scope) { - Array child_block_srefs = std::move(block_frames_.back()); + ffi::Array child_block_srefs = std::move(block_frames_.back()); self_->sref2scope[scope] = BlockScope(child_block_srefs); } @@ -67,13 +67,13 @@ class BlockDependenceInfoCollector : private StmtVisitor { BlockDependenceInfoNode* self_; /*! \brief The stack frames of blocks in the DFS visit. */ - std::vector> block_frames_; + std::vector> block_frames_; }; -BlockDependenceInfo::BlockDependenceInfo() { data_ = make_object(); } +BlockDependenceInfo::BlockDependenceInfo() { data_ = ffi::make_object(); } BlockDependenceInfo::BlockDependenceInfo(IRModule mod) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); BlockDependenceInfoNode* self = n.get(); n->stmt2ref = SRefTreeCreator::Create(mod, /* include_loops */ false); @@ -87,18 +87,18 @@ BlockDependenceInfo::BlockDependenceInfo(IRModule mod) { data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.BlockDependenceInfo", [](IRModule mod) -> BlockDependenceInfo { return BlockDependenceInfo(mod); }) .def_method("tir.BlockDependenceInfoGetBlockScope", &BlockDependenceInfoNode::GetBlockScope) .def("tir.BlockDependenceInfoGetSRef", - [](BlockDependenceInfo self, Stmt stmt) -> Optional { + [](BlockDependenceInfo self, Stmt stmt) -> ffi::Optional { auto it = self->stmt2ref.find(stmt.get()); - return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); + return it != self->stmt2ref.end() ? 
it->second : ffi::Optional(std::nullopt); }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/block_scope.cc b/src/tir/ir/block_scope.cc index ba651b953acc..676f162076ce 100644 --- a/src/tir/ir/block_scope.cc +++ b/src/tir/ir/block_scope.cc @@ -23,11 +23,11 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { StmtSRefNode::RegisterReflection(); DependencyNode::RegisterReflection(); BlockScopeNode::RegisterReflection(); -}); +} /******** Utility functions ********/ @@ -52,7 +52,7 @@ void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& ds /******** Constructors ********/ StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stmt = stmt; n->parent = parent; n->seq_index = seq_index; @@ -70,19 +70,19 @@ StmtSRef StmtSRef::RootMark() { } Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) { - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->src = std::move(src); node->dst = std::move(dst); node->kind = kind; data_ = std::move(node); } -BlockScope::BlockScope() { data_ = make_object(); } +BlockScope::BlockScope() { data_ = ffi::make_object(); } -BlockScope::BlockScope(const Array& child_block_srefs) { - ObjectPtr n = make_object(); - SMap> buffer_readers; - SMap>& buffer_writers = n->buffer_writers; +BlockScope::BlockScope(const ffi::Array& child_block_srefs) { + ObjectPtr n = ffi::make_object(); + SMap> buffer_readers; + SMap>& buffer_writers = n->buffer_writers; for (const StmtSRef& child_block_sref : child_block_srefs) { const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref); // Step 1. 
Update `buffer_readers` and `buffer_writers` for each buffer @@ -125,7 +125,7 @@ BlockScope::BlockScope(const Array& child_block_srefs) { /******** Dependency ********/ -Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { +ffi::Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { auto iter = this->src2deps.find(block_sref); if (iter != this->src2deps.end()) { return iter->second; @@ -134,7 +134,7 @@ Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const } } -Array BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { +ffi::Array BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { auto iter = this->dst2deps.find(block_sref); if (iter != this->dst2deps.end()) { return iter->second; @@ -193,20 +193,22 @@ void SRefTreeCreator::VisitStmt_(const SeqStmtNode* seq_stmt) { /******** FFI ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.StmtSRefStmt", - [](StmtSRef sref) -> Optional { return GetRef>(sref->stmt); }) + [](StmtSRef sref) -> ffi::Optional { + return ffi::GetRef>(sref->stmt); + }) .def("tir.StmtSRefParent", - [](StmtSRef sref) -> Optional { - return GetRef>(sref->parent); + [](StmtSRef sref) -> ffi::Optional { + return ffi::GetRef>(sref->parent); }) .def("tir.StmtSRefRootMark", StmtSRef::RootMark) .def("tir.StmtSRefInlineMark", StmtSRef::InlineMark) .def_method("tir.BlockScopeGetDepsBySrc", &BlockScopeNode::GetDepsBySrc) .def_method("tir.BlockScopeGetDepsByDst", &BlockScopeNode::GetDepsByDst); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 1cac41ff3ce5..87b9a2628cc7 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -38,24 +38,25 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ BufferNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { BufferNode::RegisterReflection(); } using IndexMod = tir::FloorModNode; 
using IndexDiv = tir::FloorDivNode; -Array SimplifyArray(arith::Analyzer* ana, Array array) { +ffi::Array SimplifyArray(arith::Analyzer* ana, ffi::Array array) { for (size_t i = 0; i < array.size(); ++i) { array.Set(i, ana->Simplify(array[i])); } return array; } -Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, - Optional> axis_separators, Span span) { +Buffer decl_buffer(ffi::Array shape, DataType dtype, ffi::String name, + ffi::String storage_scope, ffi::Optional> axis_separators, + Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, - Array(), PrimExpr(), name, 0, 0, kDefault, - axis_separators.value_or(Array()), span); + ffi::Array(), PrimExpr(), name, 0, 0, kDefault, + axis_separators.value_or(ffi::Array()), span); } // Split the given expression w.r.t the add operator @@ -250,14 +251,14 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { return no_opt_sum; } -Array Buffer::OffsetOf(Array input_indices) const { +ffi::Array Buffer::OffsetOf(ffi::Array input_indices) const { return (*this)->ElemOffset(std::move(input_indices)); } // The buffer offset in convention of number of elements of // original data ignoring number of lanes. // We also perform optimization to simplify the indexing expression. -Array BufferNode::ElemOffset(Array input_indices) const { +ffi::Array BufferNode::ElemOffset(ffi::Array input_indices) const { ICHECK_EQ(shape.size(), input_indices.size()) << "Buffer " << this->name << " is " << shape.size() << "-dimensional, cannot be indexed with the " << input_indices.size() @@ -272,7 +273,7 @@ Array BufferNode::ElemOffset(Array input_indices) const { // TODO(Lunderberg): Better handling for cases where there is more // than one output index. Currently, this only allows elem_offset // to be non-zero for flat memory allocations. 
- Array elem_offsets = {}; + ffi::Array elem_offsets = {}; if (elem_offset.defined() && !is_zero(elem_offset)) { elem_offsets = {elem_offset}; } @@ -283,7 +284,7 @@ Array BufferNode::ElemOffset(Array input_indices) const { << "there must be one element offset for each output index."; } - Array output_indices(axis_separators.size() + 1, 0); + ffi::Array output_indices(axis_separators.size() + 1, 0); size_t current_output_axis = 0; @@ -318,8 +319,9 @@ Array BufferNode::ElemOffset(Array input_indices) const { return SimplifyArray(&ana, output_indices); } -inline Array BufferOffset(const BufferNode* n, Array index, DataType dtype) { - Array offsets = n->ElemOffset(index); +inline ffi::Array BufferOffset(const BufferNode* n, ffi::Array index, + DataType dtype) { + ffi::Array offsets = n->ElemOffset(index); // If the Buffer has element type with more than one lane, scale to // get the offset in number of scalars. if (n->dtype.lanes() != 1) { @@ -338,7 +340,7 @@ inline Array BufferOffset(const BufferNode* n, Array index, return offsets; } -static void ValidateAxisSeparators(const Array& axis_separators, size_t buffer_dim) { +static void ValidateAxisSeparators(const ffi::Array& axis_separators, size_t buffer_dim) { // These checks ensure that all output axes contain at least one // input axis. for (size_t i = 0; (i + 1) < axis_separators.size(); i++) { @@ -370,7 +372,7 @@ Buffer Buffer::GetFlattenedBuffer() const { ValidateAxisSeparators(self->axis_separators, self->shape.size()); - Array output_shape; + ffi::Array output_shape; if (self->strides.size()) { // If strides are defined, then the extent of each flattened // buffer is the stride*size for the first input axis used for @@ -386,7 +388,7 @@ Buffer Buffer::GetFlattenedBuffer() const { // of the extents of each input axis used to generate that output // axis. This also "flattens" rank-0 tensors to a rank-1 buffer // of shape [1]. 
- output_shape = Array(self->axis_separators.size() + 1, 1); + output_shape = ffi::Array(self->axis_separators.size() + 1, 1); size_t current_output_index = 0; for (size_t i = 0; i < self->shape.size(); i++) { if ((current_output_index < self->axis_separators.size()) && @@ -398,7 +400,7 @@ Buffer Buffer::GetFlattenedBuffer() const { } // The axis_separators for the output buffer. - Array output_axis_separators; + ffi::Array output_axis_separators; for (size_t i = 0; i < self->axis_separators.size(); i++) { auto dtype = self->axis_separators[i]->dtype; output_axis_separators.push_back(IntImm(dtype, i + 1)); @@ -416,8 +418,8 @@ Buffer Buffer::GetFlattenedBuffer() const { } } -PrimExpr Buffer::vload(Array begin, DataType value_dtype, - Optional predicate) const { +PrimExpr Buffer::vload(ffi::Array begin, DataType value_dtype, + ffi::Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); @@ -425,7 +427,7 @@ PrimExpr Buffer::vload(Array begin, DataType value_dtype, value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot load " << value_dtype << " from buffer of " << n->dtype; - Array indices = begin; + ffi::Array indices = begin; PrimExpr base = indices[indices.size() - 1]; if (value_dtype.is_fixed_length_vector()) { int factor = value_dtype.lanes() / n->dtype.lanes(); @@ -436,7 +438,8 @@ PrimExpr Buffer::vload(Array begin, DataType value_dtype, return BufferLoad(*this, indices, predicate); } -Stmt Buffer::vstore(Array begin, PrimExpr value, Optional predicate) const { +Stmt Buffer::vstore(ffi::Array begin, PrimExpr value, + ffi::Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); @@ -445,7 +448,7 @@ Stmt Buffer::vstore(Array begin, PrimExpr value, Optional pr value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot store " << value_dtype << " to buffer of " << 
n->dtype; - Array indices = begin; + ffi::Array indices = begin; PrimExpr base = indices[indices.size() - 1]; if (value_dtype.is_fixed_length_vector()) { int factor = value_dtype.lanes() / n->dtype.lanes(); @@ -456,7 +459,7 @@ Stmt Buffer::vstore(Array begin, PrimExpr value, Optional pr return BufferStore(*this, value, indices, predicate); } -String Buffer::scope() const { +ffi::String Buffer::scope() const { const auto* ptr_type = (*this)->data->type_annotation.as(); ICHECK(ptr_type) << "Buffer variable is not of pointer type"; if (ptr_type->storage_scope.empty()) { @@ -471,7 +474,7 @@ Buffer Buffer::MakeStrideView() const { std::vector temp; const BufferNode* self = operator->(); ICHECK(self != nullptr); - auto n = make_object(*self); + auto n = ffi::make_object(*self); PrimExpr acc = make_const(n->DefaultIndexType(), 1); for (size_t i = n->shape.size(); i != 0; --i) { temp.push_back(acc); @@ -483,15 +486,15 @@ Buffer Buffer::MakeStrideView() const { return Buffer(n); } -Buffer Buffer::MakeSlice(Array begins, Array extents) const { +Buffer Buffer::MakeSlice(ffi::Array begins, ffi::Array extents) const { const BufferNode* n = operator->(); ICHECK(n != nullptr); arith::Analyzer ana; begins = SimplifyArray(&ana, begins); - Array elem_offset = + ffi::Array elem_offset = n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); }); - Array strides = n->strides; + ffi::Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; bool need_stride = false; @@ -526,7 +529,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const } PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset, - Optional input_extent) const { + ffi::Optional input_extent) const { const BufferNode* self = operator->(); ICHECK(self != nullptr); PrimExpr e_dtype; @@ -553,14 +556,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane if (input_extent.defined()) { extent = 
input_extent.value(); } - Array acc_args{e_dtype, self->data, elem_offset, extent, - make_const(DataType::Int(32), access_mask)}; + ffi::Array acc_args{e_dtype, self->data, elem_offset, extent, + make_const(DataType::Int(32), access_mask)}; return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args); } -Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, int data_alignment, int offset_factor, - BufferType buffer_type, Array axis_separators, Span span) { +Buffer::Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array strides, + PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, + BufferType buffer_type, ffi::Array axis_separators, Span span) { DataType storage_dtype = dtype; // specially handle bool if (storage_dtype == DataType::Bool()) { @@ -584,7 +587,7 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array ValidateAxisSeparators(axis_separators, shape.size()); - auto n = make_object(); + auto n = ffi::make_object(); n->data = std::move(data); n->dtype = dtype; @@ -614,7 +617,7 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array data_ = std::move(n); } -tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, +tir::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope) { DataType storage_dtype = (dtype == DataType::Bool() ? 
DataType::Int(8) : dtype); @@ -637,27 +640,27 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std elem_offset = PrimExpr(); } - return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, data_alignment, + return tir::Buffer(data, dtype, shape, ffi::Array(), elem_offset, name, data_alignment, offset_factor, buffer_type); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tir.Buffer", [](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_EQ(args.size(), 11); - auto buffer_type = args[8].cast(); + auto buffer_type = args[8].cast(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; auto data = args[0].cast(); auto dtype = args[1].cast(); - auto shape = args[2].cast>(); - auto strides = args[3].cast>(); + auto shape = args[2].cast>(); + auto strides = args[3].cast>(); auto elem_offset = args[4].cast(); - auto name = args[5].cast(); + auto name = args[5].cast(); auto data_alignment = args[6].cast(); auto offset_factor = args[7].cast(); - auto axis_separators = args[9].cast>(); + auto axis_separators = args[9].cast>(); auto span = args[10].cast(); *ret = Buffer(data, dtype, shape, strides, elem_offset, name, data_alignment, offset_factor, type, axis_separators, span); @@ -668,7 +671,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.BufferVLoad", &Buffer::vload) .def_method("tir.BufferVStore", &Buffer::vstore) .def_method("tir.BufferStorageScope", &Buffer::scope); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index c1fd75d44efd..75f9bb50d15e 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -35,10 +35,10 @@ using tir::IterVar; using tir::IterVarNode; using tir::Var; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { LayoutNode::RegisterReflection(); BijectiveLayoutNode::RegisterReflection(); -}); +} const LayoutAxis LayoutAxis::UPPER_CASE[] 
= { LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'), @@ -74,8 +74,8 @@ const LayoutAxis& LayoutAxis::Get(const std::string& name) { return LayoutAxis::Get(name[0]); } -Layout::Layout(const Array& axes) { - auto node = make_object(); +Layout::Layout(const ffi::Array& axes) { + auto node = ffi::make_object(); node->axes = axes; std::ostringstream repr; for (const IterVar& axis : axes) { @@ -97,7 +97,7 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type"; if (name == "__undef__") return; - auto node = make_object(); + auto node = ffi::make_object(); node->name = name; if (name.empty()) return; // scalar @@ -149,9 +149,9 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) Layout Layout::SubLayout(size_t pos, size_t len) const { if (!defined() || pos > ndim()) return Layout::Undef(); - if (len == 0) return Layout(Array()); + if (len == 0) return Layout(ffi::Array()); if (pos + len > ndim()) len = ndim() - pos; - Array new_layout; + ffi::Array new_layout; const auto axes = operator->()->axes; for (size_t i = pos; i < pos + len; ++i) { new_layout.push_back(axes[i]); @@ -170,7 +170,7 @@ Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) ICHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis << " has already been split in " << name; ICHECK(factor > 0) << "Invalid split size " << factor; - Array new_layout; + ffi::Array new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { if (i == target_pos) { new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)), @@ -207,7 +207,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Layout(" << l->name << ")"; }); -inline bool GetStoreRule(Array* index_rule, Array* shape_rule, +inline bool GetStoreRule(ffi::Array* index_rule, ffi::Array* shape_rule, const Layout& src_layout, const Layout& dst_layout) { if 
(!src_layout.defined() || src_layout.name().empty()) { LOG(WARNING) << "src layout '" << src_layout.name() << "' is invalid."; @@ -294,11 +294,11 @@ inline bool GetStoreRule(Array* index_rule, Array* shape_rul return true; } -inline Array TransformIndex(const Array& src_index, - const Array& src_axis, - const Array& transform_rule) { +inline ffi::Array TransformIndex(const ffi::Array& src_index, + const ffi::Array& src_axis, + const ffi::Array& transform_rule) { arith::Analyzer ana; - Array result; + ffi::Array result; std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; @@ -309,7 +309,7 @@ inline Array TransformIndex(const Array& src_index, return result; } -Array BijectiveLayout::ForwardIndex(const Array& src_index) const { +ffi::Array BijectiveLayout::ForwardIndex(const ffi::Array& src_index) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); ICHECK_EQ(src_index.size(), self->src_layout->axes.size()) @@ -317,7 +317,7 @@ Array BijectiveLayout::ForwardIndex(const Array& src_index) return TransformIndex(src_index, self->src_layout->axes, self->index_forward_rule); } -Array BijectiveLayout::BackwardIndex(const Array& dst_index) const { +ffi::Array BijectiveLayout::BackwardIndex(const ffi::Array& dst_index) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) @@ -325,10 +325,10 @@ Array BijectiveLayout::BackwardIndex(const Array& dst_index) return TransformIndex(dst_index, self->dst_layout->axes, self->index_backward_rule); } -inline Array TransformShape(const Array& src_shape, - const Array& src_axis, - const Array& target_axis, - const Array& transform_rule) { +inline ffi::Array TransformShape(const ffi::Array& src_shape, + const ffi::Array& src_axis, + const ffi::Array& target_axis, 
+ const ffi::Array& transform_rule) { arith::Analyzer ana; ICHECK_EQ(src_shape.size(), src_axis.size()) << "Input shape size " << src_shape.size() << " mismatch with the expected shape size " @@ -361,7 +361,7 @@ inline Array TransformShape(const Array& src_shape, // infer the target shape, // for major-axis, use the forward/backward_rule directly, // for minor-axis, simply use the extent. - Array result; + ffi::Array result; ICHECK_EQ(transform_rule.size(), target_axis.size()); for (size_t i = 0; i < transform_rule.size(); ++i) { PrimExpr rule = transform_rule[i]; @@ -395,14 +395,14 @@ inline Array TransformShape(const Array& src_shape, return result; } -Array BijectiveLayout::ForwardShape(const Array& shape) const { +ffi::Array BijectiveLayout::ForwardShape(const ffi::Array& shape) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->shape_forward_rule); } -Array BijectiveLayout::BackwardShape(const Array& shape) const { +ffi::Array BijectiveLayout::BackwardShape(const ffi::Array& shape) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, @@ -410,7 +410,7 @@ Array BijectiveLayout::BackwardShape(const Array& shape) con } BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { - auto n = make_object(); + auto n = ffi::make_object(); n->src_layout = std::move(src_layout); n->dst_layout = std::move(dst_layout); @@ -430,7 +430,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.Layout", [](std::string name, DataType dtype) { return Layout(name, dtype); }) @@ -456,6 +456,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ 
.def_method("tir.BijectiveLayoutBackwardIndex", &BijectiveLayout::BackwardIndex) .def_method("tir.BijectiveLayoutForwardShape", &BijectiveLayout::ForwardShape) .def_method("tir.BijectiveLayoutBackwardShape", &BijectiveLayout::BackwardShape); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 346f1ab63250..393ac7ee57d0 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -41,13 +41,18 @@ Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) { ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); - return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body, - op->thread_binding, op->annotations); + auto n = CopyOnWrite(op); + n->min = cast(var.dtype(), op->min); + n->extent = cast(var.dtype(), op->extent); + if (op->step.has_value()) { + n->step = cast(var.dtype(), *op->step); + } + return For(n); } Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { BlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); - Array new_iter_values; + ffi::Array new_iter_values; bool changed = false; for (int i = 0; i < static_cast(op->iter_values.size()); ++i) { auto dtype = realize->block->iter_vars[i]->var->dtype; @@ -66,17 +71,18 @@ Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { Stmt DataTypeLegalizer::VisitStmt_(const BlockNode* op) { Block new_block = Downcast(StmtExprMutator::VisitStmt_(op)); - Array new_iter_vars = MutateArray(new_block->iter_vars, [/*this*/](const IterVar& iter) { - auto dtype = iter->var.dtype(); - if (iter->dom->min->dtype != dtype || iter->dom->extent->dtype != dtype) { - IterVar new_iter = iter; - new_iter.CopyOnWrite()->dom = - Range(cast(dtype, iter->dom->min), cast(dtype, iter->dom->extent)); - return new_iter; - } else { - return iter; - } - }); + ffi::Array new_iter_vars = 
+ MutateArray(new_block->iter_vars, [/*this*/](const IterVar& iter) { + auto dtype = iter->var.dtype(); + if (iter->dom->min->dtype != dtype || iter->dom->extent->dtype != dtype) { + IterVar new_iter = iter; + new_iter.CopyOnWrite()->dom = + Range(cast(dtype, iter->dom->min), cast(dtype, iter->dom->extent)); + return new_iter; + } else { + return iter; + } + }); if (!op->iter_vars.same_as(new_iter_vars)) { new_block.CopyOnWrite()->iter_vars = std::move(new_iter_vars); } @@ -123,7 +129,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const LetNode* op) { PrimExpr new_body = this->VisitExpr(op->body); if (value.same_as(op->value) && new_body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(var, value, new_body, op->span); } @@ -141,7 +147,7 @@ Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) { Stmt new_body = this->VisitStmt(op->body); if (value.same_as(op->value) && new_body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(var, value, new_body, op->span); } @@ -151,7 +157,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const VarNode* op) { if (auto it = var_remap_.find(op); it != var_remap_.end()) { return it->second; } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) { @@ -160,7 +166,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) { PrimExpr false_value = this->VisitExpr(op->false_value); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) { - return GetRef(op); + return ffi::GetRef(op); } else { int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); DataType dtype = true_value.dtype().with_bits(bits); @@ -174,7 +180,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) { PrimExpr base = VisitExpr(op->base); PrimExpr stride = VisitExpr(op->stride); if 
(base.same_as(op->base) && stride.same_as(op->stride) && base.dtype() == stride.dtype()) { - return GetRef(op); + return ffi::GetRef(op); } else { ICHECK(base.dtype().is_int() && stride.dtype().is_int()); int bits = std::max(base.dtype().bits(), stride.dtype().bits()); @@ -194,7 +200,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CastNode* op) { PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return FUNC(a, b); \ } \ @@ -219,7 +225,7 @@ TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); #undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { - Call before = GetRef(op); + Call before = ffi::GetRef(op); PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); static const Op& builtin_pow_ = Op::Get("tir.pow"); @@ -264,7 +270,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) { auto new_body = this->VisitStmt(op->body); if (!new_extents.same_as(op->extents) || !new_cond.same_as(op->condition) || !new_body.same_as(op->body)) { - Allocate new_allocate = GetRef(op); + Allocate new_allocate = ffi::GetRef(op); auto* n = new_allocate.CopyOnWrite(); n->extents = std::move(new_extents); n->condition = std::move(new_cond); @@ -272,7 +278,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) { return new_allocate; } else { - return GetRef(op); + return ffi::GetRef(op); } } @@ -310,7 +316,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { Block new_body = Downcast(this->VisitStmt(op->block)); if (!new_predicate.same_as(op->predicate) || !new_iter_values.same_as(op->iter_values) || !new_body.same_as(op->block)) { - BlockRealize new_block_realize = GetRef(op); + BlockRealize new_block_realize = ffi::GetRef(op); auto* n = new_block_realize.CopyOnWrite(); n->predicate = std::move(new_predicate); 
n->iter_values = std::move(new_iter_values); @@ -318,14 +324,14 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { return new_block_realize; } else { - return GetRef(op); + return ffi::GetRef(op); } } Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { - Array new_alloc_buffers = + ffi::Array new_alloc_buffers = op->alloc_buffers.Map([this](const Buffer& buffer) { return this->VisitBuffer(buffer); }); - Array new_match_buffers = + ffi::Array new_match_buffers = op->match_buffers.Map([this](const MatchBufferRegion& match_buffer_region) { Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer); BufferRegion new_buffer_region = this->VisitBufferRegion(match_buffer_region->source); @@ -336,17 +342,17 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { return match_buffer_region; } }); - Array new_reads = op->reads.Map( + ffi::Array new_reads = op->reads.Map( [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); - Array new_writes = op->writes.Map( + ffi::Array new_writes = op->writes.Map( [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); - Array new_iter_vars = + ffi::Array new_iter_vars = op->iter_vars.Map([this](const IterVar& iter_var) { return this->VisitIterVar(iter_var); }); - Optional new_init = std::nullopt; + ffi::Optional new_init = std::nullopt; if (op->init.defined()) { new_init = this->VisitStmt(op->init.value()); } - Map new_annotations = VisitBlockAnnotations(op->annotations); + ffi::Map new_annotations = VisitBlockAnnotations(op->annotations); Stmt new_body = this->VisitStmt(op->body); if (!new_init.same_as(op->init) || !new_body.same_as(op->body) || @@ -354,7 +360,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { !new_match_buffers.same_as(op->match_buffers) || !new_reads.same_as(op->reads) || !new_writes.same_as(op->writes) || new_iter_vars.same_as(op->iter_vars) || 
!new_annotations.same_as(op->annotations)) { - Block new_block = GetRef(op); + Block new_block = ffi::GetRef(op); BlockNode* n = new_block.CopyOnWrite(); n->alloc_buffers = std::move(new_alloc_buffers); n->match_buffers = std::move(new_match_buffers); @@ -366,11 +372,11 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { n->body = std::move(new_body); return new_block; } - return GetRef(op); + return ffi::GetRef(op); } -Map IndexDataTypeRewriter::VisitBlockAnnotations( - const Map& annotations) { +ffi::Map IndexDataTypeRewriter::VisitBlockAnnotations( + const ffi::Map& annotations) { auto new_annotations = annotations; std::function f_mutate_obj = [this, &f_mutate_obj](const Any& obj) -> Any { @@ -383,7 +389,7 @@ Map IndexDataTypeRewriter::VisitBlockAnnotations( return new_buffer; } } else if (obj.as()) { - return Downcast>(obj).Map(f_mutate_obj); + return Downcast>(obj).Map(f_mutate_obj); } return obj; }; @@ -427,9 +433,9 @@ Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) { bool is_enabled = is_enabled_; is_enabled_ = true; - Array new_shape = + ffi::Array new_shape = buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); }); - Array new_strides = + ffi::Array new_strides = buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e); }); auto new_elem_offset = VisitExpr(buffer->elem_offset); is_enabled_ = is_enabled; @@ -467,7 +473,7 @@ BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& buffer } Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) { - BufferStore store = GetRef(op); + BufferStore store = ffi::GetRef(op); Buffer new_buffer = GetRemappedBuffer(op->buffer); auto value = this->VisitExpr(op->value); @@ -488,7 +494,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) { } PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) { - BufferLoad load = GetRef(op); + BufferLoad load = ffi::GetRef(op); Buffer new_buffer = 
GetRemappedBuffer(op->buffer); auto indices = VisitIndices(op->indices); @@ -502,7 +508,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) { return load; } -Array IndexDataTypeRewriter::VisitIndices(Array indices) { +ffi::Array IndexDataTypeRewriter::VisitIndices(ffi::Array indices) { bool is_enabled = is_enabled_; is_enabled_ = true; @@ -521,18 +527,19 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) { is_condition_ = is_condition; Stmt then_case = VisitStmt(op->then_case); - Optional else_case = - op->else_case.defined() ? Optional{VisitStmt(op->else_case.value())} : std::nullopt; + ffi::Optional else_case = op->else_case.defined() + ? ffi::Optional{VisitStmt(op->else_case.value())} + : std::nullopt; if (!cond.same_as(op->condition) || !then_case.same_as(op->then_case) || !else_case.same_as(op->else_case)) { - IfThenElse new_stmt = GetRef(op); + IfThenElse new_stmt = ffi::GetRef(op); auto* n = new_stmt.CopyOnWrite(); n->condition = std::move(cond); n->then_case = std::move(then_case); n->else_case = std::move(else_case); return new_stmt; } - return GetRef(op); + return ffi::GetRef(op); } Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { @@ -547,7 +554,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { if (!new_loop_var.same_as(op->loop_var) || !min.same_as(op->min) || !extent.same_as(op->extent) || !new_body.same_as(op->body)) { - For new_for = GetRef(op); + For new_for = ffi::GetRef(op); auto* n = new_for.CopyOnWrite(); n->loop_var = new_loop_var; n->min = cast(new_loop_var.dtype(), min); @@ -556,13 +563,13 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { auto old_thread_binding = op->thread_binding.value(); auto* ptr = old_thread_binding.CopyOnWrite(); ptr->var = old_thread_binding->var.copy_with_dtype(new_loop_var.dtype()); - n->thread_binding = Optional(std::move(old_thread_binding)); + n->thread_binding = ffi::Optional(std::move(old_thread_binding)); } n->body = new_body; return 
new_for; } else { - return GetRef(op); + return ffi::GetRef(op); } } @@ -619,7 +626,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const SelectNode* op) { if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) { - return GetRef(op); + return ffi::GetRef(op); } else { int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); DataType dtype = true_value.dtype().with_bits(bits); @@ -640,14 +647,14 @@ PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) { buffer_remap_.clear(); ivmap_.clear(); // start rewrite - Map new_buffer_map = func->buffer_map; + ffi::Map new_buffer_map = func->buffer_map; for (const auto& [var, buffer] : func->buffer_map) { new_buffer_map.Set(var, VisitBuffer(buffer)); } // remap params bool is_enabled = true; std::swap(is_enabled_, is_enabled); - Array params = func->params.Map([this](Var param) { + ffi::Array params = func->params.Map([this](Var param) { if (param.dtype().is_int()) { return Downcast(this->VisitExpr(param)); } else { @@ -670,15 +677,15 @@ bool IndexDataTypeNormalizer::CanRewriteDType(DataType dtype) const { PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) { if (is_enabled_ && CanRewriteDType(op->dtype)) { ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); - return cast(target_data_type_, GetRef(op)); + return cast(target_data_type_, ffi::GetRef(op)); } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { if (is_enabled_ && CanRewriteDType(op->dtype) && op->dtype != target_data_type_ && !var_remap_.count(op)) { - var_remap_[op] = GetRef(op).copy_with_dtype(target_data_type_); + var_remap_[op] = ffi::GetRef(op).copy_with_dtype(target_data_type_); } return DataTypeLegalizer::VisitExpr_(op); } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 4d787015cb19..e6ffd2f09b57 100644 --- a/src/tir/ir/expr.cc 
+++ b/src/tir/ir/expr.cc @@ -36,7 +36,7 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { VarNode::RegisterReflection(); SizeVarNode::RegisterReflection(); IterVarNode::RegisterReflection(); @@ -70,7 +70,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ShuffleNode::RegisterReflection(); CommReducerNode::RegisterReflection(); ReduceNode::RegisterReflection(); -}); +} /* \brief Convert an object to a PrimExpr * @@ -80,11 +80,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ * `expr.dtype` field), this function allows the FFI conversions to be * explicitly invoked. */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.convert", - [](Variant> expr) { return expr; }); -}); + [](ffi::Variant> expr) { return expr; }); +} #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ Name::Name(PrimExpr a, PrimExpr b, Span span) { \ @@ -93,7 +93,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ << b.dtype() << "\n"; \ - ObjectPtr node = make_object(); \ + ObjectPtr node = ffi::make_object(); \ node->dtype = a.dtype(); \ node->a = std::move(a); \ node->b = std::move(b); \ @@ -108,7 +108,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. 
" \ << b.dtype() << "\n"; \ - ObjectPtr node = make_object(); \ + ObjectPtr node = ffi::make_object(); \ DataType a_dtype = a.dtype(); \ node->dtype = \ DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \ @@ -119,8 +119,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ } // Var -Var::Var(String name_hint, DataType dtype, Span span) { - auto n = make_object(); +Var::Var(ffi::String name_hint, DataType dtype, Span span) { + auto n = ffi::make_object(); n->name_hint = std::move(name_hint); n->type_annotation = GetTypeFromRuntimeDataType(dtype); n->dtype = std::move(dtype); @@ -128,8 +128,8 @@ Var::Var(String name_hint, DataType dtype, Span span) { data_ = std::move(n); } -Var::Var(String name_hint, Type type_annotation, Span span) { - auto n = make_object(); +Var::Var(ffi::String name_hint, Type type_annotation, Span span) { + auto n = ffi::make_object(); n->name_hint = std::move(name_hint); n->dtype = GetRuntimeDataType(type_annotation); n->type_annotation = std::move(type_annotation); @@ -137,19 +137,19 @@ Var::Var(String name_hint, Type type_annotation, Span span) { data_ = std::move(n); } -Var Var::copy_with_name(const String& name) const { +Var Var::copy_with_name(const ffi::String& name) const { const VarNode* node = get(); ObjectPtr new_ptr; if (auto* ptr = this->as()) { - new_ptr = make_object(*ptr); + new_ptr = ffi::make_object(*ptr); } else { - new_ptr = make_object(*node); + new_ptr = ffi::make_object(*node); } new_ptr->name_hint = name; return Var(new_ptr); } -Var Var::copy_with_suffix(const String& suffix) const { +Var Var::copy_with_suffix(const ffi::String& suffix) const { return this->copy_with_name(get()->name_hint + suffix); } @@ -157,29 +157,29 @@ Var Var::copy_with_dtype(DataType dtype) const { const VarNode* node = get(); ObjectPtr new_ptr; if (auto* ptr = this->as()) { - new_ptr = make_object(*ptr); + new_ptr = ffi::make_object(*ptr); } else { - new_ptr = make_object(*node); + new_ptr = ffi::make_object(*node); } 
new_ptr->type_annotation = GetTypeFromRuntimeDataType(dtype); new_ptr->dtype = std::move(dtype); return Var(new_ptr); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Var", [](String name_hint, ffi::AnyView type, Span span) { + refl::GlobalDef().def("tir.Var", [](ffi::String name_hint, ffi::AnyView type, Span span) { if (type.as()) { return Var(name_hint, type.cast(), span); } else { return Var(name_hint, type.cast(), span); } }); -}); +} // SizeVar -SizeVar::SizeVar(String name_hint, DataType dtype, Span span) { - auto n = make_object(); +SizeVar::SizeVar(ffi::String name_hint, DataType dtype, Span span) { + auto n = ffi::make_object(); n->name_hint = std::move(name_hint); n->type_annotation = GetTypeFromRuntimeDataType(dtype); n->dtype = std::move(dtype); @@ -187,8 +187,8 @@ SizeVar::SizeVar(String name_hint, DataType dtype, Span span) { data_ = std::move(n); } -SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { - auto n = make_object(); +SizeVar::SizeVar(ffi::String name_hint, Type type_annotation, Span span) { + auto n = ffi::make_object(); n->name_hint = std::move(name_hint); n->dtype = GetRuntimeDataType(type_annotation); n->type_annotation = std::move(type_annotation); @@ -196,15 +196,15 @@ SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.SizeVar", - [](String s, DataType t, Span span) { return SizeVar(s, t, span); }); -}); + [](ffi::String s, DataType t, Span span) { return SizeVar(s, t, span); }); +} // IterVar -IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span) { - ObjectPtr n = make_object(); +IterVar::IterVar(Range dom, Var var, IterVarType t, ffi::String thread_tag, Span span) { + ObjectPtr n = ffi::make_object(); if (dom.defined() && 
dom->extent.defined()) { CHECK(dom->extent.dtype().is_int()) << "The dtype of the domain of an IterVar must be an integer type. However, the domain's " @@ -222,176 +222,176 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.IterVar", [](Range dom, Var var, int iter_type, String thread_tag, Span span) { + "tir.IterVar", [](Range dom, Var var, int iter_type, ffi::String thread_tag, Span span) { return IterVar(dom, var, static_cast(iter_type), thread_tag, span); }); -}); +} // StringImm -StringImm::StringImm(String value, Span span) { - ObjectPtr node = make_object(); +StringImm::StringImm(ffi::String value, Span span) { + ObjectPtr node = ffi::make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); node->span = std::move(span); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.StringImm", - [](String value, Span span) { return StringImm(value, span); }); -}); + [](ffi::String value, Span span) { return StringImm(value, span); }); +} // Cast Cast::Cast(DataType t, PrimExpr value, Span span) { ICHECK(value.defined()); ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor()); ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = t; node->value = std::move(value); node->span = std::move(span); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Cast", [](DataType dtype, PrimExpr value, Span span) { return Cast(dtype, value, span); }); -}); +} // Add TVM_DEFINE_BINOP_CONSTRUCTOR(Add); -TVM_FFI_STATIC_INIT_BLOCK({ 
+TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Add", [](PrimExpr a, PrimExpr b, Span span) { return Add(a, b, span); }); -}); +} // Sub TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Sub", [](PrimExpr a, PrimExpr b, Span span) { return Sub(a, b, span); }); -}); +} // Mul TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Mul", [](PrimExpr a, PrimExpr b, Span span) { return Mul(a, b, span); }); -}); +} // Div TVM_DEFINE_BINOP_CONSTRUCTOR(Div); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Div", [](PrimExpr a, PrimExpr b, Span span) { return Div(a, b, span); }); -}); +} // Mod TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Mod", [](PrimExpr a, PrimExpr b, Span span) { return Mod(a, b, span); }); -}); +} // FloorDiv TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.FloorDiv", [](PrimExpr a, PrimExpr b, Span span) { return FloorDiv(a, b, span); }); -}); +} // FloorMod TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.FloorMod", [](PrimExpr a, PrimExpr b, Span span) { return FloorMod(a, b, span); }); -}); +} // Min TVM_DEFINE_BINOP_CONSTRUCTOR(Min); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Min", [](PrimExpr a, PrimExpr b, Span span) { return Min(a, b, span); }); -}); +} // Max 
TVM_DEFINE_BINOP_CONSTRUCTOR(Max); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Max", [](PrimExpr a, PrimExpr b, Span span) { return Max(a, b, span); }); -}); +} // EQ TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.EQ", [](PrimExpr a, PrimExpr b, Span span) { return EQ(a, b, span); }); -}); +} // NE TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.NE", [](PrimExpr a, PrimExpr b, Span span) { return NE(a, b, span); }); -}); +} // LT TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.LT", [](PrimExpr a, PrimExpr b, Span span) { return LT(a, b, span); }); -}); +} // LE TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.LE", [](PrimExpr a, PrimExpr b, Span span) { return LE(a, b, span); }); -}); +} // GT TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.GT", [](PrimExpr a, PrimExpr b, Span span) { return GT(a, b, span); }); -}); +} // GE TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.GE", [](PrimExpr a, PrimExpr b, Span span) { return GE(a, b, span); }); -}); +} // And And::And(PrimExpr a, PrimExpr b, Span span) { @@ -401,7 +401,7 @@ And::And(PrimExpr a, PrimExpr b, Span span) { ICHECK(b.dtype().is_bool()); ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - ObjectPtr node = make_object(); + ObjectPtr node = 
ffi::make_object(); node->dtype = DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector()); node->a = std::move(a); @@ -410,11 +410,11 @@ And::And(PrimExpr a, PrimExpr b, Span span) { data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.And", [](PrimExpr a, PrimExpr b, Span span) { return And(a, b, span); }); -}); +} // Or Or::Or(PrimExpr a, PrimExpr b, Span span) { @@ -424,7 +424,7 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { ICHECK(b.dtype().is_bool()); ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector()); node->a = std::move(a); @@ -433,17 +433,17 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Or", [](PrimExpr a, PrimExpr b, Span span) { return Or(a, b, span); }); -}); +} // Not Not::Not(PrimExpr a, Span span) { ICHECK(a.defined()) << "ValueError: a is undefined"; ICHECK(a.dtype().is_bool()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); DataType a_dtype = a.dtype(); node->dtype = DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); node->a = std::move(a); @@ -451,10 +451,10 @@ Not::Not(PrimExpr a, Span span) { data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Not", [](PrimExpr a, Span span) { return Not(a, span); }); -}); +} // Select Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { @@ -469,7 +469,7 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp << "TypeError: 
mismatched types. " << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = true_value.dtype(); node->condition = std::move(condition); node->true_value = std::move(true_value); @@ -478,13 +478,13 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.Select", [](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { return Select(condition, true_value, false_value, span); }); -}); +} // Ramp Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { @@ -496,7 +496,7 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { stride = cast(base.dtype(), stride); } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); auto* lanes_as_int = lanes.as(); if (lanes_as_int) { int lanes = static_cast(lanes_as_int->value); @@ -518,19 +518,19 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Ramp", [](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { return Ramp(base, stride, lanes, span); }); -}); +} // Broadcast Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { ICHECK(value.defined()); ICHECK(value.dtype().is_scalar()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); auto* lanes_int = lanes.as(); if (lanes_int) { int lanes = static_cast(lanes_int->value); @@ -551,12 +551,12 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { data_ = node; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("tir.Broadcast", [](PrimExpr value, PrimExpr lanes, Span span) { return Broadcast(value, lanes, span); }); -}); +} // Let Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { @@ -564,7 +564,7 @@ Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { ICHECK(body.defined()); ICHECK_EQ(value.dtype(), var.dtype()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = body.dtype(); node->var = std::move(var); node->value = std::move(value); @@ -573,43 +573,47 @@ Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Let", [](Var var, PrimExpr value, PrimExpr body, Span span) { return Let(var, value, body, span); }); -}); +} // Call -Call::Call(DataType dtype, RelaxExpr op, Array args, Span span) { +Call::Call(DataType dtype, RelaxExpr op, ffi::Array args, + ffi::Map annotations, Span span) { for (size_t i = 0; i < args.size(); ++i) { ICHECK(args[i].defined()) << "arg " << i << " is not defined()"; } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = dtype; node->op = std::move(op); node->args = std::move(args); + node->annotations = std::move(annotations); node->span = std::move(span); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.Call", - [](Optional dtype, RelaxExpr op, - Array> args, Span span) { - Array prim_expr_args; + [](ffi::Optional dtype, RelaxExpr op, + ffi::Array> args, + ffi::Optional> annotations, + Span span) { + ffi::Array prim_expr_args; for (const auto& it : args) { - if (auto opt_str = it.as()) { + if (auto opt_str = it.as()) { prim_expr_args.push_back(StringImm(opt_str.value())); } else if (auto opt_dtype = it.as()) { 
prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value()))); } else if (const auto* iter_var = it.as()) { prim_expr_args.push_back(iter_var->var); } else if (const auto* br = it.as()) { - Array indices; + ffi::Array indices; for (Range r : br->region) { if (is_one(r->extent)) { indices.push_back(r->min); @@ -617,7 +621,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); } else { LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " - << GetRef(br); + << ffi::GetRef(br); } } prim_expr_args.push_back(BufferLoad(br->buffer, indices)); @@ -625,12 +629,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ prim_expr_args.push_back(Downcast(it)); } } - return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span); + return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, + annotations.value_or(ffi::Map()), span); }); -}); +} // Shuffle -Shuffle::Shuffle(Array vectors, Array indices, Span span) { +Shuffle::Shuffle(ffi::Array vectors, ffi::Array indices, Span span) { ICHECK_NE(vectors.size(), 0U); ICHECK_NE(indices.size(), 0U); @@ -643,7 +648,7 @@ Shuffle::Shuffle(Array vectors, Array indices, Span span) { } ICHECK_LE(indices.size(), static_cast(total_lanes)); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); @@ -651,12 +656,12 @@ Shuffle::Shuffle(Array vectors, Array indices, Span span) { data_ = node; } -PrimExpr Shuffle::Concat(Array vectors, Span span) { +PrimExpr Shuffle::Concat(ffi::Array vectors, Span span) { ICHECK_NE(vectors.size(), 0); if (vectors.size() == 1) { return vectors[0]; } - Array indices; + ffi::Array indices; int index = 0; for (const PrimExpr& e : vectors) { for (int i = 0; i < e.dtype().lanes(); ++i) { @@ -670,15 +675,17 @@ PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) { return 
Shuffle({vector}, {Integer(index)}, span); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Shuffle", [](Array vectors, Array indices, - Span span) { return Shuffle(vectors, indices, span); }); -}); + refl::GlobalDef().def("tir.Shuffle", + [](ffi::Array vectors, ffi::Array indices, Span span) { + return Shuffle(vectors, indices, span); + }); +} // CommReducer -CommReducer::CommReducer(Array lhs, Array rhs, Array result, - Array identity_element, Span span) { +CommReducer::CommReducer(ffi::Array lhs, ffi::Array rhs, ffi::Array result, + ffi::Array identity_element, Span span) { size_t n_group = result.size(); CHECK_EQ(lhs.size(), n_group) << "ValueError: The number of vars in `lhs` must equal to the " "number of elements in `results`"; @@ -708,7 +715,7 @@ CommReducer::CommReducer(Array lhs, Array rhs, Array result, p_result->SetItem(i, Substitute(result[i], var_map)); } - auto node = make_object(); + auto node = ffi::make_object(); node->lhs = lhs; node->rhs = rhs; node->result = result; @@ -717,11 +724,12 @@ CommReducer::CommReducer(Array lhs, Array rhs, Array result, data_ = std::move(node); } -Array CommReducerNode::operator()(Array a, Array b) const { +ffi::Array CommReducerNode::operator()(ffi::Array a, + ffi::Array b) const { ICHECK_EQ(a.size(), b.size()); ICHECK_EQ(lhs.size(), a.size()); ICHECK_EQ(rhs.size(), b.size()); - Map value_map; + ffi::Map value_map; for (size_t i = 0; i < a.size(); ++i) { value_map.Set(lhs[i], a[i]); value_map.Set(rhs[i], b[i]); @@ -729,26 +737,26 @@ Array CommReducerNode::operator()(Array a, Array b return Substitute(this->result, value_map); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.CommReducer", - [](Array lhs, Array rhs, Array result, - Array identity_element, + [](ffi::Array lhs, ffi::Array rhs, ffi::Array result, + ffi::Array identity_element, Span span) { return 
CommReducer(lhs, rhs, result, identity_element, span); }) .def_method("tir.CommReducerCombine", &tir::CommReducerNode::operator()); -}); +} // Reduce -Reduce::Reduce(CommReducer combiner, Array source, Array axis, - PrimExpr condition, int value_index, Array init, Span span) { +Reduce::Reduce(CommReducer combiner, ffi::Array source, ffi::Array axis, + PrimExpr condition, int value_index, ffi::Array init, Span span) { for (size_t i = 0; i < axis.size(); ++i) { ICHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; } if (!condition.defined()) { condition = const_true(); } - auto n = make_object(); + auto n = ffi::make_object(); ICHECK(source.defined()); for (size_t i = 0; i < axis.size(); ++i) { ICHECK(axis[i].defined()); @@ -774,21 +782,21 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Reduce", - [](CommReducer combiner, Array source, Array axis, - PrimExpr condition, int value_index, Array init, Span span) { - return Reduce(combiner, source, axis, condition, value_index, init, span); - }); -}); + refl::GlobalDef().def( + "tir.Reduce", [](CommReducer combiner, ffi::Array source, ffi::Array axis, + PrimExpr condition, int value_index, ffi::Array init, Span span) { + return Reduce(combiner, source, axis, condition, value_index, init, span); + }); +} // BufferLoad void BufferLoadNode::LegalizeDType() { - for (int i = 0; i < static_cast(indices.size()) - 1; i++) { - ICHECK(indices[i].dtype().is_scalar()) - << "Only the last index of a buffer access may be a vector type."; - } + // for (int i = 0; i < static_cast(indices.size()) - 1; i++) { + // ICHECK(indices[i].dtype().is_scalar()) + // << "Only the last index of a buffer access may be a vector type."; + // } if (indices.empty()) { this->dtype = buffer->dtype; @@ -812,8 +820,8 @@ void BufferLoadNode::LegalizeDType() 
{ } } -BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional predicate, - Span span) { +BufferLoad::BufferLoad(Buffer buffer, ffi::Array indices, + ffi::Optional predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -836,12 +844,12 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); node->predicate = std::move(predicate); @@ -850,16 +858,17 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional indices, Optional predicate, - Span span) { return BufferLoad(buffer, indices, predicate, span); }); -}); + refl::GlobalDef().def("tir.BufferLoad", [](Buffer buffer, ffi::Array indices, + ffi::Optional predicate, Span span) { + return BufferLoad(buffer, indices, predicate, span); + }); +} // ProducerLoad -ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span span) { - ObjectPtr node = make_object(); +ProducerLoad::ProducerLoad(DataProducer producer, ffi::Array indices, Span span) { + ObjectPtr node = ffi::make_object(); node->dtype = producer->GetDataType(); node->producer = std::move(producer); node->indices = std::move(indices); @@ -867,13 +876,13 @@ ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.ProducerLoad", - [](DataProducer producer, Array indices, Span span) { + [](DataProducer producer, ffi::Array indices, Span span) { return ProducerLoad(producer, indices, span); }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 05e333b78ac6..f9e54de52f58 100644 --- a/src/tir/ir/expr_functor.cc +++ 
b/src/tir/ir/expr_functor.cc @@ -47,6 +47,13 @@ void ExprVisitor::VisitExpr_(const LetNode* op) { void ExprVisitor::VisitExpr_(const CallNode* op) { VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); + // Also visit PrimExpr values inside annotations (e.g. barrier arguments + // stored as CallNode annotations by tile operators like tma_copy). + for (const auto& kv : op->annotations) { + if (auto opt = kv.second.as()) { + this->VisitExpr(opt.value()); + } + } } #define DEFINE_BINOP_VISIT_(OP) \ @@ -111,7 +118,7 @@ void ExprVisitor::VisitExpr_(const ShuffleNode* op) { void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); } -PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } +PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return ffi::GetRef(op); } PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { return this->VisitExpr_(static_cast(op)); @@ -119,9 +126,9 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); if (indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } else { return BufferLoad(op->buffer, indices, op->predicate); } @@ -129,9 +136,9 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); if (indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } else { return ProducerLoad(op->producer, indices); } @@ -141,7 +148,7 @@ PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr body = 
this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -149,17 +156,33 @@ PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array args = op->args.Map(fmutate); + ffi::Array args = op->args.Map(fmutate); + + // Also mutate PrimExpr values inside annotations (e.g. barrier arguments + // stored as CallNode annotations by tile operators like tma_copy). + ffi::Map new_annotations; + bool annotations_changed = false; + for (const auto& kv : op->annotations) { + if (auto opt = kv.second.as()) { + PrimExpr new_val = this->VisitExpr(opt.value()); + new_annotations.Set(kv.first, new_val); + if (!new_val.same_as(opt.value())) { + annotations_changed = true; + } + } else { + new_annotations.Set(kv.first, kv.second); + } + } - if (args.same_as(op->args)) { - return GetRef(op); + if (args.same_as(op->args) && !annotations_changed) { + return ffi::GetRef(op); } else { - return Call(op->dtype, op->op, args); + return Call(op->dtype, op->op, args, annotations_changed ? 
new_annotations : op->annotations); } } #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef(op); } + PrimExpr ExprMutator::VisitExpr_(const OP* op) { return ffi::GetRef(op); } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) @@ -170,7 +193,7 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return OP(a, b); \ } \ @@ -205,17 +228,17 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag); } }; - Array axis = op->axis.Map(fitervar); + ffi::Array axis = op->axis.Map(fitervar); auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array source = op->source.Map(fexpr); - Array init = op->init.Map(fexpr); + ffi::Array source = op->source.Map(fexpr); + ffi::Array init = op->init.Map(fexpr); PrimExpr condition = this->VisitExpr(op->condition); if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition) && init.same_as(op->init)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Reduce(op->combiner, source, axis, condition, op->value_index, init); } @@ -224,7 +247,7 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Cast(op->dtype, value); } @@ -233,7 +256,7 @@ PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { PrimExpr a = this->VisitExpr(op->a); if (a.same_as(op->a)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Not(a); } @@ -245,7 +268,7 @@ 
PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr false_value = this->VisitExpr(op->false_value); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Select(condition, true_value, false_value); } @@ -256,7 +279,7 @@ PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { PrimExpr stride = this->VisitExpr(op->stride); PrimExpr lanes = this->VisitExpr(op->lanes); if (base.same_as(op->base) && stride.same_as(op->stride) && lanes.same_as(op->lanes)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Ramp(base, stride, lanes); } @@ -266,7 +289,7 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr lanes = this->VisitExpr(op->lanes); if (value.same_as(op->value) && lanes.same_as(op->lanes)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Broadcast(value, lanes); } @@ -277,7 +300,7 @@ PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { auto vectors = op->vectors.Map(fexpr); auto indices = op->indices.Map(fexpr); if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Shuffle(vectors, indices); } diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index c8769222e02d..9daf09695086 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -31,14 +31,14 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { PrimFuncNode::RegisterReflection(); TensorIntrinNode::RegisterReflection(); -}); +} namespace { relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { - Array params; + ffi::Array params; for (const auto& param : prim_func->params) { relax::StructInfo param_sinfo = [&]() -> relax::StructInfo { if (auto opt_buf = prim_func->buffer_map.Get(param)) { @@ -62,7 +62,7 @@ relax::StructInfo 
InferStructInfo(const PrimFunc& prim_func) { if (const auto* prim = prim_func->ret_type.as()) { return relax::PrimStructInfo(prim->dtype); } else if (IsVoidType(prim_func->ret_type)) { - return relax::TupleStructInfo(Array{}); + return relax::TupleStructInfo(ffi::Array{}); } else { return relax::ObjectStructInfo(); } @@ -75,8 +75,8 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { } // namespace // Get the function type of a PrimFunc -PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs, Span span) { +PrimFunc::PrimFunc(ffi::Array params, Stmt body, Type ret_type, + ffi::Map buffer_map, DictAttrs attrs, Span span) { if (!attrs.defined()) { attrs = DictAttrs(); } @@ -85,7 +85,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, ret_type = VoidType(); } - auto n = make_object(); + auto n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); n->ret_type = std::move(ret_type); @@ -99,7 +99,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, } FuncType PrimFuncNode::func_type_annotation() const { - Array param_types; + ffi::Array param_types; for (auto param : this->params) { param_types.push_back(GetType(param)); } @@ -108,7 +108,7 @@ FuncType PrimFuncNode::func_type_annotation() const { class TensorIntrinManager { public: - Map reg; + ffi::Map reg; static TensorIntrinManager* Global() { static TensorIntrinManager* inst = new TensorIntrinManager(); @@ -129,13 +129,13 @@ TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { } ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->desc = std::move(desc); n->impl = std::move(impl); data_ = std::move(n); } -void TensorIntrin::Register(String name, TensorIntrin intrin, bool override) { +void TensorIntrin::Register(ffi::String name, TensorIntrin intrin, bool override) { TensorIntrinManager* manager = TensorIntrinManager::Global(); if 
(!override) { CHECK_EQ(manager->reg.count(name), 0) @@ -144,7 +144,7 @@ void TensorIntrin::Register(String name, TensorIntrin intrin, bool override) { manager->reg.Set(name, intrin); } -Optional TensorIntrin::Get(String name, bool allow_missing) { +ffi::Optional TensorIntrin::Get(ffi::String name, bool allow_missing) { const TensorIntrinManager* manager = TensorIntrinManager::Global(); auto it = manager->reg.find(name); if (it == manager->reg.end()) { @@ -157,12 +157,12 @@ Optional TensorIntrin::Get(String name, bool allow_missing) { return (*it).second; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.PrimFunc", - [](Array params, Stmt body, Type ret_type, Map buffer_map, - DictAttrs attrs, + [](ffi::Array params, Stmt body, Type ret_type, + ffi::Map buffer_map, DictAttrs attrs, Span span) { return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }) .def("tir.TensorIntrin", [](PrimFunc desc_func, PrimFunc intrin_func) { @@ -170,7 +170,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def("tir.TensorIntrinRegister", TensorIntrin::Register) .def("tir.TensorIntrinGet", TensorIntrin::Get); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h index 901a5d5234ca..c9f21b1b38ec 100644 --- a/src/tir/ir/functor_common.h +++ b/src/tir/ir/functor_common.h @@ -30,14 +30,14 @@ namespace tir { // Implementation of Visitors template -inline void VisitArray(const Array& arr, F fvisit) { +inline void VisitArray(const ffi::Array& arr, F fvisit) { for (size_t i = 0; i < arr.size(); i++) { fvisit(arr[i]); } } template -inline Array MutateArray(Array arr, F fmutate) { +inline ffi::Array MutateArray(ffi::Array arr, F fmutate) { return arr.Map(fmutate); } diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 34e7e9c56f9f..84e701210247 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -35,20 +35,21 @@ namespace tvm 
{ namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ IndexMapNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { IndexMapNode::RegisterReflection(); } -IndexMap::IndexMap(Array initial_indices, Array final_indices, - Optional inverse_index_map) { - auto n = make_object(); +IndexMap::IndexMap(ffi::Array initial_indices, ffi::Array final_indices, + ffi::Optional inverse_index_map) { + auto n = ffi::make_object(); n->initial_indices = std::move(initial_indices); n->final_indices = std::move(final_indices); n->inverse_index_map = std::move(inverse_index_map); data_ = std::move(n); } -IndexMap IndexMap::FromFunc(int ndim, ffi::TypedFunction(Array)> func, - Optional inverse_index_map) { - Array initial_indices; +IndexMap IndexMap::FromFunc(int ndim, + ffi::TypedFunction(ffi::Array)> func, + ffi::Optional inverse_index_map) { + ffi::Array initial_indices; initial_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32))); @@ -57,7 +58,7 @@ IndexMap IndexMap::FromFunc(int ndim, ffi::TypedFunction(Array IndexMapInverseImpl(const IndexMap& self, - const Array& initial_ranges, + const ffi::Array& initial_ranges, arith::IterMapLevel check_level, arith::Analyzer* analyzer) { ICHECK(analyzer != nullptr); @@ -70,7 +71,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, } // Dummy variables to represent the inverse's inputs. - Array output_vars; + ffi::Array output_vars; for (size_t i = 0; i < self->final_indices.size(); i++) { PrimExpr index = self->final_indices[i]; // TODO(Lunderberg): Better names for these variables. A variable @@ -85,7 +86,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, } // Dummy ranges for the extent of each input. 
- Map input_iters; + ffi::Map input_iters; ICHECK_EQ(self->initial_indices.size(), initial_ranges.size()); for (size_t i = 0; i < initial_ranges.size(); i++) { input_iters.Set(self->initial_indices[i], initial_ranges[i]); @@ -97,15 +98,16 @@ std::pair IndexMapInverseImpl(const IndexMap& self, /*check_level=*/check_level, analyzer, /*simplify_trivial_iterators=*/false); CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " - << "Error: " << padded_iter_map->errors[0]; + << "\nIndex map: " << self->initial_indices << " -> " << self->final_indices + << "\nError: " << padded_iter_map->errors[0]; // Determine expressions for the input variables, in terms of the // output variables. - Map inverse_exprs_map = InverseAffineIterMap( - padded_iter_map->indices, Array(output_vars.begin(), output_vars.end())); + ffi::Map inverse_exprs_map = InverseAffineIterMap( + padded_iter_map->indices, ffi::Array(output_vars.begin(), output_vars.end())); // Unpack the map to an array, maintaining the same parameter order. 
- Array inverse_exprs; + ffi::Array inverse_exprs; for (int i = 0, n = self->initial_indices.size(); i < n; ++i) { Var index = self->initial_indices[i]; PrimExpr expr; @@ -137,13 +139,13 @@ std::pair IndexMapInverseImpl(const IndexMap& self, return {IndexMap(output_vars, inverse_exprs), padding_predicate}; } -std::pair IndexMap::NonSurjectiveInverse(Array initial_ranges, +std::pair IndexMap::NonSurjectiveInverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck, analyzer); } -IndexMap IndexMap::Inverse(Array initial_ranges, arith::Analyzer* analyzer) const { +IndexMap IndexMap::Inverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); auto [inverse, padding_predicate] = IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective, analyzer); @@ -153,18 +155,18 @@ IndexMap IndexMap::Inverse(Array initial_ranges, arith::Analyzer* analyze return inverse; } -Array IndexMapNode::MapIndices(const Array& indices, - arith::Analyzer* analyzer) const { +ffi::Array IndexMapNode::MapIndices(const ffi::Array& indices, + arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); ICHECK_EQ(indices.size(), initial_indices.size()); - Map vmap; + ffi::Map vmap; for (size_t i = 0; i < initial_indices.size(); i++) { vmap.Set(initial_indices[i], indices[i]); } - Array output = final_indices.Map([&](PrimExpr index) { + ffi::Array output = final_indices.Map([&](PrimExpr index) { PrimExpr result = SubstituteWithDataTypeLegalization( std::move(index), [&](const Var& var) { return vmap.Get(var); }); return analyzer->Simplify(result); @@ -172,24 +174,25 @@ Array IndexMapNode::MapIndices(const Array& indices, return output; } -Array IndexMapNode::MapRanges(const Array& ranges, arith::Analyzer* analyzer) const { +ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges, + arith::Analyzer* analyzer) const { 
ICHECK(analyzer != nullptr); ICHECK_EQ(ranges.size(), initial_indices.size()); - Map input_iters; + ffi::Map input_iters; for (size_t i = 0; i < initial_indices.size(); i++) { input_iters.Set(initial_indices[i], ranges[i]); } auto iter_map = DetectIterMap(final_indices, input_iters, /* predicate = */ 1, /*check_level=*/arith::IterMapLevel::NoCheck, analyzer, /*simplify_trivial_iterators=*/false); - Array output; + ffi::Array output; if (iter_map->indices.size()) { // Preferred route, requires the map to be expressible as an // affine sum. Since the terms are orthogonal, the extent of the // sum is the extent of the largest term. for (const auto& index : iter_map->indices) { - Optional extent = std::nullopt; + ffi::Optional extent = std::nullopt; for (const auto& term : index->args) { PrimExpr term_extent = term->extent * term->scale; if (extent.defined()) { @@ -235,18 +238,18 @@ Array IndexMapNode::MapRanges(const Array& ranges, arith::Analyzer return output; } -Array IndexMapNode::MapShape(const Array& shape, - arith::Analyzer* analyzer) const { +ffi::Array IndexMapNode::MapShape(const ffi::Array& shape, + arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); ICHECK_EQ(shape.size(), initial_indices.size()); - Array ranges; + ffi::Array ranges; for (auto& dim : shape) { ranges.push_back(Range(make_zero(dim.dtype()), dim)); } - Array mapped = MapRanges(std::move(ranges), analyzer); + ffi::Array mapped = MapRanges(std::move(ranges), analyzer); - Array output; + ffi::Array output; for (auto& range : mapped) { ICHECK(is_zero(range->min)); output.push_back(range->extent); @@ -255,14 +258,14 @@ Array IndexMapNode::MapShape(const Array& shape, return output; } -runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { +runtime::Tensor IndexMapNode::MapTensor(runtime::Tensor arr_src) const { arith::Analyzer analyzer; auto shape = arr_src.Shape(); ICHECK(shape.size() == initial_indices.size()) << "The rank of the input array should be " << 
initial_indices.size() << " but got " << shape.size(); size_t size_1d = 1; - Array orig_shape; + ffi::Array orig_shape; for (size_t i = 0; i < shape.size(); ++i) { size_1d *= shape[i]; orig_shape.push_back(PrimExpr(static_cast((shape[i])))); @@ -283,7 +286,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { for (size_t i = 0; i < size_1d; ++i) { // Convert a linear coordinate to an N-d coordinate tuple // z * height * width + y * width + x -> (z, y, x) - Array src_indices; + ffi::Array src_indices; auto div_factor = size_1d; auto src_linear_index = i; for (auto s : shape) { @@ -305,15 +308,15 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { bytes_dst.begin() + dst_linear_index * elem_bytes); } - auto arr_dst = runtime::NDArray::Empty(dst_shape_int, arr_src->dtype, arr_src->device); + auto arr_dst = runtime::Tensor::Empty(dst_shape_int, arr_src->dtype, arr_src->device); arr_dst.CopyFromBytes(bytes_dst.data(), bytes_dst.size()); return arr_dst; } IndexMap IndexMap::RenameVariables( - const std::function(const Var& var)>& f_name_map) const { + const std::function(const Var& var)>& f_name_map) const { std::unordered_set used_names; - Map var_remap; + ffi::Map var_remap; NameSupply name_supply; const IndexMapNode* n = this->get(); if (f_name_map != nullptr) { @@ -329,8 +332,8 @@ IndexMap IndexMap::RenameVariables( } visited.emplace(obj.get()); Var var = Downcast(obj); - if (Optional opt_name = f_name_map(var); opt_name.has_value()) { - String name = opt_name.value(); + if (ffi::Optional opt_name = f_name_map(var); opt_name.has_value()) { + ffi::String name = opt_name.value(); ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false)); name_supply->ReserveName(name, /*add_prefix=*/false); var_remap.Set(var, Var(name, var->dtype)); @@ -344,7 +347,8 @@ IndexMap IndexMap::RenameVariables( // The name of the variable is pre-defined. 
continue; } - String unique_name = name_supply->FreshName(initial_index->name_hint, /*add_prefix=*/false); + ffi::String unique_name = + name_supply->FreshName(initial_index->name_hint, /*add_prefix=*/false); if (unique_name != initial_index->name_hint) { var_remap.Set(initial_index, Var(unique_name)); } @@ -354,7 +358,7 @@ IndexMap IndexMap::RenameVariables( [&](const Var& var) { return Downcast(Substitute(var, var_remap)); }); auto new_final_indices = n->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, var_remap); }); - Optional new_inverse_index_map = std::nullopt; + ffi::Optional new_inverse_index_map = std::nullopt; if (n->inverse_index_map.defined()) { new_inverse_index_map = Downcast(n->inverse_index_map).RenameVariables(f_name_map); } @@ -367,10 +371,10 @@ IndexMap IndexMap::RenameVariables( * \param final_indices The final indices in the index map. * \return The lambda expression string. */ -std::string IndexMap2PythonLambdaExpr(const Array& initial_indices, - const Array& final_indices) { +std::string IndexMap2PythonLambdaExpr(const ffi::Array& initial_indices, + const ffi::Array& final_indices) { std::unordered_set used_names; - Map var_remap; + ffi::Map var_remap; std::ostringstream oss; oss << "lambda "; for (size_t i = 0; i < initial_indices.size(); ++i) { @@ -391,13 +395,13 @@ std::string IndexMap2PythonLambdaExpr(const Array& initial_indices, return oss.str(); } -String IndexMapNode::ToPythonString( - const std::function(const Var& var)>& f_name_map) const { - auto index_map = GetRef(this).RenameVariables(f_name_map); +ffi::String IndexMapNode::ToPythonString( + const std::function(const Var& var)>& f_name_map) const { + auto index_map = ffi::GetRef(this).RenameVariables(f_name_map); std::string lambda_expr = IndexMap2PythonLambdaExpr(index_map->initial_indices, index_map->final_indices); if (!index_map->inverse_index_map.defined()) { - return String(lambda_expr); + return ffi::String(lambda_expr); } // Also convert the inverse 
index map. IndexMap inverse = Downcast(index_map->inverse_index_map.value()); @@ -406,51 +410,52 @@ String IndexMapNode::ToPythonString( std::ostringstream oss; oss << "tvm.tir.IndexMap.from_func(" << lambda_expr << ", inverse_index_map=" << inverse_lambda_expr << ")"; - return String(oss.str()); + return ffi::String(oss.str()); } IndexMap Substitute(const IndexMap& index_map, - std::function(const Var& var)> f_subst) { - Array new_output = + std::function(const Var& var)> f_subst) { + ffi::Array new_output = index_map->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, f_subst); }); - Optional new_inverse_map = std::nullopt; + ffi::Optional new_inverse_map = std::nullopt; if (index_map->inverse_index_map.defined()) { new_inverse_map = Substitute(Downcast(index_map->inverse_index_map.value()), f_subst); } return IndexMap{index_map->initial_indices, new_output, new_inverse_map}; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.IndexMap", - [](Array initial_indices, Array final_indices, - Optional inverse_index_map) { + [](ffi::Array initial_indices, ffi::Array final_indices, + ffi::Optional inverse_index_map) { return IndexMap(initial_indices, final_indices, inverse_index_map); }) .def("tir.IndexMapMapIndices", - [](IndexMap map, Array indices) { + [](IndexMap map, ffi::Array indices) { arith::Analyzer analyzer; return map->MapIndices(indices, &analyzer); }) .def("tir.IndexMapMapShape", - [](IndexMap map, Array shape) { + [](IndexMap map, ffi::Array shape) { arith::Analyzer analyzer; return map->MapShape(shape, &analyzer); }) .def("tir.IndexMapInverse", - [](IndexMap map, Array initial_ranges) { + [](IndexMap map, ffi::Array initial_ranges) { arith::Analyzer analyzer; return map.Inverse(initial_ranges, &analyzer); }) - .def("tir.IndexMapMapNDArray", - [](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); }) - .def("tir.IndexMapNonSurjectiveInverse", 
[](IndexMap forward, Array initial_ranges) { - arith::Analyzer analyzer; - auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); - return Array{result.first, result.second}; - }); -}); + .def("tir.IndexMapMapTensor", + [](IndexMap map, runtime::Tensor arr) { return map->MapTensor(arr); }) + .def("tir.IndexMapNonSurjectiveInverse", + [](IndexMap forward, ffi::Array initial_ranges) { + arith::Analyzer analyzer; + auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); + return ffi::Array{result.first, result.second}; + }); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index cf5e7e80a893..61bdfb15e70e 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -215,11 +215,12 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { } static void RegisterReflection() { - // No fields to register as they are not visited + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - static constexpr const char* _type_key = "tir.PyStmtExprVisitor"; - TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprVisitorNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("tir.PyStmtExprVisitor", PyStmtExprVisitorNode, Object); private: // Statement functions @@ -342,6 +343,9 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { */ class PyStmtExprVisitor : public ObjectRef { public: + explicit PyStmtExprVisitor(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function f_visit_stmt, // ffi::Function f_visit_expr, // ffi::Function f_visit_let_stmt, // @@ -392,7 +396,7 @@ class PyStmtExprVisitor : public ObjectRef { ffi::Function f_visit_int_imm, // ffi::Function f_visit_float_imm, // ffi::Function f_visit_string_imm) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_visit_stmt = std::move(f_visit_stmt); 
n->f_visit_expr = std::move(f_visit_expr); // Set statement functions @@ -448,8 +452,8 @@ class PyStmtExprVisitor : public ObjectRef { return PyStmtExprVisitor(n); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyStmtExprVisitor, ObjectRef, - PyStmtExprVisitorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyStmtExprVisitor, ObjectRef, + PyStmtExprVisitorNode); }; /*! \brief The python interface of StmtExprMutator. */ @@ -578,11 +582,12 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { } static void RegisterReflection() { - // No fields to register as they are not visited + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - static constexpr const char* _type_key = "tir.PyStmtExprMutator"; - TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprMutatorNode, Object); + static constexpr const bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("tir.PyStmtExprMutator", PyStmtExprMutatorNode, Object); private: // Statement functions @@ -702,6 +707,9 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { /*! \brief Managed reference to PyStmtExprMutatorNode. */ class PyStmtExprMutator : public ObjectRef { public: + explicit PyStmtExprMutator(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } /*! * \brief Create a PyStmtExprMutator with customized methods on the python-side. * \return The PyStmtExprMutator created. 
@@ -756,7 +764,7 @@ class PyStmtExprMutator : public ObjectRef { ffi::Function f_visit_int_imm, // ffi::Function f_visit_float_imm, // ffi::Function f_visit_string_imm) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Statement functions @@ -812,28 +820,28 @@ class PyStmtExprMutator : public ObjectRef { return PyStmtExprMutator(n); } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyStmtExprMutator, ObjectRef, - PyStmtExprMutatorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyStmtExprMutator, ObjectRef, + PyStmtExprMutatorNode); }; // ================================================ // TVM Register // ================================================ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { PyStmtExprVisitorNode::RegisterReflection(); PyStmtExprMutatorNode::RegisterReflection(); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.MakePyStmtExprVisitor", PyStmtExprVisitor::MakePyStmtExprVisitor) .def("tir.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator); -}); +} // StmtExprVisitor -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.PyStmtExprVisitorDefaultVisitExpr", @@ -844,10 +852,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); }) .def("tir.PyStmtExprVisitorVisitExpr", [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->VisitExpr(expr); }); -}); +} // StmtExprMutator -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.PyStmtExprMutatorDefaultVisitExpr", @@ -862,7 +870,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](PyStmtExprMutator mutator, const PrimExpr& expr) { return mutator->VisitExpr(expr); }) 
.def("tir.PyStmtExprMutatorVisitStmt", [](PyStmtExprMutator mutator, const Stmt& stmt) { return mutator->VisitStmt(stmt); }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index d18bda77fab6..bf2b333f2501 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -36,10 +36,11 @@ namespace tir { /*! \brief Generate surrounding loops automatically */ class ScriptCompleter : public StmtMutator { public: - explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} + explicit ScriptCompleter(ffi::Map* buffer_var_map) + : buffer_var_map_(buffer_var_map) {} private: - Map* buffer_var_map_; + ffi::Map* buffer_var_map_; Stmt VisitStmt_(const BlockRealizeNode* op) final { for (const PrimExpr& value : op->iter_values) { CHECK(value.dtype().is_int()) @@ -81,9 +82,9 @@ class ScriptCompleter : public StmtMutator { // ignore root block or blocks which already has reads/writes regions if (mask != 0) { auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); - const Array& reads = access_region[0]; - const Array& writes = access_region[1]; - const Array& opaque = access_region[2]; + const ffi::Array& reads = access_region[0]; + const ffi::Array& writes = access_region[1]; + const ffi::Array& opaque = access_region[2]; CHECK(opaque.empty()) << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or " "direct access by buffer data. 
Please annotation the access region manually"; @@ -114,8 +115,8 @@ class ScriptCompleter : public StmtMutator { bool is_root_block_ = true; }; -PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { - Map buffer_var_map; +PrimFunc ScriptComplete(PrimFunc func, const ffi::Array& root_allocates) { + ffi::Map buffer_var_map; for (const auto& pair : func->buffer_map) { const Buffer& buffer = pair.second; buffer_var_map.Set(buffer->data, buffer); @@ -161,10 +162,10 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.Complete", ScriptComplete); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h index 273ca946a7ff..1facab664346 100644 --- a/src/tir/ir/script/script_complete.h +++ b/src/tir/ir/script/script_complete.h @@ -30,7 +30,7 @@ namespace tvm { namespace tir { -PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates); +PrimFunc ScriptComplete(PrimFunc func, const ffi::Array& root_allocates); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 69a7c293b19f..083dd8dedf31 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -54,7 +54,7 @@ inline bool IsParam(const PrimFunc& func, const Var& param) { PrimExpr a = VisitExpr(op->a); \ PrimExpr b = VisitExpr(op->b); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return BinaryFunc(a, b); \ } \ @@ -63,7 +63,7 @@ inline bool IsParam(const PrimFunc& func, const Var& param) { PrimExpr VisitExpr_(const UnaryNode* op) final { \ PrimExpr a = VisitExpr(op->a); \ if (a.same_as(op->a)) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return UnaryFunc(a); \ } \ @@ -77,7 +77,7 @@ class PrimFuncSpecializer : public 
StmtExprMutator { static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) { PrimFuncSpecializer specializer(var_map); // Updating Buffer map - Map buffer_map; + ffi::Map buffer_map; bool buffer_map_updated = false; for (const auto& it : f->buffer_map) { const Var& var = it.first; @@ -91,7 +91,7 @@ class PrimFuncSpecializer : public StmtExprMutator { } // Updating parmeters - Array params; + ffi::Array params; bool param_updated = false; for (const auto& var : f->params) { // Remove parmeters which has been specialized. @@ -115,7 +115,7 @@ class PrimFuncSpecializer : public StmtExprMutator { private: Stmt VisitStmt_(const BlockNode* op) final { // Step.0. Define buffer mappings which is allocated inside the block - Array alloc_buffers = + ffi::Array alloc_buffers = op->alloc_buffers.Map([this](const auto& buf) { return MutateAllocBuffer(buf); }); // Step.1. Recursively visit block body @@ -123,14 +123,14 @@ class PrimFuncSpecializer : public StmtExprMutator { op = stmt.as(); ICHECK(op != nullptr); - Array reads = + ffi::Array reads = op->reads.Map([this](const auto& region) { return MutateBufferRegion(region); }); - Array writes = + ffi::Array writes = op->writes.Map([this](const auto& region) { return MutateBufferRegion(region); }); if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) && writes.same_as(op->writes)) { - return GetRef(op); + return ffi::GetRef(op); } else { ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); @@ -184,7 +184,7 @@ class PrimFuncSpecializer : public StmtExprMutator { auto new_buf = GetNewBuffer(op->buffer); if (new_buf.same_as(op->buffer)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->buffer = new_buf; @@ -199,18 +199,18 @@ class PrimFuncSpecializer : public StmtExprMutator { auto new_buf = GetNewBuffer(op->buffer); if (new_buf.same_as(op->buffer)) { - return GetRef(op); + return ffi::GetRef(op); } else { - auto n = make_object(*op); + auto n = 
ffi::make_object(*op); n->buffer = new_buf; return PrimExpr(n); } } PrimExpr VisitExpr_(const VarNode* op) final { - auto it = var_map_.find(GetRef(op)); + auto it = var_map_.find(ffi::GetRef(op)); if (it == var_map_.end()) { - return GetRef(op); + return ffi::GetRef(op); } else { return it->second; } @@ -242,8 +242,9 @@ class PrimFuncSpecializer : public StmtExprMutator { // of Var-to-PrimExpr remapping. Var data = VisitExpr(buffer->data).as().value_or(buffer->data); - Array shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); }); - Array strides = + ffi::Array shape = + buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); }); + ffi::Array strides = buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); }); PrimExpr elem_offset = VisitExpr(buffer->elem_offset); @@ -252,7 +253,7 @@ class PrimFuncSpecializer : public StmtExprMutator { buffer->shape.same_as(shape) && buffer->strides.same_as(strides)) { return buffer; } else { - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->data = std::move(data); n->elem_offset = std::move(elem_offset); n->shape = std::move(shape); @@ -304,7 +305,7 @@ class PrimFuncSpecializer : public StmtExprMutator { BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { auto it = buffer_map_.find(buffer_region->buffer); const Buffer& buffer = it != buffer_map_.end() ? 
it->second : buffer_region->buffer; - Array region = buffer_region->region.Map( + ffi::Array region = buffer_region->region.Map( std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1)); if (it == buffer_map_.end() && region.same_as(buffer_region->region)) { return buffer_region; @@ -415,11 +416,11 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx /**************** Implementation ****************/ -PrimFunc Specialize(PrimFunc func, const Map>& param_map) { +PrimFunc Specialize(PrimFunc func, const ffi::Map>& param_map) { VarMap var_map; for (const auto& kv : param_map) { const Var& param = kv.first; - const Variant& instance = kv.second; + const ffi::Variant& instance = kv.second; if (auto opt_buffer = instance.as()) { UpdateSpecializeVarMap(func, param, opt_buffer.value(), &var_map); } else if (auto opt_expr = instance.as()) { @@ -433,10 +434,10 @@ PrimFunc Specialize(PrimFunc func, const Map>& pa /**************** FFI ****************/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Specialize", Specialize); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 4b3b4d191510..d57196dc8d62 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -32,7 +32,7 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { StmtNode::RegisterReflection(); LetStmtNode::RegisterReflection(); AttrStmtNode::RegisterReflection(); @@ -51,7 +51,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ MatchBufferRegionNode::RegisterReflection(); BlockNode::RegisterReflection(); BlockRealizeNode::RegisterReflection(); -}); +} // LetStmt LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { @@ -66,7 +66,7 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { ICHECK_EQ(value.dtype(), var.dtype()); } - ObjectPtr node = make_object(); + ObjectPtr node = 
ffi::make_object(); node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); @@ -74,16 +74,16 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.LetStmt", [](Var var, PrimExpr value, Stmt body, Span span) { return LetStmt(var, value, body, span); }); -}); +} // AttrStmt -AttrStmt::AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span) { - auto n = make_object(); +AttrStmt::AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span) { + auto n = ffi::make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); @@ -92,10 +92,10 @@ AttrStmt::AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Sp data_ = std::move(n); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.AttrStmt", - [](Any node, String attr_key, PrimExpr value, Stmt body, Span span) { + [](Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span) { // when node is a POD data type like int or bool, first convert to // primexpr. 
if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { @@ -103,7 +103,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } return AttrStmt(node, attr_key, value, body, span); }); -}); +} // AssertStmt AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) { @@ -114,7 +114,7 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span spa ICHECK(message.dtype() == DataType::Int(32) || message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); @@ -122,17 +122,18 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span spa data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.AssertStmt", [](PrimExpr condition, StringImm message, Stmt body, Span span) { return AssertStmt(condition, message, body, span); }); -}); +} // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding, Map annotations, Span span) { + ffi::Optional thread_binding, ffi::Map annotations, + ffi::Optional step, Span span) { ICHECK(loop_var.defined()); ICHECK(min.defined()); ICHECK(extent.defined()); @@ -148,8 +149,8 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, require_scalar_int_dtype(min, "min"); require_scalar_int_dtype(extent, "extent"); - // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them - // without raising errors. + // When extent, min or step is an IntImm but has narrower dtype than loop_var + // we directly promote them without raising errors. 
auto try_promote_imm_dtype = [&](const PrimExpr& e) { ICHECK(e.dtype().bits() <= loop_var.dtype().bits()) << " Loop variable's dtype (" << loop_var.dtype() @@ -168,7 +169,13 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); - ObjectPtr node = make_object(); + if (step.has_value()) { + require_scalar_int_dtype(*step, "step"); + step = try_promote_imm_dtype(*step); + ICHECK(loop_var.dtype() == (*step).dtype()) << loop_var.dtype() << " vs " << (*step).dtype(); + } + + ObjectPtr node = ffi::make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); node->extent = std::move(extent); @@ -176,19 +183,23 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, node->body = std::move(body); node->thread_binding = std::move(thread_binding); node->annotations = std::move(annotations); + node->step = std::move(step); node->span = std::move(span); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +bool ForNode::HasTrivialStep() const { return !step.has_value() || is_one(*step); } + +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, - Stmt body, Optional thread_binding, - Optional> annotations, Span span) { + Stmt body, ffi::Optional thread_binding, + ffi::Optional> annotations, + ffi::Optional step, Span span) { return For(loop_var, min, extent, static_cast(kind), body, thread_binding, - annotations.value_or(Map()), span); + annotations.value_or(ffi::Map()), step, span); }); -}); +} std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) switch (type) { @@ -215,32 +226,25 @@ std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) While::While(PrimExpr condition, Stmt body, Span 
span) { ICHECK(condition.defined()); ICHECK(condition.dtype().is_scalar()); - ICHECK(condition.as() == nullptr) << "The condition should not be trivial."; ICHECK(body.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.While", [](PrimExpr condition, Stmt body, Span span) { return While(condition, body, span); }); -}); +} // Allocate -Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body, Map annotations, Span span) { - CHECK(IsPointerType(buffer_var->type_annotation, dtype) || - (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8)))) - << "The allocated data type (" << dtype - << ") does not match the type annotation of the buffer " << buffer_var << " (" - << buffer_var->type_annotation - << "). 
The data type should be an element of the pointer type."; +Allocate::Allocate(Var buffer_var, DataType dtype, ffi::Array extents, PrimExpr condition, + Stmt body, ffi::Map annotations, Span span) { for (size_t i = 0; i < extents.size(); ++i) { ICHECK(extents[i].defined()); @@ -250,7 +254,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim ICHECK(condition.defined()); ICHECK(condition.dtype().is_bool()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; node->extents = std::move(extents); @@ -261,7 +265,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim data_ = std::move(node); } -int64_t AllocateNode::ConstantAllocationSize(const Array& extents) { +int64_t AllocateNode::ConstantAllocationSize(const ffi::Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { if (const IntImmNode* int_size = extents[i].as()) { @@ -276,22 +280,23 @@ int64_t AllocateNode::ConstantAllocationSize(const Array& extents) { return static_cast(result); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.Allocate", [](Var buffer_var, DataType type, Array extents, PrimExpr condition, - Stmt body, Map annotations, Span span) { + "tir.Allocate", + [](Var buffer_var, DataType type, ffi::Array extents, PrimExpr condition, Stmt body, + ffi::Map annotations, Span span) { return Allocate(buffer_var, type, extents, condition, body, annotations, span); }); -}); +} // Const // The constructor to create a IRNode with constant data // depending on the type of ObjectRef, it will either // create AllocateConstNode with irmod_storage_idx or data -AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array extents, - ObjectRef data_or_idx, Stmt body, Map annotations, - Span span) { +AllocateConst::AllocateConst(Var buffer_var, DataType dtype, ffi::Array extents, + 
ObjectRef data_or_idx, Stmt body, + ffi::Map annotations, Span span) { ICHECK(IsPointerType(buffer_var->type_annotation, dtype)) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" @@ -305,26 +310,26 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array ext ICHECK(body.defined()); ICHECK(data_or_idx.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; node->extents = std::move(extents); node->body = std::move(body); node->annotations = annotations; node->span = std::move(span); - if (data_or_idx->IsInstance()) { - node->data = Optional(Downcast(data_or_idx)); - node->irmod_storage_idx = Optional(); + if (data_or_idx->IsInstance()) { + node->data = ffi::Optional(Downcast(data_or_idx)); + node->irmod_storage_idx = ffi::Optional(); } else if (data_or_idx->IsInstance()) { - node->data = Optional(); - node->irmod_storage_idx = Optional(Downcast(data_or_idx)); + node->data = ffi::Optional(); + node->irmod_storage_idx = ffi::Optional(Downcast(data_or_idx)); } else { LOG(FATAL) << "Data type not supported: " << data_or_idx->GetTypeKey(); } data_ = std::move(node); } -int64_t AllocateConstNode::ConstantAllocationSize(const Array& extents) { +int64_t AllocateConstNode::ConstantAllocationSize(const ffi::Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { if (const IntImmNode* int_size = extents[i].as()) { @@ -338,35 +343,35 @@ int64_t AllocateConstNode::ConstantAllocationSize(const Array& extents } return static_cast(result); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.AllocateConst", - [](Var buffer_var, DataType dtype, Array extents, ObjectRef data_or_idx, Stmt body, - Optional> annotations, Span span) { + [](Var buffer_var, DataType dtype, ffi::Array extents, ObjectRef data_or_idx, + 
Stmt body, ffi::Optional> annotations, Span span) { return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations.value_or({}), span); }); -}); +} // DeclBuffer DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->body = std::move(body); node->span = std::move(span); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.DeclBuffer", [](Buffer buffer, Stmt body, Span span) { return DeclBuffer(buffer, body, span); }); -}); +} // SeqStmt -SeqStmt::SeqStmt(Array seq, Span span) { +SeqStmt::SeqStmt(ffi::Array seq, Span span) { bool requires_flattening = std::any_of( seq.begin(), seq.end(), [](const Stmt& stmt) { return stmt->IsInstance(); }); @@ -386,24 +391,25 @@ SeqStmt::SeqStmt(Array seq, Span span) { << "Use the node " << seq[0] << "directly, " << "or for dynamic usage, normalize using SeqStmt::Flatten()"; - auto node = make_object(); + auto node = ffi::make_object(); node->seq = std::move(seq); node->span = std::move(span); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.SeqStmt", - [](Array seq, Span span) { return SeqStmt(std::move(seq), span); }); -}); + refl::GlobalDef().def( + "tir.SeqStmt", [](ffi::Array seq, Span span) { return SeqStmt(std::move(seq), span); }); +} // IfThenElse -IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional else_case, Span span) { +IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, ffi::Optional else_case, + Span span) { ICHECK(condition.defined()); ICHECK(then_case.defined()); // else_case may be null. 
- ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); @@ -411,33 +417,33 @@ IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional else_c data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.IfThenElse", [](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { return IfThenElse(condition, then_case, else_case, span); }); -}); +} // Evaluate Evaluate::Evaluate(PrimExpr value, Span span) { ICHECK(value.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->value = std::move(value); node->span = std::move(span); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.Evaluate", [](PrimExpr value, Span span) { return Evaluate(value, span); }); -}); +} // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate, Span span) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -483,7 +489,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, << " lanes. 
The number of lanes must match."; DataType predicate_element_dtype = predicate_dtype.element_of(); - ICHECK(predicate_element_dtype.is_bool()) + ICHECK(predicate_element_dtype.is_predicate_dtype()) << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype << "."; } @@ -502,7 +508,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, << "`, but RHS's dtype is `" << value.dtype() << "`"; } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); @@ -511,32 +517,33 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tir.BufferStore", - [](Buffer buffer, PrimExpr value, Array indices, Optional predicate, - Span span) { return BufferStore(buffer, value, indices, predicate, span); }); -}); + refl::GlobalDef().def("tir.BufferStore", + [](Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate, Span span) { + return BufferStore(buffer, value, indices, predicate, span); + }); +} // BufferRealize -BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, +BufferRealize::BufferRealize(Buffer buffer, ffi::Array bounds, PrimExpr condition, Stmt body, Span span) { - data_ = make_object(buffer, bounds, condition, body, span); + data_ = ffi::make_object(buffer, bounds, condition, body, span); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BufferRealize", [](Buffer buffer, Array bounds, + refl::GlobalDef().def("tir.BufferRealize", [](Buffer buffer, ffi::Array bounds, PrimExpr condition, Stmt body, Span span) { return BufferRealize(buffer, bounds, condition, body, span); }); -}); +} // BufferRegion PrimExpr 
BufferRegionNode::ToPrimExpr() const { // Auto convert to PrimExpr if it is a single point load - Array indices; + ffi::Array indices; indices.reserve(this->region.size()); for (const Range& r : this->region) { if (tvm::tir::is_one(r->extent)) { @@ -544,32 +551,32 @@ PrimExpr BufferRegionNode::ToPrimExpr() const { } else if (r->extent.as()) { indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), r->extent)); } else { - LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << GetRef(this); + LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << ffi::GetRef(this); } } return tir::BufferLoad(this->buffer, indices); } -BufferRegion::BufferRegion(Buffer buffer, Array region) { +BufferRegion::BufferRegion(Buffer buffer, ffi::Array region) { CHECK_EQ(buffer->shape.size(), region.size()) << "The dimension between " << buffer << " and region " << region << " mismatched, the buffer is " << buffer; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->region = std::move(region); data_ = std::move(node); } BufferRegion BufferRegion::FullRegion(Buffer buffer) { - Array region; + ffi::Array region; for (PrimExpr extent : buffer->shape) { region.push_back(Range::FromMinExtent(0, extent)); } return BufferRegion(buffer, region); } -BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { - Array region; +BufferRegion BufferRegion::FromPoint(Buffer buffer, ffi::Array indices) { + ffi::Array region; for (const PrimExpr& index : indices) { if (const RampNode* ramp_index = index.as()) { region.push_back( @@ -581,12 +588,12 @@ BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { return BufferRegion(buffer, region); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BufferRegion", [](Buffer buffer, Array region) { + refl::GlobalDef().def("tir.BufferRegion", [](Buffer buffer, ffi::Array 
region) { return BufferRegion(buffer, region); }); -}); +} // MatchBufferRegion MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { @@ -603,8 +610,8 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { // Check data_alignment CHECK(source_buffer->data_alignment % buffer->data_alignment == 0) << "Trying to match buffer to another one with lower alignment requirement " - << " required_alignment=" << buffer->data_alignment - << ", provided_alignment=" << source_buffer->data_alignment; + << " required alignment=" << buffer->data_alignment + << ", provided alignment=" << source_buffer->data_alignment; // Check BufferType. AutoBroadcast is not allowed for now. CHECK(buffer->buffer_type == BufferType::kDefault && @@ -633,24 +640,26 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { // Note that we do not check elem_offset and strides in this function // Construction - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->source = std::move(source); data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.MatchBufferRegion", [](Buffer buffer, BufferRegion source) { return MatchBufferRegion(buffer, source); }); -}); +} // Block -Block::Block(Array iter_vars, Array reads, Array writes, - String name_hint, Stmt body, Optional init, Array alloc_buffers, - Array match_buffers, Map annotations, Span span) { - ObjectPtr node = make_object(); +Block::Block(ffi::Array iter_vars, ffi::Array reads, + ffi::Array writes, ffi::String name_hint, Stmt body, + ffi::Optional init, ffi::Array alloc_buffers, + ffi::Array match_buffers, ffi::Map annotations, + Span span) { + ObjectPtr node = ffi::make_object(); node->iter_vars = std::move(iter_vars); node->reads = std::move(reads); node->writes = std::move(writes); @@ -664,24 +673,27 @@ Block::Block(Array iter_vars, 
Array reads, Array iter_vars, Array reads, Array writes, - String name_hint, Stmt body, Optional init, Array alloc_buffers, - Array match_buffers, Map annotations, Span span) { - return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, - annotations, span); - }); -}); + refl::GlobalDef().def("tir.Block", + [](ffi::Array iter_vars, ffi::Array reads, + ffi::Array writes, ffi::String name_hint, Stmt body, + ffi::Optional init, ffi::Array alloc_buffers, + ffi::Array match_buffers, + ffi::Map annotations, Span span) { + return Block(iter_vars, reads, writes, name_hint, body, init, + alloc_buffers, match_buffers, annotations, span); + }); +} // BlockRealize -BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block block, Span span) { +BlockRealize::BlockRealize(ffi::Array values, PrimExpr predicate, Block block, + Span span) { CHECK_EQ(block->iter_vars.size(), values.size()) << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values"; - CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression"; - ObjectPtr node = make_object(); + CHECK(predicate.dtype().is_bool() || predicate.dtype() == DataType::UInt(1)) + << "TypeError: Expect Block.predicate to be a bool expression"; + ObjectPtr node = ffi::make_object(); node->iter_values = std::move(values); node->predicate = std::move(predicate); node->block = std::move(block); @@ -689,17 +701,17 @@ BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block blo data_ = std::move(node); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BlockRealize", [](Array iter_values, PrimExpr predicate, + refl::GlobalDef().def("tir.BlockRealize", [](ffi::Array iter_values, PrimExpr predicate, Block block, Span span) { return BlockRealize(iter_values, predicate, block, span); }); -}); +} PrimExpr TypeAnnotation(DataType dtype, Span span) { 
static auto op = Op::Get("tir.type_annotation"); - return tir::Call(dtype, op, {}, span); + return tir::Call(dtype, op, {}, {}, span); } TVM_TIR_REGISTER_OP("type_annotation") diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index e580f22f6b7f..6eef6cd34414 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -46,6 +46,9 @@ void StmtVisitor::VisitStmt_(const AttrStmtNode* op) { void StmtVisitor::VisitStmt_(const ForNode* op) { this->VisitExpr(op->min); this->VisitExpr(op->extent); + if (op->step.has_value()) { + this->VisitExpr(*op->step); + } this->VisitStmt(op->body); } @@ -152,22 +155,22 @@ class StmtMutator::Internal { * \return The mutated array, a new copy can be created. */ template - static Array MutateArray(StmtMutator* self, const Array& arr, F fmutate) { + static ffi::Array MutateArray(StmtMutator* self, const ffi::Array& arr, F fmutate) { if (self->allow_copy_on_write_ && arr.unique()) { // if we allow copy on write, we can directly // call the inplace mutate function. 
- const_cast&>(arr).MutateByApply(fmutate); + const_cast&>(arr).MutateByApply(fmutate); return arr; } else { bool allow_cow = false; std::swap(allow_cow, self->allow_copy_on_write_); - Array copy = arr.Map(fmutate); + ffi::Array copy = arr.Map(fmutate); std::swap(allow_cow, self->allow_copy_on_write_); return copy; } } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const IterVar& iter_var) { PrimExpr min = self->VisitExpr(iter_var->dom->min); PrimExpr extent = self->VisitExpr(iter_var->dom->extent); @@ -181,17 +184,17 @@ class StmtMutator::Internal { return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); }; return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); }; return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const Range& r) { PrimExpr min = self->VisitExpr(r->min); PrimExpr extent = self->VisitExpr(r->extent); @@ -204,9 +207,9 @@ class StmtMutator::Internal { return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const BufferRegion& buffer_region) { - Array region = Mutate(self, buffer_region->region); + ffi::Array region = Mutate(self, buffer_region->region); if (region.same_as(buffer_region->region)) { return buffer_region; } else { @@ -216,9 +219,10 @@ class StmtMutator::Internal { return 
MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, + const ffi::Array& arr) { auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { - Array region = Mutate(self, match_buffer_region->source->region); + ffi::Array region = Mutate(self, match_buffer_region->source->region); if (region.same_as(match_buffer_region->source->region)) { return match_buffer_region; } else { @@ -234,7 +238,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); @@ -247,7 +251,7 @@ Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); @@ -259,13 +263,19 @@ Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); + ffi::Optional step{std::nullopt}; + if (op->step.has_value()) { + step = this->VisitExpr(*op->step); + } Stmt body = this->VisitStmt(op->body); - if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body) && + step.same_as(op->step)) { + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->min = std::move(min); n->extent = std::move(extent); + n->step = std::move(step); n->body = std::move(body); return Stmt(n); } @@ -275,7 +285,7 @@ Stmt StmtMutator::VisitStmt_(const WhileNode* op) { PrimExpr condition = 
this->VisitExpr(op->condition); Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); @@ -285,12 +295,12 @@ Stmt StmtMutator::VisitStmt_(const WhileNode* op) { } Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { - Array extents = Internal::Mutate(this, op->extents); + ffi::Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->extents = std::move(extents); @@ -301,11 +311,11 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { } Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { - Array extents = Internal::Mutate(this, op->extents); + ffi::Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); if (extents.same_as(op->extents) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->extents = std::move(extents); @@ -318,7 +328,7 @@ Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { Stmt body = this->VisitStmt(op->body); if (body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->body = std::move(body); @@ -329,13 +339,13 @@ Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); - Optional else_case = std::nullopt; + ffi::Optional else_case = std::nullopt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && 
else_case.same_as(op->else_case)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); @@ -347,10 +357,10 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { PrimExpr value = this->VisitExpr(op->value); - Array indices = Internal::Mutate(this, op->indices); + ffi::Array indices = Internal::Mutate(this, op->indices); if (value.same_as(op->value) && indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); @@ -365,7 +375,7 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { Stmt body = this->VisitStmt(op->body); if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->bounds = std::move(bounds); @@ -376,9 +386,9 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { } Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { - Array seq = Internal::Mutate(this, op->seq); + ffi::Array seq = Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { - return SeqStmt::Flatten(GetRef(op)); + return SeqStmt::Flatten(ffi::GetRef(op)); } else { auto node = CopyOnWrite(op); node->seq = std::move(seq); @@ -400,10 +410,10 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit } // function to run the visit. auto frunvisit = [&](const SeqStmtNode* op) { - Array seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate) - : Internal::Mutate(this, op->seq); + ffi::Array seq = fmutate != nullptr ? 
Internal::MutateArray(this, op->seq, fmutate) + : Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->seq = std::move(seq); @@ -411,7 +421,7 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit } }; if (flatten_before_visit) { - Array seq; + ffi::Array seq; SeqStmt::Flattener flattener(&seq); flattener(0, op->seq); // NOTE: If copy on write is allowed @@ -435,7 +445,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); @@ -448,7 +458,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); @@ -457,11 +467,11 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { } Stmt StmtMutator::VisitStmt_(const BlockNode* op) { - Array iter_vars = Internal::Mutate(this, op->iter_vars); - Array reads = Internal::Mutate(this, op->reads); - Array writes = Internal::Mutate(this, op->writes); - Array match_buffers = Internal::Mutate(this, op->match_buffers); - Optional init = std::nullopt; + ffi::Array iter_vars = Internal::Mutate(this, op->iter_vars); + ffi::Array reads = Internal::Mutate(this, op->reads); + ffi::Array writes = Internal::Mutate(this, op->writes); + ffi::Array match_buffers = Internal::Mutate(this, op->match_buffers); + ffi::Optional init = std::nullopt; if (op->init.defined()) { init = VisitStmt(op->init.value()); } @@ -469,7 +479,7 @@ Stmt StmtMutator::VisitStmt_(const BlockNode* op) { if (iter_vars.same_as(op->iter_vars) && 
reads.same_as(op->reads) && writes.same_as(op->writes) && body.same_as(op->body) && init.same_as(op->init) && match_buffers.same_as(op->match_buffers)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->iter_vars = std::move(iter_vars); @@ -483,11 +493,11 @@ Stmt StmtMutator::VisitStmt_(const BlockNode* op) { } Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { - Array v = Internal::Mutate(this, op->iter_values); + ffi::Array v = Internal::Mutate(this, op->iter_values); PrimExpr pred = this->VisitExpr(op->predicate); Stmt block = this->VisitStmt(op->block); if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->iter_values = std::move(v); @@ -575,7 +585,7 @@ class IRTransformer final : public StmtExprMutator { }; Stmt IRTransform(Stmt ir_node, const ffi::Function& f_preorder, const ffi::Function& f_postorder, - Optional> only_enable) { + ffi::Optional> only_enable) { std::unordered_set only_type_index; if (only_enable.defined()) { for (auto s : only_enable.value()) { @@ -588,10 +598,10 @@ Stmt IRTransform(Stmt ir_node, const ffi::Function& f_preorder, const ffi::Funct class IRSubstitute : public StmtExprMutator { public: - explicit IRSubstitute(std::function(const Var&)> vmap) : vmap_(vmap) {} + explicit IRSubstitute(std::function(const Var&)> vmap) : vmap_(vmap) {} PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto ret = vmap_(var); if (ret.defined()) { // Allow substitution of void variables with any expression. 
The TVM script parser @@ -673,13 +683,18 @@ class IRSubstitute : public StmtExprMutator { if (auto mapped_var = vmap_(var_node.value())) { return AttrStmt(mapped_var, op->attr_key, op->value, op->body); } + } else if (auto expr_node = op->node.as()) { + PrimExpr new_expr = VisitExpr(expr_node.value()); + if (!new_expr.same_as(expr_node.value())) { + return AttrStmt(new_expr, op->attr_key, op->value, op->body); + } } return ret; } private: // Caller provided function that defines the variables to be remapped. - std::function(const Var&)> vmap_; + std::function(const Var&)> vmap_; /* \brief Generated map to track buffers being remapped. * @@ -691,11 +706,11 @@ class IRSubstitute : public StmtExprMutator { std::unordered_map buf_remap_; }; -Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { +Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { return IRSubstitute(vmap)(std::move(stmt)); } -PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { +PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { return IRSubstitute(vmap)(std::move(expr)); } @@ -743,14 +758,15 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr, class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { public: - explicit IRSubstituteWithDataTypeLegalization(std::function(const Var&)> vmap) + explicit IRSubstituteWithDataTypeLegalization( + std::function(const Var&)> vmap) : vmap_(vmap) {} using DataTypeLegalizer::VisitExpr_; using DataTypeLegalizer::VisitStmt_; PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto ret = vmap_(var); if (ret.defined()) { return ret.value(); @@ -811,7 +827,7 @@ class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { private: // Caller provided function that defines the variables to be remapped. - std::function(const Var&)> vmap_; + std::function(const Var&)> vmap_; /* \brief Generated map to track buffers being remapped. 
* @@ -824,16 +840,16 @@ class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { }; Stmt SubstituteWithDataTypeLegalization(Stmt stmt, - std::function(const Var&)> vmap) { + std::function(const Var&)> vmap) { return IRSubstituteWithDataTypeLegalization(vmap)(std::move(stmt)); } -PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr, - std::function(const Var&)> vmap) { +PrimExpr SubstituteWithDataTypeLegalization( + PrimExpr expr, std::function(const Var&)> vmap) { return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr)); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.IRTransform", IRTransform) @@ -845,14 +861,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](ObjectRef node, ffi::Function f) { tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); }) - .def("tir.Substitute", [](ObjectRef node, Map vmap) -> ObjectRef { + .def("tir.Substitute", [](ObjectRef node, ffi::Map vmap) -> ObjectRef { if (node->IsInstance()) { return Substitute(Downcast(node), vmap); } else { return Substitute(Downcast(node), vmap); } }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index aa3ca1959c5d..b76234ecb856 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -43,7 +43,7 @@ void TIRVisitorWithPath::Visit(const IRModule& mod, AccessPath path) { std::unordered_set externally_exposed; for (const auto& [gvar, func] : mod->functions) { gvars.push_back(gvar); - if (func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { + if (func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { externally_exposed.insert(gvar); } } @@ -193,7 +193,7 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { // `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any // symbolic shapes used within `buffer_view that are not already // 
defined. - Array arr = Downcast>(op->node); + ffi::Array arr = Downcast>(op->node); ICHECK_EQ(arr.size(), 2U); Buffer buffer_view = Downcast(arr[0]); Buffer orig_buffer = Downcast(arr[1]); @@ -203,8 +203,11 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { context.push_back(std::move(var)); } - } else if (auto expr = op->node.as()) { - Visit(expr.value(), path->Attr("node")); + } else if (op->node != nullptr) { + auto expr = op->node.as(); + if (expr) { + Visit(expr.value(), path->Attr("node")); + } } Visit(op->body, path->Attr("body")); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 0ff9da33eb6d..65673d1f2b34 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -85,7 +85,7 @@ class TIRVisitorWithPath // Utility to visit an array of nodes template - inline void Visit(const Array& arr, ffi::reflection::AccessPath path) { + inline void Visit(const ffi::Array& arr, ffi::reflection::AccessPath path) { for (size_t i = 0; i < arr.size(); i++) { Visit(arr[i], path->ArrayItem(i)); } @@ -93,7 +93,7 @@ class TIRVisitorWithPath // Utility to visit an optional node nodes template - inline void Visit(const Optional& opt, ffi::reflection::AccessPath path) { + inline void Visit(const ffi::Optional& opt, ffi::reflection::AccessPath path) { if (opt) { Visit(opt.value(), path); } @@ -229,7 +229,7 @@ class TIRVisitorWithPath } }; auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def]( - const Array& arr, + const ffi::Array& arr, ffi::reflection::AccessPath path) { for (size_t i = 0; i < arr.size(); i++) { try_visit_implicit_var_def(arr[i], path->ArrayItem(i)); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index aafe6277e24d..68b494d41144 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -43,7 +43,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); 
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", ffi::Array>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool); @@ -82,9 +82,7 @@ class PrimFuncPassNode : public PassNode { * \brief Get the pass information/meta data. */ PassInfo Info() const override { return pass_info; } - - static constexpr const char* _type_key = "tir.PrimFuncPass"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.PrimFuncPass", PrimFuncPassNode, PassNode); }; class PrimFuncPass : public Pass { @@ -97,12 +95,12 @@ class PrimFuncPass : public Pass { TVM_DLL PrimFuncPass(std::function pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimFuncPass, Pass, PrimFuncPassNode); }; PrimFuncPass::PrimFuncPass(std::function pass_func, PassInfo pass_info) { - auto n = make_object(); + auto n = ffi::make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); data_ = std::move(n); @@ -141,14 +139,15 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) } Pass CreatePrimFuncPass(std::function pass_func, - int opt_level, String name, tvm::Array required, bool traceable) { + int opt_level, ffi::String name, tvm::ffi::Array required, + bool traceable) { PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return PrimFuncPass(std::move(pass_func), pass_info); } -TVM_FFI_STATIC_INIT_BLOCK({ PrimFuncPassNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { 
PrimFuncPassNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.transform.CreatePrimFuncPass", @@ -159,7 +158,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }; return PrimFuncPass(wrapped_pass_func, pass_info); }); -}); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 12c7c8d33c7f..6ce2ae09e2da 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -52,6 +52,14 @@ TIR_DEFINE_BUILTIN_FUNC(thread_return) .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) .set_num_inputs(0); +TIR_DEFINE_BUILTIN_FUNC(continue_loop) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) + .set_num_inputs(0); + +TIR_DEFINE_BUILTIN_FUNC(break_loop) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) + .set_num_inputs(0); + TIR_DEFINE_BUILTIN_FUNC(likely) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation)) @@ -115,6 +123,9 @@ TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr( TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(isfinite).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kPure)); + TIR_DEFINE_BUILTIN_FUNC(popcount) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) @@ -203,11 +214,11 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array) // When num_inputs are not set, the function is assumed to be variable length. 
TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", String("call_packed"), /*plevel=*/20); + .set_attr("TScriptPrinterName", ffi::String("call_packed"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", String("call_cpacked"), /*plevel=*/20); + .set_attr("TScriptPrinterName", ffi::String("call_cpacked"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -222,12 +233,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_invariant) TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", String("call_packed_lowered"), + .set_attr("TScriptPrinterName", ffi::String("call_packed_lowered"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", String("call_cpacked_lowered"), + .set_attr("TScriptPrinterName", ffi::String("call_cpacked_lowered"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 9ced6f556cb0..2b4ccf7a1ad8 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -114,14 +114,14 @@ Type GetTypeFromRuntimeDataType(const DataType& dtype) { PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) { return tir::Call( t, tir::builtin::large_uint_imm(), - {make_const(DataType::UInt(32), low, span), make_const(DataType::UInt(32), high, span)}, + {make_const(DataType::UInt(32), low, span), make_const(DataType::UInt(32), high, span)}, {}, span); } // Q-multiplication PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span) { return tir::Call(DataType::Int(32, x.dtype().lanes()), 
tir::builtin::q_multiply_shift(), - {x, y, q, s}, span); + {x, y, q, s}, {}, span); } void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*) @@ -214,6 +214,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } else if (ltype.is_float4() && !rtype.is_float4()) { // Cast int->float4 for rhs when lhs is a float4 rhs = cast(ltype, rhs); + } else if (ltype.is_bool() && (rtype.is_int() || rtype.is_uint())) { + // Cast bool to int for lhs when rhs is a int or uint + lhs = cast(rtype, lhs); + } else if ((ltype.is_int() || ltype.is_uint()) && rtype.is_bool()) { + // Cast bool to int for rhs when lhs is a int or uint + rhs = cast(ltype, rhs); } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) { // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 if (ltype.bits() < rtype.bits()) { @@ -243,22 +249,29 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) PrimExpr ret(PrimExpr value, Span span) { CHECK(value.defined()); - return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); + return tir::Call(value.dtype(), tir::builtin::ret(), {value}, {}, span); } -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.ret", ret); -}); - PrimExpr thread_return(Span span) { - return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span); + return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, {}, span); +} + +PrimExpr continue_loop(Span span) { + return tir::Call(DataType::Void(), tir::builtin::continue_loop(), {}, {}, span); +} + +PrimExpr break_loop(Span span) { + return tir::Call(DataType::Void(), tir::builtin::break_loop(), {}, {}, span); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.thread_return", thread_return); -}); + refl::GlobalDef() + .def("tir.ret", ret) + .def("tir.thread_return", 
thread_return) + .def("tir.continue_loop", continue_loop) + .def("tir.break_loop", break_loop); +}; // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { @@ -288,6 +301,8 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.bits() == 16) { return FloatImm(dtype, 65504.0, span); } + } else if (dtype.is_tfloat32()) { + return FloatImm(dtype, std::numeric_limits::max(), span); } else if (dtype.is_bfloat16()) { return FloatImm(dtype, std::numeric_limits::max(), span); } else if (dtype.is_float8()) { @@ -323,14 +338,7 @@ PrimExpr max_value(const DataType& dtype, Span span) { PrimExpr min_value(const DataType& dtype, Span span) { using namespace tir; ICHECK_EQ(dtype.lanes(), 1); - if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) { - // TODO(tkonolige): need to convert all registered min functions to use the span. - auto f = datatype::GetMinFunc(dtype.code()); - ICHECK(f) << "No minimum function registered for custom dtype " << (unsigned int)dtype.code(); - // TODO(@hypercubestart) Document this change (and others associated with the overflowing - // floatimm min bug) - return (*f)(dtype.bits()).cast(); - } else if (dtype.is_int()) { + if (dtype.is_int()) { if (dtype.bits() == 64) { return IntImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.bits() < 64) { @@ -348,6 +356,9 @@ PrimExpr min_value(const DataType& dtype, Span span) { } else if (dtype.bits() == 16) { return FloatImm(dtype, -65504.0, span); } + } + else if (dtype.is_tfloat32()) { + return FloatImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.is_bfloat16()) { return FloatImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.is_float8()) { @@ -484,7 +495,7 @@ PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { value.dtype().bytes() * value.dtype().lanes() == t.bytes() * t.lanes())) << "Reinterpret requires size match " << t << " vs " << value.dtype(); } - return tir::Call(t, 
tir::builtin::reinterpret(), {value}, span); + return tir::Call(t, tir::builtin::reinterpret(), {value}, {}, span); } // operator+ @@ -614,7 +625,7 @@ PrimExpr max(PrimExpr a, PrimExpr b, Span span) { // if_then_else PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { - ICHECK(cond.dtype() == DataType::Bool(1)) + ICHECK(cond.dtype() == DataType::Bool()) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value, span); if (const IntImmNode* op = cond.as()) { @@ -626,13 +637,13 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, } return tir::Call(true_value.dtype(), tir::builtin::if_then_else(), - {cond, true_value, false_value}, span); + {cond, true_value, false_value}, {}, span); } // likely PrimExpr likely(PrimExpr cond, Span span) { if (is_const_int(cond)) return cond; - return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, span); + return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, {}, span); } // operator> @@ -691,10 +702,10 @@ void type_check_boolean_args(const PrimExpr& lhs, const PrimExpr& rhs, const cha << rhs << " of type " << rhs.dtype(); } -void type_check_integer_args(const PrimExpr& arg, const char* op) { - ICHECK(arg.dtype().is_int() || arg.dtype().is_uint()) - << "Expected integer argument for " << op << ", but received " << arg << " of type " - << arg.dtype(); +void type_check_int_or_bool_args(const PrimExpr& arg, const char* op) { + ICHECK(arg.dtype().is_int() || arg.dtype().is_uint() || arg.dtype().is_bool()) + << "Expected integer or boolean argument for " << op << ", but received " << arg + << " of type " << arg.dtype(); } void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { @@ -705,6 +716,15 @@ void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const cha << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type " << 
rhs.dtype(); } + +void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { + ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() || lhs.dtype().is_bool()) + << "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type " + << lhs.dtype(); + ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() || rhs.dtype().is_bool()) + << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type " + << rhs.dtype(); +} } // namespace PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); } @@ -749,7 +769,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) { } }); - return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, {}, span); } // shift left @@ -768,58 +788,58 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, {}, span); } // bitwise and PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); } PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) { - type_check_integer_args(a, b, "& operator (bitwise AND)"); + type_check_int_or_bool_args(a, b, "& operator (bitwise AND)"); BinaryOpMatchTypes(a, b, span); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, {}, span); } // bitwise_or PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); } PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) { - type_check_integer_args(a, b, "| operator (bitwise OR)"); + type_check_int_or_bool_args(a, b, "| operator (bitwise OR)"); BinaryOpMatchTypes(a, b, span); 
TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, {}, span); } // bitwise_xor PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); } PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) { - type_check_integer_args(a, b, "^ operator (bitwise XOR)"); + type_check_int_or_bool_args(a, b, "^ operator (bitwise XOR)"); BinaryOpMatchTypes(a, b, span); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, {}, span); } // bitwise_not PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); } PrimExpr bitwise_neg(PrimExpr a, Span span) { - type_check_integer_args(a, "~ operator (bitwise NOT)"); - return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span); + type_check_int_or_bool_args(a, "~ operator (bitwise NOT)"); + return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, {}, span); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.bitwise_not", [](PrimExpr a, Span span) { return bitwise_neg(a, span); }); -}); +} // pow PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { @@ -852,7 +872,7 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { } static auto op = Op::Get("tir.pow"); - return tir::Call(x.dtype(), op, {x, y}, span); + return tir::Call(x.dtype(), op, {x, y}, {}, span); } TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr("TVectorizable", true); @@ -866,14 +886,14 @@ PrimExpr abs(PrimExpr x, Span span) { return IntImm(x.dtype(), std::abs(px->value), px->span); } return tir::Select(x >= make_zero(x.dtype()), x, -x, span); 
- } else if (x.dtype().is_float() || x.dtype().is_bfloat()) { + } else if (x.dtype().is_float() || x.dtype().is_bfloat() || x.dtype().is_tfloat()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value), fx->span); } static auto op = Op::Get("tir.fabs"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } else if (x.dtype().is_uint()) { return x; } else { @@ -898,9 +918,9 @@ PrimExpr isnan(PrimExpr x, Span span) { } static auto op = Op::Get("tir.isnan"); if (x.dtype().bits() == 16) { - return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, span); + return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, {}, span); } else { - return tir::Call(t, op, {x}, span); + return tir::Call(t, op, {x}, {}, span); } } else { LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..."; @@ -921,56 +941,73 @@ PrimExpr isinf(PrimExpr x, Span span) { } // isfinite -PrimExpr isfinite(PrimExpr x, Span span) { return !isinf(x, span) && !isnan(x, span); } +PrimExpr isfinite(PrimExpr x, Span span) { + DataType t = DataType::Bool(x.dtype().lanes()); + if (x.dtype().is_int() || x.dtype().is_uint()) { + return make_const(t, true, span); + } else if (x.dtype().is_float()) { + using tir::FloatImmNode; + const FloatImmNode* fx = x.as(); + if (fx) { + return make_const(t, std::isfinite(fx->value), fx->span); + } + if (x.dtype().bits() == 32 || x.dtype().bits() == 64) { + return tir::Call(t, builtin::isfinite(), {x}, {}, span); + } + return !isinf(x, span) && !isnan(x, span); + } else { + LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops. 
Skipping it..."; + } +} -PrimExpr sum(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr sum(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Add(x, y, span); PrimExpr identity_element = make_zero(source.dtype(), span); tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } -PrimExpr all(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr all(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { type_check_boolean_args(source, "tvm::all"); Var x("x", source.dtype(), span), y("y", source.dtype()); PrimExpr result = tir::And(x, y, span); PrimExpr identity_element = make_const(source.dtype(), true, span); tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } -PrimExpr any(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr any(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { type_check_boolean_args(source, "tvm::any"); Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Or(x, y, span); PrimExpr identity_element = make_const(source.dtype(), false, span); tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } -PrimExpr max(PrimExpr source, Array rdom, Array init, Span span) 
{ +PrimExpr max(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Max(x, y, span); PrimExpr identity_element = min_value(source.dtype(), span); tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } -PrimExpr min(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr min(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Min(x, y, span); PrimExpr identity_element = max_value(source.dtype(), span); tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } -PrimExpr prod(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr prod(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Mul(x, y, span); PrimExpr identity_element = make_const(source.dtype(), 1, span); tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } // fmod @@ -978,70 +1015,70 @@ PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); ICHECK(x.dtype().is_float()) << "fmod only applies to float"; static auto op = Op::Get("tir.fmod"); - return tir::Call(x.dtype(), op, 
{x, y}, span); + return tir::Call(x.dtype(), op, {x, y}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("fmod"); // floor PrimExpr floor(PrimExpr x, Span span) { - if (x.dtype().is_int() || x.dtype().is_uint()) { + if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value), fx->span); static auto op = Op::Get("tir.floor"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr("TVectorizable", true); // ceil PrimExpr ceil(PrimExpr x, Span span) { - if (x.dtype().is_int() || x.dtype().is_uint()) { + if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value), fx->span); static auto op = Op::Get("tir.ceil"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr("TVectorizable", true); // round PrimExpr round(PrimExpr x, Span span) { - if (x.dtype().is_int() || x.dtype().is_uint()) { + if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); static auto op = Op::Get("tir.round"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr("TVectorizable", true); // nearbyint PrimExpr nearbyint(PrimExpr x, Span span) { - if (x.dtype().is_int() || x.dtype().is_uint()) { + if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), 
fx->span); static auto op = Op::Get("tir.nearbyint"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint"); // trunc PrimExpr trunc(PrimExpr x, Span span) { - if (x.dtype().is_int() || x.dtype().is_uint()) { + if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } using tir::FloatImmNode; @@ -1051,7 +1088,7 @@ PrimExpr trunc(PrimExpr x, Span span) { fx->span); } static auto op = Op::Get("tir.trunc"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr("TVectorizable", true); @@ -1127,7 +1164,7 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // expose basic functions to node namespace -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("node._const", @@ -1158,7 +1195,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("tir.trunc", tvm::trunc) .def("tir._cast", tvm::cast) .def("tir.reinterpret", tvm::reinterpret); -}); +} // operator overloading, smarter than make #define DEF_MAKE_BINARY_OP(Node, Func) \ @@ -1169,15 +1206,25 @@ TVM_FFI_STATIC_INIT_BLOCK({ bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ if (lhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + auto arg1 = args[1].cast(); \ + if(arg1.dtype().is_uint()) { \ + *ret = Func(make_const(arg1.dtype(), args[0].cast()), arg1, args[2].cast()); \ + } else { \ + *ret = Func(make_const(arg1.dtype(), args[0].cast()), arg1, args[2].cast()); \ + } \ } else if (rhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + auto arg0 = args[0].cast(); \ + if(arg0.dtype().is_uint()) { \ + *ret = Func(arg0, make_const(arg0.dtype(), 
args[1].cast()), args[2].cast()); \ + } else { \ + *ret = Func(arg0, make_const(arg0.dtype(), args[1].cast()), args[2].cast()); \ + } \ } else { \ *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ } \ }) -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir._OpIfThenElse", @@ -1214,7 +1261,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .DEF_MAKE_BIT_OP(bitwise_xor, bitwise_xor) .DEF_MAKE_BIT_OP(left_shift, left_shift) // NOLINT(*) .DEF_MAKE_BIT_OP(right_shift, right_shift); -}); +} PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) { auto plus_4 = make_const(DataType::Float(bits), 4.f); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 25d09ff931ea..1285c2c5f0ab 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -230,7 +230,7 @@ bool IsWriteCache(const StmtSRef& block_sref); * \param analyzer The analyzer * \return A boolean flag indicating if the binding is affine */ -bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, +bool IsAffineBinding(const BlockRealize& realize, const ffi::Map& loop_var_ranges, arith::Analyzer* analyzer); /*! @@ -251,7 +251,7 @@ void CheckAffineBinding(const ScheduleState& self, Block block); * \throw ScheduleError If the input block does not have an affine binding */ void CheckPartialAffineBinding(const ScheduleState& self, Block block, - const Optional& high_exclusive); + const ffi::Optional& high_exclusive); /*! 
* \brief Extracts the ranges of loop variables in a path of the sref tree @@ -263,17 +263,17 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, * - if the storage scope is shared, it will look for threadIdx.x/y/z * \return The loop domain */ -Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, - const Optional& high_exclusive = std::nullopt, - const runtime::StorageScope& extra_relax_scope = // - runtime::StorageScope{runtime::StorageRank::kGlobal, ""}); +ffi::Map LoopDomainOfSRefTreePath( + const StmtSRef& low_inclusive, const ffi::Optional& high_exclusive = std::nullopt, + const runtime::StorageScope& extra_relax_scope = // + runtime::StorageScope{runtime::StorageRank::kGlobal, ""}); /*! * \brief Returns the block var binding * \param realize The BlockRealize to be analyzed * \return The block var binding */ -Map GetBindings(const BlockRealize& realize); +ffi::Map GetBindings(const BlockRealize& realize); /*! * \brief Get the vars involved in the bindings of data parallel block vars and reduction block @@ -316,14 +316,15 @@ void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& bloc * \param parent_sref The StmtSRef that points to the parent block/loop * \return A list of StmtSRefs of leaf block */ -Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref); +ffi::Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref); /*! * \brief Gets the BlockRealize of the leaf blocks of a scope where a specific block/loop is in * \param parent_sref The StmtSRef that points to the parent block/loop * \return A list of leaf BlockRealize */ -Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); +ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); /*! 
* \brief Get the BlockRealize of the single child block of the block or loop specified by @@ -357,7 +358,7 @@ IterVarType GetLoopIterType(const StmtSRef& loop_sref); * \return The lowest common ancestor of the input block srefs or loop srefs * \note The input array is required to have at least one sref */ -StmtSRef GetSRefLowestCommonAncestor(const Array& srefs); +StmtSRef GetSRefLowestCommonAncestor(const ffi::Array& srefs); /*! * \brief Checks if the given block has been applied by multi-level tiling. We check this by @@ -374,8 +375,8 @@ bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); * \return All the feasible compute-at locations of the input block, given as an array of loop srefs * and an array of their indices among the outer loops of the input block */ -std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, - const StmtSRef& block_sref); +std::pair, std::vector> CollectComputeLocation( + const ScheduleState& self, const StmtSRef& block_sref); /******** Producer-consumer relation ********/ @@ -385,7 +386,7 @@ std::pair, std::vector> CollectComputeLocation(const Schedu * \param scope The block scope where the given block is in * \return The producer blocks of the specified block */ -Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope); +ffi::Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope); /*! * \brief Get the consumer blocks to the given block under the given scope @@ -393,7 +394,7 @@ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope * \param scope The block scope where the given block is in * \return The consumer blocks of the specified block */ -Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); +ffi::Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); /*! 
* \brief Get the list of output blocks within the given scope @@ -403,7 +404,7 @@ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope * \return A list of all blocks that write to some output buffer * block */ -Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block); +ffi::Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block); /*! * \brief A solution to split a ordered list of subtrees into two parts, @@ -431,8 +432,9 @@ struct ProducerConsumerSplit { * \throw ScheduleError is not valid split is found */ static ProducerConsumerSplit Find( - const ScheduleState& state, const Array& subtrees, - const Array& producer_block_srefs, const Array& consumer_block_srefs, + const ScheduleState& state, const ffi::Array& subtrees, + const ffi::Array& producer_block_srefs, + const ffi::Array& consumer_block_srefs, std::unordered_map* block2realize); }; @@ -469,8 +471,8 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl * \return The defining site of the buffer and whether the buffer is allocated (otherwise the * buffer is from match_buffer). */ -std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, - const Buffer& buffer); +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer); /******** Reduction Block Related ********/ @@ -481,8 +483,8 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_ * \return The extracted init values and BufferStore updates * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block */ -std::pair, Array> GetInitValuesAndUpdatesFromReductionBlock( - const Optional& self, Block block); +std::pair, ffi::Array> GetInitValuesAndUpdatesFromReductionBlock( + const ffi::Optional& self, Block block); /*! 
* \brief Check whether the input array of IterVars only contains data-parallel and reduction block @@ -491,7 +493,7 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct * \return A boolean indicating whether the input array of IterVars only contains data-parallel and * reduction block iters */ -bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters); +bool ContainsOnlyDataParAndReductionBlockIter(const ffi::Array& iters); /*! * \brief Check whether the block's reduction block iters are not used to index the block's output @@ -511,9 +513,9 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block); * \return The corresponding CommReducer, combiner LHS values and combiner RHS values * \throw ScheduleError If no corresponding commutative reducer can be matched */ -std::tuple, Array> GetReducerAndCombinerLhsRhs( - const Optional& self, const Array& identities, - const Array& combiners); +std::tuple, ffi::Array> GetReducerAndCombinerLhsRhs( + const ffi::Optional& self, const ffi::Array& identities, + const ffi::Array& combiners); /******** Commutative Reducer ********/ @@ -522,7 +524,8 @@ std::tuple, Array> GetReducerAndCombinerL * \return The list of the registered reducer-getter functions * \sa ReducerRegistry */ -std::vector(Array)>> GetReducerGetters(); +std::vector(ffi::Array)>> +GetReducerGetters(); /*! 
* \brief Given the input identities and the combiner BufferStores of a reduction, extract the @@ -534,8 +537,9 @@ std::vector(Array)>> GetReduc * \param rhs The extracted RHS values of the reducer * \return A boolean indicating whether a corresponding commutative reducer is found */ -bool FromIdentityCombiner(const Array& identities, const Array& combiners, - CommReducer* result_reducer, Array* lhs, Array* rhs); +bool FromIdentityCombiner(const ffi::Array& identities, + const ffi::Array& combiners, CommReducer* result_reducer, + ffi::Array* lhs, ffi::Array* rhs); /******** Misc ********/ @@ -545,7 +549,7 @@ bool FromIdentityCombiner(const Array& identities, const Array SuggestIndexMap(const Buffer& buffer, const Array& indices, - const Array& loops, const PrimExpr& predicate, - arith::Analyzer* analyzer); +ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array& indices, + const ffi::Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer); /*! * \brief Checks if the given AST contains the specific operators @@ -605,7 +609,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& * \param ops The list of operators to be checked * \return A boolean indicating whether the AST contains the specific operators */ -bool HasOp(const Stmt& stmt, const Array& ops); +bool HasOp(const Stmt& stmt, const ffi::Array& ops); /*! 
* \brief Checks if the given AST statement contains if-then-else, including @@ -697,10 +701,11 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // * \param dom_high_exclusive The highest node in the sref tree path * \return An n-dimensional integer set */ -Array AnalyzeRegionUpperBound(const BufferRegion& region, const PrimExpr& predicate, - const StmtSRef& dom_low_inclusive, - const StmtSRef& dom_high_exclusive, - arith::Analyzer* analyzer); +ffi::Array AnalyzeRegionUpperBound(const BufferRegion& region, + const PrimExpr& predicate, + const StmtSRef& dom_low_inclusive, + const StmtSRef& dom_high_exclusive, + arith::Analyzer* analyzer); /*! * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) @@ -712,10 +717,11 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, const P * \param analyzer The analyzer * \return An n-dimensional integer set */ -Array AnalyzeRegionLowerBound(const BufferRegion& region, const PrimExpr& predicate, - const StmtSRef& dom_low_inclusive, - const StmtSRef& dom_high_exclusive, - arith::Analyzer* analyzer); +ffi::Array AnalyzeRegionLowerBound(const BufferRegion& region, + const PrimExpr& predicate, + const StmtSRef& dom_low_inclusive, + const StmtSRef& dom_high_exclusive, + arith::Analyzer* analyzer); /*! * \brief Simplify non-trivial expressions @@ -733,13 +739,13 @@ PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer) class TensorizeInfoNode : public Object { public: /*! \brief Maps loops in a target block to the ones in an intrinsic description */ - Map loop_map; + ffi::Map loop_map; /*! \brief Maps loops in an intrinsic description to its index, outer to inner */ - Map desc_loop_indexer; + ffi::Map desc_loop_indexer; /*! 
\brief Optional padded extents of the block iters when padding is needed to match the * intrinsic description */ - Optional> block_iter_paddings; + ffi::Optional> block_iter_paddings; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -748,14 +754,15 @@ class TensorizeInfoNode : public Object { .def_ro("desc_loop_indexer", &TensorizeInfoNode::desc_loop_indexer) .def_ro("block_iter_paddings", &TensorizeInfoNode::block_iter_paddings); } - - static constexpr const char* _type_key = "tir.schedule.TensorizeInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.schedule.TensorizeInfo", TensorizeInfoNode, Object); }; class TensorizeInfo : public ObjectRef { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); + explicit TensorizeInfo(ObjectPtr data) : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TensorizeInfo, ObjectRef, TensorizeInfoNode); }; /*! @@ -766,26 +773,27 @@ class TensorizeInfo : public ObjectRef { * \param allow_padding Whether to allow padding the block iters to match the intrinsic description * \return TensorizeInfo structure if a valid mapping is found, std::nullopt otherwise */ -Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func, bool allow_padding); +ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func, + bool allow_padding); /*!\brief Necessary information used to perform transformations for tensorization */ class AutoTensorizeMappingInfoNode : public Object { public: /*! \brief Possible mappings to apply to block iters */ - Array mappings; + ffi::Array mappings; /* Additional information from AutoTensorizeComparator */ /*! 
\brief Mapping from LHS buffer to RHS buffer */ - Map lhs_buffer_map; + ffi::Map lhs_buffer_map; /*! \brief Buffer indices on RHS */ - Map> rhs_buffer_indices; + ffi::Map> rhs_buffer_indices; /*! \brief Block iters on LHS */ - Array lhs_iters; + ffi::Array lhs_iters; /*! \brief Block iters on RHS */ - Array rhs_iters; + ffi::Array rhs_iters; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -796,15 +804,18 @@ class AutoTensorizeMappingInfoNode : public Object { .def_ro("lhs_iters", &AutoTensorizeMappingInfoNode::lhs_iters) .def_ro("rhs_iters", &AutoTensorizeMappingInfoNode::rhs_iters); } - - static constexpr const char* _type_key = "tir.schedule.AutoTensorizeMappingInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.schedule.AutoTensorizeMappingInfo", + AutoTensorizeMappingInfoNode, Object); }; class AutoTensorizeMappingInfo : public ObjectRef { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, ObjectRef, - AutoTensorizeMappingInfoNode); + explicit AutoTensorizeMappingInfo(ObjectPtr data) + : ObjectRef(data) { + TVM_FFI_ICHECK(data != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AutoTensorizeMappingInfo, ObjectRef, + AutoTensorizeMappingInfoNode); }; /*! @@ -818,9 +829,9 @@ class AutoTensorizeMappingInfo : public ObjectRef { * tensorized. We will need to apply the suggested layout transformations and then match against the * tensor intrinsics. */ -Optional GetAutoTensorizeMappingInfo(const ScheduleState& self, - const StmtSRef& block_sref, - const PrimFunc& desc_func); +ffi::Optional GetAutoTensorizeMappingInfo(const ScheduleState& self, + const StmtSRef& block_sref, + const PrimFunc& desc_func); /*! 
* \brief Perform basic checks for auto tensorization applicability, such as the structure of diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 91c63c3469bb..75cbd5f3e4c1 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -24,10 +24,10 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { TensorizeInfoNode::RegisterReflection(); AutoTensorizeMappingInfoNode::RegisterReflection(); -}); +} /******** IR Module ********/ @@ -49,7 +49,7 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl } LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the " "statement:\n" - << GetRef(root_block); + << ffi::GetRef(root_block); throw; } @@ -61,13 +61,13 @@ StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, public: explicit RootBlockError(IRModule mod) : mod_(mod) {} IRModule mod() const final { return mod_; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The primitive does not operate on the root block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The primitive does not operate on the root block"; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod_; }; @@ -75,10 +75,10 @@ StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, public: explicit NotStagePipelineError(IRModule mod, Block block) : mod_(mod), block_(block) {} IRModule mod() const final { return mod_; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The scope root is not a stage pipeline"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return R"(The scope {0} is not 
a stage pipeline. Definition of a scope that is a stage pipeline: - The region cover property holds for every of its child blocks @@ -87,7 +87,7 @@ Definition of a scope that is a stage pipeline: - All the statements in the scope are schedulable statements, i.e. Block and For )"; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; }; @@ -100,8 +100,8 @@ Definition of a scope that is a stage pipeline: const StmtSRefNode* subtree = sref.get(); for (; p != nullptr; subtree = p, p = p->parent) { if (p->stmt->IsInstance()) { - scope_root_sref = GetRef(p); - scope_root_subtree = GetRef(subtree); + scope_root_sref = ffi::GetRef(p); + scope_root_subtree = ffi::GetRef(subtree); break; } } @@ -114,7 +114,7 @@ Definition of a scope that is a stage pipeline: bool stage_pipeline = self->GetBlockInfo(scope_root_sref).stage_pipeline; if (stage_pipeline == false) { const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref); - throw NotStagePipelineError(self->mod, GetRef(block)); + throw NotStagePipelineError(self->mod, ffi::GetRef(block)); } } return scope_root_sref; @@ -123,9 +123,9 @@ Definition of a scope that is a stage pipeline: ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { struct Collector : public StmtVisitor { void VisitStmt_(const BlockRealizeNode* realize) final { - result.realizes.push_back(GetRef(realize)); - const Array& iter_vars = realize->block->iter_vars; - const Array& iter_values = realize->iter_values; + result.realizes.push_back(ffi::GetRef(realize)); + const ffi::Array& iter_vars = realize->block->iter_vars; + const ffi::Array& iter_values = realize->iter_values; ICHECK_EQ(iter_vars.size(), iter_values.size()); int n = realize->iter_values.size(); for (int i = 0; i < n; ++i) { @@ -175,7 +175,7 @@ void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) { */ bool IsDominantBlock(const ScheduleState& self, 
const StmtSRef& scope_root_sref, const StmtSRef& block_sref) { - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; CheckSRefHigherOrEqual(scope_root_sref, block_sref); const BlockNode* maybe_root_block = scope_root_sref->StmtAs(); if (maybe_root_block) { @@ -183,7 +183,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, buffer_writers = scope->buffer_writers; } else { // Collect all child blocks of root sub-tree, and merge their buffer writers. - Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref); + ffi::Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref); for (const StmtSRef& child_block_sref : child_block_srefs) { BlockScope child_scope = self->GetBlockScope(child_block_sref); for (const auto& it : child_scope->buffer_writers) { @@ -275,15 +275,15 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, public: explicit IncompleteBlockError(IRModule mod, Block block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} - String FastErrorString() const final { return "ScheduleError: Incomplete block"; } - String DetailRenderTemplate() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Incomplete block"; } + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The block {0} is not a complete block - it violates condition #" << violated_cond_; os << ".\n" << kCompleteBlockDefinition; return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; int violated_cond_; @@ -292,7 +292,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, int error_code = CheckCompleteBlockErrorCode(self, 
block_sref, scope_root_sref); if (error_code != 0) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw IncompleteBlockError(self->mod, GetRef(block), error_code); + throw IncompleteBlockError(self->mod, ffi::GetRef(block), error_code); } } @@ -327,7 +327,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc return 4; } // Cond 5. The reduction block vars are not used to index the output buffers. - return ReductionIterNotIndexOutputBuffer(GetRef(block)) ? 0 : 5; + return ReductionIterNotIndexOutputBuffer(ffi::GetRef(block)) ? 0 : 5; } bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, @@ -335,13 +335,13 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.schedule.IsReductionBlock", [](Schedule sch, BlockRV block_rv, BlockRV scope_block_rv) { return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv)); }); -}); +} void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { @@ -349,15 +349,15 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, public: explicit NotReductionBlockError(IRModule mod, Block block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} - String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } - String DetailRenderTemplate() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_; os << ".\n" << kReductionBlockDefinition; return 
os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; int violated_cond_; @@ -366,7 +366,7 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw NotReductionBlockError(self->mod, GetRef(block), error_code); + throw NotReductionBlockError(self->mod, ffi::GetRef(block), error_code); } } @@ -382,10 +382,10 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl complete_block_error_code_(complete_block_error_code), reduction_block_error_code_(reduction_block_error_code) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Not a complete or reduction block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The block {0} is not a complete block - it violates condition #" << complete_block_error_code_; @@ -396,7 +396,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ -413,8 +413,8 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl return; } const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw NotCompleteOrReductionBlockError(self->mod, GetRef(block), complete_block_error_code, - reduction_block_error_code); + throw NotCompleteOrReductionBlockError(self->mod, ffi::GetRef(block), + complete_block_error_code, reduction_block_error_code); } void 
CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) { @@ -429,12 +429,12 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt local_reduction_block_code_(local_reduction_block_code) { ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " "because some of its child block on SRef tree is neither a local complete block nor a " "local reduction block."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The queried subtree root {0} in SRef tree does not have compact dataflow, because " "its child block {1} on SRef tree is neither a local complete block nor a local " @@ -448,7 +448,9 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } + ffi::Array LocationsOfInterest() const final { + return {subtree_root_, violate_block_}; + } IRModule mod_; Stmt subtree_root_; @@ -457,14 +459,14 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt int local_reduction_block_code_; }; - Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root); + ffi::Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root); for (const StmtSRef& block_sref : child_block_srefs) { int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root), local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root); if (local_complete_block_code != 0 && local_reduction_block_code != 0) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw NotCompactDataFlowError(self->mod, 
GetRef(subtree_root->stmt), - GetRef(block), local_complete_block_code, + throw NotCompactDataFlowError(self->mod, ffi::GetRef(subtree_root->stmt), + ffi::GetRef(block), local_complete_block_code, local_reduction_block_code); } } @@ -492,19 +494,19 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, class OutputBlockError : public ScheduleError { public: explicit OutputBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Cannot operate on an output block"; } - String DetailRenderTemplate() const final { return "The block {0} is an output block"; } + ffi::String DetailRenderTemplate() const final { return "The block {0} is an output block"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; }; if (IsOutputBlock(self, block_sref, scope_root_sref)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw OutputBlockError(self->mod, GetRef(block)); + throw OutputBlockError(self->mod, ffi::GetRef(block)); } } @@ -545,7 +547,7 @@ bool IsWriteCache(const StmtSRef& block_sref) { /******** Binding ********/ -bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, +bool IsAffineBinding(const BlockRealize& realize, const ffi::Map& loop_var_ranges, arith::Analyzer* analyzer) { if (loop_var_ranges.empty()) { return true; @@ -561,7 +563,7 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va return false; } for (const arith::IterSumExpr& sum_expr : res->indices) { - const Array& args = sum_expr->args; + const ffi::Array& args = sum_expr->args; if (!args.empty() && !is_one(args[0]->scale)) { return false; } @@ -570,16 +572,17 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va } void 
CheckPartialAffineBinding(const ScheduleState& self, Block block, - const Optional& high_exclusive) { + const ffi::Optional& high_exclusive) { class NotAffineBindingError : public ScheduleError { public: - explicit NotAffineBindingError(IRModule mod, Block block, Optional high_exclusive) + explicit NotAffineBindingError(IRModule mod, Block block, + ffi::Optional high_exclusive) : mod_(std::move(mod)), block_(std::move(block)) { if (high_exclusive.defined()) { high_exclusive_loop_ = high_exclusive.value()->StmtAs(); } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; if (high_exclusive_loop_) { ss << "ScheduleError: The block is required to have an partial affine binding under " @@ -589,7 +592,7 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, } return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream ss; if (high_exclusive_loop_) { ss << "The block {0} is required to have an partial affine binding under " @@ -600,7 +603,7 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, return ss.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; const ForNode* high_exclusive_loop_{nullptr}; @@ -614,8 +617,8 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, if (block_sref->parent && high_exclusive.defined()) { // if it is not of global affine binding, check affineness under high_exclusive, arith::Analyzer analyzer; - Map dom_map = - LoopDomainOfSRefTreePath(GetRef(block_sref->parent), high_exclusive); + ffi::Map dom_map = + LoopDomainOfSRefTreePath(ffi::GetRef(block_sref->parent), high_exclusive); if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) { return; } @@ -633,18 +636,18 @@ void 
CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& bloc explicit NotTrivialBindingError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The binding values of the block are not variables of outer loops."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The binding values of the {0} are not variables of outer loops."; return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -652,14 +655,14 @@ void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& bloc }; if (!IsTrivialBinding(self, block_sref)) { - throw NotTrivialBindingError(self->mod, GetRef(block_sref->StmtAs())); + throw NotTrivialBindingError(self->mod, ffi::GetRef(block_sref->StmtAs())); } } -Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, - const Optional& high_exclusive, - const runtime::StorageScope& extra_relax_scope) { - Map result; +ffi::Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, + const ffi::Optional& high_exclusive, + const runtime::StorageScope& extra_relax_scope) { + ffi::Map result; const StmtSRefNode* p = low_inclusive.get(); const StmtSRefNode* limit = static_cast(high_exclusive.get()); for (; p != limit; p = p->parent) { @@ -673,7 +676,7 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, for (; p; p = p->parent) { if (const ForNode* loop = p->StmtAs()) { if (loop->kind == ForKind::kThreadBinding) { - const String& thread_tag = loop->thread_binding.value()->thread_tag; + const ffi::String& thread_tag = loop->thread_binding.value()->thread_tag; if (CanRelaxStorageUnderThread(extra_relax_scope, runtime::ThreadScope::Create(thread_tag))) { 
result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); @@ -685,12 +688,12 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, return result; } -Map GetBindings(const BlockRealize& realize) { +ffi::Map GetBindings(const BlockRealize& realize) { const BlockNode* block = realize->block.get(); - const Array& all_lhs = block->iter_vars; - const Array& all_rhs = realize->iter_values; + const ffi::Array& all_lhs = block->iter_vars; + const ffi::Array& all_rhs = realize->iter_values; ICHECK_EQ(all_lhs.size(), all_rhs.size()); - Map result; + ffi::Map result; for (int i = 0, n = all_lhs.size(); i < n; ++i) { const IterVar& lhs = all_lhs[i]; const PrimExpr& rhs = all_rhs[i]; @@ -724,7 +727,7 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, if (set == nullptr) { continue; } - Array vars_in_binding = UndefinedVars(iter_value); + ffi::Array vars_in_binding = UndefinedVars(iter_value); for (const Var& var : vars_in_binding) { set->insert(var.get()); } @@ -742,32 +745,32 @@ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sre explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The primitive only supports loop starting with 0"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The loop {0} does not start with 0, which is not supported"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; }; const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (!analyzer->CanProve(loop->min == 0)) { - throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + throw LoopNotStartWithZeroError(self->mod, ffi::GetRef(loop)); } } /******** Block-loop relation ********/ 
-Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, - const StmtSRef& parent_sref) { - Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); - Array child_block_srefs; +ffi::Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref) { + ffi::Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + ffi::Array child_block_srefs; child_block_srefs.reserve(child_block_realize.size()); for (BlockRealize realize : child_block_realize) { @@ -776,19 +779,19 @@ Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, return child_block_srefs; } -Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { +ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { - static Array Collect(const Stmt& stmt) { + static ffi::Array Collect(const Stmt& stmt) { Collector collector; collector(stmt); return std::move(collector.result_); } void VisitStmt_(const BlockRealizeNode* block_realize) final { - result_.push_back(GetRef(block_realize)); + result_.push_back(ffi::GetRef(block_realize)); } - Array result_; + ffi::Array result_; }; if (parent_sref->stmt->IsInstance()) { @@ -807,31 +810,31 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self class NonSingleChildBlockError : public ScheduleError { public: explicit NonSingleChildBlockError(IRModule mod, const StmtSRef& sref) - : mod_(std::move(mod)), stmt_(GetRef(sref->stmt)) { + : mod_(std::move(mod)), stmt_(ffi::GetRef(sref->stmt)) { sref_type_ = stmt_.as() != nullptr ? 
"block" : "loop"; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream os; os << "ScheduleError: The " << sref_type_ << " is required to have only one child block"; return os.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The " << sref_type_ << " {0} is required to have only one child block"; return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {stmt_}; } + ffi::Array LocationsOfInterest() const final { return {stmt_}; } IRModule mod_; Stmt stmt_; - String sref_type_; + ffi::String sref_type_; }; - Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + ffi::Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); if (child_block_realize.size() != 1) { throw NonSingleChildBlockError(self->mod, parent_sref); } @@ -867,19 +870,19 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr return Downcast(func->body); } else { BlockRealizeFinder finder(block); - finder(GetRef(block_sref->parent->stmt)); + finder(ffi::GetRef(block_sref->parent->stmt)); ICHECK(finder.result != nullptr) - << "InternalError: Cannot find the BlockRealize of block " << GetRef(block); - return GetRef(finder.result); + << "InternalError: Cannot find the BlockRealize of block " << ffi::GetRef(block); + return ffi::GetRef(finder.result); } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.schedule.GetBlockRealize", [](Schedule sch, BlockRV block_rv) { return GetBlockRealize(sch->state(), sch->GetSRef(block_rv)); }); -}); +} IterVarType GetLoopIterType(const StmtSRef& loop_sref) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); @@ -928,7 +931,7 @@ IterVarType GetLoopIterType(const StmtSRef& loop_sref) { } } -StmtSRef GetSRefLowestCommonAncestor(const Array& 
srefs) { +StmtSRef GetSRefLowestCommonAncestor(const ffi::Array& srefs) { CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref"; std::unordered_map sref_visited_cnt; @@ -945,16 +948,17 @@ StmtSRef GetSRefLowestCommonAncestor(const Array& srefs) { p = p->parent; } ICHECK(p != nullptr); - return GetRef(p); + return ffi::GetRef(p); } bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { - return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).has_value(); + return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure) + .has_value(); } -std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, - const StmtSRef& block_sref) { - Array location_srefs; +std::pair, std::vector> CollectComputeLocation( + const ScheduleState& self, const StmtSRef& block_sref) { + ffi::Array location_srefs; std::vector location_indices; // Step 1. Add the "compute-root" candidate. Add the "compute-inline" candidate if the block can @@ -967,7 +971,7 @@ std::pair, std::vector> CollectComputeLocation(const Schedu location_indices.push_back(-1); // Step 2. If the block has no consumer, there is no more candidate. - Array consumers = GetConsumers(self, block_sref); + ffi::Array consumers = GetConsumers(self, block_sref); if (consumers.empty()) { return std::make_pair(location_srefs, location_indices); } @@ -975,14 +979,14 @@ std::pair, std::vector> CollectComputeLocation(const Schedu // Step 3. Get the deepest loop that the input block can be computed at (namely "boundary"). If // such a loop cannot be found, there is no more candidate and we just return. StmtSRef loop_boundary = consumers.size() > 1 ? GetSRefLowestCommonAncestor(consumers) - : GetRef(consumers[0]->parent); + : ffi::GetRef(consumers[0]->parent); if (loop_boundary->StmtAs() == nullptr) { return std::make_pair(location_srefs, location_indices); } // Step 4. Collect the loops outside the first consumer and locate the boundary loop. 
The position // of the boundary loop reveals the number of possible additional candidates. - Array loop_srefs = GetLoops(consumers[0]); + ffi::Array loop_srefs = GetLoops(consumers[0]); size_t lca_pos = std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin(); ICHECK_LT(lca_pos, loop_srefs.size()); @@ -1035,9 +1039,9 @@ std::pair, std::vector> CollectComputeLocation(const Schedu /******** Producer-consumer relation ********/ -Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { - Array edges = scope->GetDepsByDst(block_sref); - Array results; +ffi::Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { + ffi::Array edges = scope->GetDepsByDst(block_sref); + ffi::Array results; std::unordered_set result_set; results.reserve(edges.size()); for (const Dependency& edge : edges) { @@ -1050,9 +1054,9 @@ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope return results; } -Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope) { - Array edges = scope->GetDepsBySrc(block_sref); - Array results; +ffi::Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope) { + ffi::Array edges = scope->GetDepsBySrc(block_sref); + ffi::Array results; std::unordered_set result_set; results.reserve(edges.size()); for (const Dependency& edge : edges) { @@ -1065,7 +1069,7 @@ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope return results; } -Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block) { +ffi::Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block) { struct OutputBlockCollector : public StmtVisitor { explicit OutputBlockCollector(const ScheduleState& self) : self_(self) {} @@ -1084,7 +1088,7 @@ Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scop } const ScheduleState& self_; - Array results_; + ffi::Array results_; }; OutputBlockCollector collector(self); 
collector(scope_block->body); @@ -1093,8 +1097,9 @@ Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scop } ProducerConsumerSplit ProducerConsumerSplit::Find( - const ScheduleState& self, const Array& subtrees, - const Array& producer_block_srefs, const Array& consumer_block_srefs, + const ScheduleState& self, const ffi::Array& subtrees, + const ffi::Array& producer_block_srefs, + const ffi::Array& consumer_block_srefs, std::unordered_map* block2realize) { class InsertionPointNotFoundError : public ScheduleError { public: @@ -1104,12 +1109,12 @@ ProducerConsumerSplit ProducerConsumerSplit::Find( last_producer_position_(last_producer_position), first_consumer_position_(first_consumer_position) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Cannot find the insertion point that satisfies the producer-consumer " "constraint"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "Cannot find the insertion point that satisfies the producer-consumer constraint. In " "0-based indexing, the last producer appears in subtree " + std::to_string(last_producer_position_) + @@ -1119,7 +1124,7 @@ ProducerConsumerSplit ProducerConsumerSplit::Find( IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -1202,7 +1207,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl buffer_index_(buffer_index), index_type_(index_type) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { if (index_type_ == BufferIndexType::kWrite) { return "ScheduleError: The input `buffer_index` is out of range. 
It is required to be in " "range " @@ -1216,7 +1221,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl } } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; size_t num = index_type_ == BufferIndexType::kWrite ? block_->writes.size() : block_->reads.size(); @@ -1228,7 +1233,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -1237,7 +1242,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl BufferIndexType index_type_; }; - const Array& access_region = + const ffi::Array& access_region = index_type == BufferIndexType::kWrite ? block->writes : block->reads; if (n < 0 || static_cast(access_region.size()) <= n) { @@ -1251,8 +1256,8 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, return GetNthAccessBufferRegion(self, block, n, index_type)->buffer; } -std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, - const Buffer& buffer) { +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer) { // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or // match_buffers. const StmtSRefNode* defining_site_sref = block_sref.get(); @@ -1266,13 +1271,13 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_ // Try to find the buffer in `allloc_buffers` for (const Buffer& alloc_buffer : block->alloc_buffers) { if (buffer.same_as(alloc_buffer)) { - return {GetRef(defining_site_sref), true}; + return {ffi::GetRef(defining_site_sref), true}; } } // We do not allow the buffer being defined in `match_buffer`. 
for (const MatchBufferRegion match_buffer : block->match_buffers) { if (buffer.same_as(match_buffer)) { - return {GetRef(defining_site_sref), false}; + return {ffi::GetRef(defining_site_sref), false}; } } defining_site_sref = defining_site_sref->parent; @@ -1288,7 +1293,7 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { const StmtSRefNode* p = sref.get(); for (; p->parent != nullptr; p = p->parent) { } - return GetRef(p); + return ffi::GetRef(p); } void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref, @@ -1307,7 +1312,7 @@ void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref, /******** Misc ********/ -bool HasOp(const Stmt& stmt, const Array& ops) { +bool HasOp(const Stmt& stmt, const ffi::Array& ops) { std::unordered_set op_set; op_set.reserve(ops.size()); for (const Op& op : ops) { @@ -1397,7 +1402,7 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri } // Case 2. Read index cannot be recognized as `var +/- const` // where `var` is a write index and `const` is an optional constant shift - Optional opt_const = std::nullopt; + ffi::Optional opt_const = std::nullopt; const VarNode* var = static_cast(AnalyzeVarWithShift(dom->min, &opt_const).get()); if (var == nullptr || !var2idx.count(var)) { @@ -1440,26 +1445,26 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri /******** Storage Scope ********/ -void CheckStorageScope(const ScheduleState& self, String storage_scope) { +void CheckStorageScope(const ScheduleState& self, ffi::String storage_scope) { class InvalidStorageScopeError : public ScheduleError { public: - explicit InvalidStorageScopeError(IRModule mod, String storage_scope) + explicit InvalidStorageScopeError(IRModule mod, ffi::String storage_scope) : mod_(std::move(mod)), storage_scope_(std::move(storage_scope)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input storage 
scope is invalid"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The input storage scope \"" + storage_scope_ + "\" is invalid."; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod() const final { return mod_; } private: IRModule mod_; - String storage_scope_; + ffi::String storage_scope_; }; try { @@ -1481,8 +1486,8 @@ bool IsSpatial(const StmtSRef& block_sref) { bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { TVM_SREF_TO_BLOCK(block_sref); - Array loops = GetLoops(block_sref); - Array binds = GetBlockRealize(self, block_sref)->iter_values; + ffi::Array loops = GetLoops(block_sref); + ffi::Array binds = GetBlockRealize(self, block_sref)->iter_values; if (loops.size() != binds.size()) { return false; } @@ -1495,12 +1500,12 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { return true; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.schedule.IsTrivialBinding", [](Schedule sch, BlockRV block_rv) { return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv)); }); -}); +} bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { if (HasBeenMultiLevelTiled(block_sref)) { @@ -1532,7 +1537,7 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref read_buffers.reserve(block->reads.size()); for (const BufferRegion& buffer_region : block->reads) { const BufferNode* buffer = buffer_region->buffer.get(); - const Array& regions = buffer_region->region; + const ffi::Array& regions = buffer_region->region; // Step 2.1. 
Duplication of read buffers are not allowed if (read_buffers.insert(buffer).second == false) { return false; @@ -1584,7 +1589,7 @@ bool IsSpatialPrimFunc(const PrimFunc& func) { std::pair GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) { - Array loops = tir::GetLoops(block_sref); + ffi::Array loops = tir::GetLoops(block_sref); int64_t cum_space_len = 1, cum_reduce_len = 1; /* * Return (-1, -1) if @@ -1619,7 +1624,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // int64_t max_parallel_extent, // int64_t max_parallel_basic) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Array loops = tir::GetLoops(block_sref); + ffi::Array loops = tir::GetLoops(block_sref); // Cond 1. The block must have at lease one write buffer if (block->writes.size() == 0) { @@ -1742,10 +1747,10 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, return info; } -Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func, - bool allow_padding) { +ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func, + bool allow_padding) { arith::Analyzer analyzer; const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); // Step 1. 
Analyze desc_func, extract its block, loops and loop vars @@ -1773,7 +1778,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const std::vector& desc_loops = desc_info.desc_loops; const std::unordered_set& desc_loop_vars = desc_info.desc_loop_vars; const BlockRealizeNode* desc_block = desc_info.desc_block; - ObjectPtr ret = make_object(); + ObjectPtr ret = ffi::make_object(); const int n_block_vars = block->iter_values.size(); const int n_desc_vars = desc_block->iter_values.size(); const int offset = n_block_vars - n_desc_vars; @@ -1876,19 +1881,19 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, } } - ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + ret->loop_map.Set(block_loop_sref, ffi::GetRef(desc_loop)); break; } } for (int i = 0, n = desc_loops.size(); i < n; ++i) { - ret->desc_loop_indexer.Set(GetRef(desc_loops[i]), Integer(i)); + ret->desc_loop_indexer.Set(ffi::GetRef(desc_loops[i]), Integer(i)); } if (!block_index_to_padding.empty()) { if (!allow_padding) { return std::nullopt; } - Array paddings; + ffi::Array paddings; for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) { const IterVar& iter_var = block->block->iter_vars[i]; if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) { @@ -1903,7 +1908,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, return TensorizeInfo(ret); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.IsSpatialPrimFunc", IsSpatialPrimFunc) @@ -1911,15 +1916,15 @@ TVM_FFI_STATIC_INIT_BLOCK({ PrimFunc desc_func, bool allow_padding) { return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding); }); -}); +} /******** Auto Tensorization ********/ /*! \brief IndexMap proposer for layout transformation in auto tensorization. 
*/ class AutoTensorizeMappingProposer { public: - static Array ProposeMappings(const AutoTensorizeComparator* extractor, - arith::Analyzer* analyzer) { + static ffi::Array ProposeMappings(const AutoTensorizeComparator* extractor, + arith::Analyzer* analyzer) { AutoTensorizeMappingProposer proposer(extractor, analyzer); proposer.CollectFeasibleSet(); return proposer.ProposeAllFuseMapping(); @@ -2013,7 +2018,7 @@ class AutoTensorizeMappingProposer { for (const auto& kv : rhs_buffer_masks) { const VarNode* rhs_var = kv.first; const BufferMask& mask = kv.second; - mask_to_rhs_vars[mask].insert(GetRef(rhs_var)); + mask_to_rhs_vars[mask].insert(ffi::GetRef(rhs_var)); } std::unordered_map rhs_var_iter_type; for (const auto& iter : extractor_->rhs_iters_) { @@ -2029,7 +2034,7 @@ class AutoTensorizeMappingProposer { } } - Array ProposeAllFuseMapping() { + ffi::Array ProposeAllFuseMapping() { // Now we have calcuated potential mapping for each iter var on LHS. For iters on LHS mapped to // the same iter on RHS, they will be fused in the original order in LHS block iters. We will // generate IndexMap to represent such fusion on LHS. For example, if n, h, w on LHS are mapped @@ -2037,12 +2042,12 @@ class AutoTensorizeMappingProposer { // fuse(v0, .., vn) = ((v0 * v1_extent + v1) + ... ) * vn_extent + vn // the parameters of the result index map, each parameter corresponds to a LHS iter - Array index_map_src; + ffi::Array index_map_src; // the outputs of the result index map - Array index_map_tgt; + ffi::Array index_map_tgt; // Step 1: Collect extents of LHS iters and prepare the initial indices of the IndexMap - Map lhs_iter_extents; + ffi::Map lhs_iter_extents; for (const auto& iter : extractor_->lhs_iters_) { lhs_iter_extents.Set(iter->var, iter->dom->extent); index_map_src.push_back(iter->var.copy_with_suffix("")); @@ -2050,7 +2055,7 @@ class AutoTensorizeMappingProposer { // Step 2: Each iter on RHS has a group of corresponding iters on LHS. 
Initialize the fusion // result for each group of iters on LHS. - Map fused_lhs_iters; + ffi::Map fused_lhs_iters; for (const auto& iter : extractor_->rhs_iters_) { fused_lhs_iters.Set(iter->var, 0); } @@ -2114,19 +2119,20 @@ bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::BlockRV& return CheckAutoTensorizeApplicable(sch->state(), sch->GetSRef(block_rv), desc_func, &extractor); } -Optional GetAutoTensorizeMappingInfo(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func) { +ffi::Optional GetAutoTensorizeMappingInfo( + const tir::ScheduleState& self, const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { AutoTensorizeComparator extractor(self->mod); if (!CheckAutoTensorizeApplicable(self, block_sref, desc_func, &extractor)) { return std::nullopt; } arith::Analyzer analyzer; - Array mappings = AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer); + ffi::Array mappings = + AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer); if (mappings.empty()) { return std::nullopt; } - ObjectPtr ret = make_object(); + ObjectPtr ret = ffi::make_object(); ret->mappings = std::move(mappings); ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_); ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_); @@ -2135,7 +2141,7 @@ Optional GetAutoTensorizeMappingInfo(const tir::Schedu return AutoTensorizeMappingInfo(ret); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.GetAutoTensorizeMappingInfo", @@ -2149,17 +2155,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto block_sref = sch->GetSRef(block); return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); }) - .def("tir.schedule.GetLoopIterType", [](Schedule sch, LoopRV loop) -> String { - IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); - if (kind == kDataPar) { - return "S"; - } else if (kind == 
kCommReduce) { - return "R"; - } else { - return "O"; - } - }); -}); + .def("tir.schedule.GetLoopIterType", + [](Schedule sch, LoopRV loop) -> ffi::String { + IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); + if (kind == kDataPar) { + return "S"; + } else if (kind == kCommReduce) { + return "R"; + } else { + return "O"; + } + }) + .def("tir.schedule.HasIfThenElse", + [](const Stmt& stmt) -> bool { return HasIfThenElse(stmt); }); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index f6dc0067a800..ddc15ab5e592 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -28,7 +28,7 @@ namespace tir { * \param buffer The buffer * \return The strides */ -Array GetStrides(const Buffer& buffer) { +ffi::Array GetStrides(const Buffer& buffer) { if (!buffer->strides.empty()) { ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); return buffer->strides; @@ -37,7 +37,7 @@ Array GetStrides(const Buffer& buffer) { if (ndim == 0) { return {}; } - Array strides(ndim, PrimExpr{nullptr}); + ffi::Array strides(ndim, PrimExpr{nullptr}); PrimExpr stride = make_const(buffer->DefaultIndexType(), 1); for (int i = ndim - 1; i >= 0; --i) { strides.Set(i, stride); @@ -75,9 +75,9 @@ class SplitExprCollector { * \return The collected split expressions */ static std::vector Collect(const PrimExpr& index, - const Map& input_iters, // - const PrimExpr& predicate, // - arith::IterMapLevel check_level, // + const ffi::Map& input_iters, // + const PrimExpr& predicate, // + arith::IterMapLevel check_level, // arith::Analyzer* analyzer) { arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters, predicate, check_level, analyzer); @@ -106,7 +106,7 @@ class SplitExprCollector { failed_ = true; return; } - exprs_.push_back(SplitExpr{GetRef(var), *lower_factor, *extent}); + exprs_.push_back(SplitExpr{ffi::GetRef(var), *lower_factor, *extent}); } else 
if (auto iter_sum_expr = expr->source->source.as()) { Visit(iter_sum_expr.value()); } else { @@ -126,13 +126,13 @@ class SplitExprCollector { std::vector exprs_; }; -Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, - const Array& loops, const PrimExpr& predicate, - arith::Analyzer* analyzer) { +ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array& indices, + const ffi::Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer) { int ndim = buffer->shape.size(); int n_loops = loops.size(); // Step 1. Collect the domains and indices of loop variables - Map input_iters; + ffi::Map input_iters; std::unordered_map var2id; var2id.reserve(n_loops); for (int i = 0; i < n_loops; ++i) { @@ -142,7 +142,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& } // Step 2. Calculate a functor that flattens a multi-dimensional index auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()]( - const Array& indices) -> PrimExpr { + const ffi::Array& indices) -> PrimExpr { PrimExpr flatten_index = make_const(dtype, 0); for (int i = 0; i < ndim; ++i) { flatten_index = flatten_index + strides[i] * indices[i]; @@ -179,7 +179,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& &order, // & shape = buffer->shape, // analyzer // - ](Array indices) -> Array { + ](ffi::Array indices) -> ffi::Array { ICHECK_EQ(indices.size(), shape.size()); for (int i = 0, n = indices.size(); i < n; ++i) { analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); @@ -198,7 +198,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& } std::reverse(split.begin(), split.end()); // Step 5.3. 
Reorder the indexing pattern according to `order` - Array results; + ffi::Array results; results.reserve(ndim); for (int i = 0; i < ndim; ++i) { results.push_back(split[order[i]]); @@ -207,11 +207,11 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& }; // Step 6: Create the inverse index mapping. auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape, - analyzer](Array indices) -> Array { + analyzer](ffi::Array indices) -> ffi::Array { ICHECK_EQ(indices.size(), split_exprs.size()); // Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3. // After the inverse permutation, indices[i] corresponds to split_exprs[i] - Array inv_permuted_indices; + ffi::Array inv_permuted_indices; inv_permuted_indices.reserve(indices.size()); for (int i = 0, n = indices.size(); i < n; ++i) { const Var& index = indices[inverse_order[i]]; @@ -227,27 +227,28 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& stride *= split_exprs[i].extent; } // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1. 
- Array result; + ffi::Array result; result.reserve(shape.size()); for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i])); flattened_index = floordiv(flattened_index, shape[i]); result.push_back(index); } - return Array(result.rbegin(), result.rend()); + return ffi::Array(result.rbegin(), result.rend()); }; IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse); return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.SuggestIndexMap", [](Buffer buffer, Array indices, - Array loops, PrimExpr predicate) { - arith::Analyzer analyzer; - return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); - }); -}); + refl::GlobalDef().def( + "tir.schedule.SuggestIndexMap", + [](Buffer buffer, ffi::Array indices, ffi::Array loops, PrimExpr predicate) { + arith::Analyzer analyzer; + return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); + }); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc index d85be933820c..085a4a33de87 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/tir/schedule/analysis/reducer.cc @@ -49,7 +49,7 @@ namespace tir { */ class PatternMatcher : public ExprVisitor { public: - explicit PatternMatcher(Array pattern) : pattern_(std::move(pattern)) {} + explicit PatternMatcher(ffi::Array pattern) : pattern_(std::move(pattern)) {} void VisitExpr_(const VarNode* op) final { auto it = filled_map_.find(op); @@ -258,7 +258,7 @@ class PatternMatcher : public ExprVisitor { } } - void Match(const Array& exprs_to_match) { + void Match(const ffi::Array& exprs_to_match) { this->match_success_ = true; this->filled_map_.clear(); @@ -281,7 +281,7 @@ class PatternMatcher : public ExprVisitor { private: bool 
match_success_{true}; - Array pattern_; + ffi::Array pattern_; PrimExpr expr_to_match_; std::unordered_map filled_map_; }; @@ -303,19 +303,19 @@ static const char* kRFactorCrossThreadReductionApplicableBlockDef = 11) The buffers written by the block should have same shape 12) The indices of all BufferStores in the reduction block should be the same)"; -void ErrorRFactorCrossThreadReductionNotApplicable(const Optional& self, Block block, - int violated_cond) { +void ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optional& self, + Block block, int violated_cond) { class RFactorNotApplicableError : public ScheduleError { public: explicit RFactorNotApplicableError(IRModule mod, Block block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: RFactor cannot be applied to the block since the block does not meet " "the requirements"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "RFactor cannot be applied to block {0}, because the block violates condition #" << violated_cond_ << ".\n" @@ -324,7 +324,7 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const Optional } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ -352,11 +352,12 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const Optional * \param buf2index A mapping from reduction buffers to their indices of the reduction order * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block */ -void ExtractReductionUpdates(const Optional& self, Block block, - const LetStmtNode* let, int n_buffers, Array* updates, +void ExtractReductionUpdates(const ffi::Optional& self, Block 
block, + const LetStmtNode* let, int n_buffers, + ffi::Array* updates, std::unordered_map* buf2index) { std::unordered_map var2index; - Array let_values; + ffi::Array let_values; let_values.reserve(n_buffers); updates->resize(n_buffers); @@ -390,7 +391,8 @@ void ExtractReductionUpdates(const Optional& self, Block block, if (p_seq == nullptr && p_buf_store == nullptr) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); } - Array seq = p_seq != nullptr ? p_seq->seq : Array{GetRef(p_buf_store)}; + ffi::Array seq = + p_seq != nullptr ? p_seq->seq : ffi::Array{ffi::GetRef(p_buf_store)}; if (static_cast(seq.size()) != n_buffers) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6); } @@ -426,10 +428,10 @@ void ExtractReductionUpdates(const Optional& self, Block block, } } -std::pair, Array> GetInitValuesAndUpdatesFromReductionBlock( - const Optional& self, Block block) { - Array inits; - Array updates; +std::pair, ffi::Array> GetInitValuesAndUpdatesFromReductionBlock( + const ffi::Optional& self, Block block) { + ffi::Array inits; + ffi::Array updates; // Step 1. Extract the BufferStores serving as block inits. if (auto init = block->init.as()) { @@ -455,7 +457,7 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct int n_buffers = inits.size(); std::unordered_map buf2index; if (const auto* update = block->body.as()) { - updates.push_back(GetRef(update)); + updates.push_back(ffi::GetRef(update)); buf2index[update->buffer.get()] = 0; } else { const auto* let = block->body.as(); @@ -465,15 +467,15 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct // Step 3. Set the init values according to the buffer order in `updates`, with the help of the // mapping `buf2index`. 
- Array init_values; + ffi::Array init_values; init_values.resize(n_buffers); // - Check all buffers have the same shape // - Check all indices of the BufferStores are the same // - Check buffers written in the block init and the block body can match // - Check buffers do not duplicate - const Array& expected_shape = updates[0]->buffer->shape; - const Array& expected_indices = updates[0]->indices; + const ffi::Array& expected_shape = updates[0]->buffer->shape; + const ffi::Array& expected_indices = updates[0]->indices; ICHECK_EQ(expected_shape.size(), expected_indices.size()); int n_dim = expected_indices.size(); arith::Analyzer ana; @@ -511,7 +513,7 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct return std::make_pair(init_values, updates); } -bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters) { +bool ContainsOnlyDataParAndReductionBlockIter(const ffi::Array& iters) { for (const IterVar& iter_var : iters) { if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { return false; @@ -589,18 +591,18 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) { class NoMatchedReducerError : public ScheduleError { public: - explicit NoMatchedReducerError(IRModule mod, Array identities, - Array combiners) + explicit NoMatchedReducerError(IRModule mod, ffi::Array identities, + ffi::Array combiners) : mod_(std::move(mod)), identities_(std::move(identities)), combiners_(std::move(combiners)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " "block. So rfactor and cross-thread reduction cannot be applied."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "No matched reducer for identity " << identities_ << " and combiner " << combiners_ << "In this case rfactor cannot be applied. 
You can check tvm::tir::ReducerRegistry for " @@ -609,18 +611,18 @@ class NoMatchedReducerError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod_; - Array identities_; - Array combiners_; + ffi::Array identities_; + ffi::Array combiners_; }; -std::tuple, Array> GetReducerAndCombinerLhsRhs( - const Optional& self, const Array& identities, - const Array& combiners) { +std::tuple, ffi::Array> GetReducerAndCombinerLhsRhs( + const ffi::Optional& self, const ffi::Array& identities, + const ffi::Array& combiners) { CommReducer reducer{nullptr}; - Array combiner_lhs, combiner_rhs; + ffi::Array combiner_lhs, combiner_rhs; bool matched = FromIdentityCombiner(identities, combiners, &reducer, &combiner_lhs, &combiner_rhs); if (!matched) { @@ -636,9 +638,10 @@ std::tuple, Array> GetReducerAndCombinerL /******** Commutative Reducer ********/ -bool MatchReducer(const CommReducer& reducer, const Array& identities, - const Array& combined_values, const Array& buf_loads, - Array* lhs, Array* rhs) { +bool MatchReducer(const CommReducer& reducer, const ffi::Array& identities, + const ffi::Array& combined_values, + const ffi::Array& buf_loads, ffi::Array* lhs, + ffi::Array* rhs) { ExprDeepEqual equal; ICHECK_EQ(identities.size(), combined_values.size()); int n_buffers = identities.size(); @@ -650,7 +653,7 @@ bool MatchReducer(const CommReducer& reducer, const Array& identities, PatternMatcher pattern_matcher(reducer->result); pattern_matcher.Match(combined_values); - Array lhs_tmp, rhs_tmp; + ffi::Array lhs_tmp, rhs_tmp; lhs_tmp.reserve(n_buffers); rhs_tmp.reserve(n_buffers); if (!pattern_matcher.Success()) { @@ -671,11 +674,12 @@ bool MatchReducer(const CommReducer& reducer, const Array& identities, return true; } -bool FromIdentityCombiner(const Array& identities, const Array& combiners, - CommReducer* result_reducer, Array* lhs, 
Array* rhs) { +bool FromIdentityCombiner(const ffi::Array& identities, + const ffi::Array& combiners, CommReducer* result_reducer, + ffi::Array* lhs, ffi::Array* rhs) { int n = identities.size(); - Array buf_loads; - Array stored_values; + ffi::Array buf_loads; + ffi::Array stored_values; buf_loads.reserve(n); stored_values.reserve(n); @@ -685,9 +689,9 @@ bool FromIdentityCombiner(const Array& identities, const Array(Array)>& reducer_getter : + for (const ffi::TypedFunction(ffi::Array)>& reducer_getter : GetReducerGetters()) { - Optional reducer = reducer_getter(identities); + ffi::Optional reducer = reducer_getter(identities); if (!reducer.defined()) { continue; } diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc index 4e3f04e0f389..f9a09552c21c 100644 --- a/src/tir/schedule/analysis/verify.cc +++ b/src/tir/schedule/analysis/verify.cc @@ -56,19 +56,20 @@ class SRefTreeVerifier : public StmtVisitor { } ICHECK(self_->stmt2ref.count(block)) << "InternalError: A BlockNode should appear in sref map, but it didn't\n" - << GetRef(block); + << ffi::GetRef(block); ++n_sref_visited_; ++n_block_sref_visited_; const StmtSRef& sref = self_->stmt2ref.at(block); ICHECK(self_->block_info.count(sref)) << "InternalError: Cannot find scope information of the BlockNode:\n" - << GetRef(block); + << ffi::GetRef(block); ICHECK(sref->parent == ancestors_.back()) << "InternalError: Parent information mismatch for BlockNode:\n" - << GetRef(block) << "\nIts parent is supposed to be:\n" - << GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" - << (sref->parent ? Optional(GetRef(sref->parent->stmt)) - : Optional(std::nullopt)); + << ffi::GetRef(block) << "\nIts parent is supposed to be:\n" + << ffi::GetRef(ancestors_.back()->stmt) + << "\nHowever, its parent is incorrect and is:\n" + << (sref->parent ? 
ffi::Optional(ffi::GetRef(sref->parent->stmt)) + : ffi::Optional(std::nullopt)); ancestors_.push_back(sref.operator->()); if (block->init.defined()) { ++init_block_depth_; @@ -88,16 +89,17 @@ class SRefTreeVerifier : public StmtVisitor { } ICHECK(self_->stmt2ref.count(loop)) << "InternalError: A ForNode should appear in sref map, but it didn't\n" - << GetRef(loop); + << ffi::GetRef(loop); ++n_sref_visited_; const StmtSRef& sref = self_->stmt2ref.at(loop); - Optional stmt = std::nullopt; + ffi::Optional stmt = std::nullopt; ICHECK(sref->parent == ancestors_.back()) << "InternalError: Parent information mismatch for ForNode:\n" - << GetRef(loop) << "\nIts parent is supposed to be:\n" - << GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" - << (sref->parent ? Optional(GetRef(sref->parent->stmt)) - : Optional(std::nullopt)); + << ffi::GetRef(loop) << "\nIts parent is supposed to be:\n" + << ffi::GetRef(ancestors_.back()->stmt) + << "\nHowever, its parent is incorrect and is:\n" + << (sref->parent ? 
ffi::Optional(ffi::GetRef(sref->parent->stmt)) + : ffi::Optional(std::nullopt)); ancestors_.push_back(sref.operator->()); StmtVisitor::VisitStmt_(loop); ancestors_.pop_back(); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 6f7e682d6c7a..eae4c64a15a7 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -26,7 +26,7 @@ namespace tir { Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->state_ = ScheduleState(mod, debug_mask, enable_check); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; @@ -56,7 +56,7 @@ class ScheduleCopier { TSymbolTable* new_symbol_table) { const ScheduleState& src_state = self->state_; ScheduleCopier copier(src_state); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->mod = src_state->mod; n->block_info = copier.Copy(src_state->block_info); n->stmt2ref = copier.Copy(src_state->stmt2ref); @@ -98,9 +98,9 @@ class ScheduleCopier { return old2new_[sref] = StmtSRef(nullptr, nullptr, -1); } - /*! \brief Copy Array */ - Array Copy(const Array& list) { - Array result; + /*! \brief Copy ffi::Array */ + ffi::Array Copy(const ffi::Array& list) { + ffi::Array result; result.reserve(list.size()); for (const StmtSRef& elem : list) { result.push_back(Copy(elem)); @@ -108,9 +108,9 @@ class ScheduleCopier { return result; } - /*! \brief Copy Array */ - Array Copy(const Array& list) { - Array result; + /*! \brief Copy ffi::Array */ + ffi::Array Copy(const ffi::Array& list) { + ffi::Array result; result.reserve(list.size()); for (const Dependency& elem : list) { result.push_back(Dependency(Copy(elem->src), Copy(elem->dst), elem->kind)); @@ -118,9 +118,9 @@ class ScheduleCopier { return result; } - /*! 
\brief Copy SMap> */ - SMap> Copy(const SMap>& map) { - SMap> result; + /*! \brief Copy SMap> */ + SMap> Copy(const SMap>& map) { + SMap> result; result.reserve(map.size()); for (const auto& kv : map) { result[Copy(kv.first)] = Copy(kv.second); @@ -128,9 +128,9 @@ class ScheduleCopier { return result; } - /*! \brief Copy SMap> */ - SMap> Copy(const SMap>& map) { - SMap> result; + /*! \brief Copy SMap> */ + SMap> Copy(const SMap>& map) { + SMap> result; result.reserve(map.size()); for (const auto& kv : map) { result[kv.first] = Copy(kv.second); @@ -145,7 +145,7 @@ class ScheduleCopier { const StmtSRef& old_sref = kv.first; const BlockInfo& old_info = kv.second; BlockInfo new_info = old_info; - ObjectPtr scope = make_object(); + ObjectPtr scope = ffi::make_object(); scope->src2deps = Copy(old_info.scope->src2deps); scope->dst2deps = Copy(old_info.scope->dst2deps); scope->buffer_writers = Copy(old_info.scope->buffer_writers); @@ -184,7 +184,7 @@ class ScheduleCopier { std::unordered_map old2new_; }; -void ConcreteScheduleNode::WorkOn(const String& func_name) { +void ConcreteScheduleNode::WorkOn(const ffi::String& func_name) { this->func_working_on_ = this->state_->mod->GetGlobalVar(func_name); } @@ -194,7 +194,7 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb } Schedule ConcreteScheduleNode::Copy() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->func_working_on_ = this->func_working_on_; n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); @@ -233,18 +233,18 @@ support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } -ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV ConcreteScheduleNode::SampleCategorical(const ffi::Array& candidates, + const ffi::Array& probs, + ffi::Optional decision) 
{ TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); throw; } -Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, - int max_innermost_factor, - Optional> decision) { +ffi::Array ConcreteScheduleNode::SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); // use None RV object to denotes auto-infer tile factors. return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, @@ -254,9 +254,9 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int throw; } -Array ConcreteScheduleNode::SamplePartitionedTile(const LoopRV& loop_rv, int n, - int partition_pos, int innerpart_factor, - Optional> decision) { +ffi::Array ConcreteScheduleNode::SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::SamplePartitionedTile(&this->rand_state_, this->GetSRef(loop_rv), n, partition_pos, innerpart_factor, &decision)); @@ -265,7 +265,7 @@ Array ConcreteScheduleNode::SamplePartitionedTile(const LoopRV& loop_rv, } LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, - Optional decision) { + ffi::Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV( tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); @@ -275,22 +275,25 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, /******** Schedule: Get blocks & loops ********/ -BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional& func_name) { +BlockRV ConcreteScheduleNode::GetBlock(const ffi::String& name, + const ffi::Optional& func_name) { class NotSingleResult : public ScheduleError { public: - explicit NotSingleResult(String name, 
IRModule mod, const Array& blocks) + explicit NotSingleResult(ffi::String name, IRModule mod, const ffi::Array& blocks) : name_(name), mod_(mod), blocks_{} { blocks_.reserve(blocks.size()); for (const StmtSRef& block_sref : blocks) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - blocks_.push_back(GetRef(block)); + blocks_.push_back(ffi::GetRef(block)); } } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; } + ffi::Array LocationsOfInterest() const final { + return {blocks_.begin(), blocks_.end()}; + } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { if (blocks_.empty()) { return "Cannot find a block with the name: " + name_; } else { @@ -298,7 +301,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional blocks_; + ffi::Array blocks_; }; GlobalVar gv = NullValue(); if (func_name.has_value()) { @@ -320,7 +323,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional blocks = tir::GetBlocks(this->state_, name, gv); + ffi::Array blocks = tir::GetBlocks(this->state_, name, gv); if (blocks.size() != 1) { TVM_TIR_SCHEDULE_BEGIN(); throw NotSingleResult(name, this->state_->mod, blocks); @@ -329,12 +332,12 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional(blocks[0]); } -Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { +ffi::Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } -Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) { - Array result; +ffi::Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); @@ -342,8 +345,8 @@ Array 
ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) { return result; } -Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { - Array result; +ffi::Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(loop_rv))); TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); @@ -351,21 +354,21 @@ Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { return result; } -Array ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) { +ffi::Array ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::GetProducers(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_); throw; } -Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { +ffi::Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::GetConsumers(state_, this->GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_); throw; } -Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { +ffi::Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::GetOutputBlocks(state_, this->GetSRef(scope_block_rv))); TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_); @@ -374,9 +377,9 @@ Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_ /******** Schedule: Transform loops ********/ -LoopRV ConcreteScheduleNode::Merge(const Array& loop_rvs) { +LoopRV ConcreteScheduleNode::Merge(const ffi::Array& loop_rvs) { CHECK(loop_rvs.size() > 1) << "ValueError: 'merge' requires at least 2 loop(s)"; - Array loop_srefs = this->GetSRefs(loop_rvs); + ffi::Array loop_srefs = this->GetSRefs(loop_rvs); StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result 
= tir::Merge(state_, loop_srefs); @@ -385,9 +388,9 @@ LoopRV ConcreteScheduleNode::Merge(const Array& loop_rvs) { return CreateRV(result); } -LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs, bool preserve_unit_iters) { +LoopRV ConcreteScheduleNode::Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) { CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; - Array loop_srefs = this->GetSRefs(loop_rvs); + ffi::Array loop_srefs = this->GetSRefs(loop_rvs); StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::Fuse(state_, loop_srefs, preserve_unit_iters); @@ -400,16 +403,16 @@ class NotSingleInferFactorError : public ScheduleError { public: explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: only one factor can be specified as -1 or none"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "Only one factor can be specified as -1 or none"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod_; }; @@ -419,7 +422,7 @@ class WrongFactorError : public ScheduleError { explicit WrongFactorError(IRModule mod, For loop, bool product) : mod_(mod), loop_(std::move(loop)), product_(product) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { if (product_) return "ScheduleError: The product of factors is not larger than or equal to the extent of " "loop"; @@ -427,7 +430,7 @@ class WrongFactorError : public ScheduleError { return "ScheduleError: The sum of factors is larger than or equal to the extent of loop"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { if (product_) return "The product of factors is not larger than or equal to the extent of 
loop {0}"; else @@ -435,7 +438,7 @@ class WrongFactorError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; @@ -447,18 +450,18 @@ class NonPositiveFactorError : public ScheduleError { explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx) : mod_(std::move(mod)), factor_(factor), idx_(idx) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: All the constant factors are required to be positive. However, some " "constant input factor is zero or negative."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "All the constant factors are required to be positive. However, the factor at position " << idx_ << " is " << factor_; return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -466,17 +469,17 @@ class NonPositiveFactorError : public ScheduleError { size_t idx_; }; -Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, - const Array>& factor_rvs, - bool preserve_unit_iters, bool disable_predication) { +ffi::Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, + const ffi::Array>& factor_rvs, + bool preserve_unit_iters, bool disable_predication) { // Prepare for the splitting StmtSRef loop_sref = this->GetSRef(loop_rv); const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - Array factors; + ffi::Array factors; factors.reserve(factor_rvs.size()); int infer_index = -1; PrimExpr tot_length = 1; - Array results; + ffi::Array results; TVM_TIR_SCHEDULE_BEGIN(); // infer factor if needed and check validity of factors for (size_t i = 0; i < factor_rvs.size(); i++) { @@ -502,7 +505,7 @@ Array 
ConcreteScheduleNode::Split(const LoopRV& loop_rv, factors.Set(infer_index, this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { - throw WrongFactorError(state_->mod, GetRef(loop), true); + throw WrongFactorError(state_->mod, ffi::GetRef(loop), true); } results = tir::Split(state_, loop_sref, factors, preserve_unit_iters, disable_predication); TVM_TIR_SCHEDULE_END("split", this->error_render_level_); @@ -510,24 +513,24 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, return CreateRV(results); } -Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv, - const Array>& factor_rvs, - bool preserve_unit_iters) { +ffi::Array ConcreteScheduleNode::LoopPartition( + const LoopRV& loop_rv, const ffi::Array>& factor_rvs, + bool preserve_unit_iters) { class SymbolicShapeError : public ScheduleError { public: explicit SymbolicShapeError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The min and extent values of the loop are required to be known at " "compile time. 
However, dynamic shape has been detected."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "Detected dynamic shape in either min or extent of a loop {0}"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; @@ -536,14 +539,14 @@ Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv, // Prepare for the loop_partitioning StmtSRef loop_sref = this->GetSRef(loop_rv); const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - Array factors; + ffi::Array factors; factors.reserve(factor_rvs.size()); int infer_index = -1; PrimExpr tot_length = 0; - Array results; + ffi::Array results; TVM_TIR_SCHEDULE_BEGIN(); if (!is_const_number(loop->min) || !is_const_number(loop->extent)) { - throw SymbolicShapeError(state_->mod, GetRef(loop)); + throw SymbolicShapeError(state_->mod, ffi::GetRef(loop)); } // infer factor if needed and check validity of factors for (size_t i = 0; i < factor_rvs.size(); i++) { @@ -566,7 +569,7 @@ Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv, } } if (this->analyzer_->CanProve(tot_length >= loop->extent)) { - throw WrongFactorError(state_->mod, GetRef(loop), false); + throw WrongFactorError(state_->mod, ffi::GetRef(loop), false); } if (infer_index != -1) { // if there is a 'None' in the factor list, 'None' becomes the difference between the extent and @@ -585,7 +588,7 @@ Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv, return CreateRV(results); } -void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) { +void ConcreteScheduleNode::Reorder(const ffi::Array& ordered_loop_rvs) { TVM_TIR_SCHEDULE_BEGIN(); tir::Reorder(state_, GetSRefs(ordered_loop_rvs)); TVM_TIR_SCHEDULE_END("reorder", this->error_render_level_); @@ -593,7 +596,7 @@ void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) { } void 
ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, - const Array new_order) { + const ffi::Array new_order) { TVM_TIR_SCHEDULE_BEGIN(); tir::ReorderBlockIterVar(state_, GetSRef(block_rv), new_order); TVM_TIR_SCHEDULE_END("reorder_block_iter_var", this->error_render_level_); @@ -601,7 +604,7 @@ void ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, } LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { - LoopRV result{nullptr}; + LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::AddUnitLoop(state_, GetSRef(block_rv))); TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); @@ -610,7 +613,7 @@ LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { } LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { - LoopRV result{nullptr}; + LoopRV result{ffi::UnsafeInit()}; TVM_TIR_SCHEDULE_BEGIN(); result = CreateRV(tir::AddUnitLoop(state_, GetSRef(loop_rv))); TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_); @@ -634,7 +637,7 @@ void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) { TVM_TIR_SCHEDULE_END("vectorize", this->error_render_level_); } -void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { +void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) { if (thread_axis == "vthread") { LOG(WARNING) << "`vthread` is legacy behavior and is going to be deprecated. Please use " "`vthread.x`, `vthread.y` and `vthread.z` instead"; @@ -655,11 +658,11 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { /******** Schedule: Insert cache stages ********/ BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, - const Array consumer_blocks) { + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { StmtSRef result{nullptr}; // Create a new array of SRefs from the consumer block list. 
- Array consumer_block_refs = {}; + ffi::Array consumer_block_refs = {}; for (BlockRV block : consumer_blocks) { consumer_block_refs.push_back(this->GetSRef(block)); } @@ -672,11 +675,11 @@ BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer } BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, - const Array consumer_blocks) { + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { StmtSRef result{nullptr}; // Create a new array of SRefs from the consumer block list. - Array consumer_block_refs = {}; + ffi::Array consumer_block_refs = {}; for (BlockRV block : consumer_blocks) { consumer_block_refs.push_back(this->GetSRef(block)); } @@ -689,7 +692,7 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff } BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, + const ffi::String& storage_scope, const IndexMap& index_map) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); @@ -701,7 +704,7 @@ BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read } BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, + const ffi::String& storage_scope, const IndexMap& index_map) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); @@ -712,27 +715,29 @@ BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int wri return CreateRV(result); } -Array ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) { - Array results; +ffi::Array ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, + int write_buffer_index, + const ffi::String& storage_scope) { + ffi::Array results; TVM_TIR_SCHEDULE_BEGIN(); results = tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, 
storage_scope); TVM_TIR_SCHEDULE_END("cache-buffer", this->error_render_level_); this->state_->DebugVerify(); - Array return_blocks; + ffi::Array return_blocks; return_blocks.push_back(CreateRV(results[0])); return_blocks.push_back(CreateRV(results[1])); return return_blocks; } -Array ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, - const String& storage_scope, int cse_thresh) { - Array result; +ffi::Array ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, + const ffi::String& storage_scope, + int cse_thresh) { + ffi::Array result; TVM_TIR_SCHEDULE_BEGIN(); result = tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh); TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_); this->state_->DebugVerify(); - Array return_blocks; + ffi::Array return_blocks; for (const StmtSRef& blockrv : result) { return_blocks.push_back(CreateRV(blockrv)); } @@ -740,10 +745,10 @@ Array ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, } BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) { + BufferIndexType buffer_index_type, bool skip_simplify) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type); + result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, skip_simplify); TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -752,7 +757,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, /******** Schedule: Data movement ********/ BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int read_buffer_index, const String& storage_scope) { + int read_buffer_index, const ffi::String& storage_scope) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReadAt(state_, this->GetSRef(loop_rv), 
this->GetSRef(block_rv), read_buffer_index, @@ -763,7 +768,7 @@ BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block } BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int write_buffer_index, const String& storage_scope) { + int write_buffer_index, const ffi::String& storage_scope) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, @@ -827,6 +832,15 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { this->state_->DebugVerify(); } +void ConcreteScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv, + const BlockRV& epilogue_block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv), + this->GetSRef(epilogue_block_rv)); + TVM_TIR_SCHEDULE_END("fuse-reduction-epilogue", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: Block Annotation ********/ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, @@ -838,7 +852,7 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde } void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, - const String& storage_scope) { + const ffi::String& storage_scope) { TVM_TIR_SCHEDULE_BEGIN(); tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope); TVM_TIR_SCHEDULE_END("set-scope", this->error_render_level_); @@ -846,7 +860,7 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, } void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index, - const String& dtype) { + const ffi::String& dtype) { TVM_TIR_SCHEDULE_BEGIN(); tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype); TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_); @@ -883,7 +897,8 @@ BlockRV 
ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit return CreateRV(result); } -BlockRV ConcreteScheduleNode::Blockize(const Array& blocks, bool preserve_unit_iters) { +BlockRV ConcreteScheduleNode::Blockize(const ffi::Array& blocks, + bool preserve_unit_iters) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); result = tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters); @@ -892,7 +907,7 @@ BlockRV ConcreteScheduleNode::Blockize(const Array& blocks, bool preser return CreateRV(result); } -void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(), @@ -901,7 +916,7 @@ void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } -void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin, +void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(), @@ -929,8 +944,8 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { if (const auto* expr = ann_val.as()) { ICHECK(!expr->IsInstance()) - << "TypeError: String is expected, but gets StringImm"; - auto res_expr = this->Get(GetRef(expr)); + << "TypeError: ffi::String is expected, but gets StringImm"; + auto res_expr = this->Get(ffi::GetRef(expr)); // prefer to return int/float literals for annotations if (auto opt_intimm = res_expr.as()) { return (*std::move(opt_intimm))->value; @@ -941,7 +956,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { return res_expr; } if (const auto* arr = ann_val.as()) { - 
Array result; + ffi::Array result; result.reserve(arr->size()); for (size_t i = 0; i < arr->size(); i++) { result.push_back(CheckAndGetAnnotationValue(arr->at(i))); @@ -949,7 +964,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { return result; } if (const auto* dict = ann_val.as()) { - Map result; + ffi::Map result; for (auto it = dict->begin(); it != dict->end(); ++it) { const auto& key = it->first; auto value = CheckAndGetAnnotationValue(it->second); @@ -958,7 +973,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { } else if (auto opt_str = key.try_cast()) { result.Set(opt_str.value(), value); } else { - LOG(FATAL) << "TypeError: annotation dict key expect to be String or StringImm"; + LOG(FATAL) << "TypeError: annotation dict key expect to be ffi::String or StringImm"; } } return result; @@ -969,7 +984,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { TVM_FFI_UNREACHABLE(); } -void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, +void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val)); @@ -977,14 +992,14 @@ void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } -void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { +void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) { TVM_TIR_SCHEDULE_BEGIN(); tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); } -void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, +void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const 
ffi::String& ann_key, const Any& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); tir::Annotate(state_, this->GetSRef(block_rv), ann_key, @@ -993,7 +1008,7 @@ void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_k TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } -void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) { +void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) { TVM_TIR_SCHEDULE_BEGIN(); tir::Unannotate(state_, this->GetSRef(block_rv), ann_key); this->state_->DebugVerify(); @@ -1004,10 +1019,10 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value, + const ffi::Optional& pad_value, bool assume_injective_transform) { TVM_TIR_SCHEDULE_BEGIN(); - auto f_subst = [&](const Var& var) -> Optional { + auto f_subst = [&](const Var& var) -> ffi::Optional { if (auto opt_expr = symbol_table_.Get(var)) { return Downcast(opt_expr.value()); } else { @@ -1031,7 +1046,7 @@ void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv, void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) { + const ffi::Array& axis_separators) { TVM_TIR_SCHEDULE_BEGIN(); tir::SetAxisSeparator(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, axis_separators); @@ -1050,7 +1065,7 @@ BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const Lo return CreateRV(result); } -void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array& padding) { +void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) { TVM_TIR_SCHEDULE_BEGIN(); tir::PadEinsum(state_, this->GetSRef(block_rv), padding); 
TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_); @@ -1068,8 +1083,9 @@ void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buff /******** Schedule: Misc ********/ -void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) { +void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, + const ffi::String& buf_type, + const ffi::Array& buf_index_array) { TVM_TIR_SCHEDULE_BEGIN(); tir::UnsafeHideBufferAccess(state_, this->GetSRef(block_rv), buf_type, buf_index_array); TVM_TIR_SCHEDULE_END("hide-buffer-access", this->error_render_level_); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 5f3f0c8b61f1..64d27fc10a1d 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -33,13 +33,13 @@ class ConcreteScheduleNode : public ScheduleNode { friend class ScheduleCopier; public: - using TSymbolTable = Map; + using TSymbolTable = ffi::Map; protected: /*! \brief The internal state of scheduling */ ScheduleState state_; /*! \brief The function to be worked on. */ - Optional func_working_on_; + ffi::Optional func_working_on_; /*! \brief The level of error rendering */ ScheduleErrorRenderLevel error_render_level_; /*! 
\brief A symbol table that maps random variables to concrete StmtSRef/Integers */ @@ -51,16 +51,17 @@ class ConcreteScheduleNode : public ScheduleNode { public: static void RegisterReflection() { - // No fields to register as they are not visited + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } virtual ~ConcreteScheduleNode() = default; public: ScheduleState state() const final { return state_; } - Optional trace() const override { return std::nullopt; } - Optional func_working_on() const final { return func_working_on_; } - void WorkOn(const String& func_name) final; + ffi::Optional trace() const override { return std::nullopt; } + ffi::Optional func_working_on() const final { return func_working_on_; } + void WorkOn(const ffi::String& func_name) final; Schedule Copy() override; void Seed(support::LinearCongruentialEngine::TRandState seed) final; support::LinearCongruentialEngine::TRandState ForkSeed() final; @@ -73,8 +74,8 @@ class ConcreteScheduleNode : public ScheduleNode { inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; inline bool HasBlock(const BlockRV& block_rv) const final; - inline Array GetSRefs(const Array& rvs) const; - inline Array GetSRefs(const Array& rvs) const; + inline ffi::Array GetSRefs(const ffi::Array& rvs) const; + inline ffi::Array GetSRefs(const ffi::Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } @@ -82,59 +83,63 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = std::nullopt) override; - Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = std::nullopt) 
override; - Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, - int innerpart_factor, - Optional> decision = std::nullopt) override; + ExprRV SampleCategorical(const ffi::Array& candidates, const ffi::Array& probs, + ffi::Optional decision = std::nullopt) override; + ffi::Array SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision = std::nullopt) override; + ffi::Array SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision = std::nullopt) override; LoopRV SampleComputeLocation(const BlockRV& block_rv, - Optional decision = std::nullopt) override; + ffi::Optional decision = std::nullopt) override; /******** Schedule: Get blocks & loops ********/ - BlockRV GetBlock(const String& name, const Optional& func_name) override; - Array GetLoops(const BlockRV& block_rv) override; - Array GetChildBlocks(const BlockRV& block_rv) override; - Array GetChildBlocks(const LoopRV& loop_rv) override; - Array GetProducers(const BlockRV& block_rv) override; - Array GetConsumers(const BlockRV& block_rv) override; - Array GetOutputBlocks(const BlockRV& scope_block_rv) override; + BlockRV GetBlock(const ffi::String& name, const ffi::Optional& func_name) override; + ffi::Array GetLoops(const BlockRV& block_rv) override; + ffi::Array GetChildBlocks(const BlockRV& block_rv) override; + ffi::Array GetChildBlocks(const LoopRV& loop_rv) override; + ffi::Array GetProducers(const BlockRV& block_rv) override; + ffi::Array GetConsumers(const BlockRV& block_rv) override; + ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) override; /******** Schedule: Transform loops ********/ - LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) override; - LoopRV Merge(const Array& loop_rvs) override; - Array Split(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters, bool disable_predication) override; - Array LoopPartition(const 
LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters) override; - void Reorder(const Array& ordered_loop_rvs) override; - void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) override; + LoopRV Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) override; + LoopRV Merge(const ffi::Array& loop_rvs) override; + ffi::Array Split(const LoopRV& loop_rv, const ffi::Array>& factors, + bool preserve_unit_iters, bool disable_predication) override; + ffi::Array LoopPartition(const LoopRV& loop_rv, + const ffi::Array>& factors, + bool preserve_unit_iters) override; + void Reorder(const ffi::Array& ordered_loop_rvs) override; + void ReorderBlockIterVar(const BlockRV& block_rv, const ffi::Array new_order) override; LoopRV AddUnitLoop(const BlockRV& block_rv) override; LoopRV AddUnitLoop(const LoopRV& loop_rv) override; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) override; void Vectorize(const LoopRV& loop_rv) override; - void Bind(const LoopRV& loop_rv, const String& thread_axis) override; + void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) override; void Unroll(const LoopRV& loop_rv) override; /******** Schedule: Insert cache stages ********/ - BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, - const Array consumer_blocks = {}) override; - BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, - const Array consumer_blocks = {}) override; + BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) override; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) override; BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map) override; + const 
ffi::String& storage_scope, const IndexMap& index_map) override; BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map) override; - Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) override; - Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, - int cse_thresh) override; + const ffi::String& storage_scope, const IndexMap& index_map) override; + ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) override; + ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, + int cse_thresh) override; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) override; + BufferIndexType buffer_index_type, bool skip_simplify) override; /******** Schedule: Data movement ********/ BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) override; + const ffi::String& storage_scope) override; BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) override; + const ffi::String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) override; @@ -142,41 +147,46 @@ class ConcreteScheduleNode : public ScheduleNode { int index = -1) override; void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; + void FuseReductionEpilogue(const BlockRV& reduction_block, + const BlockRV& epilogue_block) override; /******** Schedule: Reduction ********/ BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override; - void PadEinsum(const BlockRV& block_rv, 
const Array& padding) override; + void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) override; /******** Schedule: Block annotation ********/ void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) override; - void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; - void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override; + void SetScope(const BlockRV& block_rv, int buffer_index, + const ffi::String& storage_scope) override; + void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const ffi::String& dtype) override; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; - BlockRV Blockize(const Array& blocks, bool preserve_unit_iters) override; - void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override; - void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override; + BlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters) override; + void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, + bool preserve_unit_iters) override; + void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, + bool preserve_unit_iters) override; /******** Schedule: Annotation ********/ - void Annotate(const LoopRV& loop_rv, const String& ann_key, const Any& ann_val) override; - void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; - void Annotate(const BlockRV& block_rv, const String& ann_key, const Any& ann_val) override; - void Unannotate(const BlockRV& block_rv, const String& ann_key) override; + void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) override; + void Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) override; + void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) 
override; + void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map, const Optional& pad_value, + const IndexMap& index_map, const ffi::Optional& pad_value, bool assume_injective_transform = false) override; void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) override; + const ffi::Array& axis_separators) override; /******** Schedule: Padding decomposition ********/ BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override; /******** Schedule: Buffer transformation ********/ void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} - void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) override; + void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + const ffi::Array& buf_index_array) override; void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) override; @@ -195,7 +205,7 @@ class ConcreteScheduleNode : public ScheduleNode { * \return The new random variables created */ template - inline Array CreateRV(const Array& srefs); + inline ffi::Array CreateRV(const ffi::Array& srefs); /*! * \brief Add an sref as a random variable into the symbol table * \tparam T The type of the random variable @@ -217,8 +227,8 @@ class ConcreteScheduleNode : public ScheduleNode { * Which is convention of certain primitives. 
* \return The new random variables created */ - inline Array CreateRV(const std::vector& value, - bool convert_negone_to_none = false); + inline ffi::Array CreateRV(const std::vector& value, + bool convert_negone_to_none = false); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); /*! @@ -237,17 +247,17 @@ class ConcreteScheduleNode : public ScheduleNode { inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { StmtSRef sref = this->GetSRef(block_rv); const BlockNode* block = TVM_SREF_TO_BLOCK(sref); - return GetRef(block); + return ffi::GetRef(block); } inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { StmtSRef sref = this->GetSRef(loop_rv); const ForNode* loop = TVM_SREF_TO_FOR(sref); - return GetRef(loop); + return ffi::GetRef(loop); } inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { - PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional { + PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> ffi::Optional { auto it = this->symbol_table_.find(var); if (it == this->symbol_table_.end()) { LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; @@ -286,7 +296,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { if (sref->stmt == nullptr) { LOG(FATAL) << "ValueError: The block no longer exists in the IRModule"; } - return GetRef(sref); + return ffi::GetRef(sref); } inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { @@ -311,12 +321,13 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { if (sref->stmt == nullptr) { LOG(FATAL) << "ValueError: The loop no longer exists in the IRModule"; } - return GetRef(sref); + return ffi::GetRef(sref); } template -inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Array& rvs) { - Array result; +inline ffi::Array GetSRefsHelper(const ConcreteScheduleNode* 
sch, + const ffi::Array& rvs) { + ffi::Array result; result.reserve(rvs.size()); for (const T& rv : rvs) { result.push_back(sch->GetSRef(rv)); @@ -324,19 +335,19 @@ inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Arr return result; } -inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { +inline ffi::Array ConcreteScheduleNode::GetSRefs(const ffi::Array& rvs) const { return GetSRefsHelper(this, rvs); } -inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { +inline ffi::Array ConcreteScheduleNode::GetSRefs(const ffi::Array& rvs) const { return GetSRefsHelper(this, rvs); } /******** Adding/Removing elements in the symbol table ********/ template -inline Array ConcreteScheduleNode::CreateRV(const Array& srefs) { - Array result; +inline ffi::Array ConcreteScheduleNode::CreateRV(const ffi::Array& srefs) { + ffi::Array result; result.reserve(srefs.size()); for (const StmtSRef& sref : srefs) { T rv; @@ -359,9 +370,9 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { return rv; } -inline Array ConcreteScheduleNode::CreateRV(const std::vector& value, - bool convert_negone_to_none) { - Array results; +inline ffi::Array ConcreteScheduleNode::CreateRV(const std::vector& value, + bool convert_negone_to_none) { + ffi::Array results; results.reserve(value.size()); for (int64_t v : value) { if (convert_negone_to_none && v == -1) { diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 479fc34c75af..ce882ebbc9c7 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -21,13 +21,13 @@ namespace tvm { namespace tir { -String ScheduleError::RenderReport(const String& primitive) const { +ffi::String ScheduleError::RenderReport(const ffi::String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; // get locations of interest - Array locs = LocationsOfInterest(); - std::unordered_map loc_obj_to_name; + ffi::Array locs = LocationsOfInterest(); + std::unordered_map 
loc_obj_to_name; int n_locs = locs.size(); std::string msg = DetailRenderTemplate(); PrinterConfig cfg; diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h index f579a54bbc81..39c9cc203fcf 100644 --- a/src/tir/schedule/error.h +++ b/src/tir/schedule/error.h @@ -30,11 +30,12 @@ namespace tir { class ScheduleError : public tvm::runtime::Error { public: /*! \brief Base constructor */ - ScheduleError() : tvm::runtime::Error("ScheduleError", "", TVM_FFI_TRACEBACK_HERE) {} + ScheduleError() + : tvm::runtime::Error("ScheduleError", "", TVMFFIBacktrace(nullptr, 0, nullptr, 0)) {} /*! \brief The error occurred in this IRModule */ virtual IRModule mod() const = 0; /*! \brief The locations of interest that we want to point out */ - virtual Array LocationsOfInterest() const = 0; + virtual ffi::Array LocationsOfInterest() const = 0; /*! * \brief Returns an error string template for rendering, corresponds to the "detail" mode. * \sa ScheduleErrorRenderLevel @@ -44,14 +45,14 @@ class ScheduleError : public tvm::runtime::Error { * now it only printed out all the locations in plain text, but in the future, we may want to mark * the IR with underscores and attach names to each location of interest. */ - virtual String DetailRenderTemplate() const = 0; + virtual ffi::String DetailRenderTemplate() const = 0; /*! * \brief Returns an error string without needing to render, corresponds to the "fast" mode * \sa ScheduleErrorRenderLevel */ - virtual String FastErrorString() const = 0; + virtual ffi::String FastErrorString() const = 0; /*! 
\brief Render the ScheduleError with the template provided by `DetailRenderTemplate` */ - String RenderReport(const String& primitive) const; + ffi::String RenderReport(const ffi::String& primitive) const; }; class LoopPositionError : public ScheduleError { @@ -62,11 +63,11 @@ class LoopPositionError : public ScheduleError { block_(std::move(block)), primitive_(primitive) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: " + primitive_ + " expect the loop to be an ancestor of block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: The input loop {0} of " << primitive_ << " is required to be an ancestor of block {1}."; @@ -74,7 +75,7 @@ class LoopPositionError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_, block_}; } + ffi::Array LocationsOfInterest() const final { return {loop_, block_}; } IRModule mod_; For loop_; diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index 0e911580338a..02c866e0b605 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -23,19 +23,19 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { InstructionKindNode::RegisterReflection(); InstructionNode::RegisterReflection(); -}); +} bool InstructionKindNode::IsPostproc() const { static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); return this == inst_enter_postproc.get(); } -Instruction::Instruction(InstructionKind kind, Array inputs, Array attrs, - Array outputs) { - ObjectPtr n = make_object(); +Instruction::Instruction(InstructionKind kind, ffi::Array inputs, ffi::Array attrs, + ffi::Array outputs) { + ObjectPtr n = ffi::make_object(); n->kind = std::move(kind); n->inputs = std::move(inputs); n->attrs = std::move(attrs); 
@@ -45,17 +45,17 @@ Instruction::Instruction(InstructionKind kind, Array inputs, Array att using InstructionKindRegistry = AttrRegistry; -InstructionKind InstructionKind::Get(const String& name) { +InstructionKind InstructionKind::Get(const ffi::String& name) { const InstructionKindRegEntry* reg = InstructionKindRegistry::Global()->Get(name); ICHECK(reg != nullptr) << "AttributeError: Instruction kind " << name << " is not registered"; return reg->inst_kind_; } InstructionKindRegEntry::InstructionKindRegEntry(uint32_t reg_index) { - this->inst_kind_ = InstructionKind(make_object()); + this->inst_kind_ = InstructionKind(ffi::make_object()); } -InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const String& name) { +InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const ffi::String& name) { return InstructionKindRegistry::Global()->RegisterOrGet(name); } @@ -65,29 +65,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { const auto* self = obj.as(); ICHECK_NOTNULL(self); - Array inputs; + ffi::Array inputs; inputs.reserve(self->inputs.size()); for (const Any& obj : self->inputs) { if (obj == nullptr) { - inputs.push_back(String("None")); + inputs.push_back(ffi::String("None")); } else if (auto opt_str = obj.as()) { - inputs.push_back(String('"' + (*opt_str).operator std::string() + '"')); + inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); } else if (obj.as() || obj.as()) { - inputs.push_back(String("_")); + inputs.push_back(ffi::String("_")); } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { inputs.push_back(obj); } else if (obj.as() || obj.as()) { inputs.push_back(obj); } else if (const auto* expr = obj.as()) { - PrimExpr new_expr = - Substitute(GetRef(expr), [](const Var& var) -> Optional { - ObjectPtr new_var = make_object(*var.get()); + PrimExpr new_expr = Substitute( + ffi::GetRef(expr), [](const Var& var) -> ffi::Optional { + ObjectPtr 
new_var = ffi::make_object(*var.get()); new_var->name_hint = "_"; return Var(new_var); }); std::ostringstream os; os << new_expr; - inputs.push_back(String(os.str())); + inputs.push_back(ffi::String(os.str())); } else if (obj.as()) { inputs.push_back(obj); } else { @@ -99,19 +99,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) /*inputs=*/inputs, /*attrs=*/self->attrs, /*decision=*/Any(nullptr), - /*outputs=*/Array(self->outputs.size(), String("_"))); + /*outputs=*/ffi::Array(self->outputs.size(), ffi::String("_"))); }); /**************** FFI ****************/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.InstructionKindGet", InstructionKind::Get) .def("tir.schedule.Instruction", - [](InstructionKind kind, Array inputs, Array attrs, Array outputs) - -> Instruction { return Instruction(kind, inputs, attrs, outputs); }); -}); + [](InstructionKind kind, ffi::Array inputs, ffi::Array attrs, + ffi::Array outputs) -> Instruction { + return Instruction(kind, inputs, attrs, outputs); + }); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index bff619ca49cc..93a1dd77ab64 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -44,25 +44,25 @@ namespace tir { * static constexpr bool kIsPure = false; * * // Convertible to `InstructionKindNode::FInstructionApply` - * static Array ApplyToSchedule( + * static ffi::Array ApplyToSchedule( * const tir::Schedule& sch, - * const Array& inputs, - * const Array& attrs, - * const Optional& decision); + * const ffi::Array& inputs, + * const ffi::Array& attrs, + * const ffi::Optional& decision); * * // Convertible to `InstructionKindNode::FInstructionAsPython` - * static String AsPython( - * const Array& inputs, - * const Array& attrs, - * const Optional& decision, - * const Array& outputs); + * static ffi::String AsPython( + * 
const ffi::Array& inputs, + * const ffi::Array& attrs, + * const ffi::Optional& decision, + * const ffi::Array& outputs); * * // Convertible to `InstructionKindNode::FInstructionAttrsAsJSON` * static ObjectRef AttrsAsJSON( - * const Array& attrs); + * const ffi::Array& attrs); * * // Convertible to `InstructionKindNode::FInstructionAttrsFromJSON` - * static Array AttrsFromJSON( + * static ffi::Array AttrsFromJSON( * const ObjectRef& attrs_record); * }; * @@ -108,12 +108,12 @@ namespace tir { * // - The next `kNumInputs` arguments are input random variables * // - The next `kNumAttrs` arguments are attributes * // - The next argument is decision, if `kNumDecisions == 1` - * static Array UnpackedApplyToSchedule( + * static ffi::Array UnpackedApplyToSchedule( * Schedule sch, * LoopRV loop_rv, * Integer n, * Integer max_innermost_factor, - * Optional> decision) { + * ffi::Optional> decision) { * return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); * } * @@ -123,12 +123,12 @@ namespace tir { * // - The next `kNumInputs` arguments are names of input random variables * // - The next `kNumAttrs` arguments are attributes * // - The next argument is decision, if `kNumDecisions == 1` - * static String UnpackedAsPython( - * Array outputs, - * String loop_rv, + * static ffi::String UnpackedAsPython( + * ffi::Array outputs, + * ffi::String loop_rv, * Integer n, * Integer max_innermost_factor, - * Optional> decision) { + * ffi::Optional> decision) { * PythonAPICall py("sample_perfect_tile"); * py.Input("loop", loop_rv); * py.Input("n", n->value); @@ -152,16 +152,16 @@ struct UnpackedInstTraits { * `TTraits::UnpackedApplyToSchedule` * \sa InstructionKindNode::f_apply_to_schedule */ - static Array ApplyToSchedule(const Schedule& sch, const Array& inputs, - const Array& attrs, const Any& decision); + static ffi::Array ApplyToSchedule(const Schedule& sch, const ffi::Array& inputs, + const ffi::Array& attrs, const Any& decision); /*! 
* \brief Unpack the arguments in the calling convention, and feed them into * `TTraits::UnpackedAsPython` * \sa InstructionKindNode::f_as_python */ - static String AsPython(const Array& inputs, const Array& attrs, const Any& decision, - const Array& outputs); + static ffi::String AsPython(const ffi::Array& inputs, const ffi::Array& attrs, + const Any& decision, const ffi::Array& outputs); /*! \brief No customized serializer by default */ static constexpr std::nullptr_t AttrsAsJSON = nullptr; @@ -171,12 +171,12 @@ struct UnpackedInstTraits { protected: template - static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array& inputs); + static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array& inputs); template - static TVM_ALWAYS_INLINE void _SetAttrs(AnyView* packed_args, const Array& attrs); + static TVM_ALWAYS_INLINE void _SetAttrs(AnyView* packed_args, const ffi::Array& attrs); template static TVM_ALWAYS_INLINE void _SetDecision(AnyView* packed_args, const Any& decision); - static TVM_ALWAYS_INLINE Array _ConvertOutputs(const ffi::Any& rv); + static TVM_ALWAYS_INLINE ffi::Array _ConvertOutputs(const ffi::Any& rv); }; /*! @@ -190,32 +190,33 @@ class PythonAPICall { * \brief Constructor * \param method_name The name of the schedule API to be called */ - explicit PythonAPICall(String method_name) : method_name_(method_name), output_(std::nullopt) {} + explicit PythonAPICall(ffi::String method_name) + : method_name_(method_name), output_(std::nullopt) {} /*! \brief Add an integer input */ - inline void Input(String arg_name, int arg); + inline void Input(ffi::String arg_name, int arg); /*! \brief Add an integer input */ - inline void Input(String arg_name, int64_t arg); + inline void Input(ffi::String arg_name, int64_t arg); /*! \brief Add a bool input */ - inline void Input(String arg_name, bool arg); + inline void Input(ffi::String arg_name, bool arg); /*! 
\brief Add a double input */ - inline void Input(String arg_name, double arg); + inline void Input(ffi::String arg_name, double arg); /*! \brief Add an input random variable */ - inline void Input(String arg_name, String arg); + inline void Input(ffi::String arg_name, ffi::String arg); /*! \brief Add an input random variable */ - inline void Input(String arg_name, std::string arg); + inline void Input(ffi::String arg_name, std::string arg); /*! \brief Add an input, dispatched to different implementations according to the object's type */ - inline void Input(String arg_name, Any arg); + inline void Input(ffi::String arg_name, Any arg); /*! \brief Add the decision */ inline void Decision(Any decision); /*! * \brief Add a single output random variable * \param unit_array An array containing only one element */ - inline void SingleOutput(Array unit_array); + inline void SingleOutput(ffi::Array unit_array); /*! \brief Add a list of output random variables */ - inline void OutputList(Array outputs); + inline void OutputList(ffi::Array outputs); /*! \returns The schedule API call in python syntax */ - inline String Str() const; + inline ffi::String Str() const; private: /*! \brief Converts a TVM object to python string and print to the output stream */ @@ -223,13 +224,13 @@ class PythonAPICall { private: /*! \brief The name of the API to call */ - String method_name_; + ffi::String method_name_; /*! \brief The output of the instruction */ - Optional output_; + ffi::Optional output_; /*! \brief The names of input arguments */ - std::vector arg_names_; + std::vector arg_names_; /*! 
\brief The values of input arguments */ - std::vector args_; + std::vector args_; }; /********** implementation details **********/ @@ -272,7 +273,7 @@ template struct _IsTVMArray : std::false_type {}; template -struct _IsTVMArray> : std::true_type {}; +struct _IsTVMArray> : std::true_type {}; template struct _IsSingleObject @@ -297,10 +298,10 @@ static constexpr int IsSingleObject = _IsSingleObject>::valu }; // namespace details template -Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, - const Array& inputs, - const Array& attrs, - const Any& decision) { +ffi::Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, + const ffi::Array& inputs, + const ffi::Array& attrs, + const Any& decision) { using method_type = decltype(TTraits::UnpackedApplyToSchedule); using return_type = details::ReturnType; // static_assert(details::ArgumentAreAllObjects, @@ -329,8 +330,9 @@ Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, } template -String UnpackedInstTraits::AsPython(const Array& inputs, const Array& attrs, - const Any& decision, const Array& outputs) { +ffi::String UnpackedInstTraits::AsPython(const ffi::Array& inputs, + const ffi::Array& attrs, const Any& decision, + const ffi::Array& outputs) { using method_type = decltype(TTraits::UnpackedAsPython); using return_type = details::ReturnType; // static_assert(details::ArgumentAreAllObjects, @@ -355,13 +357,13 @@ String UnpackedInstTraits::AsPython(const Array& inputs, const Arr }); ffi::Any rv; pf.CallPacked(ffi::PackedArgs(packed_args, kNumArgs), &rv); - return rv.cast(); + return rv.cast(); } template template TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetInputs(AnyView* packed_args, - const Array& inputs) { + const ffi::Array& inputs) { constexpr size_t kNumInputs = TTraits::kNumInputs; ICHECK_EQ(kNumInputs, inputs.size()) << "ValueError: Incorrect kNumInputs for instruction: " << TTraits::kName; @@ -373,7 +375,7 @@ TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetInputs(AnyView* 
packed_a template template TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetAttrs(AnyView* packed_args, - const Array& attrs) { + const ffi::Array& attrs) { constexpr size_t kNumAttrs = TTraits::kNumAttrs; ICHECK_EQ(kNumAttrs, attrs.size()) << "ValueError: Incorrect kNumAttrs for instruction: " << TTraits::kName; @@ -396,7 +398,7 @@ TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetDecision(AnyView* packed } template -TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs(const ffi::Any& rv) { +TVM_ALWAYS_INLINE ffi::Array UnpackedInstTraits::_ConvertOutputs(const ffi::Any& rv) { using method_type = decltype(TTraits::UnpackedApplyToSchedule); using return_type = details::ReturnType; constexpr int is_array = details::IsTVMArray; @@ -409,7 +411,7 @@ TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs(const } else if (is_single_obj) { return {rv}; } else if (is_array) { - return rv.cast>(); + return rv.cast>(); } } @@ -466,17 +468,17 @@ inline void PythonAPICall::AsPythonString(const Any& obj, std::ostream& os) { } } -void PythonAPICall::Input(String arg_name, int arg) { +void PythonAPICall::Input(ffi::String arg_name, int arg) { arg_names_.emplace_back(std::move(arg_name)); args_.push_back(std::to_string(arg)); } -void PythonAPICall::Input(String arg_name, int64_t arg) { +void PythonAPICall::Input(ffi::String arg_name, int64_t arg) { arg_names_.emplace_back(std::move(arg_name)); args_.push_back(std::to_string(arg)); } -void PythonAPICall::Input(String arg_name, bool arg) { +void PythonAPICall::Input(ffi::String arg_name, bool arg) { static const char* true_str = "True"; static const char* false_str = "False"; arg_names_.emplace_back(std::move(arg_name)); @@ -487,7 +489,7 @@ void PythonAPICall::Input(String arg_name, bool arg) { } } -void PythonAPICall::Input(String arg_name, double arg) { +void PythonAPICall::Input(ffi::String arg_name, double arg) { arg_names_.emplace_back(std::move(arg_name)); std::ostringstream os; os.precision(17); @@ -495,17 +497,17 @@ 
void PythonAPICall::Input(String arg_name, double arg) { args_.push_back(os.str()); } -void PythonAPICall::Input(String arg_name, String arg) { +void PythonAPICall::Input(ffi::String arg_name, ffi::String arg) { arg_names_.emplace_back(std::move(arg_name)); args_.emplace_back(std::move(arg)); } -void PythonAPICall::Input(String arg_name, std::string arg) { +void PythonAPICall::Input(ffi::String arg_name, std::string arg) { arg_names_.emplace_back(std::move(arg_name)); args_.emplace_back(std::move(arg)); } -void PythonAPICall::Input(String arg_name, Any arg) { +void PythonAPICall::Input(ffi::String arg_name, Any arg) { arg_names_.emplace_back(std::move(arg_name)); std::ostringstream os; AsPythonString(arg, os); @@ -518,12 +520,12 @@ void PythonAPICall::Decision(Any decision) { } } -void PythonAPICall::SingleOutput(Array unit_array) { +void PythonAPICall::SingleOutput(ffi::Array unit_array) { ICHECK_EQ(unit_array.size(), 1); this->output_ = unit_array[0]; } -void PythonAPICall::OutputList(Array outputs) { +void PythonAPICall::OutputList(ffi::Array outputs) { if (outputs.empty()) { return; } @@ -539,7 +541,7 @@ void PythonAPICall::OutputList(Array outputs) { this->output_ = os.str(); } -String PythonAPICall::Str() const { +ffi::String PythonAPICall::Str() const { std::ostringstream os; if (output_.has_value()) { os << output_.value() << " = "; diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 71b646855d50..bef35387cbaa 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -37,11 +37,11 @@ class TensorIntrinMismatchError : public ScheduleError { ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The stmt doesn't match the tensor intrin."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The stmt 
{0} doesn't match the tensor intrin\nThe pattern attempting to be matched:\n" << lhs_stmt_ << "\nDoes not match the tensorize description:\n" @@ -54,7 +54,7 @@ class TensorIntrinMismatchError : public ScheduleError { IRModule mod() const final { return lhs_mod_; } - Array LocationsOfInterest() const final { return {lhs_stmt_}; } + ffi::Array LocationsOfInterest() const final { return {lhs_stmt_}; } private: IRModule lhs_mod_; @@ -309,7 +309,7 @@ bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { const auto* rhs = other.as(); - auto lhs = GetRef(op); + auto lhs = ffi::GetRef(op); if (lhs.same_as(other)) return true; if (op->dtype.code() != rhs->dtype.code()) { if (assert_mode_) { @@ -348,8 +348,8 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { return true; } -bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, - const std::pair& rhs) { +bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, + const std::pair& rhs) { if (lhs.first != rhs.first) { if (assert_mode_) { std::ostringstream os; @@ -376,8 +376,8 @@ bool TensorizeComparator::CompareAnnotation(const std::pair& l return true; } -bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, - const Map& rhs) { +bool TensorizeComparator::CompareAnnotationMap(const ffi::Map& lhs, + const ffi::Map& rhs) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) { if (assert_mode_) { @@ -389,14 +389,15 @@ bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, return false; } - auto sort_map = [](const Map& map) -> std::vector> { - std::vector> ret(map.begin(), map.end()); + auto sort_map = [](const ffi::Map& map) + -> std::vector> { + std::vector> ret(map.begin(), map.end()); sort(ret.begin(), ret.end(), [](const auto& a, const auto& b) { return a.first < b.first; }); return ret; }; - std::vector> lhs_array = sort_map(lhs); - std::vector> 
rhs_array = sort_map(rhs); + std::vector> lhs_array = sort_map(lhs); + std::vector> rhs_array = sort_map(rhs); for (size_t i = 0; i < lhs.size(); ++i) { if (!CompareAnnotation(lhs_array[i], rhs_array[i])) { @@ -582,7 +583,8 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { } template -bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp) { +bool TensorizeComparator::CompareArray(const ffi::Array& lhs, const ffi::Array& rhs, + F Self::*cmp) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) { if (assert_mode_) { @@ -704,7 +706,7 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { lhs_indices.push_back(SimplifyNonTrivialExpr(index, &analyzer_)); } - auto is_scalar_access = [](const Array& indices, PrimExpr index) { + auto is_scalar_access = [](const ffi::Array& indices, PrimExpr index) { // Check if the indexing is of the form C[0] if (indices.size() > 1) return false; auto int_imm = index.template as(); @@ -722,8 +724,8 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { if (it_rhs == rhs_buffer_indices_map_.end()) { return false; } - auto indices_check = [&](const Array& indices, - const Array& old_indices) -> bool { + auto indices_check = [&](const ffi::Array& indices, + const ffi::Array& old_indices) -> bool { if (indices.size() != old_indices.size()) { return false; } diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index a15de7b97a91..665d093b2fa4 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -86,13 +86,14 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool DefEqual(const Var& lhs, const Var& rhs); virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); - bool CompareAnnotation(const std::pair& lhs, - const std::pair& rhs); - bool 
CompareAnnotationMap(const Map& lhs, const Map& rhs); + bool CompareAnnotation(const std::pair& lhs, + const std::pair& rhs); + bool CompareAnnotationMap(const ffi::Map& lhs, + const ffi::Map& rhs); template bool CompareBufferAccess(const T* lhs, const T* rhs); template - bool CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp); + bool CompareArray(const ffi::Array& lhs, const ffi::Array& rhs, F Self::*cmp); bool CompareRange(const Range& lhs, const Range& rhs); bool CompareIterVar(const IterVar& lhs, const IterVar& rhs); void EmitError(const std::string& error_message); @@ -151,17 +152,17 @@ class AutoTensorizeComparator : public TensorizeComparator { /*! \brief Block iters in the RHS stmt. */ std::vector rhs_iters_; /*! \brief The buffer and its access indices in the LHS stmt. */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> lhs_buffer_indices_map_; /*! \brief The buffer and its access indices in the RHS stmt. */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> rhs_buffer_indices_map_; /*! \brief Map from LHS buffer to RHS buffer */ std::unordered_map lhs_buffer_map_; private: /*! \brief The domain of the inner block iters. */ - Map inner_iter_dom_map_; + ffi::Map inner_iter_dom_map_; }; } // namespace tir diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index de8fe7238ea7..b031266211ed 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -55,8 +55,9 @@ std::vector SampleWithoutReplacement( * \return The random variable sampled from candidates */ TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision); + const ffi::Array& candidates, + const ffi::Array& probs, + ffi::Optional* decision); /*! * \brief Create a sampling function that does multinomial sampling. 
* \param rand_state The random state. @@ -98,7 +99,7 @@ TVM_DLL std::vector SamplePerfectTile( TVM_DLL std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, - Optional>* decision); + ffi::Optional>* decision); /*! * \brief Sample the factors to a partitioned tile for a specific loop * @@ -136,7 +137,7 @@ TVM_DLL std::vector SamplePartitionedTile( TVM_DLL std::vector SamplePartitionedTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_split, int32_t partition_pos, - int32_t innerpart_factor, Optional>* decision); + int32_t innerpart_factor, ffi::Optional>* decision); /*! * \brief Sample a compute-at location of the given block * \param self The schedule state @@ -147,7 +148,7 @@ TVM_DLL std::vector SamplePartitionedTile( */ TVM_DLL tir::StmtSRef SampleComputeLocation( tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, - const tir::StmtSRef& block_sref, Optional* decision); + const tir::StmtSRef& block_sref, ffi::Optional* decision); /******** Schedule: Get blocks & loops ********/ /*! @@ -157,35 +158,36 @@ TVM_DLL tir::StmtSRef SampleComputeLocation( * \param gvar The function to be retrieved * \return A list of blocks with the specific name */ -Array GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv); +ffi::Array GetBlocks(const ScheduleState& self, const ffi::String& name, + const GlobalVar& gv); /*! * \brief Gets the parent loops of the block in its scope, from outer to inner * \param self The schedule state * \param block_sref The query block * \return A list of loops above the given block in its scope, from outer to inner */ -Array GetLoops(const StmtSRef& block_sref); +ffi::Array GetLoops(const StmtSRef& block_sref); /*! 
* \brief Get the leaf blocks of a specific block/loop * \param self The schedule state * \param parent_sref The query block/loop * \return A list of leaf blocks inside a specific block/loop */ -Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); +ffi::Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); /*! * \brief Get the producers of a specific block * \param self The schedule state * \param block_sref The block in the query * \return A list of blocks, the producers of the given block */ -Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref); +ffi::Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref); /*! * \brief Get the consumers of a specific block * \param self The schedule state * \param block_rv The block in the query * \return A list of blocks, the consumers of the given block */ -Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref); +ffi::Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref); /*! * \brief Get the list of output blocks within the given scope * An output block is a block which has atleast one buffer being written @@ -194,7 +196,7 @@ Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sr * \return A list of all blocks that write to some output buffer * block */ -Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref); +ffi::Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref); /******** Schedule: Transform loops ********/ /*! * Split a loop into a list of consecutive loops. It requires: @@ -210,9 +212,9 @@ Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope * Warning: enabling this feature may result in incorrect code generation if not used * carefully. 
\return An array of srefs to the loops after splitting */ -TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, - const Array& factors, bool preserve_unit_iters, - bool disable_predication); +TVM_DLL ffi::Array Split(ScheduleState self, const StmtSRef& loop_sref, + const ffi::Array& factors, bool preserve_unit_iters, + bool disable_predication); /*! * Partition a loop into a list of consecutive loops. It requires: @@ -223,8 +225,9 @@ TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return An array of srefs to the loops after partitioning */ -TVM_DLL Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref, - const Array& factors, bool preserve_unit_iters); +TVM_DLL ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref, + const ffi::Array& factors, + bool preserve_unit_iters); /*! * \brief Merge a list of loops into one. The loops under their LCA requires: @@ -236,7 +239,7 @@ TVM_DLL Array LoopPartition(ScheduleState self, const StmtSRef& loop_s * \param loop_srefs An array of srefs to the loops to be merged * \return The new loop after merge */ -TVM_DLL StmtSRef Merge(ScheduleState self, const Array& loop_srefs); +TVM_DLL StmtSRef Merge(ScheduleState self, const ffi::Array& loop_srefs); /*! * \brief Fuse a list of consecutive loops into one. It requires: @@ -249,7 +252,7 @@ TVM_DLL StmtSRef Merge(ScheduleState self, const Array& loop_srefs); * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The sref to the fused loop */ -TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, +TVM_DLL StmtSRef Fuse(ScheduleState self, const ffi::Array& loop_srefs, bool preserve_unit_loops); /*! * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. 
@@ -264,7 +267,7 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, * \param self The state of the schedule * \param ordered_loop_srefs An array of srefs which indicates the new order of loops */ -TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_srefs); +TVM_DLL void Reorder(ScheduleState self, const ffi::Array& ordered_loop_srefs); /*! * \brief Reorder itervars inside a block. @@ -273,7 +276,7 @@ TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_sre * \param new_order The new itervar order. */ TVM_DLL void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, - const Array& new_order); + const ffi::Array& new_order); /*! * \brief Create a new unit loop on top of the specific block or loop. @@ -320,7 +323,7 @@ TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref); * \param loop_sref The sref of the loop to be bound to the thread axis * \param thread_axis The thread axis to be bound to the loop */ -TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis); +TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const ffi::String& thread_axis); /*! * \brief Unroll the input loop. It requires nothing * \param self The state of the schedule @@ -340,7 +343,8 @@ TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref); * \return The cache stage block. */ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope, const Array consumer_blocks = {}); + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}); /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block that writes the target buffer. @@ -353,8 +357,8 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r * \return The cache stage block. 
*/ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, - const Array consumer_blocks = {}); + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}); /*! * \brief Create a block that reads a buffer region into a read cache. It requires: * 1) There is at most one block who writes the buffer in the scope. @@ -369,7 +373,7 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int * \return The cache stage block. */ TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, - int read_buffer_index, const String& storage_scope, + int read_buffer_index, const ffi::String& storage_scope, const IndexMap& index_map); /*! * \brief Create a block that writes a buffer region into a write cache. It requires: @@ -385,7 +389,7 @@ TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref * \return The cache stage block. */ TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, - int write_buffer_index, const String& storage_scope, + int write_buffer_index, const ffi::String& storage_scope, const IndexMap& index_map); /*! @@ -398,8 +402,8 @@ TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sre * \param storage_scope The target storage scope * \return The cache stage blocks, cache read block together with cache write block. */ -TVM_DLL Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, - int read_buffer_index, const String& storage_scope); +TVM_DLL ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, + int read_buffer_index, const ffi::String& storage_scope); /*! * \brief Create a block to cache precomputed index for later use. * if there is no index computation, keep unchanged. 
@@ -408,8 +412,8 @@ TVM_DLL Array CacheInplace(ScheduleState self, const StmtSRef& block_s * \param cse_thresh The repeat threshold that determines a common sub expr * \return The cache stage block. */ -TVM_DLL Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, - const String& storage_scope, int cse_thresh); +TVM_DLL ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, + const ffi::String& storage_scope, int cse_thresh); /*! *! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. @@ -424,15 +428,15 @@ TVM_DLL Array CacheIndex(ScheduleState self, const StmtSRef& block_sre * \return The reindex stage block. */ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - BufferIndexType buffer_index_type); + BufferIndexType buffer_index_type, bool skip_simplify = false); /******** Schedule: Data movement ********/ TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, - int read_buffer_index, const String& storage_scope); + int read_buffer_index, const ffi::String& storage_scope); TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, - int write_buffer_index, const String& storage_scope); + int write_buffer_index, const ffi::String& storage_scope); /******** Schedule: Compute location ********/ /*! @@ -505,6 +509,14 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref); * \param block_sref The sref to the block to be inlined to its producer */ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref); +/*! 
+ * \brief Fuse an epilogue block into a reduction block + * \param self The state of the schedule + * \param reduction_block_sref The sref to the reduction block + * \param epilogue_block_sref The sref to the epilogue block to be fused + */ +TVM_DLL void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref, + const StmtSRef& epilogue_block_sref); /******** Schedule: Reduction ********/ /*! * \brief Decompose a reduction block into two separate blocks. @@ -561,7 +573,7 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu * \param storage_scope The storage scope to be set */ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& storage_scope); + const ffi::String& storage_scope); /*! * \brief Set the data type of a buffer, where the buffer is specified by a block and a * write-index @@ -573,7 +585,7 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer * \param dtype The data type to be set */ TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& dtype); + const ffi::String& dtype); /*! 
* \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read * or write index @@ -584,7 +596,7 @@ TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int */ TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators); + const ffi::Array& axis_separators); /******** Schedule: Blockize & Tensorize ********/ @@ -604,7 +616,7 @@ TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool pr * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The new block */ -TVM_DLL StmtSRef Blockize(ScheduleState self, const Array& blocks, +TVM_DLL StmtSRef Blockize(ScheduleState self, const ffi::Array& blocks, bool preserve_unit_iters); /*! @@ -625,7 +637,7 @@ TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, * \param ann_key The annotation key * \param ann_val The annotation value */ -TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, +TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key, const Any& ann_val); /*! * \brief Unannotate a block/loop's annotation with key ann_key @@ -633,7 +645,7 @@ TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& an * \param sref The block/loop to be unannotated * \param ann_key The annotation key */ -TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); +TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key); /******** Schedule: Layout transformation ********/ /*! 
@@ -656,7 +668,8 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& */ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value, bool assume_injective_transform); + const ffi::Optional& pad_value, + bool assume_injective_transform); /*! * \brief Apply a transformation represented by IndexMap to block @@ -688,7 +701,7 @@ TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref * \param padding The padding for each block iter. */ TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref, - const Array& padding); + const ffi::Array& padding); /******** Schedule: Buffer transformation ********/ /*! * \brief Compute the target buffer via rolling buffering. @@ -715,7 +728,8 @@ TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int w * \param buf_index_array The array of buffer indices we hide access. */ TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, - const String& buf_type, const Array& buf_index_array); + const ffi::String& buf_type, + const ffi::Array& buf_index_array); /*! 
* \brief Annotate the read or write region of a specific buffer in a block diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index e00ac2a5bba9..c398a46418a6 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -21,9 +21,10 @@ namespace tvm { namespace tir { -void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, const Any& ann_val) { +void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key, + const Any& ann_val) { // Extract annotation - const Map* annotations = nullptr; + const ffi::Map* annotations = nullptr; if (const auto* loop = sref->StmtAs()) { annotations = &loop->annotations; } else if (const auto* block = sref->StmtAs()) { @@ -36,27 +37,27 @@ void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, c return; } // Add the new annotation - Map new_ann(*annotations); + ffi::Map new_ann(*annotations); new_ann.Set(ann_key, ann_val); // Create the new stmt if (const auto* loop = sref->StmtAs()) { - ObjectPtr n = make_object(*loop); + ObjectPtr n = ffi::make_object(*loop); n->annotations = std::move(new_ann); self->Replace(sref, For(n), {}); } else if (const auto* block = sref->StmtAs()) { - ObjectPtr n = make_object(*block); + ObjectPtr n = ffi::make_object(*block); n->annotations = std::move(new_ann); Block p(n); - self->Replace(sref, p, {{GetRef(block), p}}); + self->Replace(sref, p, {{ffi::GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; } } -void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) { +void Unannotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key) { // Extract annotation - const Map* annotations = nullptr; + const ffi::Map* annotations = nullptr; if (const auto* loop = sref->StmtAs()) { annotations = &loop->annotations; } else if (const auto* block = sref->StmtAs()) { 
@@ -67,18 +68,18 @@ void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) // Remove the annotation ICHECK(annotations->find(ann_key) != annotations->end()) << "IndexError: Cannot find annotation key: " << ann_key; - Map new_ann(*annotations); + ffi::Map new_ann(*annotations); new_ann.erase(ann_key); // Create the new stmt if (const auto* loop = sref->StmtAs()) { - ObjectPtr n = make_object(*loop); + ObjectPtr n = ffi::make_object(*loop); n->annotations = std::move(new_ann); self->Replace(sref, For(n), {}); } else if (const auto* block = sref->StmtAs()) { - ObjectPtr n = make_object(*block); + ObjectPtr n = ffi::make_object(*block); n->annotations = std::move(new_ann); Block p(n); - self->Replace(sref, p, {{GetRef(block), p}}); + self->Replace(sref, p, {{ffi::GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; @@ -95,7 +96,7 @@ struct AnnotateTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, Any ann_val, - String ann_key) { + ffi::String ann_key) { if (auto block = block_or_loop_rv.as()) { return sch->Annotate(block.value(), ann_key, ann_val); } @@ -106,8 +107,8 @@ struct AnnotateTraits : public UnpackedInstTraits { throw; } - static String UnpackedAsPython(Array outputs, ObjectRef block_or_loop_rv, Any ann_val, - String ann_key) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef block_or_loop_rv, + Any ann_val, ffi::String ann_key) { PythonAPICall py("annotate"); py.Input("block_or_loop", block_or_loop_rv); py.Input("ann_key", ann_key); @@ -128,7 +129,8 @@ struct UnannotateTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String ann_key) { + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef 
block_or_loop_rv, + ffi::String ann_key) { if (auto block = block_or_loop_rv.as()) { return sch->Unannotate(block.value(), ann_key); } @@ -139,8 +141,8 @@ struct UnannotateTraits : public UnpackedInstTraits { throw; } - static String UnpackedAsPython(Array outputs, ObjectRef block_or_loop_rv, - String ann_key) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef block_or_loop_rv, + ffi::String ann_key) { PythonAPICall py("unannotate"); py.Input("block_or_loop", block_or_loop_rv); py.Input("ann_key", ann_key); diff --git a/src/tir/schedule/primitive/annotate_buffer_access.cc b/src/tir/schedule/primitive/annotate_buffer_access.cc index ce767339ee50..84672dede70d 100644 --- a/src/tir/schedule/primitive/annotate_buffer_access.cc +++ b/src/tir/schedule/primitive/annotate_buffer_access.cc @@ -33,7 +33,7 @@ class AnnotateRegionRewriter : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - Array regions = + ffi::Array regions = buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads; ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative"; ICHECK_LT(buffer_index_, static_cast(regions.size())) << "Buffer index out of range"; @@ -47,12 +47,13 @@ class AnnotateRegionRewriter : public StmtExprMutator { } // Annotate the block with explicit_read_region or explicit_write_region - Map new_annotations = n->annotations; - String annotation_key = buffer_index_type_ == BufferIndexType::kWrite - ? attr::explicit_write_region - : attr::explicit_read_region; + ffi::Map new_annotations = n->annotations; + ffi::String annotation_key = buffer_index_type_ == BufferIndexType::kWrite + ? 
attr::explicit_write_region + : attr::explicit_read_region; if (new_annotations.count(annotation_key)) { - Array buffer_indices = Downcast>(new_annotations[annotation_key]); + ffi::Array buffer_indices = + Downcast>(new_annotations[annotation_key]); bool found = false; for (const Integer& index : buffer_indices) { if (index->value == buffer_index_) { @@ -65,7 +66,7 @@ class AnnotateRegionRewriter : public StmtExprMutator { new_annotations.Set(annotation_key, buffer_indices); } } else { - new_annotations.Set(annotation_key, Array{Integer(buffer_index_)}); + new_annotations.Set(annotation_key, ffi::Array{Integer(buffer_index_)}); } n->annotations = std::move(new_annotations); @@ -82,16 +83,17 @@ class AnnotateRegionRewriter : public StmtExprMutator { void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, buffer_index_type); + Buffer buffer = + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, buffer_index_type); arith::Analyzer analyzer; - Array block_iter_vars; + ffi::Array block_iter_vars; for (const IterVar& iter_var : block->iter_vars) { block_iter_vars.push_back(iter_var->var); } - Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); + ffi::Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even."; - Array new_ranges; + ffi::Array new_ranges; for (size_t i = 0; i < new_indices.size(); i += 2) { // (begin, end) represents a region new_ranges.push_back(Range::FromMinExtent( @@ -101,9 +103,9 @@ void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int bu BufferRegion new_region(buffer, new_ranges); AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type); - Stmt new_stmt = 
mutator(GetRef(block_sref->stmt)); + Stmt new_stmt = mutator(ffi::GetRef(block_sref->stmt)); - self->Replace(block_sref, new_stmt, {{GetRef(block), Downcast(new_stmt)}}); + self->Replace(block_sref, new_stmt, {{ffi::GetRef(block), Downcast(new_stmt)}}); } struct AnnotateBufferAccessTraits : public UnpackedInstTraits { @@ -122,7 +124,7 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraitsinitial_indices.size(); ++i) { @@ -139,11 +141,12 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraits outputs, String block, Integer buffer_index, - Integer buffer_index_type, IndexMap index_map) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + Integer buffer_index, Integer buffer_index_type, + IndexMap index_map) { PythonAPICall py("annotate_buffer_access"); py.Input("block", block); py.Input("buffer_index", buffer_index->value); @@ -151,7 +154,7 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraits(buffer_index_type->value)) << "\""; - py.Input("buf_type", String(os.str())); + py.Input("buf_type", ffi::String(os.str())); py.Input("gen_new_ranges", IndexMap2GenNewRangesLambda(index_map)); return py.Str(); diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 0e2a055d7afe..2bf62d409e2d 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -30,13 +30,13 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis) : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input `axis` is out of range. 
It is required to be in range " "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set " "storage alignment."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; int ndim = static_cast(buffer_->shape.size()); os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim @@ -47,7 +47,7 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) { int ndim = static_cast(buffer->shape.size()); @@ -71,12 +71,12 @@ class NonAllocatedBufferError : public ScheduleError { public: explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input buffer is not allocated by a block. This means the buffer is " " either a function parameter or defined in `match_buffer` of a block."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The input buffer " << buffer_->name << " is not allocated by a block. 
This means the buffer is either a function parameter or " @@ -94,7 +94,7 @@ class NonAllocatedBufferError : public ScheduleError { return defining_site_sref.value(); } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod() const final { return mod_; } private: @@ -107,12 +107,12 @@ class StorageAlignInvalidFactorError : public ScheduleError { explicit StorageAlignInvalidFactorError(IRModule mod, int factor) : mod_(std::move(mod)), factor_(factor) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input `factor` of storage_align is expected to be a positive " "number."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The input `factor` of storage_align is expected to be a positive number. However, the " "input `factor` is " @@ -126,7 +126,7 @@ class StorageAlignInvalidFactorError : public ScheduleError { } } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod() const final { return mod_; } private: @@ -139,12 +139,12 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The block annotation for storage align is expected to be an array of " "4-integer-tuples (buffer_index, axis, factor, offset)."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The block annotation for storage align is expected to be an array of 4-integer-tuples " "(buffer_index, axis, factor, offset). 
However, the block annotation with key " @@ -168,7 +168,7 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { return storage_align_annotation; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod() const final { return mod_; } private: @@ -194,7 +194,7 @@ class StorageScopeMutator : private ReplaceBufferMutator { * \return The new block after the mutation */ static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, - const String& storage_scope, Map* block_sref_reuse) { + const ffi::String& storage_scope, ffi::Map* block_sref_reuse) { Buffer new_buffer = WithScope(old_buffer, storage_scope); StorageScopeMutator mutator(old_buffer, new_buffer, storage_scope, block_sref_reuse); Stmt new_block = mutator.VisitStmt(allocate_site); @@ -202,8 +202,8 @@ class StorageScopeMutator : private ReplaceBufferMutator { } private: - StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, String storage_scope, - Map* block_sref_reuse) + StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, ffi::String storage_scope, + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse) {} MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { @@ -222,8 +222,8 @@ class StorageScopeMutator : private ReplaceBufferMutator { void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - Buffer buffer = - GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, BufferIndexType::kWrite); + Buffer buffer = GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, + BufferIndexType::kWrite); StorageAlignInvalidFactorError::Check(self->mod, factor); axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); 
NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); @@ -231,7 +231,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind // Step 1: Get existing or create new annotation value. StorageAlignAnnotation storage_align_annotation = StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, - GetRef(block_ptr)); + ffi::GetRef(block_ptr)); // Step 2: Update the annotation value bool found = false; @@ -250,14 +250,14 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind // Step 3: Replace the block with the new annotation Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); - self->Replace(block_sref, new_block, {{GetRef(block_ptr), new_block}}); + self->Replace(block_sref, new_block, {{ffi::GetRef(block_ptr), new_block}}); } void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& storage_scope) { + const ffi::String& storage_scope) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = - GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, BufferIndexType::kWrite); // Step 1. If `storage_scope` equals the original storage scope of the buffer, just return. if (buffer.scope() == storage_scope) { @@ -274,9 +274,9 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given // storage scope. In the meanwhile, collect the block sref reuse information. 
- Map block_reuse_map; - Block new_block = StorageScopeMutator::Mutate(GetRef(alloc_site), buffer, storage_scope, - &block_reuse_map); + ffi::Map block_reuse_map; + Block new_block = StorageScopeMutator::Mutate(ffi::GetRef(alloc_site), buffer, + storage_scope, &block_reuse_map); self->Replace(alloc_site_sref, new_block, block_reuse_map); } @@ -294,7 +294,7 @@ class DTypeMutator : private ReplaceBufferMutator { * \return The new block after the mutation */ static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype, - Map* block_sref_reuse) { + ffi::Map* block_sref_reuse) { Buffer new_buffer = WithDType(old_buffer, dtype); DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse); Stmt new_block = mutator.VisitStmt(allocate_site); @@ -303,7 +303,7 @@ class DTypeMutator : private ReplaceBufferMutator { private: DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype, - Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse), src_dtype_(old_buffer->dtype), tgt_dtype_(dtype) {} @@ -343,11 +343,11 @@ class DTypeMutator : private ReplaceBufferMutator { }; void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& dtype) { + const ffi::String& dtype) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = - GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); - DataType target_dtype(StringToDLDataType(dtype)); + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, BufferIndexType::kWrite); + DataType target_dtype(ffi::StringToDLDataType(dtype)); // Step 1. If `dtype` equals the original data type, just return. if (buffer->dtype == target_dtype) { @@ -361,9 +361,9 @@ void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_i // Step 3. 
Recursively replace old buffer to a new buffer, where the new buffer has the given // dtype, and insert data type conversions. - Map block_reuse_map; + ffi::Map block_reuse_map; Block new_block = - DTypeMutator::Mutate(GetRef(alloc_site), buffer, target_dtype, &block_reuse_map); + DTypeMutator::Mutate(ffi::GetRef(alloc_site), buffer, target_dtype, &block_reuse_map); self->Replace(alloc_site_sref, new_block, block_reuse_map); } @@ -384,8 +384,9 @@ struct StorageAlignTraits : public UnpackedInstTraits { offset->value); } - static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - Integer axis, Integer factor, Integer offset) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + Integer buffer_index, Integer axis, Integer factor, + Integer offset) { PythonAPICall py("storage_align"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); @@ -409,12 +410,12 @@ struct SetScopeTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, - String storage_scope) { + ffi::String storage_scope) { return sch->SetScope(block_rv, buffer_index->value, storage_scope); } - static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + Integer buffer_index, ffi::String storage_scope) { PythonAPICall py("set_scope"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); @@ -436,12 +437,12 @@ struct UnsafeSetDTypeTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, - String dtype) { + ffi::String dtype) { return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype); } - static String UnpackedAsPython(Array outputs, String block_rv, Integer 
buffer_index, - String dtype) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + Integer buffer_index, ffi::String dtype) { PythonAPICall py("unsafe_set_dtype"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 4828701bb571..2ae32ea66a6a 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -52,18 +52,18 @@ class SubspaceNotDivisibleError : public ScheduleError { scope_loop_(std::move(scope_loop)), inner_block_(std::move(inner_block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The bindings of the inner block can not be blockized."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops " "starting at {1}."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } + ffi::Array LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } private: IRModule mod_; @@ -86,17 +86,17 @@ class SubspaceNotDivisibleError : public ScheduleError { * \param inner_iters The iters of the inner space * \return The result of the subspace division. 
*/ -Array> TrivialSubspaceDivision(const Array& iter_vars, - const Array& bindings, - const PrimExpr& predicate, - const Array& outer_iters, - const Array& inner_iters) { +ffi::Array> TrivialSubspaceDivision( + const ffi::Array& iter_vars, const ffi::Array& bindings, + const PrimExpr& predicate, const ffi::Array& outer_iters, + const ffi::Array& inner_iters) { if (!is_one(predicate)) return {}; - Array> res; + ffi::Array> res; std::unordered_set outer_loop_vars; std::unordered_set inner_loop_vars; - auto make_uses_var = [](const Array& vars) -> std::function { + auto make_uses_var = + [](const ffi::Array& vars) -> std::function { std::unordered_set var_set; var_set.reserve(vars.size()); for (const Var& var : vars) { @@ -154,15 +154,16 @@ Array> TrivialSubspaceDivision(const Array& iter * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \param loop_sref_as_outer Whether loop_sref is divided into outer or inner */ -Array> SubspaceDivide(const BlockRealize& realize, - const StmtSRef& block_sref, // - const StmtSRef& loop_sref, // - std::vector* loops, - arith::Analyzer* analyzer, bool preserve_unit_iters, - bool loop_sref_as_outer = false) { - Array inner_vars; - Array outer_vars; - Map loop_var_domain; +ffi::Array> SubspaceDivide(const BlockRealize& realize, + const StmtSRef& block_sref, // + const StmtSRef& loop_sref, // + std::vector* loops, + arith::Analyzer* analyzer, + bool preserve_unit_iters, + bool loop_sref_as_outer = false) { + ffi::Array inner_vars; + ffi::Array outer_vars; + ffi::Map loop_var_domain; bool inner = true; for (StmtSRefNode* sref = block_sref->parent; // sref && sref->stmt->IsInstance(); // @@ -179,7 +180,7 @@ Array> SubspaceDivide(const BlockRealize& realize, inner = false; } } - Array> result = + ffi::Array> result = arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate, arith::IterMapLevel::Surjective, analyzer, 
/*simplify_trivial_iterators=*/!preserve_unit_iters); @@ -203,17 +204,18 @@ Array> SubspaceDivide(const BlockRealize& realize, * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return A substitution plan to the iterators in the original inner block. */ -Map DeriveBlockBinding(const Array& iter_vars, // - const Array>& division, // - Array* outer_iter_vars, // - Array* outer_bindings, // - Array* inner_iter_vars, // - Array* inner_bindings, // - bool preserve_unit_iters, bool reuse_outer = false) { +ffi::Map DeriveBlockBinding( + const ffi::Array& iter_vars, // + const ffi::Array>& division, // + ffi::Array* outer_iter_vars, // + ffi::Array* outer_bindings, // + ffi::Array* inner_iter_vars, // + ffi::Array* inner_bindings, // + bool preserve_unit_iters, bool reuse_outer = false) { using arith::IterMapExpr; using arith::IterMapExprNode; using arith::NormalizeIterMapToExpr; - Map block_var_subst; + ffi::Map block_var_subst; ICHECK_EQ(iter_vars.size() + 1, division.size()); arith::Analyzer ana; for (int i = 0, n = iter_vars.size(); i < n; ++i) { @@ -282,15 +284,15 @@ Map DeriveBlockBinding(const Array& iter_vars, * \return The inner block created. */ BlockRealize GenerateInner(bool is_write_reduction, - const Array& iter_vars, // - const Array& iter_values, // - const PrimExpr& predicate, // + const ffi::Array& iter_vars, // + const ffi::Array& iter_values, // + const PrimExpr& predicate, // Block block) { BlockNode* n = block.CopyOnWrite(); n->iter_vars = iter_vars; n->init = std::nullopt; if (is_write_reduction) { - Array reads; + ffi::Array reads; reads.reserve(block->writes.size() + block->reads.size()); reads.insert(reads.end(), block->writes.begin(), block->writes.end()); reads.insert(reads.end(), block->reads.begin(), block->reads.end()); @@ -308,15 +310,15 @@ BlockRealize GenerateInner(bool is_write_reduction, * \return The subtree of the init block and its outer loops. 
*/ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize, - const std::vector& loops, String block_name) { + const std::vector& loops, ffi::String block_name) { const Block& inner_block = inner_realize->block; - Map subst_map; + ffi::Map subst_map; // Step 1: Create new block vars for the block inside the init stmt of outer block // A iter is used in the block if // 1) It is data parallel // 2) It is used in the original init block - Array iter_vars; - Array iter_values; + ffi::Array iter_vars; + ffi::Array iter_values; ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size()); int n = inner_block->iter_vars.size(); iter_vars.reserve(n); @@ -326,7 +328,7 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize const PrimExpr& iter_value = inner_realize->iter_values[i]; if (old_iter_var->iter_type == IterVarType::kDataPar && UsesVar(block_init, old_iter_var->var)) { - ObjectPtr new_iter_var = make_object(*old_iter_var.get()); + ObjectPtr new_iter_var = ffi::make_object(*old_iter_var.get()); new_iter_var->var = new_iter_var->var.copy_with_suffix("_init"); subst_map.Set(old_iter_var->var, new_iter_var->var); iter_vars.push_back(IterVar(new_iter_var)); @@ -354,7 +356,7 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize } } if (is_init_loop) { - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->loop_var = loop->loop_var.copy_with_suffix(""); new_loop->body = std::move(stmt); subst_map.Set(loop->loop_var, new_loop->loop_var); @@ -373,10 +375,10 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize * \param analyzer The analyzer for arithmetic simplification. * \return The substituted stmt. 
*/ -Stmt Substitute(const Stmt& stmt, const Map& sub, - Map* block_sref_reuse, arith::Analyzer* analyzer) { +Stmt Substitute(const Stmt& stmt, const ffi::Map& sub, + ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) { struct Replacer : public StmtExprMutator { - explicit Replacer(const Map& sub, Map* block_sref_reuse, + explicit Replacer(const ffi::Map& sub, ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {} @@ -389,14 +391,14 @@ Stmt Substitute(const Stmt& stmt, const Map& sub, } PrimExpr VisitExpr_(const VarNode* op) final { - if (Optional e = sub_.Get(GetRef(op))) { + if (ffi::Optional e = sub_.Get(ffi::GetRef(op))) { return e.value(); } return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const BlockNode* op) final { - Block src = GetRef(op); + Block src = ffi::GetRef(op); Block tgt = Downcast(StmtExprMutator::VisitStmt_(op)); if (!src.same_as(tgt)) { block_sref_reuse_->Set(src, tgt); @@ -404,8 +406,8 @@ Stmt Substitute(const Stmt& stmt, const Map& sub, return tgt; } - const Map& sub_; - Map* block_sref_reuse_; + const ffi::Map& sub_; + ffi::Map* block_sref_reuse_; arith::Analyzer* analyzer_; }; return Replacer(sub, block_sref_reuse, analyzer)(stmt); @@ -417,16 +419,16 @@ Stmt Substitute(const Stmt& stmt, const Map& sub, * \param dom_map The variables to be relaxed * \return The relaxed regions */ -Array EvalSetRegions(const Array& regions, - const Map& dom_map) { - Array results; +ffi::Array EvalSetRegions(const ffi::Array& regions, + const ffi::Map& dom_map) { + ffi::Array results; results.reserve(regions.size()); for (const BufferRegion& buffer_region : regions) { const Buffer& buffer = buffer_region->buffer; - Array relaxed = arith::EvalSet(buffer_region->region, dom_map); + ffi::Array relaxed = arith::EvalSet(buffer_region->region, dom_map); ICHECK_EQ(relaxed.size(), buffer->shape.size()); int ndim = buffer->shape.size(); - Array new_region; + ffi::Array new_region; 
new_region.reserve(ndim); for (int i = 0; i < ndim; ++i) { new_region.push_back(relaxed[i].CoverRange(RangeFromExtent(buffer->shape[i]))); @@ -441,23 +443,24 @@ Array EvalSetRegions(const Array& regions, * \param regions The input regions for the union. * \return The union regions */ -Array UnionRegions(const Array& regions) { - typedef std::vector> ranges_t; +ffi::Array UnionRegions(const ffi::Array& regions) { + typedef std::vector> ranges_t; std::unordered_map intset_map; for (const BufferRegion& buffer_region : regions) { const Buffer& buffer = buffer_region->buffer; if (intset_map.find(buffer) == intset_map.end()) { - intset_map[buffer] = {buffer->shape.size(), Array()}; + intset_map[buffer] = {buffer->shape.size(), ffi::Array()}; } - std::vector> dim_range(buffer->shape.size(), Array()); + std::vector> dim_range(buffer->shape.size(), + ffi::Array()); for (size_t dim = 0; dim < buffer->shape.size(); ++dim) { intset_map[buffer][dim].push_back(arith::IntSet::FromRange(buffer_region->region[dim])); } } - Array results; + ffi::Array results; for (const auto& it : intset_map) { const Buffer& buffer = it.first; - Array regions; + ffi::Array regions; for (size_t dim = 0; dim < buffer->shape.size(); ++dim) { const arith::IntSet intset = arith::Union(it.second[dim]); regions.push_back({intset.min(), intset.max() + 1}); @@ -475,7 +478,7 @@ Array UnionRegions(const Array& regions) { */ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { for (const ForNode* loop : loops) { - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->body = std::move(stmt); stmt = For(new_loop); } @@ -483,7 +486,7 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { } BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, - Map* block_sref_reuse, arith::Analyzer* analyzer, + ffi::Map* block_sref_reuse, arith::Analyzer* analyzer, bool preserve_unit_iters) { TVM_SREF_TO_FOR(loop_sref); // Step 1: Check and get the 
only block under `loop`. @@ -492,25 +495,25 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, StmtSRef block_sref = self->stmt2ref.at(block.get()); // Step 2: Derive subspace division std::vector loops; - Array> division = + ffi::Array> division = SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer, preserve_unit_iters); if (division.empty()) { - throw SubspaceNotDivisibleError(self->mod, GetRef(loops.back()), block); + throw SubspaceNotDivisibleError(self->mod, ffi::GetRef(loops.back()), block); } PrimExpr outer_predicate = division.back()[0]->extent; PrimExpr inner_predicate = division.back()[1]->extent; // Step 3. Derive block bindings for both outer and inner block. - Array outer_iter_vars; - Array inner_iter_vars; - Array outer_bindings; - Array inner_bindings; - Map block_var_subst = // + ffi::Array outer_iter_vars; + ffi::Array inner_iter_vars; + ffi::Array outer_bindings; + ffi::Array inner_bindings; + ffi::Map block_var_subst = // DeriveBlockBinding(block->iter_vars, division, // &outer_iter_vars, &outer_bindings, // &inner_iter_vars, &inner_bindings, // preserve_unit_iters); // Step 4: Do var substitution to adjust to the new block bindings - Map inner_iter_dom; + ffi::Map inner_iter_dom; for (const IterVar& iter : inner_iter_vars) { inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom)); analyzer->Bind(iter->var, iter->dom); @@ -549,12 +552,12 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, block_subst->init.defined() // ? 
GenerateOuterInit(block_subst->init.value(), inner_realize, loops, block_subst->name_hint + "_init") - : Optional(std::nullopt))); + : ffi::Optional(std::nullopt))); } StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters) { arith::Analyzer analyzer; - Map block_sref_reuse; + ffi::Map block_sref_reuse; BlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters); self->Replace(loop_sref, blockized, block_sref_reuse); @@ -566,34 +569,34 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_u return result; } -BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& block_srefs, - const StmtSRef& lca, Map* block_sref_reuse, +BlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Array& block_srefs, + const StmtSRef& lca, ffi::Map* block_sref_reuse, bool preserve_unit_iters) { - Array seq_body; + ffi::Array seq_body; PrimExpr outer_predicate{nullptr}; - Array outer_iter_vars{nullptr}; - Array outer_bindings{nullptr}; - Array read_regions; - Array write_regions; + ffi::Array outer_iter_vars{nullptr}; + ffi::Array outer_bindings{nullptr}; + ffi::Array read_regions; + ffi::Array write_regions; std::string outer_block_name = "outer_"; - Map loop_var_subst; + ffi::Map loop_var_subst; arith::Analyzer analyzer; for (const auto& block_sref : block_srefs) { auto block_realize = GetBlockRealize(self, block_sref); auto block = block_realize->block; // Step 1: Derive subspace division std::vector loops; - Array> division = SubspaceDivide(block_realize, block_sref, lca, &loops, - &analyzer, preserve_unit_iters, true); + ffi::Array> division = SubspaceDivide( + block_realize, block_sref, lca, &loops, &analyzer, preserve_unit_iters, true); if (division.empty()) { - throw SubspaceNotDivisibleError(self->mod, GetRef(loops.back()), block); + throw SubspaceNotDivisibleError(self->mod, ffi::GetRef(loops.back()), block); } outer_predicate = 
division.back()[0]->extent; PrimExpr inner_predicate = division.back()[1]->extent; // Step 2. Derive block bindings for both outer and inner block. - Array inner_iter_vars; - Array inner_bindings; - Map block_var_subst = // + ffi::Array inner_iter_vars; + ffi::Array inner_bindings; + ffi::Map block_var_subst = // DeriveBlockBinding(block->iter_vars, division, // &outer_iter_vars, &outer_bindings, // &inner_iter_vars, &inner_bindings, // @@ -604,7 +607,7 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& bl loop_var_subst.Set(Downcast(outer_bindings[i]), outer_iter_vars[i]->var); } } - Map inner_iter_dom; + ffi::Map inner_iter_dom; for (const IterVar& iter : inner_iter_vars) { Range dom = Substitute(iter->dom, loop_var_subst); inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(dom)); @@ -637,7 +640,7 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& bl block_sref_reuse->Set(block, inner_realize->block); Stmt stmt = inner_realize; for (const ForNode* loop : loops) { - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->body = std::move(stmt); new_loop->extent = Substitute(new_loop->extent, loop_var_subst); stmt = For(new_loop); @@ -654,19 +657,19 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& bl /*writes=*/UnionRegions(write_regions), /*name_hint=*/outer_block_name, /*body=*/SeqStmt(seq_body), - /*init=*/Optional(std::nullopt))); + /*init=*/ffi::Optional(std::nullopt))); } class BlockizeRewriter : public StmtMutator { public: - static Stmt Rewrite(const StmtSRef& lca, const Array& blocks, + static Stmt Rewrite(const StmtSRef& lca, const ffi::Array& blocks, const BlockRealize& blockized) { BlockizeRewriter rewriter(lca, blocks, blockized); - return rewriter(GetRef(lca->stmt)); + return rewriter(ffi::GetRef(lca->stmt)); } private: - explicit BlockizeRewriter(const StmtSRef& lca, const Array& blocks, + explicit BlockizeRewriter(const StmtSRef& lca, const 
ffi::Array& blocks, const BlockRealize& blockized) : lca_(lca), blocks_(blocks), blockized_(blockized) {} @@ -676,7 +679,7 @@ class BlockizeRewriter : public StmtMutator { int idx_start = -1; int last_found_idx = -1; size_t cur_idx = 0; - Array new_seq; + ffi::Array new_seq; for (const Stmt& it : seq->seq) { target_in_ = false; Stmt stmt = StmtMutator::VisitStmt(it); @@ -700,7 +703,7 @@ class BlockizeRewriter : public StmtMutator { Stmt VisitStmt_(const ForNode* loop) final { if (loop == lca_->stmt) { return For(loop->loop_var, loop->min, loop->extent, loop->kind, RewriteSeq(loop->body), - loop->thread_binding, loop->annotations, loop->span); + loop->thread_binding, loop->annotations, loop->step, loop->span); } return StmtMutator::VisitStmt_(loop); } @@ -717,17 +720,18 @@ class BlockizeRewriter : public StmtMutator { break; } } - return GetRef(block); + return ffi::GetRef(block); } StmtSRef lca_; - Array blocks_; + ffi::Array blocks_; BlockRealize blockized_; bool target_in_ = false; }; -StmtSRef Blockize(ScheduleState self, const Array& blocks, bool preserve_unit_iters) { - Map block_sref_reuse; +StmtSRef Blockize(ScheduleState self, const ffi::Array& blocks, + bool preserve_unit_iters) { + ffi::Map block_sref_reuse; auto lca = GetSRefLowestCommonAncestor(blocks); BlockRealize blockized = BlockizeBlocks(self, blocks, lca, &block_sref_reuse, preserve_unit_iters); @@ -743,17 +747,17 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int bool preserve_unit_iters) { // Step 1: Blockize the subtree rooted at the given loop if needed BlockRealize block_realize{nullptr}; - Optional old_block = std::nullopt; + ffi::Optional old_block = std::nullopt; if (sref->stmt->IsInstance()) { block_realize = GetBlockRealize(self, sref); old_block = block_realize->block; } else if (sref->stmt->IsInstance()) { arith::Analyzer analyzer; - Map block_sref_reuse; + ffi::Map block_sref_reuse; block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, 
preserve_unit_iters); } else { LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: " - << GetRef(sref->stmt); + << ffi::GetRef(sref->stmt); throw; } @@ -762,7 +766,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int PrimFunc intrin_impl = DeepCopy(intrin->impl); int index_dtype_bits = -1; - auto f_update_max_dtype_bits_from_region = [&](const Array& buffer_regions) { + auto f_update_max_dtype_bits_from_region = [&](const ffi::Array& buffer_regions) { for (const BufferRegion& buffer_region : buffer_regions) { for (const auto& range : buffer_region->region) { index_dtype_bits = std::max(index_dtype_bits, range->min.dtype().bits()); @@ -794,7 +798,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int ICHECK(comparator.rhs_buffer_map_.count(desc)); impl2cur[impl] = comparator.rhs_buffer_map_[desc]; } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> impl2region; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> impl2region; Block impl_block = Downcast(intrin_impl->body)->block; for (const BufferRegion& read : impl_block->reads) { impl2region.emplace(read->buffer, read->region); @@ -804,16 +808,16 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor // intrin to make them subregions of the buffer in the original IR. 
- Array match_buffer_regions; + ffi::Array match_buffer_regions; match_buffer_regions.reserve(intrin_impl->params.size()); for (int i = 0, n = intrin_impl->params.size(); i < n; ++i) { const Buffer& impl = intrin_impl->buffer_map.at(intrin_impl->params[i]); const Buffer& cur = impl2cur.at(impl); - const Array& old_region = impl2region.at(impl); + const ffi::Array& old_region = impl2region.at(impl); const std::vector& indices_base = comparator.buffer_indices_.at(cur); int offset = static_cast(indices_base.size()) - static_cast(old_region.size()); ICHECK(offset >= 0); - Array new_region; + ffi::Array new_region; new_region.reserve(cur->shape.size()); for (int i = 0; i < offset; i++) { PrimExpr min = indices_base[i]; @@ -867,14 +871,14 @@ struct BlockizeTraits : public UnpackedInstTraits { static BlockRV UnpackedApplyToSchedule(Schedule sch, ObjectRef target, Bool preserve_unit_iters) { if (auto loop = target.as()) { return sch->Blockize(loop.value(), preserve_unit_iters.operator bool()); - } else if (auto blocks = target.as>()) { + } else if (auto blocks = target.as>()) { return sch->Blockize(blocks.value(), preserve_unit_iters.operator bool()); } LOG(FATAL) << "TypeError: expect Loop or list of Blocks, but gets:" << target->GetTypeKey(); } - static String UnpackedAsPython(Array outputs, ObjectRef target, - Bool preserve_unit_iters) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef target, + Bool preserve_unit_iters) { PythonAPICall py("blockize"); py.Input("target", target); py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); @@ -895,7 +899,7 @@ struct TensorizeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin, + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ffi::String intrin, Bool preserve_unit_iters) { if (auto block = 
block_or_loop_rv.as()) { sch->Tensorize(block.value(), intrin, preserve_unit_iters.operator bool()); @@ -907,8 +911,8 @@ struct TensorizeTraits : public UnpackedInstTraits { } } - static String UnpackedAsPython(Array outputs, String block_or_loop_rv, String intrin, - Bool preserve_unit_iters) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_or_loop_rv, + ffi::String intrin, Bool preserve_unit_iters) { PythonAPICall py("tensorize"); py.Input("block_or_loop", block_or_loop_rv); py.Input("tensor_intrin", intrin); diff --git a/src/tir/schedule/primitive/cache_index.cc b/src/tir/schedule/primitive/cache_index.cc index 9ea47def4c31..156f2ae4c59c 100644 --- a/src/tir/schedule/primitive/cache_index.cc +++ b/src/tir/schedule/primitive/cache_index.cc @@ -38,17 +38,17 @@ struct IndexInfo { /*! \brief The expr to be precomputed */ std::vector index_exprs; /*! \brief The range of the loop vars relating to index computation */ - Map range_map; + ffi::Map range_map; /*! \brief The binding table of the block var and the loop var */ - Map var_binding; + ffi::Map var_binding; /*! \brief The block var of the target block */ - std::vector> origin_block_vars; + std::vector> origin_block_vars; /*! \brief The index to insert the cache stage. */ size_t loc_pos; /*! \brief The cache stage to be inserted. */ Stmt cache_stage; /*! \brief The map used for ScheduleStateNode::Replace. */ - Map block_reuse; + ffi::Map block_reuse; }; /*! 
@@ -79,7 +79,7 @@ class IndexInfoCollector : public StmtExprVisitor { static void Collect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, IndexInfo* info) { IndexInfoCollector collector(self, block_sref, scope_sref, info->cse_thresh); - collector(GetRef(scope_sref->stmt)); + collector(ffi::GetRef(scope_sref->stmt)); info->loc_pos = collector.loc_pos_; info->index_exprs = collector.exprs_; info->range_map = collector.range_map_; @@ -150,7 +150,7 @@ class IndexInfoCollector : public StmtExprVisitor { // Analyze sub expr candidates ComputationTable table_syntactic_comp_done_by_stmt = - ComputationsDoneBy::GetComputationsDoneBy(GetRef(store), IsEligibleComputation, + ComputationsDoneBy::GetComputationsDoneBy(ffi::GetRef(store), IsEligibleComputation, [](const PrimExpr& expr) { return true; }); std::vector> semantic_comp_done_by_stmt = SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, true); @@ -211,7 +211,7 @@ class IndexInfoCollector : public StmtExprVisitor { /*! \brief The flag indicating the right scope to update seq pos */ bool update_seq_pos_{false}; /*! \brief Record the ranges of iter vars */ - Map range_map_; + ffi::Map range_map_; }; /*! @@ -220,9 +220,9 @@ class IndexInfoCollector : public StmtExprVisitor { * \param storage_scope The storage scope of the cached buffer (only used in naming here) * \returns A block indicating the body of the loop nesting. 
*/ -Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { - Array blocks; - Array bodies; +ffi::Array MakeIndexCacheStage(IndexInfo* info, const ffi::String& storage_scope) { + ffi::Array blocks; + ffi::Array bodies; bodies.reserve(info->index_exprs.size()); info->cache_buffer.reserve(info->index_exprs.size()); @@ -235,7 +235,7 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { PostOrderVisit(index_expr, [&info, &expr_index](const ObjectRef& node) { if (node->IsInstance()) { Var iter_var = Downcast(node); - const Array& origin_block_var = info->origin_block_vars[expr_index]; + const ffi::Array& origin_block_var = info->origin_block_vars[expr_index]; auto find_result = std::find_if(origin_block_var.begin(), origin_block_var.end(), [&](Var it) { return it.get() == iter_var.get(); }); if (find_result == origin_block_var.end()) { @@ -262,7 +262,7 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { DataType data_type = index_expr.dtype(); Var index_buffer_var("index_var_" + std::to_string(expr_index), PointerType(PrimType(data_type), storage_scope)); - Array buffer_shape; + ffi::Array buffer_shape; for (const Var& it : info->origin_block_vars[expr_index]) { buffer_shape.push_back( arith::EvalSet(info->var_binding.at(it), arith::AsIntSet(info->range_map)).max() + 1); @@ -272,7 +272,7 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { // Create loop vars and block vars' binding_value std::vector loop_vars; - Map replace_table; + ffi::Map replace_table; for (const Var& it : iter_vars) { DataType data_type = DetermineDatatype(arith::IntSet::FromRange(info->range_map.at(it))); Var loop_var("ax" + std::to_string(replace_table.size()), data_type); @@ -285,12 +285,12 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { iter_values.push_back(Substitute(info->var_binding.at(it), replace_table)); } // block variables - Array block_vars; + ffi::Array block_vars; 
// block access region for write buffers Region access_region; // indices used in block body - Array access_indices; - Map block_var_map; + ffi::Array access_indices; + ffi::Map block_var_map; // Create block vars, block's accessed region and accessing indices for (size_t i = 0; i < info->origin_block_vars[expr_index].size(); i++) { const Var& block_var = info->origin_block_vars[expr_index][i]; @@ -348,15 +348,15 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { */ Stmt InsertIndexStage(const Stmt& stmt, int pos, const Stmt& stage) { if (const auto* seq_stmt = stmt.as()) { - ObjectPtr result = make_object(*seq_stmt); + ObjectPtr result = ffi::make_object(*seq_stmt); result->seq.insert(result->seq.begin() + pos, stage); return SeqStmt(result); } if (pos == 0) { - return SeqStmt::Flatten>({stage, stmt}); + return SeqStmt::Flatten>({stage, stmt}); } ICHECK_EQ(pos, 1); - return SeqStmt::Flatten>({stmt, stage}); + return SeqStmt::Flatten>({stmt, stage}); } /*! \brief Mutator for CacheIndex. 
*/ @@ -370,14 +370,14 @@ class CacheIndexRewriter : public StmtExprMutator { */ static Stmt Rewrite(const StmtSRef& scope_sref, IndexInfo* info) { CacheIndexRewriter rewriter(scope_sref, info); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: explicit CacheIndexRewriter(const StmtSRef& scope_sref, IndexInfo* info) : scope_sref_(scope_sref), info_(info) { cache_indices_.reserve(info_->origin_block_vars.size()); - for (const Array& group_it : info_->origin_block_vars) { + for (const ffi::Array& group_it : info_->origin_block_vars) { cache_indices_.push_back({}); for (const Var& it : group_it) { cache_indices_.back().push_back(it); @@ -386,7 +386,7 @@ class CacheIndexRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); // Mutate the body visiting_target_block = static_cast(block == info_->target_block->stmt); Block stmt = Downcast(StmtMutator::VisitStmt_(block)); @@ -395,7 +395,7 @@ class CacheIndexRewriter : public StmtExprMutator { // Check if it is the block corresponding to the parent scope if (block == scope_sref_->stmt) { // If so, put buffer allocation and insert cache stages on the parent scope - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertIndexStage(n->body, info_->loc_pos, info_->cache_stage); for (const Buffer& it : info_->cache_buffer) { n->alloc_buffers.push_back(it); @@ -431,13 +431,13 @@ class CacheIndexRewriter : public StmtExprMutator { /*! \brief The info for inserting cache stage */ IndexInfo* info_; /*! \brief The indices for the cache buffer */ - std::vector> cache_indices_; + std::vector> cache_indices_; /*! 
\brief Indicating whether cache stage is inserted, only do index replacement afterwards*/ bool visiting_target_block{false}; }; -Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, - const String& storage_scope, int cse_thresh) { +ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, + const ffi::String& storage_scope, int cse_thresh) { /*! * Check: * - The index is in the array of block reading region @@ -460,14 +460,14 @@ Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, // Step 2. Create cache stages and rewrite the stmt. BlockRealize realize = GetBlockRealize(self, block_sref); info.var_binding = GetBindings(realize); - Array cache_stages = MakeIndexCacheStage(&info, storage_scope); + ffi::Array cache_stages = MakeIndexCacheStage(&info, storage_scope); Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); bool old_stage_pipeline = self->block_info[block_sref].stage_pipeline; // Step 3. Replacing and updating flags. 
self->Replace(scope_sref, new_scope, info.block_reuse); - Array result_block_srefs; + ffi::Array result_block_srefs; for (const Block& it : cache_stages) { StmtSRef result_block_sref = self->stmt2ref.at(it.get()); result_block_srefs.push_back(result_block_sref); @@ -478,7 +478,7 @@ Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, affine_binding = true; } else { arith::Analyzer analyzer; - StmtSRef parent_sref = GetRef(result_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(result_block_sref->parent); affine_binding = IsAffineBinding(/*realize=*/GetBlockRealize(self, result_block_sref), /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), /*analyzer=*/&analyzer); @@ -503,13 +503,14 @@ struct CacheIndexTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, String storage_scope, - Integer cse_thresh) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, + ffi::String storage_scope, + Integer cse_thresh) { return sch->CacheIndex(block, storage_scope, cse_thresh->value); } - static String UnpackedAsPython(Array outputs, String block, String storage_scope, - Integer cse_thresh) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::String storage_scope, Integer cse_thresh) { PythonAPICall py("cache_index"); py.Input("block", block); py.Input("storage_scope", storage_scope); diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 38cafbe1515e..9a883c11359b 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -30,35 +30,35 @@ namespace tir { class NotSingleWriteBlock : public ScheduleError { public: - explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, Array write_blocks) + explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, 
ffi::Array write_blocks) : mod_(std::move(mod)), buffer_(std::move(buffer)) { ICHECK_GT(write_blocks.size(), 1); write_blocks_.reserve(write_blocks.size()); for (const StmtSRef& block_sref : write_blocks) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - write_blocks_.push_back(GetRef(block)); + write_blocks_.push_back(ffi::GetRef(block)); } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The buffer is allowed to be written by single block."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { size_t k = write_blocks_.size(); return "The buffer " + buffer_->name + " is expected to be written by single block, but got " + std::to_string(k) + " blocks who write it."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { + ffi::Array LocationsOfInterest() const final { return {write_blocks_.begin(), write_blocks_.end()}; } private: IRModule mod_; Buffer buffer_; - Array write_blocks_; + ffi::Array write_blocks_; }; /******** Helper Functions/Classes ********/ @@ -70,7 +70,7 @@ struct CacheStageInfo { /*! \brief The buffer to be written. */ Buffer write_buffer; /*! \brief The buffer allocation to be inserted into the block signature. */ - Optional alloc; + ffi::Optional alloc; /*! \brief The AST node whose body is where the cache stage should be inserted. */ StmtSRef loc_sref; /*! \brief The index to insert the cache_read/cache_write stage. */ @@ -78,7 +78,7 @@ struct CacheStageInfo { /*! \brief The cache_read/cache_write stage to be inserted. */ Stmt cache_stage; /*! \brief The map used for ScheduleStateNode::Replace. */ - Map block_reuse; + ffi::Map block_reuse; /*! \brief A set of blocks that will consume the new cache. */ std::unordered_set consumer_blocks; /*! \brief cache region for the buffer to be cached */ @@ -86,9 +86,9 @@ struct CacheStageInfo { }; /*! 
\brief Return the buffer region related with the buffer */ -Optional GetBufferRegionFromBuffer(const Array& buffer_regions, - const Buffer& buffer) { - Optional res = std::nullopt; +ffi::Optional GetBufferRegionFromBuffer( + const ffi::Array& buffer_regions, const Buffer& buffer) { + ffi::Optional res = std::nullopt; for (const auto& region : buffer_regions) { if (region->buffer.same_as(buffer)) { ICHECK(!res.defined()); @@ -100,13 +100,13 @@ Optional GetBufferRegionFromBuffer(const Array& buff struct ReindexCacheStageInfo : CacheStageInfo { /* Indices used to access the allocated cache buffer. */ - Array indices; + ffi::Array indices; /* Touched loop variable related information. */ - Array loop_vars; - Array loop_ranges; + ffi::Array loop_vars; + ffi::Array loop_ranges; /* Touched block variable related information. */ - Array block_iter_vars; - Array block_iter_values; + ffi::Array block_iter_vars; + ffi::Array block_iter_values; }; /* \brief The schedule error that accessed buffer region is not a single point for @@ -119,26 +119,26 @@ class NotSinglePointAccess : public ScheduleError { primitive_name_ = is_cache_read ? 
"reindex_cache_read" : "reindex_cache_write"; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The buffer region accessed inside the block is not a single point."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The buffer region " << cache_region_ << " accessed inside block {0} is not a single point, which violates" << " the prerequisite of " << primitive_name_ << " primitive."; - return String(os.str()); + return ffi::String(os.str()); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; Block block_; BufferRegion cache_region_; - String primitive_name_; + ffi::String primitive_name_; }; /*! @@ -151,15 +151,15 @@ class NotSinglePointAccess : public ScheduleError { */ template Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageInfo* info, - const String& storage_scope) { + const ffi::String& storage_scope) { // loop variables std::vector loop_vars; // block variables - Array block_vars; + ffi::Array block_vars; // bindings in block realize std::vector iter_values; // Create loop vars and block vars' binding_value - Map var_map; + ffi::Map var_map; for (size_t i = 0; i < info->loop_vars.size(); ++i) { Var original_var = info->loop_vars[i]; Var loop_var(original_var->name_hint, original_var.dtype()); @@ -180,15 +180,15 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI // block access region for read/write buffers Region read_access_region, write_access_region; - Array read_access_indices, write_access_indices; + ffi::Array read_access_indices, write_access_indices; // Compute read/write region and read/write access indices. - Array& old_indices = (is_cache_read) ? 
read_access_indices : write_access_indices; + ffi::Array& old_indices = (is_cache_read) ? read_access_indices : write_access_indices; Region& old_region = (is_cache_read) ? read_access_region : write_access_region; for (const Range& range : cache_region->region) { old_indices.push_back(Substitute(range->min, var_map)); old_region.push_back(Range::FromMinExtent(old_indices.back(), Integer(1))); } - Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; + ffi::Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; Region& new_region = (is_cache_read) ? write_access_region : read_access_region; for (const PrimExpr& idx : info->indices) { new_indices.push_back(Substitute((idx), var_map)); @@ -237,7 +237,7 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI * \returns A block indicating the body of the loop nesting. */ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, - const String& storage_scope, bool cache_full_region = true) { + const ffi::String& storage_scope, bool cache_full_region = true) { // loop variables std::vector loop_vars; // bindings in block realize @@ -249,13 +249,13 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, iter_values.push_back(cache_full_region ? 
(axis_range->min + loop_var) : loop_var); } // block variables - Array block_vars; + ffi::Array block_vars; // block access region for read/write buffers Region read_access_region; Region write_access_region; // indices used in block body - Array read_access_indices; - Array write_access_indices; + ffi::Array read_access_indices; + ffi::Array write_access_indices; // Create block vars, block's accessed region and accessing indices for (int i = 0; i < static_cast(cache_region->buffer->shape.size()); ++i) { Range axis_range = cache_region->region[i]; @@ -344,14 +344,14 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, */ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, const std::unordered_set& covered, - const Array& original_indices, int buffer_index, + const ffi::Array& original_indices, int buffer_index, BufferIndexType buffer_index_type) { // iters of the reindex block - Array new_block_iters; + ffi::Array new_block_iters; // the substitution map from the original block iter to the iters of the reindex block std::unordered_map block_var_replace_map; // indices to access the reindex buffer and the target buffer - Array reindex_indices, target_indices; + ffi::Array reindex_indices, target_indices; // Step 1: Create block iters, access regions of the reindex block, and accessing indices to the // reindex buffer. 
@@ -383,8 +383,8 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, // The src and the dst region and indices of the data copy Region src_region{nullptr}; Region dst_region{nullptr}; - Array src_indices{nullptr}; - Array dst_indices{nullptr}; + ffi::Array src_indices{nullptr}; + ffi::Array dst_indices{nullptr}; if (buffer_index_type == BufferIndexType::kWrite) { src_indices = reindex_indices; @@ -444,7 +444,7 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) return true; } arith::Analyzer analyzer; - StmtSRef parent_sref = GetRef(block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(block_sref->parent); return IsAffineBinding(/*realize=*/GetBlockRealize(self, block_sref), /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), /*analyzer=*/&analyzer); @@ -477,7 +477,7 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { } if (const auto* seq_stmt = body.as()) { - Array seq = seq_stmt->seq; + ffi::Array seq = seq_stmt->seq; ICHECK_LE(pos, seq.size()) << "Cannot insert at position " << pos << " into sequence of length " << seq.size(); seq.insert(seq.begin() + pos, stage); @@ -506,14 +506,14 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { * or `std::nullopt` if no block writes it in the scope. * \throw NotSingleWriteBlock if there are more than one interested block. 
*/ -Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, - const Buffer& buffer) { +ffi::Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, + const Buffer& buffer) { BlockScope scope = self->GetBlockScope(scope_sref); auto it = scope->buffer_writers.find(buffer); if (it == scope->buffer_writers.end()) { return std::nullopt; } else { - const Array& block_srefs = it->second; + const ffi::Array& block_srefs = it->second; ICHECK(!block_srefs.empty()); if (block_srefs.size() > 1) { throw NotSingleWriteBlock(self->mod, buffer, block_srefs); @@ -570,11 +570,11 @@ BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_re const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive, const StmtSRef& dom_high_exclusive) { BlockRealize realize = GetBlockRealize(self, block_sref); - Map binding = GetBindings(realize); + ffi::Map binding = GetBindings(realize); const Buffer& buffer = buffer_region->buffer; arith::Analyzer analyzer; BufferRegion subst_region = BufferRegion(buffer, Substitute(buffer_region->region, binding)); - Array int_sets = AnalyzeRegionUpperBound( + ffi::Array int_sets = AnalyzeRegionUpperBound( /*region=*/subst_region, /*predicate=*/realize->predicate, /*dom_low_inclusive=*/dom_low_inclusive, @@ -632,7 +632,7 @@ class CacheLocDetector : public StmtVisitor { if (!related_blocks.empty()) { CacheLocDetector detector(self, block_sref, scope_sref, related_blocks); - detector(GetRef(scope_sref->stmt)); + detector(ffi::GetRef(scope_sref->stmt)); info->loc_sref = detector.loc_sref_; info->loc_pos = detector.loc_pos_; } else { @@ -761,7 +761,7 @@ class CacheInplaceLocDetector : public StmtVisitor { static void Detect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { CacheInplaceLocDetector detector(self, block_sref, scope_sref); - detector(GetRef(scope_sref->stmt)); + detector(ffi::GetRef(scope_sref->stmt)); info->loc_sref = 
detector.loc_sref_; info->loc_pos = detector.loc_pos_; } @@ -851,7 +851,7 @@ class CacheReadRewriter : public StmtExprMutator { static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info, bool cache_full_region = true) { CacheReadRewriter rewriter(scope_sref, info, cache_full_region); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: @@ -868,12 +868,12 @@ class CacheReadRewriter : public StmtExprMutator { return ret; }; - update_access_regions = [this, update_region](Array regions) { + update_access_regions = [this, update_region](ffi::Array regions) { if (cache_full_region_) { return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer); } - Array ret; + ffi::Array ret; for (const BufferRegion& region : regions) { if (region->buffer.same_as(info_->read_buffer)) { ret.push_back(BufferRegion(info_->write_buffer, @@ -884,12 +884,12 @@ class CacheReadRewriter : public StmtExprMutator { } return ret; }; - update_match_buffers = [this, update_region](Array match_buffers) { + update_match_buffers = [this, update_region](ffi::Array match_buffers) { if (cache_full_region_) { return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, info_->write_buffer); } - Array ret; + ffi::Array ret; for (const MatchBufferRegion& match_buffer : match_buffers) { if (match_buffer->source->buffer.same_as(info_->read_buffer)) { ret.push_back(MatchBufferRegion( @@ -909,7 +909,7 @@ class CacheReadRewriter : public StmtExprMutator { // Check the insertion point if (loop == info_->loc_sref->stmt) { // Insert cache stage into the loop if it is the right place - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); stmt = Stmt(n); } @@ -917,14 +917,14 @@ class CacheReadRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) override { - Block old_stmt = GetRef(block); + Block 
old_stmt = ffi::GetRef(block); // Check if this block is one of the specified consumers. // If no consumer blocks are specified, all blocks should be considered consumers. bool is_consumer = info_->consumer_blocks.empty(); // Otherwise check if this is one of the specified blocks. for (StmtSRef consumer_sref : info_->consumer_blocks) { const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); - Block consumer_block = GetRef(consumer_node); + Block consumer_block = ffi::GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { is_consumer = true; } @@ -941,14 +941,14 @@ class CacheReadRewriter : public StmtExprMutator { // Check the insertion point if (block == info_->loc_sref->stmt) { // Insert cache stage into the block if it is the right place - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); stmt = Block(n); } // Check if it is the block corresponding to the parent scope if (block == scope_sref_->stmt) { // If so, put buffer allocation on the parent scope - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); // In cache_inplace case, alloc_buffer may be already exits. if (info_->alloc.defined()) { n->alloc_buffers.push_back(info_->alloc.value()); @@ -959,10 +959,10 @@ class CacheReadRewriter : public StmtExprMutator { // Only make this change if the block is one of the specified consumers. 
if (is_consumer) { // Use the updated block stmt - Array reads = update_access_regions(stmt->reads); - Array match_buffers = update_match_buffers(stmt->match_buffers); + ffi::Array reads = update_access_regions(stmt->reads); + ffi::Array match_buffers = update_match_buffers(stmt->match_buffers); if (!reads.same_as(stmt->reads) || !match_buffers.same_as(stmt->match_buffers)) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); stmt = Block(n); @@ -973,7 +973,7 @@ class CacheReadRewriter : public StmtExprMutator { return stmt; } - Array RewriteIndices(const Array& indices) { + ffi::Array RewriteIndices(const ffi::Array& indices) { std::vector ret; for (size_t i = 0; i < indices.size(); ++i) { ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min)); @@ -983,7 +983,7 @@ class CacheReadRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { - ObjectPtr n = make_object(*load); + ObjectPtr n = ffi::make_object(*load); n->buffer = info_->write_buffer; if (!cache_full_region_) { n->indices = RewriteIndices(load->indices); @@ -997,7 +997,7 @@ class CacheReadRewriter : public StmtExprMutator { if (op == info_->read_buffer->data.get()) { return info_->write_buffer->data; } - return GetRef(op); + return ffi::GetRef(op); } private: @@ -1008,9 +1008,9 @@ class CacheReadRewriter : public StmtExprMutator { /*! \brief Whether the most recently visited block is a specified consumer. */ bool current_block_consumes; /*! \brief function to update read/write region of block being cache read.*/ - std::function(Array)> update_access_regions; + std::function(ffi::Array)> update_access_regions; /*! 
\brief function to update match buffers of block being cache read.*/ - std::function(Array)> update_match_buffers; + std::function(ffi::Array)> update_match_buffers; /*! * \brief A boolean indicating if the cache buffer is allocated with * full region or compact region. @@ -1033,18 +1033,18 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { */ static Stmt Rewrite(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) { ReindexCacheReadRewriter rewriter(scope_sref, info); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: explicit ReindexCacheReadRewriter(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) : CacheReadRewriter(scope_sref, info) { new_indices_ = info->indices; - update_access_regions = [&](Array reads) { - Array new_reads; + update_access_regions = [&](ffi::Array reads) { + ffi::Array new_reads; for (const BufferRegion& buf_region : reads) { if (buf_region->buffer.same_as(info_->read_buffer)) { - Array region; + ffi::Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1055,12 +1055,12 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { } return new_reads; }; - update_match_buffers = [&](const Array match_buffers) { - Array new_match_buffers; + update_match_buffers = [&](const ffi::Array match_buffers) { + ffi::Array new_match_buffers; for (const MatchBufferRegion& match_buffer_region : match_buffers) { BufferRegion source = match_buffer_region->source; if (source->buffer.same_as(info_->read_buffer)) { - Array region; + ffi::Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1076,7 +1076,7 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { PrimExpr VisitExpr_(const BufferLoadNode* load) final { if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { - ObjectPtr n = make_object(*load); + ObjectPtr n 
= ffi::make_object(*load); n->buffer = info_->write_buffer; n->indices = new_indices_; return PrimExpr(n); @@ -1085,7 +1085,7 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { } /*! \brief The indices to use for new buffer. */ - Array new_indices_; + ffi::Array new_indices_; }; class ReindexCacheWriteRewriter; @@ -1105,7 +1105,7 @@ class CacheWriteRewriter : public StmtExprMutator { static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, CacheStageInfo* info, bool cache_full_region = true) { CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info, cache_full_region); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: @@ -1125,12 +1125,12 @@ class CacheWriteRewriter : public StmtExprMutator { return ret; }; - update_access_regions = [this, update_region](Array regions) { + update_access_regions = [this, update_region](ffi::Array regions) { if (cache_full_region_) { return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer); } - Array ret; + ffi::Array ret; for (const BufferRegion& region : regions) { if (region->buffer.same_as(info_->write_buffer)) { ret.push_back(BufferRegion(info_->read_buffer, @@ -1141,12 +1141,12 @@ class CacheWriteRewriter : public StmtExprMutator { } return ret; }; - update_match_buffers = [this, update_region](Array match_buffers) { + update_match_buffers = [this, update_region](ffi::Array match_buffers) { if (cache_full_region_) { return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer); } - Array ret; + ffi::Array ret; for (const MatchBufferRegion& match_buffer : match_buffers) { if (match_buffer->source->buffer.same_as(info_->write_buffer)) { ret.push_back(MatchBufferRegion( @@ -1166,7 +1166,7 @@ class CacheWriteRewriter : public StmtExprMutator { // Check the insertion point if (loop == info_->loc_sref->stmt) { // Insert cache stage into the loop if it is the right place - ObjectPtr n = 
make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); stmt = Stmt(n); } @@ -1174,17 +1174,17 @@ class CacheWriteRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) override { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); // Check if this block is one of the specified cache consumers. // update the read buffer to the cache. for (StmtSRef consumer_sref : info_->consumer_blocks) { const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); - Block consumer_block = GetRef(consumer_node); + Block consumer_block = ffi::GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { - Array writes = update_access_regions(block->writes); - Array reads = update_access_regions(block->reads); - Array match_buffers = update_match_buffers(block->match_buffers); + ffi::Array writes = update_access_regions(block->writes); + ffi::Array reads = update_access_regions(block->reads); + ffi::Array match_buffers = update_match_buffers(block->match_buffers); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { auto n = CopyOnWrite(block); @@ -1213,13 +1213,13 @@ class CacheWriteRewriter : public StmtExprMutator { // Find the insertion point if (block == info_->loc_sref->stmt) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); stmt = Block(n); } // Put buffer allocation on the parent scope if (block == scope_sref_->stmt) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); // In cache_inplace case, alloc_buffer may be already exits. 
if (info_->alloc.defined()) { n->alloc_buffers.push_back(info_->alloc.value()); @@ -1232,7 +1232,7 @@ class CacheWriteRewriter : public StmtExprMutator { auto match_buffers = update_match_buffers(block->match_buffers); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); @@ -1243,7 +1243,7 @@ class CacheWriteRewriter : public StmtExprMutator { return stmt; } - Array RewriteIndices(const Array& indices) { + ffi::Array RewriteIndices(const ffi::Array& indices) { std::vector ret; for (size_t i = 0; i < indices.size(); ++i) { ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min)); @@ -1267,7 +1267,7 @@ class CacheWriteRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->write_buffer)) { - ObjectPtr n = make_object(*load); + ObjectPtr n = ffi::make_object(*load); n->buffer = info_->read_buffer; if (!cache_full_region_) { n->indices = RewriteIndices(n->indices); @@ -1281,7 +1281,7 @@ class CacheWriteRewriter : public StmtExprMutator { if (op == info_->write_buffer->data.get()) { return info_->read_buffer->data; } - return GetRef(op); + return ffi::GetRef(op); } private: @@ -1294,9 +1294,9 @@ class CacheWriteRewriter : public StmtExprMutator { /*! \brief Whether the current node is under the given block. */ bool under_writer_block_{false}; /*! \brief function to update read/write region of block being cache write.*/ - std::function(Array)> update_access_regions; + std::function(ffi::Array)> update_access_regions; /*! \brief function to update match buffers of block being cache write.*/ - std::function(Array)> update_match_buffers; + std::function(ffi::Array)> update_match_buffers; /*! 
* \brief A boolean indicating if the cache buffer is allocated with * full region or compact region. @@ -1321,7 +1321,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, ReindexCacheStageInfo* info) { ReindexCacheWriteRewriter rewriter(scope_sref, writer_block_sref, info); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: @@ -1329,11 +1329,11 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { ReindexCacheStageInfo* info) : CacheWriteRewriter(scope_sref, writer_block_sref, info) { new_indices_ = info->indices; - update_access_regions = [&](Array reads) { - Array new_reads; + update_access_regions = [&](ffi::Array reads) { + ffi::Array new_reads; for (const BufferRegion& buf_region : reads) { if (buf_region->buffer.same_as(info_->write_buffer)) { - Array region; + ffi::Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1344,12 +1344,12 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { } return new_reads; }; - update_match_buffers = [&](const Array match_buffers) { - Array new_match_buffers; + update_match_buffers = [&](const ffi::Array match_buffers) { + ffi::Array new_match_buffers; for (const MatchBufferRegion& match_buffer_region : match_buffers) { BufferRegion source = match_buffer_region->source; if (source->buffer.same_as(info_->write_buffer)) { - Array region; + ffi::Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1377,7 +1377,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { PrimExpr VisitExpr_(const BufferLoadNode* load) final { if (load->buffer.same_as(info_->write_buffer)) { - ObjectPtr n = make_object(*load); + ObjectPtr n = ffi::make_object(*load); n->buffer = info_->read_buffer; n->indices = new_indices_; return 
PrimExpr(n); @@ -1386,7 +1386,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { } /*! \brief The indices to use for new buffer. */ - Array new_indices_; + ffi::Array new_indices_; }; /*! @@ -1396,10 +1396,10 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { * \param covered Set of block iter vars covered by the buffer access indices * \return The new buffer with target shape. */ -Buffer CreateReindexBuffer(const Buffer& buffer, const Array& block_iters, +Buffer CreateReindexBuffer(const Buffer& buffer, const ffi::Array& block_iters, const std::unordered_set& covered) { - ObjectPtr new_buffer = make_object(*buffer.get()); - ObjectPtr new_var = make_object(*buffer->data.get()); + ObjectPtr new_buffer = ffi::make_object(*buffer.get()); + ObjectPtr new_var = ffi::make_object(*buffer->data.get()); std::vector new_shape; std::vector new_strides; for (const auto& iter : block_iters) { @@ -1421,14 +1421,16 @@ Buffer CreateReindexBuffer(const Buffer& buffer, const Array& block_ite class NotLeafBlockError : public ScheduleError { public: NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The target block is not a leaf block."; } - String DetailRenderTemplate() const final { return "The target block {0} is not a leaf block."; } + ffi::String DetailRenderTemplate() const final { + return "The target block {0} is not a leaf block."; + } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; }; @@ -1444,12 +1446,12 @@ class InvalidBufferAccessError : public ScheduleError { InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind kind) : mod_(std::move(mod)), buffer_(std::move(buffer)), block_(std::move(block)), kind_(kind) {} - 
String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The target buffer should be accessed via BufferLoad or BufferStore. The " "indices should be the same if there are multiple accesses to the target buffer."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The target buffer " << buffer_->name << " should be accessed in the leaf block {0} via BufferLoad or BufferStore. The indices " @@ -1464,7 +1466,7 @@ class InvalidBufferAccessError : public ScheduleError { return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -1476,7 +1478,8 @@ class InvalidBufferAccessError : public ScheduleError { /*! \brief Collect the related Load/Store to reindex */ class ReIndexCollector : public StmtExprVisitor { public: - static Array Collect(const IRModule& mod, const Buffer& buffer, const Block& block) { + static ffi::Array Collect(const IRModule& mod, const Buffer& buffer, + const Block& block) { ReIndexCollector collector(mod, buffer, block); collector(block->body); if (!collector.buffer_access_indices_.defined()) { @@ -1509,7 +1512,7 @@ class ReIndexCollector : public StmtExprVisitor { } } - void CheckAndUpdateBufferAccessIndices(const Array indices) { + void CheckAndUpdateBufferAccessIndices(const ffi::Array indices) { if (!buffer_access_indices_.defined()) { buffer_access_indices_ = indices; return; @@ -1534,7 +1537,7 @@ class ReIndexCollector : public StmtExprVisitor { /*! \brief The block to visit */ Block block_; /*! \brief The indices of buffer acess to rewrite */ - Optional> buffer_access_indices_; + ffi::Optional> buffer_access_indices_; }; /*! 
\brief Mutator of ReIndex */ @@ -1543,7 +1546,7 @@ class ReIndexRewriter : public StmtExprMutator { static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref, CacheStageInfo* info, const std::unordered_set& covered) { ReIndexRewriter rewriter(block_sref, info, covered); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: @@ -1555,12 +1558,12 @@ class ReIndexRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); if (is_scope_) { is_scope_ = false; Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); // Insert cache stage into the loop - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); n->alloc_buffers.push_back(info_->alloc.value()); stmt = Block(n); @@ -1587,7 +1590,7 @@ class ReIndexRewriter : public StmtExprMutator { BufferRegion{new_buffer_, region_}); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); @@ -1632,7 +1635,7 @@ class ReIndexRewriter : public StmtExprMutator { /*! \brief The reindex buffer */ Buffer new_buffer_; /*! \brief The new indices */ - Array indices_; + ffi::Array indices_; /*! 
\brief The new region */ Region region_; }; @@ -1642,15 +1645,15 @@ void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer rea public: explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {} IRModule mod() const final { return mod_; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The scope root's region cover is not complete."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return R"(The scope {0} 's region cover is not complete. The region cover property require to hold for every of its child blocks )"; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; }; @@ -1661,7 +1664,7 @@ The region cover property require to hold for every of its child blocks if (region->buffer.same_as(read_buffer)) { if (!self->block_info.at(child_block_sref).region_cover) { const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); - throw NotRegionCoverError(self->mod, GetRef(block)); + throw NotRegionCoverError(self->mod, ffi::GetRef(block)); } } } @@ -1671,7 +1674,7 @@ The region cover property require to hold for every of its child blocks /******** Implementation ********/ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope, const Array consumer_blocks) { + const ffi::String& storage_scope, const ffi::Array consumer_blocks) { /*! * Check: * - The index is in the array of block reading region @@ -1688,8 +1691,8 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Step 1. 
Check index, getting the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer read_buffer = - GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); + Buffer read_buffer = GetNthAccessBuffer(self, ffi::GetRef(block), read_buffer_index, + BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Check required region cover for cache_read CheckRegionCover(self, scope_sref, read_buffer); @@ -1709,13 +1712,14 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Step 3. Update cache stage info. BufferRegion cache_region{nullptr}; - if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + if (ffi::Optional _write_block_sref = + GetOnlyWriteBlock(self, scope_sref, read_buffer)) { // Case 1. The buffer is written inside the block. StmtSRef write_block_sref = _write_block_sref.value(); const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); // Find the producing region BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value(); - StmtSRef parent_sref = GetRef(write_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(write_block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); @@ -1724,7 +1728,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Case 2. The buffer is the input block for the scope. 
info.loc_sref = scope_sref; info.loc_pos = 0; - if (Optional region = + if (ffi::Optional region = GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) { cache_region = region.value(); } else { @@ -1764,7 +1768,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff } StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, const Array consumer_blocks) { + const ffi::String& storage_scope, const ffi::Array consumer_blocks) { /*! * Check: * - The index is in the array of block reading region @@ -1781,8 +1785,8 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // Step 1. Checking index, getting the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer write_buffer = - GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kWrite); + Buffer write_buffer = GetNthAccessBuffer(self, ffi::GetRef(block), write_buffer_index, + BufferIndexType::kWrite); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Step 2. Creating CacheStageInfo @@ -1803,7 +1807,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // Step 4. 
Find the producing region and insert position BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value(); - StmtSRef parent_sref = GetRef(block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, block_sref, scope_sref, &info); BufferRegion cache_region = @@ -1841,12 +1845,12 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } -Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& top_sref) { +ffi::Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& top_sref) { std::vector result; for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); parent = parent->parent) { if (parent == top_sref.get()) break; - result.push_back(GetRef(parent)); + result.push_back(ffi::GetRef(parent)); } return {result.rbegin(), result.rend()}; } @@ -1858,8 +1862,9 @@ Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& t class ReindexCacheReadWriteNotMatchError : public ScheduleError { public: ReindexCacheReadWriteNotMatchError(IRModule mod, Block block, Var var, - Array old_indices, Array new_indices, - bool is_cache_read, bool appears_in_old) + ffi::Array old_indices, + ffi::Array new_indices, bool is_cache_read, + bool appears_in_old) : mod_(std::move(mod)), block_(std::move(block)), var_(std::move(var)) { primitive_name_ = is_cache_read ? 
"reindex_cache_read" : "reindex_cache_write"; if (appears_in_old) { @@ -1870,26 +1875,26 @@ class ReindexCacheReadWriteNotMatchError : public ScheduleError { other_indices_ = std::move(old_indices); } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: the block itervars appeared in lhs and rhs of reindex cache stage do " "not match."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::stringstream s; s << "Error when applying " << primitive_name_ << " on block {0}, the block itervar " << var_ << " appears in " << appears_indices_ << ", but not in " << other_indices_ << "."; - return String(s.str()); + return ffi::String(s.str()); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - String primitive_name_; + ffi::String primitive_name_; Block block_; Var var_; - Array appears_indices_; - Array other_indices_; + ffi::Array appears_indices_; + ffi::Array other_indices_; }; /*! 
@@ -1908,21 +1913,21 @@ class ReindexCacheReadWriteNotMatchError : public ScheduleError { template void CollectReindexCacheStageInfoAndCreateBuffer( ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& block_sref, - const String& storage_scope, const IndexMap& index_map, const Block& block, + const ffi::String& storage_scope, const IndexMap& index_map, const Block& block, const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) { arith::Analyzer analyzer; - Array block_iter_vars, block_shape; + ffi::Array block_iter_vars, block_shape; for (const IterVar& iter_var : block->iter_vars) { block_iter_vars.push_back(iter_var); block_shape.push_back(iter_var->dom->extent); } - Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); - Array new_shape = index_map->MapShape(block_shape, &analyzer); + ffi::Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); + ffi::Array new_shape = index_map->MapShape(block_shape, &analyzer); info->indices = new_indices; // Step 5. 
Update CacheTouchedInfo VarUseDefAnalyzer collector_old(/*defined_vars=*/{}); - Array old_indices; + ffi::Array old_indices; for (const Range& range : cache_region->region) { collector_old(range->min); old_indices.push_back(range->min); @@ -1959,8 +1964,8 @@ void CollectReindexCacheStageInfoAndCreateBuffer( } // Create new buffer - ObjectPtr new_buffer = make_object(*old_buffer.get()); - ObjectPtr new_var = make_object(*old_buffer->data.get()); + ObjectPtr new_buffer = ffi::make_object(*old_buffer.get()); + ObjectPtr new_var = ffi::make_object(*old_buffer->data.get()); const auto* ptr_type = TVM_TYPE_AS(old_buffer->data->type_annotation, PointerTypeNode); new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); @@ -1992,7 +1997,7 @@ void CheckSinglePoint(ScheduleState self, const Block& block, const BufferRegion } StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map) { + const ffi::String& storage_scope, const IndexMap& index_map) { /*! * Check: * - The index is in the array of block reading region @@ -2008,7 +2013,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re CheckStorageScope(self, storage_scope); // Step 1. Check index, getting the target buffer and the parent scope - Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + Block block = ffi::GetRef(TVM_SREF_TO_BLOCK(block_sref)); BlockRealize realize = GetBlockRealize(self, block_sref); Buffer read_buffer = GetNthAccessBuffer(self, block, read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); @@ -2019,15 +2024,16 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re info.consumer_blocks.insert(block_sref); // Step 3. Update cache stage info. 
- Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); + ffi::Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); ICHECK(maybe_region.defined()) << read_buffer << " should appear in the block's read region: " << block->reads; BufferRegion cache_region = maybe_region.value(); - if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + if (ffi::Optional _write_block_sref = + GetOnlyWriteBlock(self, scope_sref, read_buffer)) { // Case 1. The buffer is written inside the block. StmtSRef write_block_sref = _write_block_sref.value(); // Find the producing region - StmtSRef parent_sref = GetRef(write_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(write_block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); } else { @@ -2062,7 +2068,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re } StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map) { + const ffi::String& storage_scope, const IndexMap& index_map) { /*! * Check: * - The index is in the array of block reading region @@ -2078,7 +2084,7 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w CheckStorageScope(self, storage_scope); // Step 1. Checking index, getting the target buffer and the parent scope - Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + Block block = ffi::GetRef(TVM_SREF_TO_BLOCK(block_sref)); BlockRealize realize = GetBlockRealize(self, block_sref); Buffer write_buffer = GetNthAccessBuffer(self, block, write_buffer_index, BufferIndexType::kWrite); @@ -2092,9 +2098,9 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); // Step 4. 
Find the producing region and insert position - Optional maybe_region = GetBufferRegionFromBuffer(block->writes, write_buffer); + ffi::Optional maybe_region = GetBufferRegionFromBuffer(block->writes, write_buffer); ICHECK(maybe_region.defined()) << write_buffer << " should appear in the block's write region"; - StmtSRef parent_sref = GetRef(block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, block_sref, scope_sref, &info); BufferRegion cache_region = maybe_region.value(); @@ -2130,23 +2136,23 @@ class NotReadWriteError : public ScheduleError { public: NotReadWriteError(IRModule mod, Block block, Buffer buffer) : mod_(std::move(mod)), block_(std::move(block)), buffer_(std::move(buffer)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The target block does not both read & write target buffer."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The target block {0} does not both read & write target buffer {1}."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_, buffer_}; } + ffi::Array LocationsOfInterest() const final { return {block_, buffer_}; } IRModule mod_; Block block_; Buffer buffer_; }; -Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope) { +ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, + int read_buffer_index, const ffi::String& storage_scope) { /*! * Do cache read then cache write */ @@ -2156,8 +2162,8 @@ Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int // Check 1. 
Check index, get the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer buffer = - GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); + Buffer buffer = GetNthAccessBuffer(self, ffi::GetRef(block), read_buffer_index, + BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Check 3. Check required region cover for cache_read @@ -2165,13 +2171,13 @@ Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int // Check 4. Check if target block both read & write target buffer. const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref); - Optional read_region = GetBufferRegionFromBuffer(rw_block->reads, buffer); - Optional write_region = GetBufferRegionFromBuffer(rw_block->writes, buffer); + ffi::Optional read_region = GetBufferRegionFromBuffer(rw_block->reads, buffer); + ffi::Optional write_region = GetBufferRegionFromBuffer(rw_block->writes, buffer); if (!read_region.defined() || !write_region.defined()) { - throw NotReadWriteError(self->mod, GetRef(rw_block), buffer); + throw NotReadWriteError(self->mod, ffi::GetRef(rw_block), buffer); } - Array results_block_sref; + ffi::Array results_block_sref; Buffer new_buffer = WithScope(buffer, storage_scope); // Do cache read @@ -2235,22 +2241,25 @@ Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int } StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - BufferIndexType buffer_index_type) { + BufferIndexType buffer_index_type, bool skip_simplify) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - Block block = GetRef(block_ptr); + Block block = ffi::GetRef(block_ptr); Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); arith::Analyzer analyzer; // Step 1. 
Collect the original indices and check there's only single pattern of related // Load/Store and the buffer is not accessed opaquely - Array original_indices = ReIndexCollector::Collect(self->mod, buffer, block); + ffi::Array original_indices = ReIndexCollector::Collect(self->mod, buffer, block); // Simplify the indices if possible - for (const IterVar& iter : block->iter_vars) { - analyzer.Bind(iter->var, iter->dom); + if (!skip_simplify){ + // skip simplification in case to preserve unit loops. + for (const IterVar& iter : block->iter_vars) { + analyzer.Bind(iter->var, iter->dom); + } + original_indices.MutateByApply( + [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); }); } - original_indices.MutateByApply( - [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); }); // Collect block iters appearing in the original_indices std::unordered_set covered; @@ -2319,13 +2328,14 @@ struct CacheReadTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, - Array consumer_blocks, Integer read_buffer_index, - String storage_scope) { + ffi::Array consumer_blocks, + Integer read_buffer_index, ffi::String storage_scope) { return sch->CacheRead(block, read_buffer_index->value, storage_scope, consumer_blocks); } - static String UnpackedAsPython(Array outputs, String block, Array consumer_blocks, - Integer read_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::Array consumer_blocks, + Integer read_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2352,13 +2362,14 @@ struct CacheWriteTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, - Array 
consumer_blocks, Integer write_buffer_index, - String storage_scope) { + ffi::Array consumer_blocks, + Integer write_buffer_index, ffi::String storage_scope) { return sch->CacheWrite(block, write_buffer_index->value, storage_scope, consumer_blocks); } - static String UnpackedAsPython(Array outputs, String block, Array consumer_blocks, - Integer write_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::Array consumer_blocks, + Integer write_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); @@ -2384,13 +2395,14 @@ struct CacheInplaceTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, - Integer read_buffer_index, String storage_scope) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, + Integer read_buffer_index, + ffi::String storage_scope) { return sch->CacheInplace(block, read_buffer_index->value, storage_scope); } - static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, - String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + Integer read_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_inplace"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2409,23 +2421,26 @@ struct ReIndexTraits : public UnpackedInstTraits { private: static constexpr size_t kNumInputs = 1; - static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumAttrs = 3; static constexpr size_t kNumDecisions = 0; static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index, - Integer buffer_index_type) { + Integer buffer_index_type, Bool skip_simplify) { return sch->ReIndex(block, 
buffer_index.IntValue(), - static_cast(buffer_index_type->value)); + static_cast(buffer_index_type->value), + skip_simplify.operator bool()); } - static String UnpackedAsPython(Array outputs, String block, Integer buffer_index, - Integer buffer_index_type) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + Integer buffer_index, Integer buffer_index_type, + Bool skip_simplify) { PythonAPICall py("reindex"); py.Input("block", block); std::ostringstream os; os << "(\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) << "\", " << buffer_index << ")"; - py.Input("buffer", String(os.str())); + py.Input("buffer", ffi::String(os.str())); + py.Input("skip_simplify", skip_simplify.operator bool()); py.SingleOutput(outputs); return py.Str(); } @@ -2444,12 +2459,13 @@ struct ReindexCacheReadTraits : public UnpackedInstTraitsReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map); } - static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, - Integer read_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + IndexMap index_map, Integer read_buffer_index, + ffi::String storage_scope) { PythonAPICall py("reindex_cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2473,12 +2489,13 @@ struct ReindexCacheWriteTraits : public UnpackedInstTraitsReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map); } - static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, - Integer write_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + IndexMap index_map, Integer write_buffer_index, + ffi::String storage_scope) { PythonAPICall py("reindex_cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); diff --git a/src/tir/schedule/primitive/compute_at.cc 
b/src/tir/schedule/primitive/compute_at.cc index 0075fee18f4c..cd56ff8b9ddf 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -33,21 +33,21 @@ template class NotAllRequiredBlocksAreVisitedError : public ScheduleError { public: explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited, - const Array& required) + const ffi::Array& required) : mod_(mod), num_not_visited_(num_not_visited) { required_.reserve(required.size()); for (const StmtSRef& block_sref : required) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - required_.push_back(GetRef(block)); + required_.push_back(ffi::GetRef(block)); } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Not all required blocks are under the loop scope"; } - String DetailRenderTemplate() const final { - String relation = is_consumer ? "consumer(s)" : "producer(s)"; + ffi::String DetailRenderTemplate() const final { + ffi::String relation = is_consumer ? "consumer(s)" : "producer(s)"; std::ostringstream os; os << "The primitive requires all the " << relation << " of the given block to be present under the target loop. However, there are " @@ -61,14 +61,14 @@ class NotAllRequiredBlocksAreVisitedError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { + ffi::Array LocationsOfInterest() const final { return {required_.begin(), required_.end()}; } private: IRModule mod_; int num_not_visited_; - Array required_; + ffi::Array required_; }; /*! 
@@ -96,22 +96,22 @@ class NotInSameScopeError : public ScheduleError { } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Expected the block and loop to be under the same block scope, and loop " "not to be the ancestor of block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, " "and loop not to be the ancestor of block"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_, loop_}; } + ffi::Array LocationsOfInterest() const final { return {block_, loop_}; } private: explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref) : mod_(mod), - block_(GetRef(block_sref->StmtAs())), - loop_(GetRef(loop_sref->StmtAs())) {} + block_(ffi::GetRef(block_sref->StmtAs())), + loop_(ffi::GetRef(loop_sref->StmtAs())) {} IRModule mod_; Block block_; @@ -138,8 +138,9 @@ class NotInSameScopeError : public ScheduleError { * \throws ScheduleError if there is no such insertion point found */ template -int FindInsertionPoint(const ScheduleState& self, const Array& subtrees, - const Array& producer_srefs, const Array& consumer_srefs, +int FindInsertionPoint(const ScheduleState& self, const ffi::Array& subtrees, + const ffi::Array& producer_srefs, + const ffi::Array& consumer_srefs, std::unordered_map* block2realize, int index) { ProducerConsumerSplit split = @@ -254,9 +255,9 @@ class ScopeReconstructor : private StmtMutator { void MakeNewLoop(int insert_position, std::vector iter_doms, arith::Analyzer* analyzer, bool preserve_unit_loops) { int n_iters = iter_doms.size(); - Array loop_vars; - Array loop_extents; - Array iter_values; + ffi::Array loop_vars; + ffi::Array loop_extents; + ffi::Array iter_values; loop_vars.reserve(n_iters); loop_extents.reserve(n_iters); iter_values.reserve(n_iters); 
@@ -302,9 +303,9 @@ class ScopeReconstructor : private StmtMutator { /*ForKind=*/ForKind::kSerial, /*body=*/std::move(new_subtree)); } - Array subtrees = AsArray(loop_->body); + ffi::Array subtrees = AsArray(loop_->body); subtrees.insert(subtrees.begin() + insert_position, std::move(new_subtree)); - ObjectPtr new_loop = make_object(*loop_.get()); + ObjectPtr new_loop = ffi::make_object(*loop_.get()); new_loop->body = SeqStmt(std::move(subtrees)); this->new_loop_ = For(std::move(new_loop)); } @@ -312,7 +313,7 @@ class ScopeReconstructor : private StmtMutator { private: Stmt VisitStmt_(const BlockNode* block) final { if (block != scope_root_.get()) { - return GetRef(block); + return ffi::GetRef(block); } if (block == rm_src_stmt_.get()) { block = TVM_TYPE_AS(rm_tgt_stmt_, BlockNode); @@ -358,19 +359,19 @@ class ScopeReconstructor : private StmtMutator { * \param relaxed Where the calculation result is stored */ template -void RelaxBufferRegions(const Map& binding, - const Array& buffer_regions, +void RelaxBufferRegions(const ffi::Map& binding, + const ffi::Array& buffer_regions, const StmtSRef& relax_path_low_inclusive, const StmtSRef& relax_path_high_exclusive, std::unordered_map>* relaxed) { runtime::StorageScope global_scope{runtime::StorageRank::kGlobal, ""}; // We cache the variable domains runtime::StorageRank previous_rank = runtime::StorageRank::kGlobal; - Optional> var_dom = std::nullopt; + ffi::Optional> var_dom = std::nullopt; // Enumerate every buffer region for (const BufferRegion& buffer_region : buffer_regions) { const Buffer& buffer = buffer_region->buffer; - const Array& region = buffer_region->region; + const ffi::Array& region = buffer_region->region; // Skip the buffer regions we are not interested in auto it = relaxed->find(buffer.get()); if (it == relaxed->end()) { @@ -389,7 +390,7 @@ void RelaxBufferRegions(const Map& binding, /*extra_relax_scope=*/scope)); } // Relax the region - Array relaxed_region = + ffi::Array relaxed_region = 
arith::EvalSet(Substitute(region, binding), var_dom.value()); relaxed_regions.push_back({relaxed_region.begin(), relaxed_region.end()}); } @@ -412,7 +413,7 @@ std::pair SolveBlockVarDomain(const arith::IntSet& prov PrimExpr required_min = analyzer->Simplify(required.min()); PrimExpr required_max = analyzer->Simplify(required.max()); arith::IntSet var_dom, var_bound; - Optional var; + ffi::Optional var; arith::PVar p_v; arith::PVar p_e; if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) { @@ -506,9 +507,10 @@ void UpdateBlockVarDomainDimwise( } /*! \brief Helper function to implement intset version of `InverseAffineIterMap`. */ -Map InverseAffineIterMap(const Array& iter_map, - const NDIntSet& outputs, arith::Analyzer* analyzer) { - Array min_point, max_point; +ffi::Map InverseAffineIterMap(const ffi::Array& iter_map, + const NDIntSet& outputs, + arith::Analyzer* analyzer) { + ffi::Array min_point, max_point; min_point.reserve(outputs.size()); max_point.reserve(outputs.size()); for (const auto& intset : outputs) { @@ -518,7 +520,7 @@ Map InverseAffineIterMap(const Array& it } auto rev_min = InverseAffineIterMap(iter_map, min_point); auto rev_max = InverseAffineIterMap(iter_map, max_point); - Map dom_map; + ffi::Map dom_map; for (const auto& kv : rev_min) { const Var& var = kv.first; auto it = rev_max.find(var); @@ -543,7 +545,7 @@ Map InverseAffineIterMap(const Array& it * \param iter_doms The result iteration domains to be updated * \returns bool. 
Denotes whether update success */ -bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array& iter_vars, +bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const ffi::Array& iter_vars, const NDIntSet& provided_region, const NDIntSet& required_region, arith::Analyzer* analyzer, std::unordered_map* iter_doms) { @@ -552,12 +554,12 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array& if (!intset.CanProveSinglePoint(analyzer)) return false; } // calculate forward mapping (block vars -> provided region point) - Map dom_map; + ffi::Map dom_map; for (const IterVar& iter_var : iter_vars) { dom_map.Set(iter_var->var, iter_var->dom); } size_t ndim = buffer->shape.size(); - Array provide_indices; + ffi::Array provide_indices; provide_indices.reserve(ndim); for (size_t i = 0; i < ndim; ++i) { provide_indices.push_back(provided_region[i].min()); @@ -573,8 +575,10 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array& required_bound.push_back( arith::IntSet::Interval(make_zero(buffer->shape[i]->dtype), max(buffer->shape[i] - 1, 0))); } - Map var_dom = InverseAffineIterMap(res->indices, required_region, analyzer); - Map var_bound = InverseAffineIterMap(res->indices, required_bound, analyzer); + ffi::Map var_dom = + InverseAffineIterMap(res->indices, required_region, analyzer); + ffi::Map var_bound = + InverseAffineIterMap(res->indices, required_bound, analyzer); for (const auto& kv : var_dom) { const Var& var = kv.first; auto it = var_bound.find(var); @@ -593,7 +597,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array& * \return A list of iteration domain info corresponding to the given list of block vars */ std::vector CalculateBlockVarDomain( - const Array& iter_vars, + const ffi::Array& iter_vars, std::unordered_map> provided_regions, std::unordered_map> required_regions, arith::Analyzer* analyzer) { @@ -657,16 +661,16 @@ template void CalculateProvidedRequiredRegions( const BlockNode* block, 
const StmtSRef& loop_sref, std::unordered_map block2realize, - Array producer_srefs, Array consumer_srefs, + ffi::Array producer_srefs, ffi::Array consumer_srefs, std::unordered_map>* provided_regions, std::unordered_map>* required_regions) { // Step 1. Calculate the region provided by a single execution instance of `block` - const Array& provided_buffers = is_compute_at ? block->writes : block->reads; + const ffi::Array& provided_buffers = is_compute_at ? block->writes : block->reads; provided_regions->reserve(provided_buffers.size()); required_regions->reserve(provided_buffers.size()); for (const BufferRegion& provided_buffer_region : provided_buffers) { const BufferNode* buffer = provided_buffer_region->buffer.get(); - const Array& region = provided_buffer_region->region; + const ffi::Array& region = provided_buffer_region->region; (*provided_regions)[buffer].push_back(support::NDIntSetFromRegion(region)); (*required_regions)[buffer].clear(); } @@ -675,9 +679,9 @@ void CalculateProvidedRequiredRegions( const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block_sref); ICHECK(block2realize.count(required_block)); RelaxBufferRegions( - /*binding=*/GetBindings(GetRef(block2realize.at(required_block))), + /*binding=*/GetBindings(ffi::GetRef(block2realize.at(required_block))), /*buffer_regions=*/is_compute_at ? 
required_block->reads : required_block->writes, - /*relax_path_low_inclusive=*/GetRef(required_block_sref->parent), + /*relax_path_low_inclusive=*/ffi::GetRef(required_block_sref->parent), /*relax_path_high_exclusive=*/loop_sref, /*relaxed=*/required_regions); } } @@ -695,11 +699,11 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s // Check condition 1) : scope stage pipeline StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); - Block scope_root = GetRef(scope_root_sref->StmtAs()); + Block scope_root = ffi::GetRef(scope_root_sref->StmtAs()); AddShapeVarBounds(self, scope_root_sref.get(), analyzer); BlockScope scope = self->GetBlockScope(scope_root_sref); - Array producer_srefs = GetProducers(block_sref, scope); - Array consumer_srefs = GetConsumers(block_sref, scope); + ffi::Array producer_srefs = GetProducers(block_sref, scope); + ffi::Array consumer_srefs = GetConsumers(block_sref, scope); // Check condition 2) : `block` is a complete or reduction block CheckCompleteOrReductionBlock(self, block_sref, scope_root_sref); // Check condition 3): `block` and `loop` are under the same scope, @@ -711,7 +715,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s CheckNotOutputBlock(self, block_sref, scope_root_sref); } // Step 2. Plan for the removal of `block` - ScopeReconstructor reconstructor(scope_root, GetRef(block), GetRef(loop)); + ScopeReconstructor reconstructor(scope_root, ffi::GetRef(block), ffi::GetRef(loop)); LeafBlockRemovalPlan(self, block_sref, &reconstructor.rm_src_stmt_, &reconstructor.rm_tgt_stmt_); // Step 3. 
Find the insertion point under `loop` // Check condition 5): all the required block are under the given loop @@ -755,7 +759,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s BlockInfo& block_info = self->block_info[block_sref]; block_info.affine_binding = IsAffineBinding( /*realize=*/reconstructor.new_block_realize_, - /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(ffi::GetRef(block_sref->parent)), /*analyzer=*/analyzer); } @@ -813,8 +817,8 @@ struct ComputeAtTraits : public UnpackedInstTraits { return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value); } - static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv, - Bool preserve_unit_loops, IntImm index) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + ffi::String loop_rv, Bool preserve_unit_loops, IntImm index) { PythonAPICall py("compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); @@ -842,8 +846,8 @@ struct ReverseComputeAtTraits : public UnpackedInstTraitsvalue); } - static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv, - Bool preserve_unit_loops, IntImm index) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + ffi::String loop_rv, Bool preserve_unit_loops, IntImm index) { PythonAPICall py("reverse_compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 4e037158d98a..cc3785d5c103 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -36,14 +36,16 @@ class HasInitBlock : public ScheduleError { public: explicit HasInitBlock(IRModule mod, Block block) : mod_(mod), block_(block) {} - String FastErrorString() const final { return "ScheduleError: The block has init 
statement"; } + ffi::String FastErrorString() const final { + return "ScheduleError: The block has init statement"; + } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "ScheduleError: The block has init statement: {0}"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } static void Check(const IRModule& mod, const Block& block) { if (block->init.defined()) { @@ -61,12 +63,12 @@ class NotSingleReadWriteBuffer : public ScheduleError { explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block) : mod_(mod), is_read_(is_read), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return is_read_ ? "ScheduleError: The block is allowed to read only a single buffer region" : "ScheduleError: The block is allowed to write only a single buffer region"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { if (is_read_) { int k = block_->reads.size(); return "The block is only allowed to read a single buffer region, but it reads " + @@ -79,7 +81,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; bool is_read_; @@ -87,7 +89,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { static Buffer GetSingleRead(const ScheduleState& self, const Block& block, const StmtSRef& scope_root_sref) { - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers; const BufferNode* read_buffer = nullptr; for (const BufferRegion& read_region : block->reads) { @@ 
-95,7 +97,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { if (buffer == read_buffer) { continue; } - if (buffer_writers.count(GetRef(buffer)) > 0) { + if (buffer_writers.count(ffi::GetRef(buffer)) > 0) { if (read_buffer != nullptr) { throw NotSingleReadWriteBuffer(self->mod, true, block); } @@ -105,7 +107,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { if (read_buffer == nullptr) { throw NotSingleReadWriteBuffer(self->mod, true, block); } - return GetRef(read_buffer); + return ffi::GetRef(read_buffer); } static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) { @@ -121,17 +123,17 @@ class BodyAnalysisError : public ScheduleError { explicit BodyAnalysisError(bool is_reverse, IRModule mod, Block block) : is_reverse_(is_reverse), mod_(mod), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The block cannot be inlined because its body pattern does not meet the " "condition for inlining"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return is_reverse_ ? 
kErrBodyReverseInline : kErrBodyInline; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } bool is_reverse_; IRModule mod_; @@ -143,20 +145,20 @@ class NonSingleProducerError : public ScheduleError { explicit NonSingleProducerError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The consumer block to be inlined is required to have only a single " "producer block, and the producer block should be a complete block who has only a " "single consumer"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The consumer block {0} to be inlined is required to have only a single " "producer block, and the producer block should be a complete block who has only a " "single consumer"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ -174,7 +176,7 @@ class NonSingleProducerError : public ScheduleError { const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref); const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); Buffer consumer_buffer = NotSingleReadWriteBuffer::GetSingleRead( - self, GetRef(consumer_block), scope_root_sref); + self, ffi::GetRef(consumer_block), scope_root_sref); class ProducerFinder : public StmtVisitor { public: static std::vector GetProducer(const ScheduleState& self, @@ -211,9 +213,9 @@ class NonSingleProducerError : public ScheduleError { // Check if the producer block is a complete block StmtSRef producer_block_sref = self_->stmt2ref.at(node); if (!IsCompleteBlock(self_, producer_block_sref, scope_root_sref_)) { - throw NonSingleProducerError(self_->mod, 
GetRef(node)); + throw NonSingleProducerError(self_->mod, ffi::GetRef(node)); } - producer_across_scope_.back().push_back(GetRef(node)); + producer_across_scope_.back().push_back(ffi::GetRef(node)); break; } } @@ -224,9 +226,9 @@ class NonSingleProducerError : public ScheduleError { std::vector> producer_across_scope_; }; std::vector producer_across_scope = ProducerFinder::GetProducer( - self, scope_root_sref, consumer_buffer, GetRef(scope_block)); + self, scope_root_sref, consumer_buffer, ffi::GetRef(scope_block)); if (producer_across_scope.size() != 1) { - throw NonSingleProducerError(self->mod, GetRef(consumer_block)); + throw NonSingleProducerError(self->mod, ffi::GetRef(consumer_block)); } return self->stmt2ref.at(producer_across_scope[0].get()); } @@ -237,21 +239,21 @@ class OpaqueAccessError : public ScheduleError { explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref) : mod_(mod), scope_root_(nullptr) { const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); - this->scope_root_ = GetRef(scope_root); + this->scope_root_ = ffi::GetRef(scope_root); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The buffer to be inlined has opaque access (e.g. `B.data`), or its " "subregion is matched into other blocks"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The buffer to be inlined has opaque access (e.g. 
`B.data`), or its " "subregion is matched into other blocks: {0}"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {scope_root_}; } + ffi::Array LocationsOfInterest() const final { return {scope_root_}; } IRModule mod_; Block scope_root_; @@ -263,11 +265,11 @@ class ProducerHasNonTrivialPredicateError : public ScheduleError { PrimExpr new_predicate) : mod_(mod), producer_(producer), new_predicate_(new_predicate) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The producer block has a non-trivial predicate."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: The producer block {0} has a non-trivial predicate " << producer_->predicate << " that cannot be implied by the synthesized predicate " @@ -276,7 +278,7 @@ class ProducerHasNonTrivialPredicateError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {producer_}; } + ffi::Array LocationsOfInterest() const final { return {producer_}; } IRModule mod_; BlockRealize producer_; @@ -315,7 +317,7 @@ class BaseInliner : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* block) { CheckMatchBufferRegion(block); AddBuffersInBlockSignature(block); - Block src_block = GetRef(block); + Block src_block = ffi::GetRef(block); if (src_block.same_as(src_stmt)) { block = tgt_stmt.as(); ICHECK(block != nullptr); @@ -358,7 +360,7 @@ class BaseInliner : public StmtExprMutator { */ Block UpdateBuffersInBlockSignature(Block block, bool is_scope_root) { // Step 1. 
Update `BlockNode::alloc_buffers` - Array alloc_buffers; + ffi::Array alloc_buffers; if (is_scope_root) { alloc_buffers.reserve(block->alloc_buffers.size()); for (const Buffer& alloc_buffer : block->alloc_buffers) { @@ -370,14 +372,15 @@ class BaseInliner : public StmtExprMutator { alloc_buffers = std::move(block->alloc_buffers); } // Step 2. Update `BlockNode::reads` and `BlockNode::writes` - Array reads = std::move(block->reads); - Array writes = std::move(block->writes); + ffi::Array reads = std::move(block->reads); + ffi::Array writes = std::move(block->writes); auto f_access_inline_buffer = [this](const BufferRegion& access) { return access->buffer.same_as(this->inlined_buffer_); }; if (!is_scope_root && (std::any_of(reads.begin(), reads.end(), f_access_inline_buffer) || std::any_of(writes.begin(), writes.end(), f_access_inline_buffer))) { - Array> inspected = GetBlockReadWriteRegion(block, buffer_var_map_); + ffi::Array> inspected = + GetBlockReadWriteRegion(block, buffer_var_map_); reads = inspected[0]; writes = inspected[1]; } @@ -422,7 +425,7 @@ class BaseInliner : public StmtExprMutator { /*! \brief The scope root */ StmtSRef scope_root_sref_{nullptr}; /*! \brief Maps a buffer's data field to itself */ - Map buffer_var_map_; + ffi::Map buffer_var_map_; /*! \brief The indices used for indexing the buffer to be inlined */ std::vector idx_vars_; /*! \brief The mapping to substitute index variables to PrimExprs */ @@ -438,7 +441,7 @@ class BaseInliner : public StmtExprMutator { /*! \brief The Stmt to be replaced to when removing the leaf block */ Stmt tgt_stmt{nullptr}; /*! \brief The reuse mapping of block srefs */ - Map block_reuse; + ffi::Map block_reuse; /*! 
\brief Indicates if there is any opaque access of the inlined buffer */ bool has_opaque_access{false}; }; @@ -489,7 +492,7 @@ class ComputeInliner : public BaseInliner { // If the mapping for store indices is non-trivial // check bijective mapping from producer iter var to store indices - Map producer_iter_doms; + ffi::Map producer_iter_doms; for (const auto& iter : producer_block->iter_vars) { producer_iter_doms.Set(iter->var, iter->dom); } @@ -509,7 +512,7 @@ class ComputeInliner : public BaseInliner { idx_vars_[i] = Var("ph_" + std::to_string(i), inlined_store_->indices[i].dtype()); } auto inverse_iter_map = arith::InverseAffineIterMap( - res->indices, Array(idx_vars_.begin(), idx_vars_.end())); + res->indices, ffi::Array(idx_vars_.begin(), idx_vars_.end())); for (const auto& iter : producer_block->iter_vars) { if (is_const_int(iter->dom->min) && analyzer_.CanProveEqual(iter->dom->extent, 1)) { // fallback mapping for constant iters @@ -541,7 +544,7 @@ class ComputeInliner : public BaseInliner { * \brief Set the mapping of index substitution `self->idx_sub_` * \param indices The expressions that the corresponding index variables are replaced to */ - void SetIndexSubstitution(const Array& indices) { + void SetIndexSubstitution(const ffi::Array& indices) { ICHECK_EQ(indices.size(), idx_vars_.size()); int n = idx_vars_.size(); for (int i = 0; i < n; ++i) { @@ -573,7 +576,7 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr VisitExpr_(const VarNode* var) final { auto it = self_->idx_sub_.find(var); if (it == self_->idx_sub_.end()) { - return GetRef(var); + return ffi::GetRef(var); } return (*it).second; } @@ -594,7 +597,7 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr VisitExpr_(const VarNode* var) final { auto it = self_->idx_sub_.find(var); if (it == self_->idx_sub_.end()) { - return GetRef(var); + return ffi::GetRef(var); } return (*it).second; } @@ -644,7 +647,7 @@ class ReverseComputeInliner : public BaseInliner { } // Collect block 
iter domains and update the substition map - Map consumer_iter_doms; + ffi::Map consumer_iter_doms; for (const auto& iter_var : consumer_block->iter_vars) { consumer_iter_doms.Set(iter_var->var, iter_var->dom); // Set default mapping for unit iters @@ -708,7 +711,7 @@ class ReverseComputeInliner : public BaseInliner { /*! \brief Generate the predicate after inlining based on the consumer predicate */ BlockRealize BuildInlinedConsumerPredicate(BlockRealize producer_block_realize) { // Bind the producer block iter domains for simplification - Map subst_map; + ffi::Map subst_map; Block producer_block = producer_block_realize->block; for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) { const IterVar& iter = producer_block->iter_vars[i]; @@ -748,7 +751,7 @@ class ReverseComputeInliner : public BaseInliner { auto n = producer_block_realize.CopyOnWrite(); n->block = producer_block; n->predicate = analyzer_.Simplify(outer_predicate); - return GetRef(n); + return ffi::GetRef(n); } Stmt VisitStmt_(const BlockRealizeNode* op) final { @@ -774,7 +777,7 @@ class ReverseComputeInliner : public BaseInliner { * \return Whether the consumer block iter domains are covered */ bool CheckConsumerCovered() { - Map producer_iter_doms; + ffi::Map producer_iter_doms; for (const IterVar& iter_var : producer_block_->iter_vars) { producer_iter_doms.Set(iter_var, arith::IntSet::FromRange(iter_var->dom)); } @@ -800,7 +803,7 @@ class ReverseComputeInliner : public BaseInliner { * the result. It will be later used to transform the BufferStore indices of the producer. * \param producer_indices The BufferStore indices of the producer. 
*/ - void CreateInverseMapping(const Array producer_indices) { + void CreateInverseMapping(const ffi::Array producer_indices) { auto inverse_iter_map = arith::InverseAffineIterMap(buffer_load_iter_map_, producer_indices); for (const auto& pair : inverse_iter_map) { idx_sub_[pair.first.get()] = pair.second; @@ -811,7 +814,7 @@ class ReverseComputeInliner : public BaseInliner { // "producer->value" may contain the buffer that is inlined in cases of reduction, // so we need to resolve the recursion first producer_rhs_ = RecursionResolver(this)(producer->value); - return Substituter(this)(GetRef(inlined_store_)); + return Substituter(this)(ffi::GetRef(inlined_store_)); } /*! @@ -847,7 +850,7 @@ class ReverseComputeInliner : public BaseInliner { * \param expected_ndim The expected ndim of the access * \return A boolean flag indicating if the check is successful */ - bool UpdateAndCheckIndexExprs(const Array& indices) { + bool UpdateAndCheckIndexExprs(const ffi::Array& indices) { if (buffer_load_indices_.empty()) { buffer_load_indices_ = indices; } else if (!std::equal(buffer_load_indices_.begin(), buffer_load_indices_.end(), @@ -861,9 +864,9 @@ class ReverseComputeInliner : public BaseInliner { /*! \brief The RHS value of the producer's BufferStore statement */ PrimExpr producer_rhs_{nullptr}; /*! \brief The indices of the consumer's BufferLoad */ - Array buffer_load_indices_; + ffi::Array buffer_load_indices_; /*! \brief The IterMap representing the indices of the consumer's BufferLoad */ - Array buffer_load_iter_map_{nullptr}; + ffi::Array buffer_load_iter_map_{nullptr}; /*! 
\brief The producer block */ const BlockNode* producer_block_{nullptr}; /* \brief The consumer block */ @@ -879,7 +882,7 @@ class ReverseComputeInliner : public BaseInliner { void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, bool check_only = false) { const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(producer_block_sref); - Block producer_block = GetRef(_producer_block); + Block producer_block = ffi::GetRef(_producer_block); HasInitBlock::Check(self->mod, producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block @@ -897,7 +900,7 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, LeafBlockRemovalPlan(self, producer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt); // Step 5. Create an AST where the leaf `producer_block_sref` points to is removed, // and update other blocks who read from the removed block - Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); + Stmt tgt_stmt = inliner(ffi::GetRef(scope_root_sref->stmt)); if (inliner.has_opaque_access) { throw OpaqueAccessError(self->mod, scope_root_sref); } @@ -924,7 +927,7 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref, bool check_only = false) { const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); - Block consumer_block = GetRef(_consumer_block); + Block consumer_block = ffi::GetRef(_consumer_block); BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref); HasInitBlock::Check(self->mod, consumer_block); // Step 1. Get the scope block @@ -949,7 +952,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block LeafBlockRemovalPlan(self, consumer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt); // Step 6. 
Create an AST where the leaf `consumer_block_sref` points to is removed, // and update other blocks who read from the removed block - Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); + Stmt tgt_stmt = inliner(ffi::GetRef(scope_root_sref->stmt)); if (inliner.has_opaque_access) { throw OpaqueAccessError(self->mod, scope_root_sref); } @@ -963,7 +966,8 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block BlockInfo& block_info = self->block_info[producer_block_sref]; block_info.affine_binding = IsAffineBinding( /*realize=*/GetBlockRealize(self, producer_block_sref), - /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(producer_block_sref->parent)), + /*loop_var_ranges=*/ + LoopDomainOfSRefTreePath(ffi::GetRef(producer_block_sref->parent)), /*analyzer=*/&analyzer); } @@ -980,6 +984,467 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre ReverseComputeInlineImpl(self, consumer_block_sref); } +/*! + * \brief Helper to fuse epilogue block into reduction block + * Analyzes epilogue pattern and transforms reduction init/update + */ +class ReductionEpilogueFuser : public BaseInliner { + public: + explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block, + const BlockRealize& epilogue_block_realize, + const StmtSRef& scope_root_sref) + : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref), + reduction_block_(reduction_block), + epilogue_block_(epilogue_block_realize->block.get()) {} + + bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize); + + // Step 2: Create single fused reduction block + Block CreateFusedReductionBlock(const BlockNode* reduction_block, + const BlockRealizeNode* reduction_realize); + + private: + bool AnalyzeEpiloguePattern(const PrimExpr& value); + bool IsReductionBlock(const BlockNode* block); + void ExtractEpilogueInfo(); + // Helper function to extract BufferLoad nodes from BufferStore + static 
std::vector ExtractBufferLoad(const Buffer& buffer, + const BufferStoreNode* from) { + struct Extractor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.get() == buffer) { + result.push_back(load); + } + ExprVisitor::VisitExpr_(load); + } + const BufferNode* buffer; + std::vector result; + } extractor; + extractor.buffer = buffer.get(); + for (const PrimExpr& expr : from->indices) { + extractor(expr); + } + extractor(from->value); + return std::move(extractor.result); + } + + const BlockNode* reduction_block_; + const BlockNode* epilogue_block_; + PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C + Buffer epilogue_output_buffer_{nullptr}; // Output buffer D + ffi::Array epilogue_output_indices_{nullptr}; // Indices of D[vi, vj] + BufferRegion epilogue_output_region_{nullptr}; // Write region of D + Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C + BufferRegion epilogue_addend_region_{nullptr}; // Read region of C +}; + +bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { + // 1. Validate predicate + if (!is_one(epilogue_block_realize->predicate)) { + // Failure: Predicate in epilogue block is not supported + return false; + } + + // 2. Check if epilogue body is BufferStore + if (inlined_store_ == nullptr) { + // Failure: epilogue block body is not BufferStore + return false; + } + + // 3. Check if epilogue reads from reduction buffer + std::vector loads = ExtractBufferLoad(inlined_buffer_, inlined_store_); + if (loads.size() == 0) { + // Failure: no BufferLoad from the reduction buffer + return false; + } + + // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] + if (!AnalyzeEpiloguePattern(inlined_store_->value)) { + // Failure: epilogue is not a simple addition pattern + return false; + } + + // 5. 
Check if producer is a reduction block + if (!IsReductionBlock(reduction_block_)) { + // Failure: producer is not a reduction block + return false; + } + + // 6. Extract epilogue information (output buffer, indices, regions, etc.) + ExtractEpilogueInfo(); + + return true; +} + +bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { + // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] + if (const auto* add = value.as()) { + const auto* load_a = add->a.as(); + const auto* load_b = add->b.as(); + + bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); + bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); + + // Ensure exactly one operand is from the reduction buffer + if (a_is_target != b_is_target) { + epilogue_addend_ = a_is_target ? add->b : add->a; + return true; + } + } + + return false; +} + +bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) { + // Check if block has reduction iter vars + for (const IterVar& iter : block->iter_vars) { + if (iter->iter_type == kCommReduce) { + return true; + } + } + return false; +} + +void ReductionEpilogueFuser::ExtractEpilogueInfo() { + // Extract epilogue output buffer and indices + epilogue_output_buffer_ = inlined_store_->buffer; + epilogue_output_indices_ = inlined_store_->indices; + + // Extract epilogue output region from epilogue block writes + for (const BufferRegion& write : epilogue_block_->writes) { + if (write->buffer.same_as(epilogue_output_buffer_)) { + epilogue_output_region_ = write; + break; + } + } + + // Extract epilogue addend buffer and region from epilogue_addend_ + if (const auto* load = epilogue_addend_.as()) { + epilogue_addend_buffer_ = load->buffer; + // Find the read region from epilogue block reads + for (const BufferRegion& read : epilogue_block_->reads) { + if (read->buffer.same_as(epilogue_addend_buffer_)) { + epilogue_addend_region_ = read; + break; + } + } + } +} + +Block 
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block, + const BlockRealizeNode* reduction_realize) { + ObjectPtr new_block = ffi::make_object(*reduction_block); + + // 1. Map epilogue block vars to reduction block vars + std::vector reduction_data_vars; + for (const IterVar& iter_var : reduction_block->iter_vars) { + if (iter_var->iter_type == IterVarType::kDataPar) { + reduction_data_vars.push_back(iter_var->var); + } + } + std::vector epilogue_data_vars; + for (const IterVar& iter_var : epilogue_block_->iter_vars) { + if (iter_var->iter_type == IterVarType::kDataPar) { + epilogue_data_vars.push_back(iter_var->var); + } + } + + ICHECK_EQ(reduction_data_vars.size(), epilogue_data_vars.size()) + << "ValueError: The number of data parallel iter vars must be the same in the reduction " + "and epilogue blocks."; + + std::unordered_map var_map; + for (size_t i = 0; i < reduction_data_vars.size(); ++i) { + var_map[epilogue_data_vars[i]] = reduction_data_vars[i]; + } + + // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj] + BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), + Substitute(epilogue_output_indices_, var_map)); + new_block->init = new_init_store; + + // 3. 
Replace output buffer from temp to D in body + class BufferReplacer : public StmtExprMutator { + public: + BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {} + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + if (store->buffer.same_as(old_buffer_)) { + return BufferStore(new_buffer_, store->value, store->indices); + } + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(old_buffer_)) { + return BufferLoad(new_buffer_, load->indices); + } + return load; + } + + private: + Buffer old_buffer_; + Buffer new_buffer_; + }; + + BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_); + new_block->body = replacer(reduction_block->body); + + // 4. Update write regions + ffi::Array new_writes; + for (const BufferRegion& write : reduction_block->writes) { + if (write->buffer.same_as(inlined_buffer_)) { + new_writes.push_back( + BufferRegion(epilogue_output_buffer_, Substitute(write->region, var_map))); + } else { + new_writes.push_back(write); + } + } + new_block->writes = new_writes; + + // 5. Update read regions (C first, then A, B) + ffi::Array new_reads; + std::unordered_set read_bufs; + + // Add C buffer read first (used in init) + if (epilogue_addend_buffer_.defined()) { + new_reads.push_back(BufferRegion(epilogue_addend_buffer_, + Substitute(epilogue_addend_region_->region, var_map))); + read_bufs.insert(epilogue_addend_buffer_.get()); + } + + // Add existing read regions (A, B, etc.) + for (const BufferRegion& read : reduction_block->reads) { + if (!read->buffer.same_as(inlined_buffer_)) { + // Only add non-temp buffers + if (read_bufs.find(read->buffer.get()) == read_bufs.end()) { + new_reads.push_back(read); + read_bufs.insert(read->buffer.get()); + } + } + } + + new_block->reads = new_reads; + + return Block(new_block); +} + +/*! 
+ * \brief Check if a buffer is still referenced by other blocks in the scope + */ +static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) { + class BufferUsageChecker : public StmtVisitor { + public: + explicit BufferUsageChecker(const Buffer& buffer) : buffer_(buffer) {} + + bool CheckStmt(const Stmt& stmt) { + found_usage_ = false; + VisitStmt(stmt); + return found_usage_; + } + + private: + void VisitStmt_(const BlockRealizeNode* op) final { + if (found_usage_) return; + + if (!op || !op->block.defined()) { + StmtVisitor::VisitStmt_(op); + return; + } + + const BlockNode* block = op->block.get(); + if (!block) { + StmtVisitor::VisitStmt_(op); + return; + } + + // Check reads + for (const BufferRegion& read : block->reads) { + if (read->buffer.same_as(buffer_)) { + found_usage_ = true; + return; + } + } + + // Check writes + for (const BufferRegion& write : block->writes) { + if (write->buffer.same_as(buffer_)) { + found_usage_ = true; + return; + } + } + + // Continue visiting nested blocks + StmtVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BlockNode* op) final { + if (found_usage_) return; + if (!op) return; + + // Check alloc_buffers + for (const Buffer& buf : op->alloc_buffers) { + if (buf.same_as(buffer_)) { + found_usage_ = true; + return; + } + } + + StmtVisitor::VisitStmt_(op); + } + + const Buffer& buffer_; + bool found_usage_{false}; + }; + + if (!scope_root->body.defined()) { + return false; + } + + BufferUsageChecker checker(buffer); + return checker.CheckStmt(scope_root->body); +} + +/*! 
+ * \brief Helper class to replace reduction and epilogue blocks with a single fused block + */ +class SingleBlockFusionReplacer : public StmtMutator { + public: + static Block Replace(Block old_scope_root, Block new_fused_block, Block old_reduction_block, + Block old_epilogue_block, Buffer reduction_buffer) { + SingleBlockFusionReplacer replacer(std::move(new_fused_block), std::move(old_reduction_block), + std::move(old_epilogue_block), std::move(reduction_buffer)); + Block result = Downcast(replacer(std::move(old_scope_root))); + + // Check if reduction_buffer is still referenced by other blocks + bool buffer_still_used = CheckBufferStillUsed(result, reduction_buffer); + + // Remove intermediate temp buffer only if it's not used by other blocks + if (!buffer_still_used) { + BlockNode* p = result.CopyOnWrite(); + ffi::Array new_alloc_buffers; + for (const Buffer& buf : p->alloc_buffers) { + if (!buf.same_as(reduction_buffer)) { + new_alloc_buffers.push_back(buf); + } + } + p->alloc_buffers = new_alloc_buffers; + } + + return result; + } + + private: + explicit SingleBlockFusionReplacer(Block new_fused_block, Block old_reduction_block, + Block old_epilogue_block, Buffer reduction_buffer) + : new_fused_block_(std::move(new_fused_block)), + old_reduction_block_(std::move(old_reduction_block)), + old_epilogue_block_(std::move(old_epilogue_block)), + reduction_buffer_(std::move(reduction_buffer)) {} + + Stmt VisitStmt_(const ForNode* loop) final { + Stmt mutated_body = StmtMutator::VisitStmt(loop->body); + // Remove empty loops (containing only Evaluate(0)) + if (mutated_body.as()) { + return mutated_body; // Return Evaluate(0) to be removed by SeqStmt + } + + return For(loop->loop_var, loop->min, loop->extent, loop->kind, mutated_body, + loop->thread_binding, loop->annotations); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + if (realize->block.same_as(old_reduction_block_)) { + // Replace reduction block with new fused block + ObjectPtr new_realize 
= ffi::make_object(*realize); + new_realize->block = new_fused_block_; + return BlockRealize(new_realize); + } else if (realize->block.same_as(old_epilogue_block_)) { + // Remove epilogue block completely + return Evaluate(0); + } + return StmtMutator::VisitStmt_(realize); + } + + Stmt VisitStmt_(const SeqStmtNode* seq) final { + ffi::Array new_stmts; + for (const Stmt& stmt : seq->seq) { + Stmt new_stmt = VisitStmt(stmt); + // Remove Evaluate(0) + if (!new_stmt.as()) { + new_stmts.push_back(new_stmt); + } + } + return SeqStmt::Flatten(new_stmts); + } + + private: + Block new_fused_block_; + Block old_reduction_block_; + Block old_epilogue_block_; + Buffer reduction_buffer_; +}; + +void FuseReductionEpilogueImpl(ScheduleState self, const StmtSRef& reduction_block_sref, + const StmtSRef& epilogue_block_sref, bool check_only = false) { + const BlockNode* _reduction_block = TVM_SREF_TO_BLOCK(reduction_block_sref); + const BlockNode* _epilogue_block = TVM_SREF_TO_BLOCK(epilogue_block_sref); + + Block reduction_block = ffi::GetRef(_reduction_block); + Block epilogue_block = ffi::GetRef(_epilogue_block); + BlockRealize epilogue_block_realize = GetBlockRealize(self, epilogue_block_sref); + + // Step 1. Get the scope block + StmtSRef scope_root_sref = + GetScopeRoot(self, epilogue_block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Get the reduction buffer (intermediate buffer) + Buffer reduction_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, reduction_block); + + // Step 3. Check completeness and reduction block properties + CheckReductionBlock(self, reduction_block_sref, scope_root_sref); + CheckCompleteBlock(self, epilogue_block_sref, scope_root_sref); + CheckNotOutputBlock(self, reduction_block_sref, scope_root_sref); + + // Step 4. 
Analyze the epilogue pattern + ReductionEpilogueFuser fuser(reduction_buffer, _reduction_block, epilogue_block_realize, + scope_root_sref); + if (!fuser.BodyPatternAllowFusion(epilogue_block_realize)) { + throw BodyAnalysisError(true, self->mod, epilogue_block); + } + + if (check_only) { + return; + } + + // Step 5. Create single fused reduction block + BlockRealize reduction_realize = GetBlockRealize(self, reduction_block_sref); + Block fused_block = fuser.CreateFusedReductionBlock(_reduction_block, reduction_realize.get()); + + // Step 6. Transform and replace IR + const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); + + Block new_scope_root = + SingleBlockFusionReplacer::Replace(ffi::GetRef(old_scope_root), fused_block, + reduction_block, epilogue_block, reduction_buffer); + + // Step 7. Update schedule state + ffi::Map block_reuse; + block_reuse.Set(ffi::GetRef(old_scope_root), new_scope_root); + block_reuse.Set(reduction_block, fused_block); + self->Replace(scope_root_sref, new_scope_root, block_reuse); + + // Step 8. 
Update BlockInfo + self->UpdateScopeBlockInfo(GetBlockRealize(self, scope_root_sref)); +} + +void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref, + const StmtSRef& epilogue_block_sref) { + FuseReductionEpilogueImpl(self, reduction_block_sref, epilogue_block_sref); +} + /******** InstructionKind Registration ********/ struct ComputeInlineTraits : public UnpackedInstTraits { @@ -995,7 +1460,7 @@ struct ComputeInlineTraits : public UnpackedInstTraits { return sch->ComputeInline(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("compute_inline"); py.Input("block", block_rv); return py.Str(); @@ -1018,7 +1483,7 @@ struct ReverseComputeInlineTraits : public UnpackedInstTraitsReverseComputeInline(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("reverse_compute_inline"); py.Input("block", block_rv); return py.Str(); @@ -1031,5 +1496,34 @@ struct ReverseComputeInlineTraits : public UnpackedInstTraits { + static constexpr const char* kName = "FuseReductionEpilogue"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV reduction_block_rv, + BlockRV epilogue_block_rv) { + return sch->FuseReductionEpilogue(reduction_block_rv, epilogue_block_rv); + } + + static ffi::String UnpackedAsPython(ffi::Array outputs, + ffi::String reduction_block_rv, + ffi::String epilogue_block_rv) { + PythonAPICall py("fuse_reduction_epilogue"); + py.Input("reduction_block", reduction_block_rv); + py.Input("epilogue_block", epilogue_block_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; 
+}; + +TVM_REGISTER_INST_KIND_TRAITS(FuseReductionEpilogueTraits); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index d848dad28f27..7e61fd4eb20a 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -27,7 +27,7 @@ namespace tir { /*! \brief Information used to create new padding block */ struct PaddingBlockInfo { /*! \brief In-bound block iter regions, wrt loop vars. */ - Array in_bound_region; + ffi::Array in_bound_region; /*! \brief In-bound value, wrt block iter vars. */ PrimExpr in_bound_value; /*! \brief Condition of in-bound write, wrt loop vars. */ @@ -41,12 +41,12 @@ class PaddingPatternMatchError : public ScheduleError { PaddingPatternMatchError(IRModule mod, Block block, const std::string& error_msg) : mod_(std::move(mod)), block_(std::move(block)), error_msg_(error_msg) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: decompose_padding expect the block to match padding pattern\n " + error_msg_; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: decompose_padding expect the block {0} to match padding pattern\n " << error_msg_; @@ -54,7 +54,7 @@ class PaddingPatternMatchError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ -68,7 +68,7 @@ class PaddingPatternMatchError : public ScheduleError { class PaddingInfoAnalyzer { public: static PaddingBlockInfo CheckAndGetPaddingInfo(IRModule mod, const BlockRealizeNode* realize, - const Map& dom_map, + const ffi::Map& dom_map, arith::Analyzer* analyzer) { PaddingInfoAnalyzer padding_analyzer(analyzer); if 
(!padding_analyzer.MatchPadding(realize, dom_map)) { @@ -81,7 +81,7 @@ class PaddingInfoAnalyzer { explicit PaddingInfoAnalyzer(arith::Analyzer* analyzer) : analyzer_(analyzer) {} /*! \brief Detect padding pattern and update result. */ - bool MatchPadding(const BlockRealizeNode* realize, const Map& dom_map) { + bool MatchPadding(const BlockRealizeNode* realize, const ffi::Map& dom_map) { // Step 1. Check match padding computation pattern. // A[...] = T.if_then_else(predicate, B[...], imm) Block block = realize->block; @@ -120,7 +120,7 @@ class PaddingInfoAnalyzer { SetError("The in-bound predicate is trivial"); return false; } - Array in_bound_region = this->EstimateInBoundRegion( + ffi::Array in_bound_region = this->EstimateInBoundRegion( /*iter_values=*/realize->iter_values, /*dom_map=*/dom_map, /*in_bound_predicate=*/in_bound_predicate); if (in_bound_region.empty()) { @@ -157,10 +157,10 @@ class PaddingInfoAnalyzer { } /*! \brief Return iteration region of block vars where the padding predicate evals to true. */ - Array EstimateInBoundRegion(const Array& iter_values, - const Map& dom_map, - const PrimExpr& in_bound_predicate) { - Array region; + ffi::Array EstimateInBoundRegion(const ffi::Array& iter_values, + const ffi::Map& dom_map, + const PrimExpr& in_bound_predicate) { + ffi::Array region; auto res = arith::DetectIterMap(iter_values, dom_map, in_bound_predicate, arith::IterMapLevel::Surjective, analyzer_); @@ -196,12 +196,12 @@ class PaddingInfoAnalyzer { /*! 
\brief Create block to fill constant pad values into full region */ static std::pair CreateConstBlock(const BlockRealizeNode* realize, const PaddingBlockInfo& info, - const Array& loops, + const ffi::Array& loops, const Stmt& highest_pos_inclusive, arith::Analyzer* analyzer) { const Block& block = realize->block; - Array new_iter_vars; - Map repl_dict; + ffi::Array new_iter_vars; + ffi::Map repl_dict; // create new block itervars for (size_t i = 0; i < block->iter_vars.size(); ++i) { @@ -231,7 +231,7 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re /*name_hint=*/block->name_hint + "_pad_const", /*body=*/std::move(store)); // create new loop vars - Array new_loop_vars; + ffi::Array new_loop_vars; for (const For& loop : loops) { Var new_var = loop->loop_var.copy_with_suffix(""); new_loop_vars.push_back(new_var); @@ -242,7 +242,7 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re } // create new block realize node - Array new_iter_values; + ffi::Array new_iter_values; for (size_t i = 0; i < realize->iter_values.size(); ++i) { new_iter_values.push_back(rewrite_expr(realize->iter_values[i])); } @@ -265,15 +265,15 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re static std::pair CreateInBoundBlock(const BlockRealizeNode* realize, const PaddingBlockInfo& info, - const Array& loops, + const ffi::Array& loops, const Stmt& highest_pos_inclusive, arith::Analyzer* analyzer) { const Block& block = realize->block; - Array new_iter_vars; - Map repl_dict; + ffi::Array new_iter_vars; + ffi::Map repl_dict; // record loop ranges to be mutated - Map new_loop_ranges; + ffi::Map new_loop_ranges; for (const For& loop : loops) { new_loop_ranges.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); if (loop.same_as(highest_pos_inclusive)) { @@ -282,7 +282,7 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* } // create new block iter vars and iter bindings - Array new_iter_binding; + ffi::Array new_iter_binding; for 
(size_t i = 0; i < info.in_bound_region.size(); ++i) { // add new block itervar const IterVar& origin_itervar = block->iter_vars[i]; @@ -318,7 +318,7 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* }; // create new read/write region for in-bound accesses - Array reads, writes; + ffi::Array reads, writes; for (const BufferRegion& read : block->reads) { reads.push_back(BufferRegion(read->buffer, rewrite_region(read->region))); } @@ -343,7 +343,7 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* PrimExpr min = it == new_loop_ranges.end() ? loop->min : (*it).second->min; PrimExpr extent = it == new_loop_ranges.end() ? loop->extent : (*it).second->extent; nest_stmt_root = For(loop->loop_var, min, extent, loop->kind, nest_stmt_root, - loop->thread_binding, loop->annotations, loop->span); + loop->thread_binding, loop->annotations, loop->step, loop->span); if (loop.same_as(highest_pos_inclusive)) { break; } @@ -413,7 +413,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, // Condition Checks and Information Collection const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); - Map dom_map; + ffi::Map dom_map; arith::Analyzer analyzer; // Check 1. check the block is complete. @@ -423,14 +423,14 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, // Check 2. Check loop_sref is an ancestor of block_sref. Also collect // - the highest loop position (inclusive) to insert const pad value filling code above. // - the highest loop position (inclusive) to replace with in-bound value filling code. 
- Array loop_srefs = GetLoops(block_sref); - Array loops; + ffi::Array loop_srefs = GetLoops(block_sref); + ffi::Array loops; bool found_const_filling_pos = false; bool found_in_bound_filling_pos = false; - For const_filling_pos = GetRef(loop_sref->StmtAs()); + For const_filling_pos = ffi::GetRef(loop_sref->StmtAs()); For in_bound_filling_pos{nullptr}; for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) { - For cur_loop = GetRef((*it)->StmtAs()); + For cur_loop = ffi::GetRef((*it)->StmtAs()); Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); dom_map.Set(cur_loop->loop_var, range); analyzer.Bind(cur_loop->loop_var, range); @@ -454,7 +454,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, } ICHECK(in_bound_filling_pos.defined()); if (!found_const_filling_pos) { - throw LoopPositionError(self->mod, const_filling_pos, GetRef(block), + throw LoopPositionError(self->mod, const_filling_pos, ffi::GetRef(block), "decompose_padding"); } @@ -473,7 +473,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, CreateInBoundBlock(realize, info, loops, in_bound_filling_pos, &analyzer); // Step 2. Execute IR replacement. - Block old_scope_root_block = GetRef(scope_root_sref->StmtAs()); + Block old_scope_root_block = ffi::GetRef(scope_root_sref->StmtAs()); Block new_scope_root = DecomposePaddingBlockReplacer::Replace(old_scope_root_block, replace_desc); if (check_only) { return block_sref; @@ -482,7 +482,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, // Step 3. Update schedule states. 
self->Replace(scope_root_sref, new_scope_root, {{old_scope_root_block, new_scope_root}, - {GetRef(block), replace_desc.in_bound_filling_block->block}}); + {ffi::GetRef(block), replace_desc.in_bound_filling_block->block}}); auto new_block_sref = self->stmt2ref.at(replace_desc.const_filling_block->block.get()); // Set block info of created const pad value filling block @@ -533,13 +533,13 @@ bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref, /******** FFI ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.schedule.CanDecomposePadding", [](Schedule self, BlockRV block_rv, LoopRV loop_rv) { return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv)); }); -}); +} /******** InstructionKind Registration ********/ @@ -556,7 +556,8 @@ struct DecomposPaddingTraits : public UnpackedInstTraits return sch->DecomposePadding(block_rv, loop_rv); } - static String UnpackedAsPython(Array outputs, String block_rv, LoopRV loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + LoopRV loop_rv) { PythonAPICall py("decompose_padding"); py.Input("block", block_rv); py.Input("loop", loop_rv); diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 6dd1eafcc076..de550979c18f 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -29,13 +29,13 @@ class WrongBlockIterTypeError : public ScheduleError { ? "parallel" : (for_kind == ForKind::kVectorized ? 
"vectorize" : "bind"); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream os; os << "ScheduleError: The \"" << op_str_ << "\" cannot be fulfilled with regard to some of its underlying block"; return os.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; if (op_str_ != "bind") { os << "The \"" << op_str_ @@ -52,7 +52,7 @@ class WrongBlockIterTypeError : public ScheduleError { return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; std::string op_str_; Var loop_var_; @@ -127,8 +127,8 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind if (!self->stmt2ref.count(realize->block.get())) { return false; } - CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef(realize), - thread_scope); + CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, + ffi::GetRef(realize), thread_scope); } return true; }); @@ -144,7 +144,7 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind * `for_kind` is `kThreadBinding` */ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind, - Optional thread_axis) { + ffi::Optional thread_axis) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); /* @@ -163,12 +163,12 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each // underlying block. - CheckParallelizability(self, GetRef(loop), for_kind, + CheckParallelizability(self, ffi::GetRef(loop), for_kind, thread_axis.has_value() ? runtime::ThreadScope::Create(thread_axis.value()) : runtime::ThreadScope{-1, -1}); // Step 3. 
Loop update and IR replacement - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->kind = for_kind; if (thread_axis.has_value()) { new_loop->thread_binding = IterVar(/*dom=*/Range(nullptr), // @@ -189,13 +189,13 @@ void Vectorize(ScheduleState self, const StmtSRef& loop_sref) { ParallelizeComputation(self, loop_sref, ForKind::kVectorized, std::nullopt); } -void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis) { +void Bind(ScheduleState self, const StmtSRef& loop_sref, const ffi::String& thread_axis) { ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis); } void Unroll(ScheduleState self, const StmtSRef& loop_sref) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->kind = ForKind::kUnrolled; new_loop->thread_binding = std::nullopt; self->Replace(loop_sref, For(new_loop), {}); @@ -216,7 +216,7 @@ struct ParallelTraits : public UnpackedInstTraits { return sch->Parallel(loop_rv); } - static String UnpackedAsPython(Array outputs, String loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv) { PythonAPICall py("parallel"); py.Input("loop", loop_rv); return py.Str(); @@ -239,7 +239,7 @@ struct VectorizeTraits : public UnpackedInstTraits { return sch->Vectorize(loop_rv); } - static String UnpackedAsPython(Array outputs, String loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv) { PythonAPICall py("vectorize"); py.Input("loop", loop_rv); return py.Str(); @@ -258,11 +258,12 @@ struct BindTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String thread) { + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, ffi::String thread) { return sch->Bind(loop_rv, 
thread); } - static String UnpackedAsPython(Array outputs, String loop_rv, String thread) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + ffi::String thread) { PythonAPICall py("bind"); py.Input("loop", loop_rv); py.Input("thread_axis", thread); @@ -284,7 +285,7 @@ struct UnrollTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { return sch->Unroll(loop_rv); } - static String UnpackedAsPython(Array outputs, String loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv) { PythonAPICall py("unroll"); py.Input("loop", loop_rv); return py.Str(); diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 588770d968ef..0ad1d82ee0df 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -22,9 +22,11 @@ namespace tvm { namespace tir { -Array GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv) { +ffi::Array GetBlocks(const ScheduleState& self, const ffi::String& name, + const GlobalVar& gv) { struct Finder : public StmtVisitor { - explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} + explicit Finder(const ScheduleState& self, const ffi::String& name) + : self_(self), name_(name) {} void VisitStmt_(const BlockNode* block) override { if (block->name_hint == name_) { @@ -36,8 +38,8 @@ Array GetBlocks(const ScheduleState& self, const String& name, const G } const ScheduleState& self_; - const String& name_; - Array results_; + const ffi::String& name_; + ffi::Array results_; }; BaseFunc func = self->mod->Lookup(gv); @@ -47,16 +49,16 @@ Array GetBlocks(const ScheduleState& self, const String& name, const G return std::move(finder.results_); } -Array GetLoops(const StmtSRef& block_sref) { +ffi::Array GetLoops(const StmtSRef& block_sref) { std::vector result; for (StmtSRefNode* parent = 
block_sref->parent; parent && parent->stmt->IsInstance(); parent = parent->parent) { - result.push_back(GetRef(parent)); + result.push_back(ffi::GetRef(parent)); } return {result.rbegin(), result.rend()}; } -Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { +ffi::Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { private: void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); } @@ -65,7 +67,7 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent explicit Collector(const ScheduleState& self) : self(self) {} const ScheduleState& self; - Array result; + ffi::Array result; }; Collector collector(self); if (parent_sref->stmt->IsInstance()) { @@ -78,17 +80,17 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent return std::move(collector.result); } -Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { +ffi::Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); return tir::GetProducers(block_sref, self->GetBlockScope(scope_root)); } -Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { +ffi::Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root)); } -Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) { +ffi::Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) { const auto* scope_block = TVM_SREF_TO_BLOCK(scope_sref); return tir::GetOutputBlocks(self, scope_block); } @@ -104,11 +106,12 @@ struct GetBlockTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - 
static BlockRV UnpackedApplyToSchedule(Schedule sch, String name, String func_name) { + static BlockRV UnpackedApplyToSchedule(Schedule sch, ffi::String name, ffi::String func_name) { return sch->GetBlock(name, func_name); } - static String UnpackedAsPython(Array outputs, String name, String func_name) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String name, + ffi::String func_name) { PythonAPICall py("get_block"); py.Input("name", name); py.Input("func_name", func_name); @@ -129,11 +132,11 @@ struct GetLoopsTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { return sch->GetLoops(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("get_loops"); py.Input("block", block_rv); py.OutputList(outputs); @@ -153,7 +156,7 @@ struct GetChildBlocksTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { if (auto block = block_or_loop_rv.as()) { return sch->GetChildBlocks(block.value()); } @@ -164,7 +167,8 @@ struct GetChildBlocksTraits : public UnpackedInstTraits { throw; } - static String UnpackedAsPython(Array outputs, String block_or_loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, + ffi::String block_or_loop_rv) { PythonAPICall py("get_child_blocks"); py.Input("", block_or_loop_rv); py.OutputList(outputs); @@ -184,11 +188,11 @@ struct GetProducersTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - 
static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { return sch->GetProducers(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("get_producers"); py.Input("block", block_rv); py.OutputList(outputs); @@ -208,11 +212,11 @@ struct GetConsumersTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { return sch->GetConsumers(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("get_consumers"); py.Input("block", block_rv); py.OutputList(outputs); @@ -232,11 +236,11 @@ struct GetOutputBlocksTraits : public UnpackedInstTraits static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { return sch->GetOutputBlocks(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("get_output_blocks"); py.Input("block", block_rv); py.OutputList(outputs); diff --git a/src/tir/schedule/primitive/hide_buffer_access.cc b/src/tir/schedule/primitive/hide_buffer_access.cc index 469dc278e503..f5e92b8ba50b 100644 --- a/src/tir/schedule/primitive/hide_buffer_access.cc +++ b/src/tir/schedule/primitive/hide_buffer_access.cc @@ -27,25 +27,25 @@ namespace tir { namespace { class BufTypeError : public ScheduleError { public: - explicit BufTypeError(IRModule 
mod, const String& buf_type) + explicit BufTypeError(IRModule mod, const ffi::String& buf_type) : mod_(std::move(mod)), buf_type_(buf_type) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Invalid buffer type for hide_buffer_access schedule."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The buffer type for hide_buffer_access schedule should either be 'read'" " or 'write', got " + buf_type_ + " instead."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; - String buf_type_; + ffi::String buf_type_; }; class InvalidIndexError : public ScheduleError { @@ -53,11 +53,11 @@ class InvalidIndexError : public ScheduleError { explicit InvalidIndexError(IRModule mod, int num_access_regions, int buf_idx) : mod_(std::move(mod)), num_access_regions_(num_access_regions), buf_idx_(buf_idx) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Invalid buffer index array for hide_buffer_access schedule."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The buffer index array for hide_buffer_access schedule should be a list of integers" " between 0 and " + std::to_string(num_access_regions_ - 1) + ", got " + std::to_string(buf_idx_) + @@ -66,7 +66,7 @@ class InvalidIndexError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -78,8 +78,9 @@ class InvalidIndexError : public ScheduleError { /******** Implementation ********/ -void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, const String& buf_type, - const Array& 
buf_index_array) { +void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, + const ffi::String& buf_type, + const ffi::Array& buf_index_array) { /*! * Check: * - validity of buf_index_array @@ -107,7 +108,7 @@ void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, cons /* Step 0: Collect new buffer access regions. */ - Array reads, writes; + ffi::Array reads, writes; if (buf_type == "read") { for (size_t i = 0; i < block->reads.size(); ++i) { @@ -129,12 +130,12 @@ void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, cons /* Step 1: Replace old block with the new block */ - auto n = make_object(*block); + auto n = ffi::make_object(*block); n->reads = reads; n->writes = writes; Block new_block = Block(n); - Map blk_map; - blk_map.Set(GetRef(block), new_block); + ffi::Map blk_map; + blk_map.Set(ffi::GetRef(block), new_block); self->Replace(block_sref, new_block, blk_map); } @@ -147,13 +148,13 @@ struct UnsafeHideBufferAccessTraits : public UnpackedInstTraits buf_index_array) { + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, ffi::String buf_type, + ffi::Array buf_index_array) { sch->UnsafeHideBufferAccess(block, buf_type, buf_index_array); } - static String UnpackedAsPython(Array outputs, String block, String buf_type, - Array buf_index_array) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::String buf_type, ffi::Array buf_index_array) { PythonAPICall py("unsafe_hide_buffer_access"); py.Input("block", block); py.Input("buf_type", buf_type); diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 8931c0e71c11..c625d8c153cf 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -75,8 +75,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { // Loops within the analyzed block that should be replaced struct 
ReplacementPlan { - Map replacements; - Map new_block_to_old; + ffi::Map replacements; + ffi::Map new_block_to_old; }; // The block to be inserted, along with the location at which it @@ -94,7 +94,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value, arith::Analyzer* analyzer) { + ffi::Optional pad_value, arith::Analyzer* analyzer) { ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) << "Internal error: Should be caught by ScheduleError checks prior to this point"; TransformLayoutPlanner visitor(old_buffer); @@ -108,7 +108,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { BufferStore store; // The block realize that contains the store, if any. - Optional innermost_block_realize; + ffi::Optional innermost_block_realize; // The nested loops whose values contribute to the indices used in // the store. 
Not all loop variables in the loopnest need to @@ -125,7 +125,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { explicit TransformLayoutPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {} void VisitStmt_(const ForNode* op) override { - BindLoopVar context(this, GetRef(op)); + BindLoopVar context(this, ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); } @@ -135,7 +135,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } void VisitStmt_(const BlockRealizeNode* op) override { - BindBlockRealize context(this, GetRef(op)); + BindBlockRealize context(this, ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); } @@ -158,7 +158,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } WriteInfo write_info; - write_info.store = GetRef(op); + write_info.store = ffi::GetRef(op); if (loop_dependency_range) { size_t i = loop_dependency_range.value().first; size_t j = loop_dependency_range.value().second; @@ -220,8 +220,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { class BufferStoreReplacer : public StmtExprMutator { public: BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate, - const IndexMap& inverse, const Optional& pad_value, - Map* new_block_to_old, arith::Analyzer* analyzer) + const IndexMap& inverse, const ffi::Optional& pad_value, + ffi::Map* new_block_to_old, arith::Analyzer* analyzer) : info(info), new_buffer(new_buffer), new_indices(inverse->initial_indices), @@ -250,7 +250,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { BlockRealize block_realize = info.innermost_block_realize.value(); const auto& block = block_realize->block; - const Array& old_indices = info.store->indices; + const ffi::Array& old_indices = info.store->indices; const auto& old_iter_vars = block->iter_vars; this->new_iter_vars = old_iter_vars; @@ -294,10 +294,10 @@ class TransformLayoutPlanner : private StmtExprVisitor { return Var(ss.str(), var.dtype()); }); - Map + ffi::Map 
loop_var_to_virtual_var; // For updating padding_predicate in terms of the new indices - Array new_iter_values; // For BlockRealize - Array new_iter_vars; // For Block + ffi::Array new_iter_values; // For BlockRealize + ffi::Array new_iter_vars; // For Block for (size_t i = 0; i < block_index_start; i++) { new_iter_vars.push_back(old_iter_vars[i]); @@ -339,7 +339,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { return false; } - const Array& old_indices = info.store->indices; + const ffi::Array& old_indices = info.store->indices; ICHECK_EQ(old_indices.size(), op->indices.size()); ExprDeepEqual expr_equal; @@ -351,9 +351,9 @@ class TransformLayoutPlanner : private StmtExprVisitor { return true; }(); - BufferStore store = GetRef(op); + BufferStore store = ffi::GetRef(op); if (can_replace) { - Array new_index_exprs = + ffi::Array new_index_exprs = new_indices.Map([](const auto& var) -> PrimExpr { return var; }); PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_index_exprs, analyzer)[0]; store = @@ -387,7 +387,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } Stmt VisitStmt_(const BlockNode* op) final { - Block orig = GetRef(op); + Block orig = ffi::GetRef(op); Block mutated = Downcast(StmtExprMutator::VisitStmt_(op)); RecordReplacement(orig, mutated); @@ -395,7 +395,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (auto opt = var_remap.Get(var)) { return opt.value(); } else { @@ -423,21 +423,21 @@ class TransformLayoutPlanner : private StmtExprVisitor { const WriteInfo& info; const Buffer& new_buffer; - Array new_indices; - Array new_iter_vars; - Array new_iter_values; + ffi::Array new_indices; + ffi::Array new_iter_vars; + ffi::Array new_iter_values; PrimExpr padding_predicate; const IndexMap& inverse; - const Optional& pad_value; - Map& new_block_to_old; + const ffi::Optional& pad_value; + ffi::Map& 
new_block_to_old; bool all_stores_replaced{true}; arith::Analyzer* analyzer; - Map var_remap; + ffi::Map var_remap; }; TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, - PrimExpr padding_predicate, Optional pad_value, + PrimExpr padding_predicate, ffi::Optional pad_value, arith::Analyzer* analyzer) const { if (auto prologue_plan = FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value, analyzer); @@ -458,16 +458,16 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value, + ffi::Optional pad_value, arith::Analyzer* analyzer) const { if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } - Array iter_vars; - Array iter_values; - Array indices; - Map loop_indices_to_block_indices; + ffi::Array iter_vars; + ffi::Array iter_values; + ffi::Array indices; + ffi::Map loop_indices_to_block_indices; ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t i = 0; i < inverse->initial_indices.size(); i++) { const auto& loop_var = inverse->initial_indices[i]; @@ -503,14 +503,14 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value, + ffi::Optional pad_value, arith::Analyzer* analyzer) const { if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } - Map new_block_to_old; - auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional { + ffi::Map new_block_to_old; + auto generate_if_then_else_block = [&](const WriteInfo& info) -> ffi::Optional { if (!info.contains_row_major_traversal || !pad_value.defined() || is_zero(padding_predicate)) { return std::nullopt; @@ -534,7 +534,7 @@ class 
TransformLayoutPlanner : private StmtExprVisitor { return stmt; }; - Map loop_replacements; + ffi::Map loop_replacements; for (const auto& info : write_info_) { if (info.dependent_loopnest.size()) { @@ -553,15 +553,15 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value, + ffi::Optional pad_value, arith::Analyzer* analyzer) const { if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } - Array iter_vars; - Array iter_values; - Array indices; + ffi::Array iter_vars; + ffi::Array iter_values; + ffi::Array indices; ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t i = 0; i < inverse->initial_indices.size(); i++) { const auto& loop_var = inverse->initial_indices[i]; @@ -673,7 +673,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { BindBlockRealize& operator=(BindBlockRealize&&) = delete; TransformLayoutPlanner* self_{nullptr}; - Optional cache_; + ffi::Optional cache_; std::vector bound_vars_; }; @@ -707,7 +707,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { * * Used to fill the `WriteInfo::innermost_block_realize` field.. */ - Optional innermost_block_realize_{std::nullopt}; + ffi::Optional innermost_block_realize_{std::nullopt}; /*! \brief The buffer to be replaced */ Buffer old_buffer_; @@ -719,23 +719,23 @@ class TransformLayoutPlanner : private StmtExprVisitor { */ class ReuseBlocksCollector : public tir::StmtVisitor { public: - static Map Collect(Block result, Map new_block_to_old) { + static ffi::Map Collect(Block result, ffi::Map new_block_to_old) { return ReuseBlocksCollector(new_block_to_old).Run(result); } private: /*! \brief Entry point */ - Map Run(const Block result) { + ffi::Map Run(const Block result) { VisitStmt(result); return block_sref_reuse_; } /*! 
\brief Constructor */ - explicit ReuseBlocksCollector(Map new_block_to_old) + explicit ReuseBlocksCollector(ffi::Map new_block_to_old) : new_block_to_old_(new_block_to_old) {} /*! \brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::BlockNode* block) override { - Block block_ref = GetRef(block); + Block block_ref = ffi::GetRef(block); auto it = new_block_to_old_.find(block_ref); if (it != new_block_to_old_.end()) { block_sref_reuse_.Set((*it).second, (*it).first); @@ -744,9 +744,9 @@ class ReuseBlocksCollector : public tir::StmtVisitor { } /*! \brief New map to be filled with just blocks from scope block */ - Map block_sref_reuse_; + ffi::Map block_sref_reuse_; /*! \brief All block replacements collected so far */ - Map new_block_to_old_; + ffi::Map new_block_to_old_; }; class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { @@ -760,10 +760,10 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { * \return The new AST rooting at the original parent scope and the map from the old block to the * new block */ - static std::pair> Rewrite( + static std::pair> Rewrite( const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer, - const IndexMap& index_map, const Optional& opt_inverse, - const PrimExpr& padding_predicate, const Optional& pad_value) { + const IndexMap& index_map, const ffi::Optional& opt_inverse, + const PrimExpr& padding_predicate, const ffi::Optional& pad_value) { arith::Analyzer analyzer; auto plan = pad_value.defined() ? 
TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, @@ -778,7 +778,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body}); } - Map block_sref_reuse = + ffi::Map block_sref_reuse = ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_); return {result, block_sref_reuse}; @@ -800,7 +800,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { } } - void RewriteBufferAccess(Buffer* buffer, Array* indices) { + void RewriteBufferAccess(Buffer* buffer, ffi::Array* indices) { *buffer = new_buffer_; *indices = index_map_->MapIndices(*indices, &index_simplifier_); *indices = this->IterMapSimplifyWithContext(*indices, true); @@ -825,7 +825,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { // replacing `loop` with `{loop, post_proc}`. In this case, avoid // infinite recursion. - For node = GetRef(op); + For node = ffi::GetRef(op); if (auto plan_ptr = std::get_if(&plan_)) { auto it = plan_ptr->replacements.find(node); if (it != plan_ptr->replacements.end()) { @@ -853,8 +853,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { return buffer_store; } - void RewriteAccessRegion(Array* old_access_regions, - const Array& infered_access_regions) { + void RewriteAccessRegion(ffi::Array* old_access_regions, + const ffi::Array& infered_access_regions) { auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { if (buffer_region->buffer.same_as(old_buffer_)) { ICHECK(infered_access_regions.size() == 1); @@ -867,7 +867,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const BlockNode* op) final { Block orig = [&]() { - Block block = GetRef(op); + Block block = ffi::GetRef(op); while (true) { if (auto it = new_block_to_old_.find(block); it != new_block_to_old_.end()) { block = (*it).second; @@ -918,8 +918,8 @@ class 
TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { const Buffer& new_buffer_; const IndexMap& index_map_; const TransformLayoutPlanner::TransformPlan& plan_; - Map buffer_data_to_buffer_; - Map new_block_to_old_; + ffi::Map buffer_data_to_buffer_; + ffi::Map new_block_to_old_; arith::Analyzer index_simplifier_; }; @@ -927,19 +927,19 @@ class BufferIsSubregionError : public ScheduleError { public: explicit BufferIsSubregionError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input buffer is defined in `match_buffer` of a block, it is expected" " to be a function parameter or allocated by a block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: The input buffer " << buffer_->name << " is defined in `match_buffer` of " << "a block, it is expected to be a function parameter or allocated by a block."; return os.str(); } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod() const final { return mod_; } private: @@ -952,14 +952,14 @@ class TransformationPaddingIndexMapError : public ScheduleError { TransformationPaddingIndexMapError(IRModule mod, IndexMap pad_value) : mod_(mod), pad_value_(pad_value) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; ss << "ScheduleError: The IndexMap specifying pad_value has " << pad_value_->final_indices.size() << " outputs, should only have one output"; return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream ss; ss << "ScheduleError: Pad value is specified as " << pad_value_ << " which has " << pad_value_->final_indices.size() << " outputs, but should only have one output"; @@ -967,7 
+967,7 @@ class TransformationPaddingIndexMapError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -982,13 +982,13 @@ class TransformationPaddingTypeError : public ScheduleError { pad_value_dtype_ = pad_value_->final_indices[0].dtype(); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; ss << "ScheduleError: Type mismatch " << buffer_->dtype << " vs " << pad_value_dtype_; return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream ss; ss << "ScheduleError: Buffer " << buffer_->name << " has elements of type " << buffer_->dtype << ", but the transformation fills padding with " << pad_value_ << ", which is of type " @@ -997,7 +997,7 @@ class TransformationPaddingTypeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -1025,26 +1025,26 @@ class TransformationPaddingExpressionError : public ScheduleError { void VisitExpr_(const BufferLoadNode* op) final { if (!op->buffer.same_as(buffer_)) { - illegal_load = GetRef(op); + illegal_load = ffi::GetRef(op); } ExprVisitor::VisitExpr_(op); } const Buffer& buffer_; - Optional illegal_load; + ffi::Optional illegal_load; }; TransformationPaddingExpressionError(IRModule mod, Buffer buffer, IndexMap pad_value, BufferLoad illegal_load) : mod_(mod), buffer_(buffer), pad_value_(pad_value), illegal_load_(illegal_load) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; ss << "ScheduleError: Pad value may not contain load from " << illegal_load_->buffer->name; return ss.str(); } - String 
DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream ss; ss << "ScheduleError: Pad value may only contain BufferLoad from the transformed buffer " << buffer_->name << ", but pad_value " << pad_value_ << " contains expression " @@ -1053,7 +1053,7 @@ class TransformationPaddingExpressionError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod_; Buffer buffer_; @@ -1070,13 +1070,13 @@ class TransformationIntroducesPaddingError : public ScheduleError { index_map_(std::move(index_map)), padding_predicate_(std::move(padding_predicate)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; ss << "ScheduleError: Transformation would introduce padding at " << padding_predicate_ << "."; return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { arith::Analyzer analyzer; auto new_shape = index_map_->MapShape(buffer_->shape, &analyzer); std::ostringstream os; @@ -1087,7 +1087,7 @@ class TransformationIntroducesPaddingError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -1098,12 +1098,12 @@ class TransformationIntroducesPaddingError : public ScheduleError { // Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid // dtype-mismatch issues later. 
-IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& args) { +IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const ffi::Array& args) { const auto& initial_indices_orig = index_map->initial_indices; ICHECK(args.size() == initial_indices_orig.size()); - Array initial_indices; - Map var_map; + ffi::Array initial_indices; + ffi::Map var_map; std::optional index_dtype = std::nullopt; for (size_t i = 0; i < args.size(); ++i) { @@ -1134,8 +1134,8 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& [&](const Var& var) { return var_map.Get(var); }); } }); - Optional opt_inverse_index_map = - Downcast>(index_map->inverse_index_map); + ffi::Optional opt_inverse_index_map = + Downcast>(index_map->inverse_index_map); if (opt_inverse_index_map.defined()) { opt_inverse_index_map = LegalizeIndexMapDType(opt_inverse_index_map.value(), final_indices); } @@ -1146,13 +1146,13 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map_orig, - const Optional& pad_value, bool assume_injective_transform) { + const ffi::Optional& pad_value, bool assume_injective_transform) { arith::Analyzer analyzer; AddShapeVarBounds(self, block_sref.get(), &analyzer); // Step 1: Input handling and error checking const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = - GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); + GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, buffer_index_type); auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer->shape); @@ -1176,11 +1176,11 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); - Optional opt_inverse = 
std::nullopt; + ffi::Optional opt_inverse = std::nullopt; PrimExpr padding_predicate = Bool(false); if (!assume_injective_transform) { std::tie(opt_inverse, padding_predicate) = [&]() { - Array region; + ffi::Array region; for (const auto& dim : old_buffer->shape) { region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim)); } @@ -1200,7 +1200,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block // alloc_buffers. auto [new_stmt, block_sref_reuse] = - TransformLayoutRewriter::Rewrite(GetRef(scope_block), old_buffer, new_buffer, + TransformLayoutRewriter::Rewrite(ffi::GetRef(scope_block), old_buffer, new_buffer, index_map, opt_inverse, padding_predicate, pad_value); Block new_scope_block = Downcast(new_stmt); @@ -1211,7 +1211,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ IRModuleNode* new_mod = self->mod.CopyOnWrite(); ffi::MapObj* new_map = new_mod->functions.CopyOnWrite(); - Map new_buffer_map; + ffi::Map new_buffer_map; for (auto [var, buffer] : old_func->buffer_map) { if (buffer.same_as(old_buffer)) { buffer = new_buffer; @@ -1266,11 +1266,11 @@ class NotBijectiveAffineIndexMapError : public ScheduleError { public: NotBijectiveAffineIndexMapError(IRModule mod, IndexMap index_map) : mod_(std::move(mod)), index_map_(std::move(index_map)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The index map is not bijective affine."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The index map " << index_map_->ToPythonString() << " is not bijective affine."; return os.str(); @@ -1278,7 +1278,7 @@ class NotBijectiveAffineIndexMapError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } 
+ ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -1295,12 +1295,12 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError { explicit IndexMapNotApplicableToBlockIterError(IRModule mod, Block block, IndexMap index_map) : mod_(std::move(mod)), block_(std::move(block)), index_map_(std::move(index_map)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The index map can't be applied to block iters because the number of " "parameters mismatch."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The index map " << index_map_->ToPythonString() << " can't be applied to block iters of {0} because the number of parameters mismatch. " @@ -1311,7 +1311,7 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -1324,12 +1324,12 @@ class OpaqueNewIterTypeError : public ScheduleError { explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value) : mod_(std::move(mod)), block_(std::move(block)), iter_value_(std::move(iter_value)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Cannot detect the new block iter type because it contains more than one " "type of original iter vars."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "Cannot detect the block iter type for new iter value " << iter_value_ << " in {0} because it contains more than one type of original iter vars."; @@ -1337,7 +1337,7 @@ class OpaqueNewIterTypeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array 
LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -1348,13 +1348,13 @@ class OpaqueNewIterTypeError : public ScheduleError { void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, const IndexMap& index_map) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - const Block& block = GetRef(block_ptr); + const Block& block = ffi::GetRef(block_ptr); arith::Analyzer analyzer; AddShapeVarBounds(self, block_sref.get(), &analyzer); // Step 1: Collect outer loops and loop vars - Array loops = GetLoops(block_sref); // outer loops of the block - std::unordered_set loop_vars; // loop vars of the outer loops + ffi::Array loops = GetLoops(block_sref); // outer loops of the block + std::unordered_set loop_vars; // loop vars of the outer loops for (const StmtSRef& loop_sref : loops) { CheckLoopStartsWithZero(self, loop_sref, &analyzer); loop_vars.emplace(loop_sref->StmtAs()->loop_var.get()); @@ -1374,11 +1374,11 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, CheckBlockHasTrivialBinding(self, block_sref); // Step 3: Collect information of block iter vars - Array block_vars; // iter_var->var of each block iter - Map block_iter_dom; // domain of block iter + ffi::Array block_vars; // iter_var->var of each block iter + ffi::Map block_iter_dom; // domain of block iter std::unordered_map block_iter_type; // iter type of block iter - Array + ffi::Array block_iter_range_array; // array of block iter extents in the same order as block iters for (const auto& iter_var : block->iter_vars) { block_vars.push_back(iter_var->var); @@ -1390,15 +1390,16 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 4: Apply the IndexMap to block iters. 
IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map); - Array transformed_block_iters = index_map->MapIndices(block_vars, &analyzer); - Array new_block_iter_range = index_map->MapShape(block_iter_range_array, &analyzer); + ffi::Array transformed_block_iters = index_map->MapIndices(block_vars, &analyzer); + ffi::Array new_block_iter_range = + index_map->MapShape(block_iter_range_array, &analyzer); // Step 5: Create the new block after transformation. // Step 5.1: Create new block iters. After applying the IndexMap f to block iters ax_0, ..., ax_n, // create block iter each expression in f(ax_0, ..., ax_n). - Array new_block_iters; // new block iters - Array new_block_vars; // iter_var->var of new block iters + ffi::Array new_block_iters; // new block iters + ffi::Array new_block_vars; // iter_var->var of new block iters for (size_t i = 0; i < transformed_block_iters.size(); ++i) { Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i]->dtype}; new_block_vars.push_back(new_block_var); @@ -1409,7 +1410,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type); } if (iter_type == kOpaque) { - throw OpaqueNewIterTypeError(self->mod, GetRef(block_ptr), transformed_block_iters[i]); + throw OpaqueNewIterTypeError(self->mod, ffi::GetRef(block_ptr), + transformed_block_iters[i]); } auto dtype = new_block_var.dtype(); new_block_iters.push_back(IterVar( @@ -1419,10 +1421,10 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters // in the body. 
- Map inverse_subst_map; + ffi::Map inverse_subst_map; // Construct the inverse map { - Array initial_ranges; + ffi::Array initial_ranges; for (const PrimExpr& extent : block_iter_range_array) { initial_ranges.push_back(Range::FromMinExtent(make_const(extent.dtype(), 0), extent)); } @@ -1433,20 +1435,20 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, throw NotBijectiveAffineIndexMapError(self->mod, index_map); } // old block vars written in terms of new block vars - Array inversed_new_block_vars = + ffi::Array inversed_new_block_vars = inverse_index_map->MapIndices(new_block_vars, &analyzer); for (int i = 0, n = block_vars.size(); i < n; ++i) { inverse_subst_map.Set(Downcast(block_vars[i]), inversed_new_block_vars[i]); } } - Block new_block = Downcast(Substitute(GetRef(block_ptr), inverse_subst_map)); + Block new_block = Downcast(Substitute(ffi::GetRef(block_ptr), inverse_subst_map)); new_block.CopyOnWrite()->iter_vars = new_block_iters; new_block = Downcast(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer)); // Step 5.3: Create outer loops for each new block iter. 
// Make new loop vars - Array new_loop_vars; + ffi::Array new_loop_vars; for (int i = 0; i < static_cast(new_block_iters.size()); ++i) { new_loop_vars.push_back(Var("ax" + std::to_string(i), new_block_iters[i]->var.dtype())); } @@ -1457,7 +1459,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, new_block_realize->block = new_block; // Generate outer loops - Stmt body = GetRef(new_block_realize); + Stmt body = ffi::GetRef(new_block_realize); for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; --i) { body = For(Downcast(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial, std::move(body)); @@ -1474,14 +1476,14 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, class BufferAxisSeparatorMutator : private ReplaceBufferMutator { public: static Block Mutate(const Block& scope_block, const Buffer& old_buffer, Buffer new_buffer, - Map* block_sref_reuse) { + ffi::Map* block_sref_reuse) { BufferAxisSeparatorMutator mutator(old_buffer, std::move(new_buffer), block_sref_reuse); return Downcast(mutator.VisitStmt(scope_block)); } private: BufferAxisSeparatorMutator(const Buffer& old_buffer, Buffer new_buffer, - Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, new_buffer, block_sref_reuse) {} MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { @@ -1493,8 +1495,8 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) { new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; } else { - new_target_buffer.CopyOnWrite()->axis_separators = - Array(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0)); + new_target_buffer.CopyOnWrite()->axis_separators = ffi::Array( + new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0)); LOG(WARNING) << "Buffer view " << new_target_buffer << " has 
different dimensionality than backing buffer " << new_source_buffer << ". The `axis_separators` for " << new_target_buffer << "." @@ -1509,10 +1511,11 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { }; void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - BufferIndexType buffer_index_type, const Array& axis_separators) { + BufferIndexType buffer_index_type, + const ffi::Array& axis_separators) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = - GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); + GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, buffer_index_type); auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); if (defining_site_sref.defined() && !is_alloc) { throw BufferIsSubregionError(self->mod, old_buffer); @@ -1527,11 +1530,11 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer Buffer new_buffer = old_buffer; new_buffer.CopyOnWrite()->axis_separators = axis_separators; - Map block_sref_reuse; + ffi::Map block_sref_reuse; // Step 2: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc. 
- Block new_scope_block = BufferAxisSeparatorMutator::Mutate(GetRef(scope_block), old_buffer, - new_buffer, &block_sref_reuse); + Block new_scope_block = BufferAxisSeparatorMutator::Mutate( + ffi::GetRef(scope_block), old_buffer, new_buffer, &block_sref_reuse); if (!defining_site_sref.defined()) { // mutate buffer_map of the PrimFunc GlobalVar g_var; @@ -1566,16 +1569,17 @@ struct TransformLayoutTraits : public UnpackedInstTraits static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map, Integer buffer_index, Integer buffer_index_type, - Optional pad_value, + ffi::Optional pad_value, Bool assume_injective_transform) { return sch->TransformLayout(block_rv, buffer_index.IntValue(), static_cast(buffer_index_type->value), index_map, pad_value, assume_injective_transform.operator bool()); } - static String UnpackedAsPython(Array outputs, String block_rv, IndexMap index_map, - Integer buffer_index, Integer buffer_index_type, - Optional pad_value, Bool assume_injective_transform) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + IndexMap index_map, Integer buffer_index, + Integer buffer_index_type, ffi::Optional pad_value, + Bool assume_injective_transform) { PythonAPICall py("transform_layout"); py.Input("block", block_rv); @@ -1591,13 +1595,13 @@ struct TransformLayoutTraits : public UnpackedInstTraits } public: - static ObjectRef AttrsAsJSON(const Array& attrs) { - Array attrs_record; + static ObjectRef AttrsAsJSON(const ffi::Array& attrs) { + ffi::Array attrs_record; attrs_record.reserve(kNumAttrs); attrs_record.push_back(attrs[0]); attrs_record.push_back(attrs[1]); if (attrs[2] != nullptr) { - attrs_record.push_back(String(::tvm::SaveJSON(attrs[2]))); + attrs_record.push_back(ffi::String(::tvm::SaveJSON(attrs[2]))); } else { attrs_record.push_back(attrs[2]); } @@ -1605,13 +1609,13 @@ struct TransformLayoutTraits : public UnpackedInstTraits return attrs_record; } - static Array AttrsFromJSON(const 
ObjectRef& attrs_record_) { - Array attrs_record = Downcast>(attrs_record_); - Array attrs; + static ffi::Array AttrsFromJSON(const ObjectRef& attrs_record_) { + ffi::Array attrs_record = Downcast>(attrs_record_); + ffi::Array attrs; attrs.push_back(attrs_record[0]); attrs.push_back(attrs_record[1]); if (attrs_record[2] != nullptr) { - attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[2]))); + attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[2]))); } else { attrs.push_back(attrs_record[2]); } @@ -1636,7 +1640,8 @@ struct TransformBlockLayoutTraits : public UnpackedInstTraitsTransformBlockLayout(block_rv, index_map); } - static String UnpackedAsPython(Array outputs, String block_rv, IndexMap index_map) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + IndexMap index_map) { PythonAPICall py("transform_block_layout"); py.Input("block", block_rv); py.Input("index_map", index_map->ToPythonString()); @@ -1644,17 +1649,17 @@ struct TransformBlockLayoutTraits : public UnpackedInstTraits& attrs) { - Array attrs_record; + static ObjectRef AttrsAsJSON(const ffi::Array& attrs) { + ffi::Array attrs_record; attrs_record.reserve(kNumAttrs); - attrs_record.push_back(String(::tvm::SaveJSON(attrs[0]))); + attrs_record.push_back(ffi::String(::tvm::SaveJSON(attrs[0]))); return attrs_record; } - static Array AttrsFromJSON(const ObjectRef& attrs_record_) { - Array attrs_record = Downcast>(attrs_record_); - Array attrs; - attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[0]))); + static ffi::Array AttrsFromJSON(const ObjectRef& attrs_record_) { + ffi::Array attrs_record = Downcast>(attrs_record_); + ffi::Array attrs; + attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[0]))); return attrs; } @@ -1672,14 +1677,16 @@ struct SetAxisSeparatorTraits : public UnpackedInstTraits axis_separators) { + Integer buffer_index_type, + ffi::Array axis_separators) { return sch->SetAxisSeparator(block_rv, buffer_index.IntValue(), 
static_cast(buffer_index_type->value), axis_separators); } - static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - Integer buffer_index_type, Array axis_separators) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + Integer buffer_index, Integer buffer_index_type, + ffi::Array axis_separators) { PythonAPICall py("set_axis_separator"); py.Input("block", block_rv); diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 7baf4e98b775..3cd364b0fd2b 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -46,14 +46,15 @@ class BlockPredicateAppender : public StmtMutator { /*! \brief Substitute vars and collect the reuse mapping of opaque blocks */ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { public: - explicit SubstituteVarAndCollectOpaqueBlock(std::function(const Var&)> vmap, - Map* opaque_blocks) + explicit SubstituteVarAndCollectOpaqueBlock( + std::function(const Var&)> vmap, + ffi::Map* opaque_blocks) : vmap_(vmap), opaque_blocks_(opaque_blocks) {} private: PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); - if (Optional ret = vmap_(var)) { + Var var = ffi::GetRef(op); + if (ffi::Optional ret = vmap_(var)) { return tvm::cast(var.dtype(), ret.value()); } else { return var; @@ -69,23 +70,24 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { } /*! \brief The substitute function */ - std::function(const Var&)> vmap_; + std::function(const Var&)> vmap_; /*! \brief The reuse mapping of opaque blocks */ - Map* opaque_blocks_; + ffi::Map* opaque_blocks_; }; /*! 
\brief Simplify the binding of block realize and update the opaque block reuse mapping */ class IterMapSimplifyBlockBinding : public StmtExprMutator { public: - explicit IterMapSimplifyBlockBinding(ffi::MapObj* opaque_blocks, Map loop_var2extent, + explicit IterMapSimplifyBlockBinding(ffi::MapObj* opaque_blocks, + ffi::Map loop_var2extent, bool preserve_unit_iters) : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent), preserve_unit_iters_(preserve_unit_iters) {} - static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, + static For SimplifyBindings(Stmt stmt, const ffi::Array& loop_srefs, ffi::MapObj* opaque_blocks, bool preserve_unit_iters) { - Map loop_var2extent; + ffi::Map loop_var2extent; for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(sref); loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); @@ -115,7 +117,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { } return realize; } - Array v = + ffi::Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, /*input_iters=*/loop_var2extent_, /*input_pred=*/op->predicate, @@ -123,7 +125,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { /*analyzer=*/&analzyer_, /*simplify_trivial_iterators=*/!preserve_unit_iters_); if (v.same_as(op->iter_values)) { - return GetRef(op); + return ffi::GetRef(op); } else { ObjectPtr n = CopyOnWrite(op); n->iter_values = std::move(v); @@ -134,7 +136,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { /*! \brief The reuse mapping */ ffi::MapObj* opaque_blocks_; /*! \brief The range of loops */ - Map loop_var2extent_; + ffi::Map loop_var2extent_; /*! \brief Internal analyzer */ arith::Analyzer analzyer_; /*! 
\brief Whether or not to simplify unit iterators */ @@ -161,11 +163,12 @@ class BlockPropertyError : public ScheduleError { void VisitStmt_(const BlockNode* op) final { for (const IterVar& iter_var : op->iter_vars) { if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { - throw BlockPropertyError(state_->mod, GetRef(op)); + throw BlockPropertyError(state_->mod, ffi::GetRef(op)); } - Optional high_exclusive = - top_->parent ? GetRef(top_->parent) : Optional(std::nullopt); - CheckPartialAffineBinding(state_, GetRef(op), high_exclusive); + ffi::Optional high_exclusive = top_->parent + ? ffi::GetRef(top_->parent) + : ffi::Optional(std::nullopt); + CheckPartialAffineBinding(state_, ffi::GetRef(op), high_exclusive); } } const ScheduleState& state_; @@ -173,23 +176,23 @@ class BlockPropertyError : public ScheduleError { }; BlockIterTypeAndAffineBindingChecker checker(self, top); - checker(GetRef(sref->stmt)); + checker(ffi::GetRef(sref->stmt)); } explicit BlockPropertyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The block under the loops to be reordered have block iter type other " "than data-parallel or reduction"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The block {0} under the loops to be reordered have block iter type other than " "data-parallel or reduction"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ -200,17 +203,17 @@ class HasAnnotationOrThreadBindingError : public ScheduleError { explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return 
"ScheduleError: The primitive can't be applied because the loop has annotation or " "thread binding"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The primitive can't be applied because the loop {0} has annotation or thread binding"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; @@ -221,17 +224,17 @@ class OuterNotInnerParent : public ScheduleError { explicit OuterNotInnerParent(IRModule mod, For outer, For inner) : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The outer loop is not the parent of the inner loop"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The loops can't be fused because the outer loop {0} is not the parent of the inner " "loop {1}"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {outer_, inner_}; } + ffi::Array LocationsOfInterest() const final { return {outer_, inner_}; } IRModule mod_; For outer_; @@ -243,17 +246,17 @@ class NotOnlyChildError : public ScheduleError { explicit NotOnlyChildError(IRModule mod, For outer, For inner) : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The inner loop is not the only child of outer loop"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The loops can't be fused because the inner loop {1} is not the only child of outer " "loop {0}."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {outer_, inner_}; } + ffi::Array 
LocationsOfInterest() const final { return {outer_, inner_}; } IRModule mod_; For outer_; @@ -264,16 +267,16 @@ class NotSingleInferFactorError : public ScheduleError { public: explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: only one factor can be specified as -1 or none"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "Only one factor can be specified as -1 or none"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod_; }; @@ -282,17 +285,17 @@ class WrongFactorProductError : public ScheduleError { public: explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The product of factors is not larger than or equal to the extent of " "loop"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The product of factors is not larger than or equal to the extent of loop {0}"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; @@ -302,16 +305,16 @@ class LoopMultiAppearanceError : public ScheduleError { public: explicit LoopMultiAppearanceError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Some loop appears in the input array for multiple times."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "Loop {0} appears in the input array for 
multiple times."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; @@ -321,12 +324,14 @@ class LoopsNotAChainError : public ScheduleError { public: enum class ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt }; - explicit LoopsNotAChainError(IRModule mod, Optional problematic_loop, ProblemKind kind) + explicit LoopsNotAChainError(IRModule mod, ffi::Optional problematic_loop, ProblemKind kind) : mod_(mod), problematic_loop_(std::move(problematic_loop)), kind_(kind) {} - String FastErrorString() const final { return "ScheduleError: the loops are not in a chain"; } + ffi::String FastErrorString() const final { + return "ScheduleError: the loops are not in a chain"; + } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::stringstream ss; ss << "The loops are not in a chain because"; if (kind_ == ProblemKind::kNotUnderAScope) { @@ -338,7 +343,7 @@ class LoopsNotAChainError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { + ffi::Array LocationsOfInterest() const final { if (kind_ == ProblemKind::kNotUnderAScope) { return {}; } else { @@ -348,17 +353,17 @@ class LoopsNotAChainError : public ScheduleError { } IRModule mod_; - Optional problematic_loop_; + ffi::Optional problematic_loop_; ProblemKind kind_; }; class DependentLoopError : public ScheduleError { public: enum class PrimitiveKind { kFuse, kReorder }; - explicit DependentLoopError(IRModule mod, For loop, String inner_var, PrimitiveKind kind) + explicit DependentLoopError(IRModule mod, For loop, ffi::String inner_var, PrimitiveKind kind) : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)), kind_(kind) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { if (kind_ == PrimitiveKind::kReorder) 
{ return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " "in the new order"; @@ -367,7 +372,7 @@ class DependentLoopError : public ScheduleError { } } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { if (kind_ == PrimitiveKind::kReorder) { return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ + " in the new order"; @@ -377,16 +382,17 @@ class DependentLoopError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; - String inner_var_; + ffi::String inner_var_; PrimitiveKind kind_; }; -Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array& factors, - bool preserve_unit_iters, bool disable_predication) { +ffi::Array Split(ScheduleState self, const StmtSRef& loop_sref, + const ffi::Array& factors, bool preserve_unit_iters, + bool disable_predication) { // Invariance // - The total repeat number has not changed for each direct child block with updating predicate. // - The execution order has not changed. (The block executes with the same args and the same @@ -394,7 +400,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array // Step 1. 
Check correctness const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (!loop->annotations.empty() || loop->thread_binding.defined()) { - throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef(loop)); } // Currently, loops not starting with 0 are not supported arith::Analyzer analyzer; @@ -420,10 +426,10 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array analyzer.Bind(var, Range::FromMinExtent(make_const(dtype, 0), tvm::cast(dtype, factor))); new_loop_vars.emplace_back(std::move(var)); } - Map opaque_block_reuse; + ffi::Map opaque_block_reuse; Stmt new_stmt = loop->body; new_stmt = SubstituteVarAndCollectOpaqueBlock( - [&](const Var& v) -> Optional { + [&](const Var& v) -> ffi::Optional { if (v.same_as(loop->loop_var)) { return substitute_value; } else { @@ -444,7 +450,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array opaque_block_reuse.CopyOnWrite(), preserve_unit_iters); self->Replace(loop_sref, new_stmt, opaque_block_reuse); - Array result_srefs; + ffi::Array result_srefs; result_srefs.reserve(n); for (int i = 0; i < n; i++) { result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); @@ -458,7 +464,7 @@ class BufferIndicesMapExtractor : public StmtExprVisitor { public: explicit BufferIndicesMapExtractor(Var loop_var) : loop_var_(loop_var) {} - static Map> Extract(Var loop_var, Block& block) { + static ffi::Map> Extract(Var loop_var, Block& block) { BufferIndicesMapExtractor extractor(loop_var); extractor(std::move(block->body)); return extractor.buffer_indices_map; @@ -466,7 +472,7 @@ class BufferIndicesMapExtractor : public StmtExprVisitor { private: void VisitStmt_(const BufferStoreNode* store) final { - Array indices; + ffi::Array indices; bool check_ = false; for (size_t i = 0; i < store->indices.size(); i++) { const VarNode* var_node = store->indices[i].as(); @@ -482,7 +488,7 @@ class BufferIndicesMapExtractor : public 
StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* load) final { - Array indices; + ffi::Array indices; bool check_ = false; for (size_t i = 0; i < load->indices.size(); i++) { const VarNode* var_node = load->indices[i].as(); @@ -500,21 +506,21 @@ class BufferIndicesMapExtractor : public StmtExprVisitor { void VisitStmt_(const BlockNode* op) final { StmtVisitor::VisitStmt_(op); } Var loop_var_; - Map> buffer_indices_map; + ffi::Map> buffer_indices_map; }; -Array MutateBufferRegion(Map> buffer_indices_map, - Map index_range_map, - Array region_arr) { +ffi::Array MutateBufferRegion( + ffi::Map> buffer_indices_map, + ffi::Map index_range_map, ffi::Array region_arr) { // Update the region with new Ranges and return new BufferRegion - Array new_region_arr = + ffi::Array new_region_arr = MutateArray(region_arr, [&buffer_indices_map, &index_range_map](const BufferRegion& region) { BufferRegion new_region = region; auto it = buffer_indices_map.find(new_region->buffer->name); if (it == buffer_indices_map.end()) return new_region; - Array old_indices = buffer_indices_map[new_region->buffer->name]; - Array new_ranges; + ffi::Array old_indices = buffer_indices_map[new_region->buffer->name]; + ffi::Array new_ranges; for (size_t i = 0; i < old_indices.size(); i++) { new_ranges.push_back(index_range_map[old_indices[i]]); } @@ -543,7 +549,7 @@ class BlockMutator : public StmtExprMutator { Var iter_var_ = new_block->iter_vars[inner_iter_var_index]->var; inner_iter_var_index = -1; // As we are working on cloned block, we need to create new instances of iter_var - Array new_iter_vars = + ffi::Array new_iter_vars = MutateArray(new_block->iter_vars, [this, &iter_var_](const IterVar& iter) { auto dtype = iter->var.dtype(); // Create new Var instance for each IterVar @@ -565,29 +571,29 @@ class BlockMutator : public StmtExprMutator { } // Get the (iter_var, new Range) map - Map index_range_map; + ffi::Map index_range_map; for (size_t i = 0; i < new_block->iter_vars.size(); i++) { 
IterVar iter = new_block->iter_vars[i]; index_range_map.Set(iter->var->name_hint, iter->dom); } // Get the (Buffer, indices) map - Map> buffer_indices_map = + ffi::Map> buffer_indices_map = BufferIndicesMapExtractor::Extract(new_loop_var_, new_block); - Array new_writes = + ffi::Array new_writes = MutateBufferRegion(buffer_indices_map, index_range_map, new_block->writes); if (!new_block->writes.same_as(new_writes)) { // Update the writes with new_writes new_block.CopyOnWrite()->writes = std::move(new_writes); } - Array new_reads = + ffi::Array new_reads = MutateBufferRegion(buffer_indices_map, index_range_map, new_block->reads); if (!new_block->reads.same_as(new_reads)) { // Update the reads with new_reads new_block.CopyOnWrite()->reads = std::move(new_reads); } - Map var_map; + ffi::Map var_map; for (size_t i = 0; i < new_block->iter_vars.size(); i++) { var_map.Set(_op->iter_vars[i]->var, new_block->iter_vars[i]->var); } @@ -598,7 +604,7 @@ class BlockMutator : public StmtExprMutator { } Stmt VisitStmt_(const BlockRealizeNode* realize) final { - Array iter_values = realize->iter_values; + ffi::Array iter_values = realize->iter_values; for (size_t i = 0; i < iter_values.size(); i++) { if (new_loop_var_.same_as(iter_values[i])) { // Get the iter_var index corresponding to loop_var iter_value index @@ -627,7 +633,7 @@ class BlockMutator : public StmtExprMutator { int inner_iter_var_index = -1; }; -const String get_block_name(Stmt loop_body) { +const ffi::String get_block_name(Stmt loop_body) { const BlockRealizeNode* blk_realize = loop_body.as(); if (blk_realize == nullptr) { return get_block_name(loop_body.as()->body); @@ -635,11 +641,11 @@ const String get_block_name(Stmt loop_body) { return blk_realize->block->name_hint; } -Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref, - const Array& factors, bool preserve_unit_iters) { +ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref, + const ffi::Array& factors, bool 
preserve_unit_iters) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (!loop->annotations.empty() || loop->thread_binding.defined()) { - throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef(loop)); } arith::Analyzer analyzer; @@ -653,12 +659,12 @@ Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref, dtype = DataType::Int(bits); } - String block_name = get_block_name(loop->body) + "_" + loop->loop_var->name_hint; + ffi::String block_name = get_block_name(loop->body) + "_" + loop->loop_var->name_hint; int n = factors.size(); PrimExpr min_value = loop->min; PrimExpr extent_value; - Array block_partitions; + ffi::Array block_partitions; block_partitions.reserve(n); // Iterate over each pair of factors and create partition @@ -696,7 +702,7 @@ Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref, self->block_info[scope_root].affine_binding = scope_block_affine_binding; // Collect the SRef for each partitioned loop and return - Array partition_srefs; + ffi::Array partition_srefs; partition_srefs.reserve(n); for (int i = 0; i < n; i++) { StmtSRef partition_loop_sref = @@ -717,11 +723,11 @@ class LoopReconstructor : private StmtMutator { * \brief Create the new nest loops induced by the given loops */ void MakeNewLoop() { - Array new_loop_vars; - Array new_loop_extents; - Array new_stmts; + ffi::Array new_loop_vars; + ffi::Array new_loop_extents; + ffi::Array new_stmts; for (size_t i = 0; i < loops_.size(); i++) { - Map var_map; + ffi::Map var_map; for (size_t j = 0; j < loops_[i].size(); j++) { if (i == 0) { Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m"); @@ -748,15 +754,16 @@ class LoopReconstructor : private StmtMutator { private: Stmt VisitStmt_(const BlockNode* block) final { if (block != scope_root_.get()) { - return GetRef(block); + return ffi::GetRef(block); } return StmtMutator::VisitStmt_(block); } Stmt VisitStmt_(const ForNode* loop) 
final { - if (GetRef(loop) == need_remove_loop_.back()) { + if (ffi::GetRef(loop) == need_remove_loop_.back()) { return new_outer_loop_; - } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(), GetRef(loop))) { + } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(), + ffi::GetRef(loop))) { return Evaluate(0); } return StmtMutator::VisitStmt_(loop); @@ -764,7 +771,7 @@ class LoopReconstructor : private StmtMutator { Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final { auto ret = Downcast(StmtMutator::VisitSeqStmt_(seq_stmt, true)); - Array filtered; + ffi::Array filtered; for (Stmt stmt : ret->seq) { if (!is_no_op(stmt)) { filtered.push_back(std::move(stmt)); @@ -793,7 +800,7 @@ class LoopReconstructor : private StmtMutator { std::vector need_remove_loop_; }; -StmtSRef Merge(ScheduleState self, const Array& loop_srefs) { +StmtSRef Merge(ScheduleState self, const ffi::Array& loop_srefs) { // Invariance // - The total repeat number has not changed for each direct child block. // - The execution order has not changed. 
(The block executes with the same @@ -813,10 +820,10 @@ StmtSRef Merge(ScheduleState self, const Array& loop_srefs) { for (auto p = sref.get(); p != lca.get(); p = p->parent) { if (auto loop = p->StmtAs()) { if (!loop->annotations.empty() || loop->thread_binding.defined()) { - throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef(loop)); } - CheckLoopStartsWithZero(self, GetRef(p), &analyzer); - nest_loop_i_loops.push_back(GetRef(loop)); + CheckLoopStartsWithZero(self, ffi::GetRef(p), &analyzer); + nest_loop_i_loops.push_back(ffi::GetRef(loop)); nest_loop_i_extents.push_back(loop->extent); } } @@ -824,7 +831,7 @@ StmtSRef Merge(ScheduleState self, const Array& loop_srefs) { const ForNode* outer_loop = nullptr; for (auto iter = nest_loop_i_loops.rbegin(); iter != nest_loop_i_loops.rend(); ++iter) { if (outer_loop && !outer_loop->body.same_as(*iter)) { - throw NotOnlyChildError(self->mod, GetRef(outer_loop), *iter); + throw NotOnlyChildError(self->mod, ffi::GetRef(outer_loop), *iter); } outer_loop = (*iter).get(); } @@ -853,7 +860,7 @@ StmtSRef Merge(ScheduleState self, const Array& loop_srefs) { } } // Step 2. Create merged loops and replace the original loops - Block scope_root = GetRef(scope_root_sref->StmtAs()); + Block scope_root = ffi::GetRef(scope_root_sref->StmtAs()); LoopReconstructor reconstructor(scope_root, lca_nest_loops); reconstructor.MakeNewLoop(); Block new_scope_root = Downcast(reconstructor(scope_root)); @@ -862,7 +869,8 @@ StmtSRef Merge(ScheduleState self, const Array& loop_srefs) { return self->stmt2ref.at(reconstructor.new_inner_loop_.get()); } -StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preserve_unit_iters) { +StmtSRef Fuse(ScheduleState self, const ffi::Array& loop_srefs, + bool preserve_unit_iters) { // Invariance // - The total repeat number has not changed for each direct child block. // - The execution order has not changed. 
(The block executes with the same @@ -877,14 +885,14 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preser for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(sref); if (!loop->annotations.empty() || loop->thread_binding.defined()) { - throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef(loop)); } if (outer_loop_sref.defined()) { if (sref->parent != outer_loop_sref.get()) { - throw OuterNotInnerParent(self->mod, GetRef(outer_loop), GetRef(loop)); + throw OuterNotInnerParent(self->mod, ffi::GetRef(outer_loop), ffi::GetRef(loop)); } - if (!outer_loop->body.same_as(GetRef(loop))) { - throw NotOnlyChildError(self->mod, GetRef(outer_loop), GetRef(loop)); + if (!outer_loop->body.same_as(ffi::GetRef(loop))) { + throw NotOnlyChildError(self->mod, ffi::GetRef(outer_loop), ffi::GetRef(loop)); } } outer_loop_sref = sref; @@ -899,7 +907,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preser return false; }; if (UsesVar(loop->extent, f_contain)) { - throw DependentLoopError(self->mod, GetRef(loop), used_var->name_hint, + throw DependentLoopError(self->mod, ffi::GetRef(loop), used_var->name_hint, DependentLoopError::PrimitiveKind::kFuse); } outer_loop_vars.insert(loop->loop_var.get()); @@ -915,7 +923,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preser } suffix += "_fused"; Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits)); - Array substitute_value; + ffi::Array substitute_value; substitute_value.resize(loops.size()); PrimExpr lower = 1; for (int i = static_cast(loops.size()) - 1; i > 0; i--) { @@ -926,8 +934,8 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preser } substitute_value.Set(0, is_one(loops[0]->extent) ? 
0 : floordiv(fused_var, lower)); Stmt new_stmt = loops.back()->body; - Map opaque_block_reuse; - auto f_substitute = [&](const Var& v) -> Optional { + ffi::Map opaque_block_reuse; + auto f_substitute = [&](const Var& v) -> ffi::Optional { for (int i = 0; i < n; i++) { if (v.same_as(loops[i]->loop_var)) { return substitute_value[i]; @@ -959,14 +967,14 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preser * \throws ScheduleError If there are duplicate loops in the array */ std::unordered_set CollectLoopsIntoSet( - const ScheduleState& self, const Array& ordered_loop_srefs) { + const ScheduleState& self, const ffi::Array& ordered_loop_srefs) { std::unordered_set loop_srefs; loop_srefs.reserve(ordered_loop_srefs.size()); for (const StmtSRef& loop_sref : ordered_loop_srefs) { auto inserted = loop_srefs.insert(loop_sref.get()); if (!inserted.second) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - throw LoopMultiAppearanceError(self->mod, GetRef(loop)); + throw LoopMultiAppearanceError(self->mod, ffi::GetRef(loop)); } } return loop_srefs; @@ -1004,7 +1012,7 @@ std::pair GetBoundaryOfReorderRange( // `bottom`. 
if (visited.count(v)) { if (v != bottom) { - throw LoopsNotAChainError(self->mod, GetRef(v->stmt), + throw LoopsNotAChainError(self->mod, ffi::GetRef(v->stmt), LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); } bottom = loop_sref; @@ -1041,7 +1049,7 @@ std::vector GetLoopsInReorderRange(const ScheduleState& sel const ForNode* inner = loop_sref->StmtAs(); ICHECK(outer != nullptr && inner != nullptr); if (outer->body.get() != inner) { - throw LoopsNotAChainError(self->mod, GetRef(outer), + throw LoopsNotAChainError(self->mod, ffi::GetRef(outer), LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); } chain.push_back(loop_sref); @@ -1062,7 +1070,7 @@ std::vector GetLoopsInReorderRange(const ScheduleState& sel * reordering */ For ConstructNewLoopChain(const ScheduleState& self, std::vector chain, - const Array& ordered_loop_srefs, + const ffi::Array& ordered_loop_srefs, const std::unordered_set& loop_srefs) { std::unordered_set inner_vars; inner_vars.reserve(chain.size()); @@ -1077,7 +1085,7 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vectorStmtAs(); } ICHECK(copy != nullptr); - ObjectPtr n = make_object(*copy); + ObjectPtr n = ffi::make_object(*copy); if (new_loop.defined()) { n->body = new_loop; } else { @@ -1092,7 +1100,7 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vectormin, f_contain) || UsesVar(copy->extent, f_contain)) { - throw DependentLoopError(self->mod, GetRef(copy), used_var->name_hint, + throw DependentLoopError(self->mod, ffi::GetRef(copy), used_var->name_hint, DependentLoopError::PrimitiveKind::kReorder); } inner_vars.insert(copy->loop_var.get()); @@ -1101,7 +1109,7 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vector& ordered_loop_srefs) { +void Reorder(ScheduleState self, const ffi::Array& ordered_loop_srefs) { if (ordered_loop_srefs.size() <= 1) { return; } @@ -1124,12 +1132,13 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { // Step 5. 
Replace the original loops with the reordered loops and check that outer loop is // not dependent on inner loop For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs); - self->Replace(GetRef(top), new_loop, {}); + self->Replace(ffi::GetRef(top), new_loop, {}); } StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { if (sref->stmt->IsInstance()) { - For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef(sref->stmt)); + For new_loop = + For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, ffi::GetRef(sref->stmt)); self->Replace(sref, new_loop, {}); return self->stmt2ref.at(new_loop.get()); } @@ -1139,8 +1148,8 @@ StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { Stmt VisitStmt_(const BlockRealizeNode* realize) final { if (realize->block.get() == src_block_) { - new_loop_ = - For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef(realize)); + new_loop_ = For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, + ffi::GetRef(realize)); return new_loop_; } return StmtMutator::VisitStmt_(realize); @@ -1151,13 +1160,13 @@ StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { }; CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block"; - StmtSRef parent_sref = GetRef(sref->parent); + StmtSRef parent_sref = ffi::GetRef(sref->parent); NewLoopCreator creator(sref->stmt); - Stmt new_stmt = creator(GetRef(parent_sref->stmt)); + Stmt new_stmt = creator(ffi::GetRef(parent_sref->stmt)); if (new_stmt->IsInstance()) { self->Replace(parent_sref, std::move(new_stmt), {}); } else { - Block old_parent_block = GetRef(parent_sref->StmtAs()); + Block old_parent_block = ffi::GetRef(parent_sref->StmtAs()); Block new_parent_block = Downcast(new_stmt); self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}}); } @@ -1176,24 +1185,26 @@ struct SplitTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; template - static 
TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array& inputs) { + static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array& inputs) { thread_local Any loop_rv{nullptr}; - thread_local Array factors{nullptr}; + thread_local ffi::Array factors{nullptr}; loop_rv = inputs[0]; - factors = Array{inputs.begin() + 1, inputs.end()}; + factors = ffi::Array{inputs.begin() + 1, inputs.end()}; packed_args[delta] = loop_rv; packed_args[delta + 1] = factors; } - static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, - Array> factors, - Bool preserve_unit_iters, Bool disable_predication) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, + ffi::Array> factors, + Bool preserve_unit_iters, + Bool disable_predication) { return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool(), disable_predication.operator bool()); } - static String UnpackedAsPython(Array outputs, String loop_rv, Array factors, - Bool preserve_unit_iters, Bool disable_predication) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + ffi::Array factors, Bool preserve_unit_iters, + Bool disable_predication) { PythonAPICall py("split"); py.Input("loop", loop_rv); py.Input("factors", factors); @@ -1217,23 +1228,23 @@ struct LoopPartitionTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; template - static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array& inputs) { + static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array& inputs) { thread_local Any loop_rv{nullptr}; - thread_local Array factors{nullptr}; + thread_local ffi::Array factors{nullptr}; loop_rv = inputs[0]; - factors = Array{inputs.begin() + 1, inputs.end()}; + factors = ffi::Array{inputs.begin() + 1, inputs.end()}; packed_args[delta] = loop_rv; packed_args[delta + 1] = factors; } - static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, - Array> factors, - Bool 
preserve_unit_iters) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, + ffi::Array> factors, + Bool preserve_unit_iters) { return sch->LoopPartition(loop_rv, factors, preserve_unit_iters.operator bool()); } - static String UnpackedAsPython(Array outputs, String loop_rv, Array factors, - Bool preserve_unit_iters) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + ffi::Array factors, Bool preserve_unit_iters) { PythonAPICall py("loop_partition"); py.Input("loop", loop_rv); py.Input("factors", factors); @@ -1256,17 +1267,18 @@ struct MergeTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; template - static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array& inputs) { + static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array& inputs) { packed_args[delta] = inputs; } - static LoopRV UnpackedApplyToSchedule(Schedule sch, Array loop_rvs) { + static LoopRV UnpackedApplyToSchedule(Schedule sch, ffi::Array loop_rvs) { return sch->Merge(loop_rvs); } - static String UnpackedAsPython(Array outputs, Array loop_rvs) { + static ffi::String UnpackedAsPython(ffi::Array outputs, + ffi::Array loop_rvs) { PythonAPICall py("merge"); - for (const String& loop_rv : loop_rvs) { + for (const ffi::String& loop_rv : loop_rvs) { py.Input("", loop_rv); } py.SingleOutput(outputs); @@ -1287,19 +1299,19 @@ struct FuseTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; template - static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array& inputs) { + static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array& inputs) { packed_args[delta] = inputs; } - static LoopRV UnpackedApplyToSchedule(Schedule sch, Array loop_rvs, + static LoopRV UnpackedApplyToSchedule(Schedule sch, ffi::Array loop_rvs, Bool preserve_unit_iters) { return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool()); } - static String 
UnpackedAsPython(Array outputs, Array loop_rvs, - Bool preserve_unit_iters) { + static ffi::String UnpackedAsPython(ffi::Array outputs, + ffi::Array loop_rvs, Bool preserve_unit_iters) { PythonAPICall py("fuse"); - for (const String& loop_rv : loop_rvs) { + for (const ffi::String& loop_rv : loop_rvs) { py.Input("", loop_rv); } py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); @@ -1321,17 +1333,18 @@ struct ReorderTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; template - static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array& inputs) { + static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array& inputs) { packed_args[delta] = inputs; } - static void UnpackedApplyToSchedule(Schedule sch, Array loop_rvs) { + static void UnpackedApplyToSchedule(Schedule sch, ffi::Array loop_rvs) { return sch->Reorder(loop_rvs); } - static String UnpackedAsPython(Array outputs, Array loop_rvs) { + static ffi::String UnpackedAsPython(ffi::Array outputs, + ffi::Array loop_rvs) { PythonAPICall py("reorder"); - for (const String& loop_rv : loop_rvs) { + for (const ffi::String& loop_rv : loop_rvs) { py.Input("", loop_rv); } return py.Str(); @@ -1361,7 +1374,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits { } } - static String UnpackedAsPython(Array outputs, String rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String rv) { PythonAPICall py("add_unit_loop"); py.Input("block_or_loop", rv); py.SingleOutput(outputs); diff --git a/src/tir/schedule/primitive/pad_einsum.cc b/src/tir/schedule/primitive/pad_einsum.cc index 5b724b6bd295..f66ee2f63e33 100644 --- a/src/tir/schedule/primitive/pad_einsum.cc +++ b/src/tir/schedule/primitive/pad_einsum.cc @@ -29,8 +29,9 @@ namespace tir { * \param buffer_access The BufferLoad or BufferStore * \return The indices if the indices are all Vars, otherwise std::nullopt */ -Optional> CheckTrivialBufferIndices(const Array& buffer_access) { 
- Array indices; +ffi::Optional> CheckTrivialBufferIndices( + const ffi::Array& buffer_access) { + ffi::Array indices; for (const PrimExpr& index : buffer_access) { if (index->IsInstance()) { continue; @@ -39,13 +40,13 @@ Optional> CheckTrivialBufferIndices(const Array& buffer_acc if (var == nullptr) { return std::nullopt; } - indices.push_back(GetRef(var)); + indices.push_back(ffi::GetRef(var)); } return indices; } -Optional> CheckTrivialBufferAccess(const BufferRegion& buffer_region) { - Array indices; +ffi::Optional> CheckTrivialBufferAccess(const BufferRegion& buffer_region) { + ffi::Array indices; indices.reserve(buffer_region->region.size()); for (const Range& range : buffer_region->region) { if (!tir::is_one(range->extent)) { @@ -55,7 +56,7 @@ Optional> CheckTrivialBufferAccess(const BufferRegion& buffer_region) continue; } if (const auto* var = range->min.as()) { - indices.push_back(GetRef(var)); + indices.push_back(ffi::GetRef(var)); } else { return std::nullopt; } @@ -66,21 +67,21 @@ Optional> CheckTrivialBufferAccess(const BufferRegion& buffer_region) /*! \brief The schedule error class when the padding size is invalid. */ class InvalidPaddingError : public ScheduleError { public: - InvalidPaddingError(IRModule mod, Block block, Array padding) + InvalidPaddingError(IRModule mod, Block block, ffi::Array padding) : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {} IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } - String FastErrorString() const final { + ffi::Array LocationsOfInterest() const final { return {block_}; } + ffi::String FastErrorString() const final { return "ScheduleError: The padding size for the block is invalid."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The padding for the block {0} are invalid. 
It should be a list of " << block_->iter_vars.size() << " positive integers. Got " << padding_; return os.str(); } - static void Check(const ScheduleState& self, const Block& block, Array padding) { + static void Check(const ScheduleState& self, const Block& block, ffi::Array padding) { if (padding.size() != block->iter_vars.size()) { throw InvalidPaddingError(self->mod, block, padding); } @@ -94,7 +95,7 @@ class InvalidPaddingError : public ScheduleError { private: IRModule mod_; Block block_; - Array padding_; + ffi::Array padding_; }; /*! \brief The schedule error class when the block body is not an Einsum pattern. */ @@ -104,11 +105,11 @@ class NonEinsumError : public ScheduleError { : mod_(std::move(mod)), block_(std::move(block)) {} IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } - String FastErrorString() const final { + ffi::Array LocationsOfInterest() const final { return {block_}; } + ffi::String FastErrorString() const final { return "ScheduleError: The block is not a computation of Einsum pattern."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The block {0} not a computation of Einsum pattern."; } @@ -120,13 +121,13 @@ class NonEinsumError : public ScheduleError { /*! \brief Data structure that represents a Einsum computation. 
*/ struct Einsum { // The output buffer - Array output_buffers; + ffi::Array output_buffers; // The indices of the output buffer - Map> output_indices; + ffi::Map> output_indices; // The input buffers - Array input_buffers; + ffi::Array input_buffers; // The indices of the input buffers - Map> input_indices; + ffi::Map> input_indices; }; struct BufferPadding { @@ -134,10 +135,10 @@ struct BufferPadding { Buffer padded_buffer; static BufferPadding FromBufferRegion(const BufferRegion& buffer_region, - const Map& iter_extents) { + const ffi::Map& iter_extents) { BufferPadding result; result.buffer = buffer_region->buffer; - Array shape; + ffi::Array shape; shape.reserve(buffer_region->region.size()); int ndim = buffer_region->region.size(); for (int i = 0; i < ndim; ++i) { @@ -145,7 +146,7 @@ struct BufferPadding { ICHECK(pos->IsInstance() || pos->IsInstance()); if (pos->IsInstance()) { shape.push_back(IntImm(pos->dtype, 1)); - } else if (Optional extent = iter_extents.Get(Downcast(pos))) { + } else if (ffi::Optional extent = iter_extents.Get(Downcast(pos))) { shape.push_back(extent.value()); } else { shape.push_back(buffer_region->buffer->shape[i]); @@ -156,12 +157,12 @@ struct BufferPadding { return result; } - Stmt MakeCopyBlock(bool is_read, Array* blocks, arith::Analyzer* analyzer) { - Array loop_vars; - Array loop_doms; - Array iter_vars; - Array instance_dom; - Array indices; + Stmt MakeCopyBlock(bool is_read, ffi::Array* blocks, arith::Analyzer* analyzer) { + ffi::Array loop_vars; + ffi::Array loop_doms; + ffi::Array iter_vars; + ffi::Array instance_dom; + ffi::Array indices; int ndim = buffer->shape.size(); for (int i = 0; i < ndim; ++i) { PrimExpr dim{nullptr}; @@ -199,7 +200,8 @@ struct BufferPadding { } Block new_block(iter_vars, {read_region}, {write_region}, padded_buffer->name, std::move(body)); blocks->push_back(new_block); - body = BlockRealize(Array{loop_vars.begin(), loop_vars.end()}, Bool(true), new_block); + body = 
BlockRealize(ffi::Array{loop_vars.begin(), loop_vars.end()}, Bool(true), + new_block); for (int i = ndim - 1; i >= 0; --i) { body = For(loop_vars[i], loop_doms[i]->min, loop_doms[i]->extent, ForKind::kSerial, std::move(body)); @@ -218,7 +220,7 @@ Einsum ExtractEinsum(const ScheduleState& self, const Block& block) { throw NonEinsumError(self->mod, block); } buffer_used.insert(buffer.get()); - if (Optional> opt_indices = CheckTrivialBufferAccess(block->reads[i])) { + if (ffi::Optional> opt_indices = CheckTrivialBufferAccess(block->reads[i])) { result.input_buffers.push_back(buffer); result.input_indices.Set(buffer, opt_indices.value()); } else { @@ -232,7 +234,7 @@ Einsum ExtractEinsum(const ScheduleState& self, const Block& block) { throw NonEinsumError(self->mod, block); } buffer_used.insert(buffer.get()); - if (Optional> opt_indices = CheckTrivialBufferAccess(block->writes[i])) { + if (ffi::Optional> opt_indices = CheckTrivialBufferAccess(block->writes[i])) { result.output_buffers.push_back(buffer); result.output_indices.Set(buffer, opt_indices.value()); } else { @@ -247,12 +249,12 @@ class BufferNotAllocatedInScopeError : public ScheduleError { explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer) : mod_(std::move(mod)), buffer_(std::move(buffer)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The buffer is not allocated as an intermediate buffer in current " "PrimFunc."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The buffer " << buffer_->name << " is not allocated as an intermediate buffer in current PrimFunc."; @@ -260,7 +262,7 @@ class BufferNotAllocatedInScopeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -273,11 +275,11 
@@ class InvalidProducerError : public ScheduleError { explicit InvalidProducerError(IRModule mod, Block producer) : mod_(std::move(mod)), producer_(std::move(producer)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The producer block cannot be padded."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The producer block {0} cannot be padded. It should write to a single buffer and the " "body should be a BufferStore."; @@ -285,7 +287,7 @@ class InvalidProducerError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {producer_}; } + ffi::Array LocationsOfInterest() const final { return {producer_}; } private: IRModule mod_; @@ -296,32 +298,32 @@ class InvalidProducerError : public ScheduleError { class PadEinsumBufferReplacer : public StmtExprMutator { public: Stmt VisitStmt_(const BlockNode* old_block_ptr) final { - Block old_block = GetRef(old_block_ptr); + Block old_block = ffi::GetRef(old_block_ptr); Block block = Downcast(StmtMutator::VisitStmt_(old_block_ptr)); - Array iter_vars; + ffi::Array iter_vars; iter_vars.reserve(block->iter_vars.size()); for (const IterVar& iter_var : block->iter_vars) { - if (Optional new_dom = iter2padded_extents.Get(iter_var->var)) { - ObjectPtr new_iter_var = make_object(*iter_var.get()); + if (ffi::Optional new_dom = iter2padded_extents.Get(iter_var->var)) { + ObjectPtr new_iter_var = ffi::make_object(*iter_var.get()); new_iter_var->dom = Range::FromMinExtent(iter_var->dom->min, new_dom.value()); iter_vars.push_back(IterVar(new_iter_var)); } else { iter_vars.push_back(iter_var); } } - Array reads; + ffi::Array reads; reads.reserve(block->reads.size()); for (const BufferRegion& read : block->reads) { - if (Optional buffer = buffer_map_.Get(read->buffer)) { + if (ffi::Optional buffer = buffer_map_.Get(read->buffer)) { 
reads.push_back(BufferRegion(buffer.value(), read->region)); } else { reads.push_back(read); } } - Array writes; + ffi::Array writes; writes.reserve(block->writes.size()); for (const BufferRegion& write : block->writes) { - if (Optional buffer = buffer_map_.Get(write->buffer)) { + if (ffi::Optional buffer = buffer_map_.Get(write->buffer)) { writes.push_back(BufferRegion(buffer.value(), write->region)); } else { writes.push_back(write); @@ -335,10 +337,10 @@ class PadEinsumBufferReplacer : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* old_for_ptr) final { - For old_for = GetRef(old_for_ptr); + For old_for = ffi::GetRef(old_for_ptr); For new_for = Downcast(StmtMutator::VisitStmt_(old_for_ptr)); - if (Optional new_extent = loop_var2padded_extent.Get(new_for->loop_var)) { - ObjectPtr new_for_ptr = make_object(*new_for.get()); + if (ffi::Optional new_extent = loop_var2padded_extent.Get(new_for->loop_var)) { + ObjectPtr new_for_ptr = ffi::make_object(*new_for.get()); new_for_ptr->extent = new_extent.value(); new_for = For(new_for_ptr); } @@ -347,7 +349,7 @@ class PadEinsumBufferReplacer : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* old_store_ptr) final { BufferStore store = Downcast(StmtMutator::VisitStmt_(old_store_ptr)); - if (Optional buffer = buffer_map_.Get(store->buffer)) { + if (ffi::Optional buffer = buffer_map_.Get(store->buffer)) { return BufferStore(buffer.value(), store->value, store->indices); } else { return store; @@ -356,29 +358,29 @@ class PadEinsumBufferReplacer : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* old_load_ptr) final { BufferLoad load = Downcast(ExprMutator::VisitExpr_(old_load_ptr)); - if (Optional buffer = buffer_map_.Get(load->buffer)) { + if (ffi::Optional buffer = buffer_map_.Get(load->buffer)) { return BufferLoad(buffer.value(), load->indices); } else { return load; } } - Map iter2padded_extents; - Map loop_var2padded_extent; - Map buffer_map_; - Map block_sref_reuse_; + ffi::Map 
iter2padded_extents; + ffi::Map loop_var2padded_extent; + ffi::Map buffer_map_; + ffi::Map block_sref_reuse_; }; -void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array& padding) { +void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array& padding) { arith::Analyzer analyzer; // Step 1: Input checking and error handling const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); BlockRealize realize = GetBlockRealize(self, block_sref); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); - InvalidPaddingError::Check(self, GetRef(block), padding); + InvalidPaddingError::Check(self, ffi::GetRef(block), padding); // Step 2. Extract the Einsum pattern - ExtractEinsum(self, GetRef(block)); + ExtractEinsum(self, ffi::GetRef(block)); // Step 3. Figure out the padding needed PadEinsumBufferReplacer replacer; for (int i = 0, n = padding.size(); i < n; ++i) { @@ -388,15 +390,15 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Arrayvar, new_dom); if (const auto* loop_var = realize->iter_values[i].as()) { - replacer.iter2padded_extents.Set(GetRef(loop_var), new_dom); - replacer.loop_var2padded_extent.Set(GetRef(loop_var), new_dom); + replacer.iter2padded_extents.Set(ffi::GetRef(loop_var), new_dom); + replacer.loop_var2padded_extent.Set(ffi::GetRef(loop_var), new_dom); } } } - auto f_needs_padding = [&replacer](const Array& region) { + auto f_needs_padding = [&replacer](const ffi::Array& region) { for (const Range& range : region) { if (const auto* var = range->min.as()) { - if (replacer.iter2padded_extents.count(GetRef(var))) { + if (replacer.iter2padded_extents.count(ffi::GetRef(var))) { return true; } } @@ -404,7 +406,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array scope_body; + ffi::Array scope_body; if (const auto* seq_stmt = scope_block->body.as()) { scope_body = seq_stmt->seq; } else 
{ @@ -426,10 +428,10 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array read_blocks; - Array write_blocks; - Array new_copy_blocks; - Array alloc_buffers; + ffi::Array read_blocks; + ffi::Array write_blocks; + ffi::Array new_copy_blocks; + ffi::Array alloc_buffers; for (const BufferRegion& buffer_region : block->reads) { if (f_needs_padding(buffer_region->region)) { BufferPadding bp = @@ -449,7 +451,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array new_scope_body; + ffi::Array new_scope_body; for (int i = 0; i < static_cast(scope_body.size()); ++i) { if (i != pos) { new_scope_body.push_back(scope_body[i]); @@ -462,12 +464,12 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array n = make_object(*scope_block); + ObjectPtr n = ffi::make_object(*scope_block); n->body = SeqStmt::Flatten(new_scope_body); n->alloc_buffers.insert(n->alloc_buffers.end(), alloc_buffers.begin(), alloc_buffers.end()); new_scope_block = Block(n); } - replacer.block_sref_reuse_.Set(GetRef(scope_block), new_scope_block); + replacer.block_sref_reuse_.Set(ffi::GetRef(scope_block), new_scope_block); // Step 8. 
Do replacement and update flags self->Replace(scope_sref, new_scope_block, replacer.block_sref_reuse_); for (const Block& block : new_copy_blocks) { @@ -490,11 +492,12 @@ struct PadEinsumTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array padding) { + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, ffi::Array padding) { sch->PadEinsum(block, padding); } - static String UnpackedAsPython(Array outputs, String block, Array padding) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::Array padding) { PythonAPICall py("pad_einsum"); py.Input("block", block); py.Input("padding", padding); diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc index 9fdb322a4996..44a0f9bbe284 100644 --- a/src/tir/schedule/primitive/read_write_at.cc +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -26,7 +26,7 @@ namespace tir { using support::NDIntSet; -bool HasBuffer(const Array& buffer_regions, const Buffer& buffer) { +bool HasBuffer(const ffi::Array& buffer_regions, const Buffer& buffer) { for (const BufferRegion& buffer_region : buffer_regions) { if (buffer_region->buffer.same_as(buffer)) { return true; @@ -35,14 +35,14 @@ bool HasBuffer(const Array& buffer_regions, const Buffer& buffer) return false; } -void RelaxBufferRegions(const Array& buffer_regions, - const Buffer& buffer, // - const Map& var_dom, // - const Map& bindings, // +void RelaxBufferRegions(const ffi::Array& buffer_regions, + const Buffer& buffer, // + const ffi::Map& var_dom, // + const ffi::Map& bindings, // std::vector* relaxed_regions) { for (const BufferRegion& buffer_region : buffer_regions) { if (buffer_region->buffer.same_as(buffer)) { - Array relaxed_region = + ffi::Array relaxed_region = arith::EvalSet(Substitute(buffer_region->region, bindings), var_dom); 
relaxed_regions->push_back({relaxed_region.begin(), relaxed_region.end()}); } @@ -53,7 +53,7 @@ class ScopeReplacer : public StmtMutator { public: static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, const ForNode* new_loop) { - ObjectPtr new_scope_block = make_object(*scope_block); + ObjectPtr new_scope_block = ffi::make_object(*scope_block); new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body)); new_scope_block->alloc_buffers.push_back(dst); return Block(new_scope_block); @@ -64,11 +64,11 @@ class ScopeReplacer : public StmtMutator { : old_loop_(old_loop), new_loop_(new_loop), found_(false) {} Stmt VisitStmt(const Stmt& stmt) final { return found_ ? stmt : StmtMutator::VisitStmt(stmt); } - Stmt VisitStmt_(const BlockNode* block) final { return GetRef(block); } + Stmt VisitStmt_(const BlockNode* block) final { return ffi::GetRef(block); } Stmt VisitStmt_(const ForNode* loop) final { if (loop == old_loop_) { found_ = true; - return GetRef(new_loop_); + return ffi::GetRef(new_loop_); } return StmtMutator::VisitStmt_(loop); } @@ -81,14 +81,14 @@ class ScopeReplacer : public StmtMutator { class ReadWriteAtBufferReplacer : public StmtExprMutator { public: explicit ReadWriteAtBufferReplacer(const Buffer& src, const Buffer& dst, - Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {} private: Stmt VisitStmt_(const BufferStoreNode* _store) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); if (store->buffer.same_as(src_)) { - ObjectPtr new_store = make_object(*store.get()); + ObjectPtr new_store = ffi::make_object(*store.get()); new_store->buffer = dst_; return BufferStore(new_store); } @@ -98,7 +98,7 @@ class ReadWriteAtBufferReplacer : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* _load) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); if 
(load->buffer.same_as(src_)) { - ObjectPtr new_load = make_object(*load.get()); + ObjectPtr new_load = ffi::make_object(*load.get()); new_load->buffer = dst_; return BufferLoad(new_load); } @@ -106,9 +106,9 @@ class ReadWriteAtBufferReplacer : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* _block) final { - Block old_block = GetRef(_block); + Block old_block = ffi::GetRef(_block); Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); - ObjectPtr new_block = make_object(*block.get()); + ObjectPtr new_block = ffi::make_object(*block.get()); new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_); new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_); block_sref_reuse_->Set(old_block, Block(new_block)); @@ -117,16 +117,16 @@ class ReadWriteAtBufferReplacer : public StmtExprMutator { const Buffer& src_; const Buffer& dst_; - Map* block_sref_reuse_; + ffi::Map* block_sref_reuse_; }; struct ReadWriteAtImpl { template static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, - int buffer_index, const String& storage_scope, - Map annotations) { + int buffer_index, const ffi::String& storage_scope, + ffi::Map annotations) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer src = GetNthAccessBuffer(self, GetRef(block), buffer_index, + Buffer src = GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, is_read ? 
BufferIndexType::kRead : BufferIndexType::kWrite); Buffer dst = WithScope(src, storage_scope); ReadWriteAtImpl impl(self, loop_sref, src, dst, annotations); @@ -139,8 +139,8 @@ struct ReadWriteAtImpl { } private: - static Map GetLoopDomain(const StmtSRefNode* loop_sref) { - Map result; + static ffi::Map GetLoopDomain(const StmtSRefNode* loop_sref) { + ffi::Map result; for (const ForNode* loop; (loop = loop_sref->StmtAs()) != nullptr; loop_sref = loop_sref->parent) { result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); @@ -153,7 +153,7 @@ struct ReadWriteAtImpl { /*require_stage_pipeline=*/true); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref); Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); - block_sref_reuse_.Set(GetRef(scope_block), new_scope_block); + block_sref_reuse_.Set(ffi::GetRef(scope_block), new_scope_block); self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_); return self_->stmt2ref.at(new_block); } @@ -166,8 +166,8 @@ struct ReadWriteAtImpl { } template - std::pair MakeLoopAndBlock(const String& new_block_name_hint) { - Array subtrees = AsArray(loop_->body); + std::pair MakeLoopAndBlock(const ffi::String& new_block_name_hint) { + ffi::Array subtrees = AsArray(loop_->body); int n_subtrees = subtrees.size(); runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope()); std::vector relaxed_regions; @@ -197,10 +197,10 @@ struct ReadWriteAtImpl { /*buffer=*/src_, /*var_dom=*/ arith::AsIntSet(LoopDomainOfSRefTreePath( - /*low_inclusive=*/GetRef(self_->stmt2ref.at(block)->parent), + /*low_inclusive=*/ffi::GetRef(self_->stmt2ref.at(block)->parent), /*high_exclusive=*/loop_sref_, /*extra_relax_scope=*/scope)), - /*bindings=*/GetBindings(GetRef(realize)), + /*bindings=*/GetBindings(ffi::GetRef(realize)), /*relaxed_regions=*/&relaxed_regions); } return false; @@ -236,7 +236,7 @@ struct ReadWriteAtImpl { // Step 3. 
Calculate `domain`, the domain of buffer access NDIntSet relaxed = support::NDIntSetUnion(relaxed_regions); int ndim = relaxed.size(); - Array domain; + ffi::Array domain; domain.reserve(ndim); for (int i = 0; i < ndim; ++i) { const arith::IntSet& int_set = relaxed[i]; @@ -256,42 +256,43 @@ struct ReadWriteAtImpl { ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); subtrees.insert(subtrees.begin() + insert_pos, realize); - ObjectPtr new_loop = make_object(*loop_); + ObjectPtr new_loop = ffi::make_object(*loop_); new_loop->body = SeqStmt(std::move(subtrees)); return {For(new_loop), realize}; } - BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, const String& name_hint, - const Map& loop_domain, Array domain) const { + BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, + const ffi::String& name_hint, const ffi::Map& loop_domain, + ffi::Array domain) const { int n = domain.size(); std::vector loop_vars; loop_vars.reserve(n); for (int i = 0; i < n; ++i) { loop_vars.push_back(Var("ax" + std::to_string(i))); } - Map bindings; - Array iter_vars; - Array iter_values; - Array indices; + ffi::Map bindings; + ffi::Array iter_vars; + ffi::Array iter_values; + ffi::Array indices; iter_vars.reserve(n); iter_values.reserve(n); indices.reserve(n); for (int i = 0; i < n; ++i) { auto f_substitute = [&loop_domain, &bindings, &iter_vars, - &iter_values](const Var& var) -> Optional { + &iter_values](const Var& var) -> ffi::Optional { auto it = bindings.find(var); if (it != bindings.end()) { return (*it).second; } Range range = loop_domain.at(var); - ObjectPtr v = make_object(*var.get()); + ObjectPtr v = ffi::make_object(*var.get()); v->name_hint = "v" + std::to_string(iter_vars.size()); bindings.Set(var, Var(v)); iter_values.push_back(var); iter_vars.push_back(IterVar(range, Var(v), IterVarType::kDataPar)); return Var(v); }; 
- ObjectPtr dom = make_object(*domain[i].get()); + ObjectPtr dom = ffi::make_object(*domain[i].get()); dom->min = Substitute(std::move(dom->min), f_substitute); dom->extent = Substitute(std::move(dom->extent), f_substitute); domain.Set(i, Range(dom)); @@ -318,7 +319,7 @@ struct ReadWriteAtImpl { } explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src, - const Buffer& dst, Map annotations) + const Buffer& dst, ffi::Map annotations) : self_(self), loop_sref_(loop_sref), loop_(nullptr), @@ -335,19 +336,19 @@ struct ReadWriteAtImpl { const ForNode* loop_; const Buffer& src_; const Buffer& dst_; - Map annotations_; - Map block_sref_reuse_; + ffi::Map annotations_; + ffi::Map block_sref_reuse_; std::unique_ptr analyzer_; }; StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, - int read_buffer_index, const String& storage_scope) { + int read_buffer_index, const ffi::String& storage_scope) { return ReadWriteAtImpl::Main(self, loop_sref, block_sref, read_buffer_index, storage_scope, {{tir::attr::auto_copy, true}}); } StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, - int write_buffer_index, const String& storage_scope) { + int write_buffer_index, const ffi::String& storage_scope) { return ReadWriteAtImpl::Main(self, loop_sref, block_sref, write_buffer_index, storage_scope, {{tir::attr::auto_copy, true}}); } @@ -364,14 +365,15 @@ struct ReadAtTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, - int buffer_index, const String& storage_scope); + int buffer_index, const ffi::String& storage_scope); static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, - Integer read_buffer_index, String storage_scope) { + Integer read_buffer_index, ffi::String storage_scope) { return sch->ReadAt(loop, block, read_buffer_index->value, 
storage_scope); } - static String UnpackedAsPython(Array outputs, String loop, String block, - Integer read_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop, + ffi::String block, Integer read_buffer_index, + ffi::String storage_scope) { PythonAPICall py("read_at"); py.Input("loop", loop); py.Input("block", block); @@ -395,12 +397,13 @@ struct WriteAtTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, - Integer write_buffer_index, String storage_scope) { + Integer write_buffer_index, ffi::String storage_scope) { return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); } - static String UnpackedAsPython(Array outputs, String loop, String block, - Integer write_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop, + ffi::String block, Integer write_buffer_index, + ffi::String storage_scope) { PythonAPICall py("write_at"); py.Input("loop", loop); py.Input("block", block); diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index b46801a0684d..0629757a13d8 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -67,7 +67,7 @@ class DecomposeReductionBlockReplacer : public StmtMutator { p_new_block->name_hint = p_new_block->name_hint + "_update"; p_new_block->init = std::nullopt; // Add write regions back to read regions in update block. 
- Array new_reads; + ffi::Array new_reads; std::unordered_set read_bufs; for (const BufferRegion& read_access : block->reads) { read_bufs.insert(read_access->buffer.get()); @@ -89,7 +89,7 @@ class DecomposeReductionBlockReplacer : public StmtMutator { } Stmt VisitStmt_(const SeqStmtNode* seq) final { - Array new_stmts; + ffi::Array new_stmts; new_stmts.reserve(seq->seq.size()); for (const Stmt& old_stmt : seq->seq) { new_stmts.push_back(VisitStmt(old_stmt)); @@ -108,7 +108,7 @@ class LoopHeightError : public ScheduleError { public: static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block, const BlockRealizeNode* realize, - const Array& loops, + const ffi::Array& loops, const StmtSRef& loop_sref) { for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { // For each block var of type kCommReduce, check its binding @@ -126,7 +126,7 @@ class LoopHeightError : public ScheduleError { const Var& loop_var = higher_loop->StmtAs()->loop_var; if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - throw LoopHeightError(mod, GetRef(loop), GetRef(block)); + throw LoopHeightError(mod, ffi::GetRef(loop), ffi::GetRef(block)); } } } @@ -135,12 +135,12 @@ class LoopHeightError : public ScheduleError { explicit LoopHeightError(IRModule mod, For loop, Block block) : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops " "related to reduce block var"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops " "related to reduce block var of block {1}"; @@ -148,7 +148,7 @@ class LoopHeightError : public ScheduleError { } 
IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_, block_}; } + ffi::Array LocationsOfInterest() const final { return {loop_, block_}; } IRModule mod_; For loop_; @@ -188,14 +188,14 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); // Get the outer loops from high to low - Array loops = GetLoops(block_sref); + ffi::Array loops = GetLoops(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); if (self->enable_check) { // Cond 0. Check loop_sref is an ancestor of block_sref if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) { - throw LoopPositionError(self->mod, GetRef(loop), GetRef(block), + throw LoopPositionError(self->mod, ffi::GetRef(loop), ffi::GetRef(block), "decompose_reduction"); } // Cond 1. 
Check block is reduction @@ -204,8 +204,8 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref); } // IR Manipulation - ObjectPtr init_block = make_object(); - ObjectPtr init_realize = make_object(); + ObjectPtr init_block = ffi::make_object(); + ObjectPtr init_realize = ffi::make_object(); init_block->name_hint = block->name_hint + "_init"; init_block->annotations = block->annotations; init_realize->iter_values = {}; @@ -268,33 +268,32 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, std::unordered_map loop_var_map; Stmt body = BlockRealize(init_realize); for (int i : chosen_loops) { - const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]); + For old_loop = ffi::GetRef(TVM_SREF_TO_FOR(loops[i])); // Create a new equivalent to the chosen loop Var old_loop_var = old_loop->loop_var; Var new_loop_var = old_loop_var.copy_with_suffix("_init"); loop_var_map[old_loop_var] = new_loop_var; - Optional opt_thread_binding = old_loop->thread_binding; + ffi::Optional opt_thread_binding = old_loop->thread_binding; if (opt_thread_binding) { auto thread_binding = opt_thread_binding.value(); auto new_var = thread_binding->var.copy_with_suffix(""); thread_binding.CopyOnWrite()->var = new_var; opt_thread_binding = thread_binding; } - body = For(/*loop_var=*/new_loop_var, - /*min=*/old_loop->min, - /*extent=*/old_loop->extent, - /*kind=*/old_loop->kind, - /*body=*/body, - /*thread_binding=*/opt_thread_binding); + auto new_loop = old_loop.CopyOnWrite(); + new_loop->loop_var = new_loop_var; + new_loop->thread_binding = opt_thread_binding; + new_loop->body = body; + body = ffi::GetRef(new_loop); } body = Substitute(body, loop_var_map); // Step 6. 
Mutate IR const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); auto [new_scope_root, new_reduction_block] = DecomposeReductionBlockReplacer::Replace( - GetRef(old_scope_root), GetRef(loop), body, GetRef(block)); + ffi::GetRef(old_scope_root), ffi::GetRef(loop), body, ffi::GetRef(block)); self->Replace(scope_root_sref, new_scope_root, - {{GetRef(old_scope_root), new_scope_root}, - {GetRef(block), new_reduction_block}}); + {{ffi::GetRef(old_scope_root), new_scope_root}, + {ffi::GetRef(block), new_reduction_block}}); self->UpdateScopeBlockInfo(new_scope_root); return self->stmt2ref.at(init_block.get()); } @@ -312,112 +311,114 @@ struct ReducerRegistry { : reducer_getters{ CreateReducerGetter( /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{x[0] + y[0]}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{x[0] + y[0]}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, 0)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, 0)}; }), CreateReducerGetter( /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{x[0] * y[0]}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{x[0] * y[0]}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, 1)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, 1)}; }), CreateReducerGetter( /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{min(x[0], y[0])}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{min(x[0], y[0])}; }, - [](const Array& values) { - return Array{max_value(values[0]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{max_value(values[0]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{max(x[0], y[0])}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{max(x[0], y[0])}; }, - [](const Array& 
values) { - return Array{min_value(values[0]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{min_value(values[0]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { - return Array{x[0] + y[0], x[1] + y[1]}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{x[0] + y[0], x[1] + y[1]}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, 0), - make_const(values[1]->dtype, 0)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, 0), + make_const(values[1]->dtype, 0)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { + [](const ffi::Array& x, const ffi::Array& y) { PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]); PrimExpr val = Select(x[1] >= y[1], x[1], y[1]); - return Array{idx, val}; + return ffi::Array{idx, val}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - min_value(values[1]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, -1), + min_value(values[1]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { + [](const ffi::Array& x, const ffi::Array& y) { PrimExpr idx = Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]); PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]); - return Array{idx, val}; + return ffi::Array{idx, val}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - min_value(values[1]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, -1), + min_value(values[1]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { + [](const ffi::Array& x, const ffi::Array& y) { PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]); PrimExpr val = Select(x[1] <= y[1], x[1], y[1]); - return Array{idx, val}; + return ffi::Array{idx, val}; }, - [](const Array& values) 
{ - return Array{make_const(values[0]->dtype, -1), - max_value(values[1]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, -1), + max_value(values[1]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { + [](const ffi::Array& x, const ffi::Array& y) { PrimExpr idx = Select( Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]); PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]); - return Array{idx, val}; + return ffi::Array{idx, val}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - max_value(values[1]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, -1), + max_value(values[1]->dtype)}; })} {} static void RegisterReducer( - int n_buffers, ffi::TypedFunction(Array, Array)> combiner_getter, - ffi::TypedFunction(Array)> identity_getter) { + int n_buffers, + ffi::TypedFunction(ffi::Array, ffi::Array)> combiner_getter, + ffi::TypedFunction(ffi::Array)> identity_getter) { ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter( n_buffers, std::move(combiner_getter), std::move(identity_getter))); } - static ffi::TypedFunction(Array)> CreateReducerGetter( - int n_buffers, ffi::TypedFunction(Array, Array)> combiner_getter, - ffi::TypedFunction(Array)> identity_getter) { + static ffi::TypedFunction(ffi::Array)> CreateReducerGetter( + int n_buffers, + ffi::TypedFunction(ffi::Array, ffi::Array)> combiner_getter, + ffi::TypedFunction(ffi::Array)> identity_getter) { return [n_buffers, // combiner_getter = std::move(combiner_getter), // identity_getter = std::move(identity_getter) // - ](Array values) -> Optional { + ](ffi::Array values) -> ffi::Optional { if (static_cast(values.size()) != n_buffers) { return std::nullopt; } - Array lhs; - Array rhs; + ffi::Array lhs; + ffi::Array rhs; for (int i = 0; i < n_buffers; ++i) { lhs.push_back(Var("x" + std::to_string(i), 
values[i]->dtype)); rhs.push_back(Var("y" + std::to_string(i), values[i]->dtype)); @@ -431,10 +432,11 @@ struct ReducerRegistry { return &instance; } - std::vector(Array)>> reducer_getters; + std::vector(ffi::Array)>> reducer_getters; }; -std::vector(Array)>> GetReducerGetters() { +std::vector(ffi::Array)>> +GetReducerGetters() { return ReducerRegistry::Global()->reducer_getters; } @@ -443,12 +445,12 @@ class NotSerialLoopKindError : public ScheduleError { explicit NotSerialLoopKindError(IRModule mod, For loop) : mod_(std::move(mod)), loop_(std::move(loop)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input loop of rfactor is required to be `kSerial`"; } - String DetailRenderTemplate() const final { - String str_kind = ForKind2String(loop_->kind); + ffi::String DetailRenderTemplate() const final { + ffi::String str_kind = ForKind2String(loop_->kind); std::ostringstream os; os << "ScheduleError: The input loop {0} of rfactor is required to be `Serial`. However, the " "kind of {0} is `" @@ -457,7 +459,7 @@ class NotSerialLoopKindError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; @@ -468,12 +470,12 @@ class FactorAxisOutOfRangeError : public ScheduleError { explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis) : mod_(std::move(mod)), buffer_(std::move(buffer)), factor_axis_(factor_axis) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input `factor_axis` is out of range. 
It is required to be in range " "[-(ndim + 1), ndim] where `ndim` is the number of dimensions of the write buffer"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; int ndim = static_cast(buffer_->shape.size()); os << "The write buffer " << buffer_->name << " has " << ndim @@ -484,7 +486,7 @@ class FactorAxisOutOfRangeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int factor_axis) { int ndim = static_cast(buffer->shape.size()); @@ -515,7 +517,7 @@ class LoopPropertyError : public ScheduleError { explicit LoopPropertyError(IRModule mod, For loop, ErrorType error_type) : mod_(std::move(mod)), loop_(std::move(loop)), error_type_(error_type) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { switch (error_type_) { case kDataParIterTouchRFactorLoop: return "ScheduleError: The loop to be applied rfactor is required not to be touched by any " @@ -534,7 +536,7 @@ class LoopPropertyError : public ScheduleError { throw; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { switch (error_type_) { case kDataParIterTouchRFactorLoop: return "The loop to be applied rfactor is {0}, which is required not to be touched by any " @@ -554,13 +556,13 @@ class LoopPropertyError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } - static void CheckLoopProperty(const ScheduleState& self, const Array& loops, + static void CheckLoopProperty(const ScheduleState& self, const ffi::Array& loops, const ForNode* rf_loop, const Block& block, const std::unordered_set& 
data_par_loop_vars, const std::unordered_set& reduce_loop_vars) { - Array children_of_outermost_loop = + ffi::Array children_of_outermost_loop = GetChildBlockRealizeOnSRefTree(self->stmt2ref.at(loops[0].get())); if (!children_of_outermost_loop[0]->block.same_as(block)) { throw LoopPropertyError(self->mod, loops[0], kNotFirstChildBlockOfOutermostLoop); @@ -601,7 +603,7 @@ class LoopPropertyError : public ScheduleError { * \param loops The loops to be analyzed * \return A mapping from loops to their corresponding loop vars */ -std::unordered_map GetLoopVar2LoopMap(const Array& loops) { +std::unordered_map GetLoopVar2LoopMap(const ffi::Array& loops) { std::unordered_map loop_vars2loop; loop_vars2loop.reserve(loops.size()); for (const For& loop : loops) { @@ -619,16 +621,16 @@ std::unordered_map GetLoopVar2LoopMap(const Array& loo * \param rf_loop The rfactor loop * \return The new created intermediate rfactor buffer */ -Array CreateRFactorBuffers(const Array& buf_stores, int factor_axis, - const ForNode* rf_loop) { - Array rf_buffers; +ffi::Array CreateRFactorBuffers(const ffi::Array& buf_stores, int factor_axis, + const ForNode* rf_loop) { + ffi::Array rf_buffers; rf_buffers.reserve(buf_stores.size()); for (const BufferStore& buf_store : buf_stores) { Buffer buffer = buf_store->buffer; - Array rf_shape = buffer->shape; + ffi::Array rf_shape = buffer->shape; rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent); - ObjectPtr n = make_object(*buffer.get()); + ObjectPtr n = ffi::make_object(*buffer.get()); n->shape = rf_shape; n->name = buffer->name + ".rf"; n->data = buffer->data.copy_with_suffix(".rf"); @@ -648,8 +650,8 @@ Array CreateRFactorBuffers(const Array& buf_stores, int fac class BaseBlockCreator { public: explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop, - Array old_reduction_updates, CommReducer reducer, - Array rf_buffers, bool is_rf_block) + ffi::Array old_reduction_updates, CommReducer reducer, + ffi::Array rf_buffers, bool 
is_rf_block) : old_block_realize_(std::move(old_block_realize)), rf_loop_(std::move(rf_loop)), old_reduction_updates_(std::move(old_reduction_updates)), @@ -681,13 +683,13 @@ class BaseBlockCreator { // accesses, and the reduction LHS and RHS of the stored values. PreProcess(); Stmt block_body = Substitute(CreateBlockBody(has_reduce_iter), var_map_); - Optional block_init = CreateBlockInit(has_reduce_iter); + ffi::Optional block_init = CreateBlockInit(has_reduce_iter); if (block_init.defined()) { block_init = Substitute(block_init.value(), var_map_); } CreateReadWriteRegions(); - String new_block_name = old_block_realize_->block->name_hint; + ffi::String new_block_name = old_block_realize_->block->name_hint; PrimExpr predicate = const_true(); if (is_rf_block_) { new_block_name = new_block_name + "_rf"; @@ -713,7 +715,7 @@ class BaseBlockCreator { virtual void CreateReadWriteRegions() = 0; Stmt CreateBlockBody(bool has_reduce_iter) { - Array buf_stores; + ffi::Array buf_stores; buf_stores.reserve(n_buffers_); // Case 1. If the block has no reduction iterator, we just store the RHS values into the @@ -726,14 +728,14 @@ class BaseBlockCreator { } // Case 2. If the reduction is for single buffer, the block body is a single BufferStore. - Array stored_values = (*reducer_.get())(update_lhs_, update_rhs_); + ffi::Array stored_values = (*reducer_.get())(update_lhs_, update_rhs_); if (n_buffers_ == 1) { return BufferStore(update_buffers_[0], stored_values[0], update_indices_[0]); } // Case 3. In case the reduction is for multiple buffers, we should create the reduction with // LetStmt so that the reduction execution generates correct results. 
- Array let_vars; + ffi::Array let_vars; let_vars.reserve(n_buffers_); for (int i = 0; i < n_buffers_; ++i) { Var var("v_" + update_buffers_[i]->name, PrimType(stored_values[i]->dtype)); @@ -747,12 +749,12 @@ class BaseBlockCreator { return body; } - Optional CreateBlockInit(bool has_reduce_iter) { + ffi::Optional CreateBlockInit(bool has_reduce_iter) { if (!has_reduce_iter) { return std::nullopt; } - Array inits; + ffi::Array inits; inits.reserve(n_buffers_); for (int i = 0; i < n_buffers_; ++i) { inits.push_back( @@ -767,7 +769,7 @@ class BaseBlockCreator { /*! \brief The new created block-realize */ BlockRealize new_block_realize_; /*! \brief The indices used to access the intermediate rfactor buffer */ - Array rf_buf_access_indices_; + ffi::Array rf_buf_access_indices_; protected: /*! \brief The old block-realize */ @@ -777,18 +779,18 @@ class BaseBlockCreator { /*! \brief The rfactor loop */ For rf_loop_; /*! \brief The update BufferStores of the old block */ - Array old_reduction_updates_; + ffi::Array old_reduction_updates_; /*! \brief The matched commutative reducer */ CommReducer reducer_; /*! \brief The intermediate rfactor buffers */ - Array rf_buffers_; + ffi::Array rf_buffers_; /*! \brief The number of rfactor buffers. */ const int n_buffers_; /*! * \brief A mapping which maps old block iters to new expressions. The old iters will be replaced * by the expressions in future substitution for the two blocks */ - Map var_map_; + ffi::Map var_map_; /*! \brief Whether we are creating the rfactor block or the write-back block */ bool is_rf_block_; @@ -797,17 +799,17 @@ class BaseBlockCreator { /*! \brief The new block iter bindings of the new created block-realize */ std::vector iter_values_; /*! \brief The buffers updated in this block */ - Array update_buffers_; + ffi::Array update_buffers_; /*! \brief The indices of the buffers updated in this block, respectively */ - Array> update_indices_; + ffi::Array> update_indices_; /*! 
\brief The LHS values of the reduction in this block */ - Array update_lhs_; + ffi::Array update_lhs_; /*! \brief THe RHS values of the reduction in this block */ - Array update_rhs_; + ffi::Array update_rhs_; /*! \brief The read regions of the new created block */ - Array read_regions_; + ffi::Array read_regions_; /*! \brief The write regions of the new created block */ - Array write_regions_; + ffi::Array write_regions_; }; /*! @@ -835,10 +837,10 @@ class BaseBlockCreator { class RFactorBlockCreator : public BaseBlockCreator { public: explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop, - Array old_reduction_updates, CommReducer reducer, - Array rf_buffers, + ffi::Array old_reduction_updates, CommReducer reducer, + ffi::Array rf_buffers, std::unordered_map loop_vars2loop, - int factor_axis, Array combiner_rhs) + int factor_axis, ffi::Array combiner_rhs) : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), std::move(old_reduction_updates), std::move(reducer), std::move(rf_buffers), true), @@ -872,7 +874,7 @@ class RFactorBlockCreator : public BaseBlockCreator { ICHECK(old_iter->iter_type == kCommReduce); // This block iter is a reduction block iter that touches the rfactor loop. So next we try to // create a new block iter for all loop vars that appear in the old binding. 
- Array vars_in_old_binding = UndefinedVars(old_binding); + ffi::Array vars_in_old_binding = UndefinedVars(old_binding); for (const Var& var : vars_in_old_binding) { auto it = loop_vars2loop_.find(var.get()); if (it == loop_vars2loop_.end()) { @@ -909,7 +911,7 @@ class RFactorBlockCreator : public BaseBlockCreator { } void CreateReadWriteRegions() final { - Map buffer_map; + ffi::Map buffer_map; for (int i = 0; i < n_buffers_; ++i) { buffer_map.Set(old_reduction_updates_[i]->buffer, rf_buffers_[i]); } @@ -921,11 +923,11 @@ class RFactorBlockCreator : public BaseBlockCreator { } write_regions_.reserve(old_block->writes.size()); for (const BufferRegion& write_region : old_block->writes) { - Array region = write_region->region; + ffi::Array region = write_region->region; region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, make_const(additional_iter_->var.dtype(), 1))); - Optional rf_buffer = buffer_map.Get(write_region->buffer); + ffi::Optional rf_buffer = buffer_map.Get(write_region->buffer); ICHECK(rf_buffer.defined()); write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_))); } @@ -944,7 +946,7 @@ class RFactorBlockCreator : public BaseBlockCreator { /*! \brief The factor_axis specified for rfactor */ int factor_axis_; /*! \brief The RHS values of the reduction in the old block */ - Array combiner_rhs_; + ffi::Array combiner_rhs_; /*! * \brief A mapping which maps loop vars to new created block iters. 
This map is used to * substitute the loop vars which appear in the bindings of some old block iters with the new @@ -960,10 +962,10 @@ class RFactorBlockCreator : public BaseBlockCreator { class WriteBackBlockCreator : public BaseBlockCreator { public: explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop, - Array old_reduction_updates, CommReducer reducer, - Array rf_buffers, IterVar rf_additional_iter, - Array combiner_lhs, - Array rf_buf_access_indices) + ffi::Array old_reduction_updates, CommReducer reducer, + ffi::Array rf_buffers, IterVar rf_additional_iter, + ffi::Array combiner_lhs, + ffi::Array rf_buf_access_indices) : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), std::move(old_reduction_updates), std::move(reducer), std::move(rf_buffers), false), @@ -1009,12 +1011,12 @@ class WriteBackBlockCreator : public BaseBlockCreator { CreateRegion(update_lhs_, false); } - void CreateRegion(const Array& buf_loads, bool is_read) { - Array& buf_regions = is_read ? read_regions_ : write_regions_; + void CreateRegion(const ffi::Array& buf_loads, bool is_read) { + ffi::Array& buf_regions = is_read ? read_regions_ : write_regions_; for (const PrimExpr& expr : buf_loads) { const auto* buf_load = expr.as(); ICHECK(buf_load != nullptr); - Array region; + ffi::Array region; region.reserve(buf_load->indices.size()); for (const PrimExpr& index : buf_load->indices) { region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1))); @@ -1027,7 +1029,7 @@ class WriteBackBlockCreator : public BaseBlockCreator { /*! \brief The new created additional block iter of the rfactor block */ IterVar rf_additional_iter_; /*! \brief The LHS values of the reduction in the old block */ - Array combiner_lhs_; + ffi::Array combiner_lhs_; }; /*! 
@@ -1037,11 +1039,11 @@ class WriteBackBlockCreator : public BaseBlockCreator { * \param loops The loops to be wrapped over the rfactor block * \return A Stmt which is the wrapping result */ -Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array& loops) { +Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const ffi::Array& loops) { int n_loops = static_cast(loops.size()); // Step 1. Create new loop vars. - Array new_loops; + ffi::Array new_loops; std::unordered_map new_loop_var_map; new_loops.reserve(n_loops); new_loop_var_map.reserve(n_loops); @@ -1051,7 +1053,7 @@ Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array new_bindings; + ffi::Array new_bindings; new_bindings.reserve(rf_block_realize->iter_values.size()); for (const PrimExpr& old_binding : rf_block_realize->iter_values) { new_bindings.push_back(Substitute(old_binding, new_loop_var_map)); @@ -1065,7 +1067,7 @@ Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array= 0; --i) { - ObjectPtr p_loop = make_object(*loops[i].get()); + ObjectPtr p_loop = ffi::make_object(*loops[i].get()); p_loop->loop_var = Downcast(new_loop_var_map[loops[i]->loop_var.get()]); p_loop->body = rf_body; rf_body = For(std::move(p_loop)); @@ -1102,7 +1104,7 @@ class BlockReplacer : public StmtMutator { BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop, std::unordered_set reduce_loop_vars, std::unordered_map loop_vars2loop, - const Array& rf_buffers) { + const ffi::Array& rf_buffers) { BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop), std::move(wb_block_realize), std::move(old_block_realize), std::move(rf_loop), std::move(reduce_loop_vars), @@ -1133,7 +1135,7 @@ class BlockReplacer : public StmtMutator { // that the scope root block has stage-pipeline property, if this loop is not outside the // reduction block, there's no need to recursively mutate. 
if (!loop_vars2loop_.count(loop->loop_var.get())) { - return GetRef(loop); + return ffi::GetRef(loop); } // Step 2. Recursively mutate. @@ -1160,7 +1162,7 @@ class BlockReplacer : public StmtMutator { } Stmt VisitStmt_(const SeqStmtNode* seq) final { - Array new_stmts; + ffi::Array new_stmts; new_stmts.reserve(static_cast(seq->seq.size())); for (const Stmt old_stmt : seq->seq) { @@ -1195,7 +1197,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax } const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref); if (rf_loop->kind != ForKind::kSerial) { - throw NotSerialLoopKindError(self->mod, GetRef(rf_loop)); + throw NotSerialLoopKindError(self->mod, ffi::GetRef(rf_loop)); } // Step 2. Collect loop vars that are touched by data parallel block iters and reduction block @@ -1206,7 +1208,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // Step 3. Collect the loops of the reduction block. Construct a mapping from loops to // corresponding loop vars. - Array loops = LoopSRefs2Loops(GetLoops(block_sref)); + ffi::Array loops = LoopSRefs2Loops(GetLoops(block_sref)); std::unordered_map loop_vars2loop = GetLoopVar2LoopMap(loops); // Step 4. Check four properties that the loops should have: @@ -1224,11 +1226,11 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs // will be used when constructing the rfactor block. 
- Array init_values{nullptr}; - Array updates{nullptr}; + ffi::Array init_values{nullptr}; + ffi::Array updates{nullptr}; CommReducer reducer{nullptr}; - Array combiner_lhs{nullptr}; - Array combiner_rhs{nullptr}; + ffi::Array combiner_lhs{nullptr}; + ffi::Array combiner_rhs{nullptr}; std::tie(init_values, updates) = GetInitValuesAndUpdatesFromReductionBlock(self, block); std::tie(reducer, combiner_lhs, combiner_rhs) = GetReducerAndCombinerLhsRhs(self, init_values, updates); @@ -1246,16 +1248,16 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // Step 1. Create the intermediate buffer (a.k.a. rfactor buffer), which has an additional // dimension that specified by `factor_axis` and `rf_loop`. - Array rf_buffers = CreateRFactorBuffers(updates, factor_axis, rf_loop); + ffi::Array rf_buffers = CreateRFactorBuffers(updates, factor_axis, rf_loop); // Step 2. Create the rfactor block. - RFactorBlockCreator rf_block_creator(block_realize, GetRef(rf_loop), updates, reducer, + RFactorBlockCreator rf_block_creator(block_realize, ffi::GetRef(rf_loop), updates, reducer, rf_buffers, loop_vars2loop, factor_axis, std::move(combiner_rhs)); rf_block_creator.CreateBlock(); // Step 3. Create the write-back block. - WriteBackBlockCreator wb_block_creator(block_realize, GetRef(rf_loop), updates, reducer, + WriteBackBlockCreator wb_block_creator(block_realize, ffi::GetRef(rf_loop), updates, reducer, rf_buffers, std::move(rf_block_creator.additional_iter_), std::move(combiner_lhs), std::move(rf_block_creator.rf_buf_access_indices_)); @@ -1269,10 +1271,10 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // ***************************************************** // Step 1. Substitute the old scope root block with the new scope root block. 
- Block old_scope_root_block = GetRef(scope_root->StmtAs()); + Block old_scope_root_block = ffi::GetRef(scope_root->StmtAs()); Block new_scope_root_block = BlockReplacer::Replace( old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize, - GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffers); + ffi::GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffers); self->Replace( scope_root, new_scope_root_block, {{old_scope_root_block, new_scope_root_block}, {block, wb_block_creator.new_block_}}); @@ -1304,7 +1306,8 @@ struct DecomposeReductionTraits : public UnpackedInstTraitsDecomposeReduction(block_rv, loop_rv); } - static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + ffi::String loop_rv) { PythonAPICall py("decompose_reduction"); py.Input("block", block_rv); py.Input("loop", loop_rv); @@ -1329,7 +1332,8 @@ struct RFactorTraits : public UnpackedInstTraits { return sch->RFactor(loop_rv, factor_axis->value); } - static String UnpackedAsPython(Array outputs, String loop_rv, Integer factor_axis) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + Integer factor_axis) { PythonAPICall py("rfactor"); py.Input("loop", loop_rv); py.Input("factor_axis", factor_axis->value); @@ -1346,7 +1350,7 @@ TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits); /******** FFI ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.schedule.RegisterReducer", @@ -1354,7 +1358,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), std::move(identity_getter)); }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/reorder_block_iter_var.cc b/src/tir/schedule/primitive/reorder_block_iter_var.cc index c7967a3ee904..6acc5fa2d924 100644 --- 
a/src/tir/schedule/primitive/reorder_block_iter_var.cc +++ b/src/tir/schedule/primitive/reorder_block_iter_var.cc @@ -27,29 +27,29 @@ namespace tir { */ class InvalidReorderIndex : public ScheduleError { public: - explicit InvalidReorderIndex(IRModule mod, Block block, Array new_order) + explicit InvalidReorderIndex(IRModule mod, Block block, ffi::Array new_order) : mod_(mod), block_(block), new_order_(new_order) {} IRModule mod() const final { return mod_; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The specified reorder indices are invalid."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The user provided block itervar index order " << new_order_ << " is not a valid permutation of [0, 1, ..., num_block_iter_vars-1] in block {0}."; - return String(os.str()); + return ffi::String(os.str()); } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; Block block_; - Array new_order_; + ffi::Array new_order_; }; class BlockIterVarRewriter : public StmtMutator { public: - Map block_map; + ffi::Map block_map; explicit BlockIterVarRewriter(const BlockNode* block_n, std::vector order) : order_(std::move(order)), block_to_rewrite(block_n) {} @@ -60,8 +60,8 @@ class BlockIterVarRewriter : public StmtMutator { if (op->block.get() == block_to_rewrite) { auto block_n = CopyOnWrite(op->block.get()); Block block = op->block; - Array new_iter_vars; - Array new_iter_values; + ffi::Array new_iter_vars; + ffi::Array new_iter_values; for (int idx : order_) { new_iter_vars.push_back(block->iter_vars[idx]); new_iter_values.push_back(op->iter_values[idx]); @@ -80,7 +80,7 @@ class BlockIterVarRewriter : public StmtMutator { }; void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, - const Array& new_order) { + const 
ffi::Array& new_order) { const BlockNode* block_n = TVM_SREF_TO_BLOCK(block_sref); std::vector new_order_vec; for (const Integer& x : new_order) { @@ -95,7 +95,7 @@ void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, return x >= 0 && x < static_cast(num_block_itervars); }); if (!is_full || !is_unique || !is_within_boundary) { - throw InvalidReorderIndex(self->mod, GetRef(block_n), new_order); + throw InvalidReorderIndex(self->mod, ffi::GetRef(block_n), new_order); } // find parent block @@ -103,13 +103,13 @@ void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, const StmtSRefNode* p = block_sref.get()->parent; while (p != nullptr) { if (p->stmt->IsInstance()) { - parent_block_n = TVM_SREF_TO_BLOCK(GetRef(p)); + parent_block_n = TVM_SREF_TO_BLOCK(ffi::GetRef(p)); break; } p = p->parent; } - const StmtSRef parent_block_sref = GetRef(p); - const Block& parent_block = GetRef(parent_block_n); + const StmtSRef parent_block_sref = ffi::GetRef(p); + const Block& parent_block = ffi::GetRef(parent_block_n); // rewrite block and blockrealize BlockIterVarRewriter rewriter(block_n, std::move(new_order_vec)); @@ -127,11 +127,12 @@ struct ReorderBlockIterVarTraits : public UnpackedInstTraits new_order) { + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, ffi::Array new_order) { sch->ReorderBlockIterVar(block, new_order); } - static String UnpackedAsPython(Array outputs, String block, Array new_order) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::Array new_order) { PythonAPICall py("reorder_block_iter_var"); py.Input("block", block); py.Input("new_order", new_order); diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc index bef5faf92b67..ff030bbef7a2 100644 --- a/src/tir/schedule/primitive/rolling_buffer.cc +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -32,14 +32,14 @@ struct RollingBufferInfo { int rolling_axis; PrimExpr 
rolling_extent; std::vector axis_overlaps; - std::vector> axis_iter_vars; + std::vector> axis_iter_vars; /*! \brief The map used for ScheduleStateNode::Replace. */ - Map block_reuse; + ffi::Map block_reuse; }; BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, - const Map& dom_map) { - Array relaxed_intsets = + const ffi::Map& dom_map) { + ffi::Array relaxed_intsets = arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); Region relaxed_region; relaxed_region.reserve(relaxed_intsets.size()); @@ -55,16 +55,16 @@ class RollingBufferDependencyError : public ScheduleError { explicit RollingBufferDependencyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The target block is required to have only RAW dependencies"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The target block {0} is required to have only RAW dependencies"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } /*! * \brief Check if the block has only RAW dependencies. 
@@ -79,13 +79,13 @@ class RollingBufferDependencyError : public ScheduleError { for (const Dependency& producers : scope->GetDepsByDst(block_sref)) { if (!(producers->kind == DepKind::kRAW)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferDependencyError(self->mod, GetRef(block)); + throw RollingBufferDependencyError(self->mod, ffi::GetRef(block)); } } for (const Dependency& consumers : scope->GetDepsBySrc(block_sref)) { if (!(consumers->kind == DepKind::kRAW)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferDependencyError(self->mod, GetRef(block)); + throw RollingBufferDependencyError(self->mod, ffi::GetRef(block)); } } } @@ -99,11 +99,11 @@ class RollingBufferMatchError : public ScheduleError { public: RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) : mod_(mod), block_(block), buffer_region_(buffer_region) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" "matching the rolling pattern such as: hh.outer * stride + hh.inner"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The target buffer " << buffer_region_->buffer->name << " with region " << buffer_region_->region @@ -113,7 +113,7 @@ class RollingBufferMatchError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -125,12 +125,12 @@ class RollingBufferInsertionError : public ScheduleError { public: RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) : mod_(mod), buffer_(std::move(buffer)), block_(block) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: 
rolling_buffer injection is invalid, the lca of the access " "location of the target buffer is not a for loop. "; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " << "the lca of the access location of the target buffer " << buffer_->name @@ -138,7 +138,7 @@ class RollingBufferInsertionError : public ScheduleError { return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -154,7 +154,7 @@ class RollingBufferInfoCollector { RollingBufferInfoCollector collector; if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferMatchError(mod, GetRef(block), buffer_region); + throw RollingBufferMatchError(mod, ffi::GetRef(block), buffer_region); } return collector.info_; } @@ -164,7 +164,7 @@ class RollingBufferInfoCollector { const Buffer& buffer = buffer_region->buffer; const Region& region = buffer_region->region; - std::vector> bound_iter_vars; + std::vector> bound_iter_vars; std::vector bound_overlaps; arith::PVar p_var; @@ -173,7 +173,7 @@ class RollingBufferInfoCollector { auto stride = 0; auto divisor = 1; - Optional iter_var; + ffi::Optional iter_var; if (floordiv((p_var * p_stride), p_divisor).Match(bound->min)) { // Handle the case of fractional strides // They take this form: floordiv(hh.outer, 2) @@ -211,17 +211,17 @@ class RollingBufferInfoCollector { bound_overlaps.push_back(bound_overlap); } - Array loop_srefs = GetLoops(block_sref); + ffi::Array loop_srefs = GetLoops(block_sref); // Pick the outermost iter_var that's mentioned in the bounds // to be the rolling axis - Optional roll_iter_var; + ffi::Optional roll_iter_var; int roll_axis = 0; for (const 
tir::StmtSRef& loop_sref : loop_srefs) { auto loop_var = loop_sref->StmtAs()->loop_var; - auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional var) { - return var && (var.get() == loop_var.get()); - })}; + auto it{std::find_if( + bound_iter_vars.begin(), bound_iter_vars.end(), + [&](ffi::Optional var) { return var && (var.get() == loop_var.get()); })}; if (it != bound_iter_vars.end()) { auto i = std::distance(bound_iter_vars.begin(), it); roll_iter_var = loop_var; @@ -233,7 +233,7 @@ class RollingBufferInfoCollector { if (!roll_iter_var.defined()) { return false; } - Array new_shape = buffer->shape; + ffi::Array new_shape = buffer->shape; new_shape.Set(roll_axis, region[roll_axis]->extent); Buffer new_buffer = buffer; new_buffer.CopyOnWrite()->shape = new_shape; @@ -255,15 +255,15 @@ class RollingBufferRewriter : public StmtExprMutator { public: static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { RollingBufferRewriter rewriter(scope_sref, info); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) : scope_sref_(scope_sref), info_(info) {} - void RewriteAccessRegion(Array* old_access_regions, - const Array& infered_access_regions) { + void RewriteAccessRegion(ffi::Array* old_access_regions, + const ffi::Array& infered_access_regions) { auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { if (buffer_region->buffer.same_as(info_->old_buffer)) { ICHECK(infered_access_regions.size() == 1); @@ -274,8 +274,8 @@ class RollingBufferRewriter : public StmtExprMutator { (*old_access_regions).MutateByApply(fmutate); } - void RewriteBufferAccess(Buffer* buffer, Array* indices) const { - Array new_indices; + void RewriteBufferAccess(Buffer* buffer, ffi::Array* indices) const { + ffi::Array new_indices; new_indices.reserve(indices->size()); // First modify the 
access indices to use modulo arithmetic // for the rolling axis @@ -292,11 +292,11 @@ class RollingBufferRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); BlockNode* n = stmt.CopyOnWrite(); if (block == scope_sref_->stmt) { - Array new_alloc_buffers; + ffi::Array new_alloc_buffers; for (const Buffer& buffer : stmt->alloc_buffers) { if (buffer != info_->old_buffer) { new_alloc_buffers.push_back(buffer); @@ -306,7 +306,7 @@ class RollingBufferRewriter : public StmtExprMutator { } n->alloc_buffers = std::move(new_alloc_buffers); } else { - Array new_iter_vars; + ffi::Array new_iter_vars; for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { auto old_iter_var = stmt->iter_vars[i]; if (static_cast(i) == info_->rolling_axis) { @@ -323,7 +323,7 @@ class RollingBufferRewriter : public StmtExprMutator { new_iter_vars.push_back(old_iter_var); } } - Map buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; + ffi::Map buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); n->iter_vars = std::move(new_iter_vars); @@ -344,7 +344,8 @@ class RollingBufferRewriter : public StmtExprMutator { auto iter_var = info_->axis_iter_vars[i]; if (iter_var && info_->axis_overlaps[i] > 0) { Var var = iter_var.value(); - const Map dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; + const ffi::Map dmap = { + std::make_pair(var, arith::IntSet::Interval(0, 0))}; auto iter_value = realize->iter_values[i]; arith::Analyzer analyzer; auto term_2 = analyzer.int_set(iter_value, dmap).min(); @@ -399,7 +400,7 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf * indices to circularize the buffer along the rolling dimension. 
* - Append block predicate to avoid recomputing overlapping elements. */ - Map dom_map; + ffi::Map dom_map; const BlockRealize& realize = GetBlockRealize(self, block_sref); const Block& block = realize->block; @@ -412,8 +413,8 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf RollingBufferDependencyError::Check(self, block_sref, scope_root_sref); // Step 3. Find the lca of the access location of the target buffer and relax the buffer - Array loop_srefs = GetLoops(block_sref); - Array consumers_sref = GetConsumers(self, block_sref); + ffi::Array loop_srefs = GetLoops(block_sref); + ffi::Array consumers_sref = GetConsumers(self, block_sref); consumers_sref.push_back(block_sref); StmtSRef lca = GetSRefLowestCommonAncestor(consumers_sref); if (!lca->StmtAs()) { @@ -426,7 +427,7 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf if (stmt == lca) { break; } - For cur_loop = GetRef(stmt->StmtAs()); + For cur_loop = ffi::GetRef(stmt->StmtAs()); Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); dom_map.Set(cur_loop->loop_var, arith::IntSet::FromRange(range)); } @@ -458,7 +459,8 @@ struct RollingBufferTraits : public UnpackedInstTraits { return sch->RollingBuffer(block, write_buffer_index.IntValue()); } - static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + Integer write_buffer_index) { PythonAPICall py("rolling_buffer"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 1d3cabee1dd6..a8042e0c37eb 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -163,8 +163,8 @@ std::vector SampleWithoutReplacement( } int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& 
candidates, const Array& probs, - Optional* decision) { + const ffi::Array& candidates, const ffi::Array& probs, + ffi::Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; int32_t i = -1; @@ -309,7 +309,7 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, - Optional>* decision) { + ffi::Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const int64_t* extent = GetLoopIntExtent(loop); std::vector result; @@ -370,7 +370,7 @@ TVM_DLL std::vector SamplePartitionedTile( std::vector SamplePartitionedTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t partition_pos, - int32_t innerpart_factor, Optional>* decision) { + int32_t innerpart_factor, ffi::Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const int64_t* extent = GetLoopIntExtent(loop); std::vector result; @@ -419,7 +419,7 @@ std::vector SamplePartitionedTile( tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, - const StmtSRef& block_sref, Optional* decision) { + const StmtSRef& block_sref, ffi::Optional* decision) { // Step 1. Collect all possible compute-at locations. 
auto [location_srefs, location_indices] = CollectComputeLocation(self, block_sref); ICHECK_EQ(location_srefs.size(), location_indices.size()); @@ -460,17 +460,17 @@ struct SampleCategoricalTraits : public UnpackedInstTraits candidates, // - Array probs, // - Optional decision) { + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + ffi::Array candidates, // + ffi::Array probs, // + ffi::Optional decision) { return sch->SampleCategorical(candidates, probs, decision); } - static String UnpackedAsPython(Array outputs, // - Array candidates, // - Array probs, // - Optional decision) { + static ffi::String UnpackedAsPython(ffi::Array outputs, // + ffi::Array candidates, // + ffi::Array probs, // + ffi::Optional decision) { PythonAPICall py("sample_categorical"); py.Input("candidates", candidates); py.Input("probs", probs); @@ -492,14 +492,15 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, - Integer max_innermost_factor, - Optional> decision) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, + Integer max_innermost_factor, + ffi::Optional> decision) { return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); } - static String UnpackedAsPython(Array outputs, String loop_rv, Integer n, - Integer max_innermost_factor, Optional> decision) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + Integer n, Integer max_innermost_factor, + ffi::Optional> decision) { PythonAPICall py("sample_perfect_tile"); py.Input("loop", loop_rv); py.Input("n", n->value); @@ -522,16 +523,16 @@ struct SamplePartitionedTileTraits : public UnpackedInstTraits UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, - Integer partition_pos, Integer innerpart_factor, - Optional> decision) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, + Integer partition_pos, Integer 
innerpart_factor, + ffi::Optional> decision) { return sch->SamplePartitionedTile(loop_rv, n->value, partition_pos->value, innerpart_factor->value, decision); } - static String UnpackedAsPython(Array outputs, String loop_rv, Integer n, - Integer partition_pos, Integer innerpart_factor, - Optional> decision) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + Integer n, Integer partition_pos, Integer innerpart_factor, + ffi::Optional> decision) { PythonAPICall py("sample_partitioned_tile"); py.Input("loop", loop_rv); py.Input("n", n->value); @@ -557,13 +558,13 @@ struct SampleComputeLocationTraits : public UnpackedInstTraits decision) { + ffi::Optional decision) { return sch->SampleComputeLocation(block_rv, decision); } - static String UnpackedAsPython(Array outputs, // - String block_rv, // - Optional decision) { + static ffi::String UnpackedAsPython(ffi::Array outputs, // + ffi::String block_rv, // + ffi::Optional decision) { PythonAPICall py("sample_compute_location"); py.Input("block", block_rv); py.Decision(decision); diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 86b8675dbf56..d15b43afb965 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -22,16 +22,18 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); BlockRVNode::RegisterReflection(); LoopRVNode::RegisterReflection(); -}); +} /**************** Constructor ****************/ -BlockRV::BlockRV() { this->data_ = make_object(); } +BlockRV::BlockRV() { this->data_ = ffi::make_object(); } -LoopRV::LoopRV() { this->data_ = make_object(); } +LoopRV::LoopRV() { this->data_ = ffi::make_object(); } /**************** GetSRef ****************/ @@ -46,7 +48,7 @@ StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const { /**************** FFI ****************/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { 
namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleGetMod", &ScheduleNode::mod) @@ -57,11 +59,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.schedule.ScheduleSeed", &ScheduleNode::Seed) .def_method("tir.schedule.ScheduleForkSeed", &ScheduleNode::ForkSeed) .def_method("tir.schedule.ScheduleWorkOn", &ScheduleNode::WorkOn); -}); +} /**************** (FFI) Constructor ****************/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.BlockRV", []() { return BlockRV(); }) @@ -80,11 +82,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ static_cast(error_render_level), enable_check); }); -}); +} /******** (FFI) Lookup random variables ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.ScheduleGet", @@ -103,7 +105,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ throw; }) .def("tir.schedule.ScheduleGetSRef", - [](Schedule self, ObjectRef obj) -> Optional { + [](Schedule self, ObjectRef obj) -> ffi::Optional { if (auto loop_rv = obj.as()) { return self->GetSRef(loop_rv.value()); } @@ -129,10 +131,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); throw; }); -}); +} /******** (FFI) Sampling ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleSampleCategorical", &ScheduleNode::SampleCategorical) @@ -141,9 +143,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ &ScheduleNode::SamplePartitionedTile) .def_method("tir.schedule.ScheduleSampleComputeLocation", &ScheduleNode::SampleComputeLocation); -}); +} /******** (FFI) Get blocks & loops ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleGetBlock", &ScheduleNode::GetBlock) @@ -163,9 +165,9 @@ 
TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.schedule.ScheduleGetProducers", &ScheduleNode::GetProducers) .def_method("tir.schedule.ScheduleGetConsumers", &ScheduleNode::GetConsumers) .def_method("tir.schedule.ScheduleGetOutputBlocks", &ScheduleNode::GetOutputBlocks); -}); +} /******** (FFI) Transform loops ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleMerge", &ScheduleNode::Merge) @@ -185,18 +187,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ throw; } }); -}); +} /******** (FFI) Manipulate ForKind ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleParallel", &ScheduleNode::Parallel) .def_method("tir.schedule.ScheduleVectorize", &ScheduleNode::Vectorize) .def_method("tir.schedule.ScheduleBind", &ScheduleNode::Bind) .def_method("tir.schedule.ScheduleUnroll", &ScheduleNode::Unroll); -}); +} /******** (FFI) Insert cache stages ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleCacheRead", &ScheduleNode::CacheRead) @@ -206,57 +208,59 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.schedule.ScheduleCacheInplace", &ScheduleNode::CacheInplace) .def_method("tir.schedule.ScheduleCacheIndex", &ScheduleNode::CacheIndex) .def("tir.schedule.ScheduleReIndex", - [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type) { + [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, bool skip_simplify) { return self->ReIndex(block_rv, buffer_index, - static_cast(buffer_index_type)); + static_cast(buffer_index_type), skip_simplify); }); -}); +} /******** (FFI) Data movement ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() 
.def_method("tir.schedule.ScheduleReadAt", &ScheduleNode::ReadAt) .def_method("tir.schedule.ScheduleWriteAt", &ScheduleNode::WriteAt); -}); +} /******** (FFI) Compute location ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleComputeAt", &ScheduleNode::ComputeAt) .def_method("tir.schedule.ScheduleReverseComputeAt", &ScheduleNode::ReverseComputeAt) .def_method("tir.schedule.ScheduleComputeInline", &ScheduleNode::ComputeInline) - .def_method("tir.schedule.ScheduleReverseComputeInline", &ScheduleNode::ReverseComputeInline); -}); + .def_method("tir.schedule.ScheduleReverseComputeInline", &ScheduleNode::ReverseComputeInline) + .def_method("tir.schedule.ScheduleFuseReductionEpilogue", + &ScheduleNode::FuseReductionEpilogue); +} /******** (FFI) Reduction ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleDecomposeReduction", &ScheduleNode::DecomposeReduction) .def_method("tir.schedule.ScheduleRFactor", &ScheduleNode::RFactor); -}); +} /******** (FFI) Block annotation ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleStorageAlign", &ScheduleNode::StorageAlign) .def_method("tir.schedule.ScheduleSetScope", &ScheduleNode::SetScope) .def_method("tir.schedule.ScheduleUnsafeSetDType", &ScheduleNode::UnsafeSetDType); -}); +} /******** (FFI) Blockize & Tensorize ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.ScheduleBlockize", [](Schedule self, ObjectRef target, bool preserve_unit_iters) { if (auto loop_rv = target.as()) { return self->Blockize(loop_rv.value(), preserve_unit_iters); - } else if (auto blocks = target.as>()) { + } else if (auto blocks = 
target.as>()) { return self->Blockize(blocks.value(), preserve_unit_iters); } LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); }) .def("tir.schedule.ScheduleTensorize", - [](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { + [](Schedule self, ObjectRef rv, ffi::String intrin, bool preserve_unit_iters) { if (auto block_rv = rv.as()) { self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); } else if (auto loop_rv = rv.as()) { @@ -266,14 +270,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ << rv->GetTypeKey() << ". Its value is: " << rv; } }); -}); +} /******** (FFI) Annotation ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.ScheduleAnnotate", - [](Schedule self, ObjectRef rv, const String& ann_key, const Any& ann_val) { + [](Schedule self, ObjectRef rv, const ffi::String& ann_key, const Any& ann_val) { if (auto block_rv = rv.as()) { return self->Annotate(block_rv.value(), ann_key, ann_val); } @@ -285,7 +289,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ throw; }) .def("tir.schedule.ScheduleUnannotate", [](Schedule self, ObjectRef rv, - const String& ann_key) { + const ffi::String& ann_key) { if (auto block_rv = rv.as()) { return self->Unannotate(block_rv.value(), ann_key); } @@ -296,15 +300,15 @@ TVM_FFI_STATIC_INIT_BLOCK({ << ". 
Its value is: " << rv; throw; }); -}); +} /******** (FFI) Layout transformation ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.ScheduleTransformLayout", [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, - const IndexMap& index_map, const Optional& pad_value, + const IndexMap& index_map, const ffi::Optional& pad_value, bool assume_injective_transform) { return self->TransformLayout(block_rv, buffer_index, static_cast(buffer_index_type), @@ -313,35 +317,35 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.schedule.ScheduleTransformBlockLayout", &ScheduleNode::TransformBlockLayout) .def("tir.schedule.ScheduleSetAxisSeparator", [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, - const Array& axis_separators) { + const ffi::Array& axis_separators) { return self->SetAxisSeparator(block_rv, buffer_index, static_cast(buffer_index_type), axis_separators); }); -}); +} /******** (FFI) Padding decomposition ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleDecomposePadding", &ScheduleNode::DecomposePadding) .def_method("tir.schedule.SchedulePadEinsum", &ScheduleNode::PadEinsum); -}); +} /******** (FFI) Buffer transformation ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_method("tir.schedule.ScheduleRollingBuffer", &ScheduleNode::RollingBuffer); -}); +} /******** (FFI) Misc ********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_method("tir.schedule.ScheduleEnterPostproc", &ScheduleNode::EnterPostproc) .def_method("tir.schedule.ScheduleUnsafeHideBufferAccess", &ScheduleNode::UnsafeHideBufferAccess); -}); +} /******** (FFI) Annotate buffer access 
********/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.schedule.ScheduleAnnotateBufferAccess", [](Schedule self, const BlockRV& block_rv, int buffer_index, @@ -350,7 +354,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ block_rv, buffer_index, static_cast(buffer_index_type), index_map); }); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index ff653502ccaa..c299f52fde55 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -23,7 +23,7 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ ScheduleStateNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ScheduleStateNode::RegisterReflection(); } template using SMap = std::unordered_map; @@ -39,12 +39,12 @@ using SMap = std::unordered_map; * \param dom_high_exclusive The highest node in the sref tree path * \return An n-dimensional integer set */ -Array AnalyzeRegionUpperBound(const BufferRegion& region, // - const PrimExpr& predicate, // - const StmtSRef& dom_low_inclusive, // - const StmtSRef& dom_high_exclusive, // - arith::Analyzer* analyzer) { - Map var_dom = LoopDomainOfSRefTreePath( +ffi::Array AnalyzeRegionUpperBound(const BufferRegion& region, // + const PrimExpr& predicate, // + const StmtSRef& dom_low_inclusive, // + const StmtSRef& dom_high_exclusive, // + arith::Analyzer* analyzer) { + ffi::Map var_dom = LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); @@ -64,22 +64,22 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, * \param analyzer The analyzer * \return An n-dimensional integer set */ -Array AnalyzeRegionLowerBound(const BufferRegion& region, // - const PrimExpr& predicate, // - const StmtSRef& dom_low_inclusive, // - const StmtSRef& dom_high_exclusive, // - arith::Analyzer* analyzer) { - Map 
var_dom = LoopDomainOfSRefTreePath( +ffi::Array AnalyzeRegionLowerBound(const BufferRegion& region, // + const PrimExpr& predicate, // + const StmtSRef& dom_low_inclusive, // + const StmtSRef& dom_high_exclusive, // + arith::Analyzer* analyzer) { + ffi::Map var_dom = LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); - if (Optional> result = EstimateRegionLowerBound( + if (ffi::Optional> result = EstimateRegionLowerBound( /*region=*/region->region, /*var_dom=*/var_dom, /*predicate=*/predicate, /*analyzer=*/analyzer)) { return result.value(); } - return Array(region->buffer->shape.size(), arith::IntSet::Nothing()); + return ffi::Array(region->buffer->shape.size(), arith::IntSet::Nothing()); } /*! @@ -90,9 +90,9 @@ Array AnalyzeRegionLowerBound(const BufferRegion& region, * \param analyzer The analyzer * \return A boolean indicating if the produced region could cover the consumed region */ -bool ProducerCoversConsumer(const Array& buffer_shape, - const Array& produced_region, - const Array& consumed_region, +bool ProducerCoversConsumer(const ffi::Array& buffer_shape, + const ffi::Array& produced_region, + const ffi::Array& consumed_region, arith::Analyzer* analyzer) { ICHECK_EQ(buffer_shape.size(), consumed_region.size()); ICHECK_EQ(produced_region.size(), consumed_region.size()); @@ -140,7 +140,7 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new ICHECK(new_stmt->IsInstance() || new_stmt->IsInstance()); const StmtNode* old_stmt = sref->stmt; ICHECK_NE(new_stmt, old_stmt); - self->stmt2ref[new_stmt] = GetRef(sref); + self->stmt2ref[new_stmt] = ffi::GetRef(sref); self->stmt2ref.erase(sref->stmt); sref->stmt = new_stmt; } @@ -177,7 +177,7 @@ class BlockInfoCollector : private StmtVisitor { void MakeBlockInfo(StmtSRef scope_root) { bool is_root_block = srefs_.empty(); // Calculate `BlockInfo::scope` - Array 
child_block_srefs = std::move(block_frames_.back()); + ffi::Array child_block_srefs = std::move(block_frames_.back()); BlockInfo& info = self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); // Set `affine_binding` if (is_root_block) { @@ -198,26 +198,26 @@ class BlockInfoCollector : private StmtVisitor { } bool CheckRegionCoverAndStagePipeline(const BlockInfo& info, const StmtSRef& scope_root, - const Array& child_block_srefs) { + const ffi::Array& child_block_srefs) { const StmtSRefNode* limit = scope_root->parent; bool stage_pipeline = true; // Step 1. Unbind the read/write regions of each child block - std::unordered_map> block_reads_unbound; - std::unordered_map> block_writes_unbound; + std::unordered_map> block_reads_unbound; + std::unordered_map> block_writes_unbound; block_reads_unbound.reserve(child_block_srefs.size()); block_writes_unbound.reserve(child_block_srefs.size()); for (const StmtSRef& block_sref : child_block_srefs) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Map binding = GetBindings(block2realize_.at(block)); + ffi::Map binding = GetBindings(block2realize_.at(block)); // Step 1.1. Unbind read regions - Array reads; + ffi::Array reads; reads.reserve(block->reads.size()); for (const BufferRegion& region : block->reads) { reads.push_back(BufferRegion(region->buffer, Substitute(region->region, binding))); } block_reads_unbound.emplace(block_sref.get(), std::move(reads)); // Step 1.2. Unbind write regions - Array writes; + ffi::Array writes; writes.reserve(block->writes.size()); for (const BufferRegion& region : block->writes) { writes.push_back(BufferRegion(region->buffer, Substitute(region->region, binding))); @@ -227,7 +227,7 @@ class BlockInfoCollector : private StmtVisitor { // Step 2. 
For each consumer, check the region cover property for (const auto& kv : info.scope->dst2deps) { const StmtSRef& consumer_block_sref = kv.first; - const Array& deps = kv.second; + const ffi::Array& deps = kv.second; const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); const BlockRealize& consumer_realize = block2realize_.at(consumer_block); bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true; @@ -261,14 +261,15 @@ class BlockInfoCollector : private StmtVisitor { // Step 2.3. For each LCA, gather the produced regions, // then check if it could cover the consumed region for (StmtSRef lca = consumer_block_sref; region_cover && lca.get() != limit; - lca = GetRef(lca->parent)) { + lca = ffi::GetRef(lca->parent)) { const std::vector& producer_block_srefs = lca_loc.at(lca.get()); // Skip empty LCA positions if (producer_block_srefs.empty()) { continue; } // For each buffer, record the regions generated under this loop - std::unordered_map>> touched_regions; + std::unordered_map>> + touched_regions; // Step 2.3.1. Find all the regions read by the consumer that we care about for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) { const BufferNode* buffer = region->buffer.get(); @@ -277,13 +278,13 @@ class BlockInfoCollector : private StmtVisitor { // Step 2.3.2. 
Find all the regions written by each producer for (const StmtSRefNode* producer_block_sref : producer_block_srefs) { const BlockRealize& producer_realize = block2realize_.at(producer_block_sref->stmt); - StmtSRef parent_sref = GetRef(producer_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(producer_block_sref->parent); for (const BufferRegion& region : block_writes_unbound.at(producer_block_sref)) { const BufferNode* buffer = region->buffer.get(); auto it = touched_regions.find(buffer); // Skip the regions that is not read by the consumer if (it != touched_regions.end()) { - std::vector>& touched_region = it->second; + std::vector>& touched_region = it->second; // The analysis here is trying to be conservation to rule out false positive cases, // and to make sure region cover property must be satisfied once the flag is on // Therefore, we use lower-bound analysis for producers and upper-bound analysis for @@ -299,14 +300,15 @@ class BlockInfoCollector : private StmtVisitor { } // Step 2.3.3. 
For each buffer, check the region cover property { - StmtSRef parent_sref = GetRef(consumer_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(consumer_block_sref->parent); for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) { const BufferNode* buffer = region->buffer.get(); - const std::vector>& touched_region = touched_regions.at(buffer); + const std::vector>& touched_region = + touched_regions.at(buffer); if (!touched_region.empty()) { - Array produced_region = + ffi::Array produced_region = arith::UnionRegionLowerBound({touched_region.begin(), touched_region.end()}); - Array consumed_region = AnalyzeRegionUpperBound( + ffi::Array consumed_region = AnalyzeRegionUpperBound( /*region=*/region, /*predicate=*/consumer_realize->predicate, /*dom_low_inclusive=*/parent_sref, @@ -337,7 +339,7 @@ class BlockInfoCollector : private StmtVisitor { void VisitStmt_(const BlockRealizeNode* realize) final { block_frames_.emplace_back(); const BlockNode* block = realize->block.get(); - block2realize_.emplace(block, GetRef(realize)); + block2realize_.emplace(block, ffi::GetRef(realize)); // Recursive visit PushSRef(block); VisitStmt(block->body); // `block->init` is not visited @@ -362,7 +364,7 @@ class BlockInfoCollector : private StmtVisitor { /*! \brief The BlockRealize corresponding to blocks */ std::unordered_map block2realize_; /*! \brief The stack frames of blocks in the DFS visit. */ - std::vector> block_frames_; + std::vector> block_frames_; /*! 
\brief The auxiliary analyzer */ arith::Analyzer analyzer_; }; @@ -371,7 +373,7 @@ class BlockInfoCollector : private StmtVisitor { ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) { CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported"; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); ScheduleStateNode* self = n.get(); // Set `n->mod` n->mod = std::move(mod); @@ -544,7 +546,7 @@ class SRefTreePruner : public StmtVisitor { auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) << "IndexError: Cannot find corresponding StmtSRef for the loop:\n" - << GetRef(op); + << ffi::GetRef(op); StmtSRef& sref = it->second; // Detect reuse const VarNode* loop_var = op->loop_var.get(); @@ -567,7 +569,7 @@ class SRefTreePruner : public StmtVisitor { auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) << "IndexError: Cannot find corresponding StmtSRef for the block:\n" - << GetRef(op); + << ffi::GetRef(op); StmtSRef& sref = it->second; // Detect reuse const auto& sref_reuse = reuse_info_.block_sref_reuse; @@ -617,7 +619,7 @@ class SRefUpdater : public StmtVisitor { private: explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent, const std::unordered_map& reused_srefs) - : self_(GetRef(self)), + : self_(ffi::GetRef(self)), ancestors_{src_stmt_parent}, reused_srefs_(reused_srefs) {} @@ -745,15 +747,15 @@ class ChildReplacer : private StmtMutator { } // Skipping sibling blocks and loops other than `src_stmt_` - Stmt VisitStmt_(const BlockNode* op) final { return GetRef(op); } - Stmt VisitStmt_(const ForNode* op) final { return GetRef(op); } + Stmt VisitStmt_(const BlockNode* op) final { return ffi::GetRef(op); } + Stmt VisitStmt_(const ForNode* op) final { return ffi::GetRef(op); } Stmt VisitStmt_(const SeqStmtNode* op) final { int i = this->seq_index_; int n = static_cast(op->seq.size()); if (0 <= i && i < n) { const Stmt& stmt = 
op->seq[i]; - Optional new_stmt = std::nullopt; + ffi::Optional new_stmt = std::nullopt; const StmtNode* src_stmt = this->src_stmt_; // `stmt` can be For or BlockRealize // `src_stmt` can be For or Block @@ -767,8 +769,8 @@ class ChildReplacer : private StmtMutator { // Case 2. stmt is BlockRealize, src_stmt is Block if (realize->block.get() == src_stmt) { const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, BlockNode); - ObjectPtr new_realize = make_object(*realize); - new_realize->block = GetRef(tgt_block); + ObjectPtr new_realize = ffi::make_object(*realize); + new_realize->block = ffi::GetRef(tgt_block); new_stmt = BlockRealize(std::move(new_realize)); } } @@ -814,7 +816,7 @@ class ChildReplacer : private StmtMutator { }; void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, - const Map& _block_sref_reuse) { + const ffi::Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; bool input_correct = @@ -824,7 +826,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ if (!input_correct) { LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey() << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n" - << GetRef(src_stmt) << "\ntgt_stmt:\n" + << ffi::GetRef(src_stmt) << "\ntgt_stmt:\n" << tgt_stmt; } } @@ -834,7 +836,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ } // Reset sref as a new sref so that its content won't be affected by subsequent changes StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index); - Stmt src_stmt = GetRef(src_sref->stmt); + Stmt src_stmt = ffi::GetRef(src_sref->stmt); // Step 1. Create all the nodes needed for the new sref tree. 
// After this step // 1) all `parent`s are correct @@ -962,18 +964,18 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ const auto* realize = TVM_TYPE_AS(g_func->body, BlockRealizeNode); // Make `child_tgt_stmt` the root block const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, BlockNode); - ObjectPtr new_realize = make_object(*realize); - new_realize->block = GetRef(child_block); + ObjectPtr new_realize = ffi::make_object(*realize); + new_realize->block = ffi::GetRef(child_block); new_func->body = BlockRealize(std::move(new_realize)); // Finally, move the `ref_new_func` back and update `this->mod` new_map->at(g_var) = std::move(ref_new_func); - this->mod = GetRef(new_mod); + this->mod = ffi::GetRef(new_mod); } uint32_t flag = (debug_mask != -1) // ? static_cast(debug_mask) // : std::numeric_limits::max(); if (flag & ScheduleDebugMask::kVerifySRefTree) { - VerifySRefTree(GetRef(this)); + VerifySRefTree(ffi::GetRef(this)); } } @@ -983,10 +985,10 @@ void ScheduleStateNode::DebugVerify() const { ? 
static_cast(debug_mask) // : std::numeric_limits::max(); if (flag & ScheduleDebugMask::kVerifySRefTree) { - VerifySRefTree(GetRef(this)); + VerifySRefTree(ffi::GetRef(this)); } if (flag & ScheduleDebugMask::kVerifyCachedFlags) { - VerifyCachedFlags(GetRef(this)); + VerifyCachedFlags(ffi::GetRef(this)); } } @@ -997,7 +999,7 @@ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { auto it = this->block_info.find(block_sref); CHECK(it != this->block_info.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" - << GetRef(block_sref->stmt); + << ffi::GetRef(block_sref->stmt); return it->second; } @@ -1005,7 +1007,7 @@ void ScheduleStateNode::UpdateScopeBlockInfo(const Stmt& stmt) { BlockInfoCollector::Collect(this, stmt); } -TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { +TVM_DLL ffi::Array GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { const BlockInfo& info = self->GetBlockInfo(block_sref); return {Bool(info.affine_binding), // Bool(info.region_cover), // @@ -1014,7 +1016,7 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl /**************** FFI ****************/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.ScheduleState", @@ -1024,12 +1026,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.schedule.ScheduleStateGetBlockScope", &ScheduleStateNode::GetBlockScope) .def_method("tir.schedule.ScheduleStateReplace", &ScheduleStateNode::Replace) .def("tir.schedule.ScheduleStateGetSRef", - [](ScheduleState self, Stmt stmt) -> Optional { + [](ScheduleState self, Stmt stmt) -> ffi::Optional { auto it = self->stmt2ref.find(stmt.get()); - return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); + return it != self->stmt2ref.end() ? 
it->second : ffi::Optional(std::nullopt); }) .def("tir.schedule.ScheduleStateGetCachedFlags", GetCachedFlags); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 5322f85ac1b4..371aa0cb092d 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -23,14 +23,14 @@ namespace tvm { namespace tir { -TVM_FFI_STATIC_INIT_BLOCK({ TraceNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { TraceNode::RegisterReflection(); } /**************** Constructors ****************/ -Trace::Trace() { data_ = make_object(); } +Trace::Trace() { data_ = ffi::make_object(); } -Trace::Trace(Array insts, Map decisions) { - ObjectPtr n = make_object(); +Trace::Trace(ffi::Array insts, ffi::Map decisions) { + ObjectPtr n = ffi::make_object(); n->insts = std::move(insts); n->decisions = std::move(decisions); data_ = std::move(n); @@ -38,7 +38,7 @@ Trace::Trace(Array insts, Map decisions) { /**************** Utilities ****************/ -int GetNumValidInstructions(const Array& insts, bool remove_postproc) { +int GetNumValidInstructions(const ffi::Array& insts, bool remove_postproc) { if (!remove_postproc) { return insts.size(); } @@ -55,11 +55,11 @@ int GetNumValidInstructions(const Array& insts, bool remove_postpro /**************** TranslateInputRVs ****************/ -Array TranslateInputRVs(const Array& inputs, - const std::unordered_map& rv_map) { - Array result; +ffi::Array TranslateInputRVs(const ffi::Array& inputs, + const std::unordered_map& rv_map) { + ffi::Array result; result.reserve(inputs.size()); - auto f_subst_with_rv_map = [&rv_map](const Var& var) -> Optional { + auto f_subst_with_rv_map = [&rv_map](const Var& var) -> ffi::Optional { auto it = rv_map.find(var.get()); if (it == rv_map.end()) { return std::nullopt; @@ -67,7 +67,7 @@ Array TranslateInputRVs(const Array& inputs, const Object* dst = it->second; ICHECK(dst->IsInstance()) << "TypeError: Expect 'tir.Var', but gets: " << 
dst->GetTypeKey(); - return GetRef(static_cast(dst)); + return ffi::GetRef(static_cast(dst)); }; for (const Any& input : inputs) { @@ -81,12 +81,12 @@ Array TranslateInputRVs(const Array& inputs, input.as()) { // RV: var auto it = rv_map.find(input.as()); ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; - result.push_back(GetRef(it->second)); + result.push_back(ffi::GetRef(it->second)); } else if (auto expr = input.try_cast()) { // RV: Expr result.push_back(Substitute(expr.value(), f_subst_with_rv_map)); } else if (auto index_map = input.as()) { result.push_back(Substitute(index_map.value(), f_subst_with_rv_map)); - } else if (auto arr = input.as>()) { + } else if (auto arr = input.as>()) { // Recursively convert elements of the array into a new list of ObjectRefs. result.push_back(TranslateInputRVs(arr.value(), rv_map)); } else { @@ -99,20 +99,20 @@ Array TranslateInputRVs(const Array& inputs, } // translate rv to string -Array TranslateInputRVs( - const Array& inputs, - const std::unordered_map& rv_names) { - Array results; +ffi::Array TranslateInputRVs( + const ffi::Array& inputs, + const std::unordered_map& rv_names) { + ffi::Array results; results.reserve(inputs.size()); for (const Any& input : inputs) { if (input == nullptr) { // Case 0. 
nullptr => None - results.push_back(String("None")); + results.push_back(ffi::String("None")); continue; } // string => "content" if (auto opt_str = input.as()) { - results.push_back(String('"' + (*opt_str).operator std::string() + '"')); + results.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); } else if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { // directly put back POD type and not string results.push_back(input); @@ -132,19 +132,20 @@ Array TranslateInputRVs( results.push_back(input); } else if (input.as()) { // Case 4: array - results.push_back(TranslateInputRVs(Downcast>(Any(input)), rv_names)); + results.push_back(TranslateInputRVs(Downcast>(Any(input)), rv_names)); } else if (input.as()) { // Case 5: dict results.push_back(input); } else if (input.as()) { // // Case 6: IndexMap IndexMap index_map = Downcast(input); - index_map = index_map.RenameVariables([&rv_names](const Var& var) -> Optional { - if (auto it = rv_names.find(var); it != rv_names.end()) { - return it->second; - } - return std::nullopt; - }); + index_map = + index_map.RenameVariables([&rv_names](const Var& var) -> ffi::Optional { + if (auto it = rv_names.find(var); it != rv_names.end()) { + return it->second; + } + return std::nullopt; + }); results.push_back(index_map); } else { LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << input.GetTypeKey(); @@ -154,9 +155,9 @@ Array TranslateInputRVs( return results; } -Array TranslateInputRVs(const Array& inputs, - const std::unordered_map& named_rvs) { - Array results; +ffi::Array TranslateInputRVs(const ffi::Array& inputs, + const std::unordered_map& named_rvs) { + ffi::Array results; results.reserve(inputs.size()); for (const Any& input : inputs) { if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { @@ -171,7 +172,7 @@ Array TranslateInputRVs(const Array& inputs, } // Case 4. 
array if (input.as()) { - results.push_back(TranslateInputRVs(Downcast>(input), named_rvs)); + results.push_back(TranslateInputRVs(Downcast>(input), named_rvs)); continue; } // Case 5. dict @@ -189,7 +190,7 @@ Array TranslateInputRVs(const Array& inputs, // Case 6. IndexMap if (obj.as()) { IndexMap index_map = Downcast(obj); - index_map = Substitute(index_map, [&named_rvs](const Var& var) -> Optional { + index_map = Substitute(index_map, [&named_rvs](const Var& var) -> ffi::Optional { auto it = named_rvs.find(var->name_hint); if (it != named_rvs.end()) { return Downcast(it->second); @@ -205,7 +206,7 @@ Array TranslateInputRVs(const Array& inputs, } // Case 2. string if (size >= 2 && name[0] == '"' && name[size - 1] == '"') { - results.push_back(String(std::string(name + 1, size - 2))); + results.push_back(ffi::String(std::string(name + 1, size - 2))); continue; } // Case 0 & 1. None, BlockRV, LoopRV, VarRV @@ -218,7 +219,7 @@ Array TranslateInputRVs(const Array& inputs, /**************** TranslateAddOutputRVs ****************/ -void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, +void TranslateAddOutputRVs(const ffi::Array& old_outputs, const ffi::Array& new_outputs, std::unordered_map* rv_map) { ICHECK_EQ(old_outputs.size(), new_outputs.size()); int n = old_outputs.size(); @@ -230,17 +231,17 @@ void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_ } } -Array TranslateAddOutputRVs( - const Array& outputs, - std::unordered_map* rv_names) { - Array results; +ffi::Array TranslateAddOutputRVs( + const ffi::Array& outputs, + std::unordered_map* rv_names) { + ffi::Array results; results.reserve(outputs.size()); for (const Any& output : outputs) { int i = rv_names->size(); ICHECK(!rv_names->count(output.cast())) << "ValueError: The random variable has been produced once: " << rv_names->at(output.cast()); - String result; + ffi::String result; if (output == nullptr) { result = "_"; } else if (output.as()) { @@ -260,12 +261,13 @@ 
Array TranslateAddOutputRVs( return results; } -void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, +void TranslateAddOutputRVs(const ffi::Array& old_outputs, + const ffi::Array& new_outputs, std::unordered_map* named_rvs) { ICHECK_EQ(old_outputs.size(), new_outputs.size()); int n = old_outputs.size(); for (int i = 0; i < n; ++i) { - named_rvs->emplace(Downcast(old_outputs[i]), new_outputs[i].cast()); + named_rvs->emplace(Downcast(old_outputs[i]), new_outputs[i].cast()); } } @@ -282,7 +284,7 @@ void TraceNode::Append(Instruction inst, Any decision) { insts.push_back(std::move(inst)); } -Optional TraceNode::Pop() { +ffi::Optional TraceNode::Pop() { if (insts.empty()) { return std::nullopt; } @@ -298,8 +300,8 @@ Optional TraceNode::Pop() { void TraceNode::ApplyToSchedule( Schedule sch, bool remove_postproc, - ffi::TypedFunction& inputs, // - const Array& attrs, // + ffi::TypedFunction& inputs, // + const ffi::Array& attrs, // const Any& decision)> decision_provider) const { std::unordered_map rv_map; @@ -307,21 +309,21 @@ void TraceNode::ApplyToSchedule( if (remove_postproc && inst->kind->IsPostproc()) { break; } - Array inputs = TranslateInputRVs(inst->inputs, rv_map); - Array attrs = inst->attrs; + ffi::Array inputs = TranslateInputRVs(inst->inputs, rv_map); + ffi::Array attrs = inst->attrs; Any decision = this->GetDecision(inst); if (decision_provider != nullptr) { decision = decision_provider(inst, inputs, attrs, decision); } - Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); + ffi::Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); } } ObjectRef TraceNode::AsJSON(bool remove_postproc) const { - std::unordered_map rv_names; - Array json_insts; - Array json_decisions; + std::unordered_map rv_names; + ffi::Array json_insts; + ffi::Array json_decisions; json_insts.reserve(this->insts.size()); 
json_decisions.reserve(this->insts.size()); @@ -331,40 +333,40 @@ ObjectRef TraceNode::AsJSON(bool remove_postproc) const { if (remove_postproc && kind->IsPostproc()) { break; } - json_insts.push_back(Array{ + json_insts.push_back(ffi::Array{ /* 0: inst name */ kind->name, /* 1: inputs */ TranslateInputRVs(inst->inputs, rv_names), /* 2: attrs */ kind->f_attrs_as_json != nullptr ? kind->f_attrs_as_json(inst->attrs) : ObjectRef(inst->attrs), /* 3: outputs */ TranslateAddOutputRVs(inst->outputs, &rv_names), }); - if (auto decision = this->GetDecision(inst).cast>()) { - json_decisions.push_back(Array{ + if (auto decision = this->GetDecision(inst).cast>()) { + json_decisions.push_back(ffi::Array{ /* 0: index */ Integer(i), /* 1: decision */ decision.value(), }); } ++i; } - return Array{ + return ffi::Array{ /* 0: trace */ std::move(json_insts), /* 1: decision */ std::move(json_decisions), }; } -Array TraceNode::AsPython(bool remove_postproc) const { - std::unordered_map rv_names; - Array py_trace; +ffi::Array TraceNode::AsPython(bool remove_postproc) const { + std::unordered_map rv_names; + ffi::Array py_trace; py_trace.reserve(this->insts.size()); for (const Instruction& inst : this->insts) { if (remove_postproc && inst->kind->IsPostproc()) { break; } - Array attrs; + ffi::Array attrs; attrs.reserve(inst->attrs.size()); for (const Any& obj : inst->attrs) { if (auto opt_str = obj.as()) { - attrs.push_back(String('"' + (*opt_str).operator std::string() + '"')); + attrs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); } else { attrs.push_back(obj); } @@ -379,8 +381,8 @@ Array TraceNode::AsPython(bool remove_postproc) const { } void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { - Array json_insts{nullptr}; - Array json_decisions{nullptr}; + ffi::Array json_insts{nullptr}; + ffi::Array json_decisions{nullptr}; // Parse `json` into `json_insts` and `json_decisions` try { const ffi::ArrayObj* arr = json.as(); @@ -388,8 +390,8 @@ void 
Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { const auto* arr0 = arr->at(0).as(); const auto* arr1 = arr->at(1).as(); ICHECK(arr0 && arr1); - json_insts = GetRef>(arr0); - json_decisions = GetRef>(arr1); + json_insts = ffi::GetRef>(arr0); + json_decisions = ffi::GetRef>(arr1); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: The json entry of a trace should contain two arrays, an array of " "instructions and an array of decisions, but gets: " @@ -421,18 +423,18 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { int i = 0; for (const Any& inst_entry : json_insts) { InstructionKind kind{nullptr}; - Array inputs{nullptr}; - Array attrs{nullptr}; - Array outputs{ObjectPtr{nullptr}}; + ffi::Array inputs{nullptr}; + ffi::Array attrs{nullptr}; + ffi::Array outputs{ObjectPtr{nullptr}}; // Parse the entry try { const auto* arr = inst_entry.as(); ICHECK(arr && arr->size() == 4); ffi::String arr0 = arr->at(0).cast(); kind = InstructionKind::Get(arr0); - inputs = arr->at(1).cast>(); - attrs = arr->at(2).cast>(); - outputs = arr->at(3).cast>(); + inputs = arr->at(1).cast>(); + attrs = arr->at(2).cast>(); + outputs = arr->at(3).cast>(); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, " "inputs, attrs, outputs], but gets: " @@ -446,7 +448,7 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { attrs = kind->f_attrs_from_json(attrs); } // Apply to the schedule - Array new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); + ffi::Array new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); // Parse outputs TranslateAddOutputRVs(outputs, new_outputs, &named_rvs); ++i; @@ -457,9 +459,9 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { Trace TraceNode::WithDecision(Instruction inst, Any decision, bool remove_postproc) const { int n_insts = GetNumValidInstructions(this->insts, remove_postproc); - Array 
new_insts = - Array{this->insts.begin(), this->insts.begin() + n_insts}; - Map new_decisions{this->decisions.begin(), this->decisions.end()}; + ffi::Array new_insts = + ffi::Array{this->insts.begin(), this->insts.begin() + n_insts}; + ffi::Map new_decisions{this->decisions.begin(), this->decisions.end()}; new_decisions.Set(std::move(inst), std::move(decision)); return Trace(new_insts, new_decisions); } @@ -512,8 +514,8 @@ Trace TraceNode::Simplified(bool remove_postproc) const { } } } - return Trace(Array(new_insts.rbegin(), new_insts.rend()), - Map(new_decisions)); + return Trace(ffi::Array(new_insts.rbegin(), new_insts.rend()), + ffi::Map(new_decisions)); } /**************** Repr ****************/ @@ -524,9 +526,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ICHECK_NOTNULL(self); p->stream << "# from tvm import tir\n"; p->stream << "def apply_trace(sch: tir.Schedule) -> None:\n"; - Array repr = self->AsPython(/*remove_postproc=*/false); + ffi::Array repr = self->AsPython(/*remove_postproc=*/false); bool is_first = true; - for (const String& line : repr) { + for (const ffi::String& line : repr) { if (is_first) { is_first = false; } else { @@ -553,7 +555,7 @@ struct EnterPostprocTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch) { return sch->EnterPostproc(); } - static String UnpackedAsPython(Array outputs) { + static ffi::String UnpackedAsPython(ffi::Array outputs) { PythonAPICall py("enter_postproc"); return py.Str(); } @@ -566,16 +568,17 @@ TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); /**************** FFI ****************/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.Trace", - [](Optional> insts, Optional> decisions) { - return Trace(insts.value_or(Array()), decisions.value_or({})); + [](ffi::Optional> insts, + ffi::Optional> decisions) { + return Trace(insts.value_or(ffi::Array()), decisions.value_or({})); }) 
.def_method("tir.schedule.TraceGetDecision", &TraceNode::GetDecision) .def("tir.schedule.TraceAppend", - [](Trace self, Instruction inst, Optional decision) { + [](Trace self, Instruction inst, ffi::Optional decision) { if (decision.defined()) { return self->Append(inst, decision.value()); } else { @@ -589,7 +592,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.schedule.TraceWithDecision", &TraceNode::WithDecision) .def_method("tir.schedule.TraceSimplified", &TraceNode::Simplified) .def("tir.schedule.TraceApplyJSONToSchedule", Trace::ApplyJSONToSchedule); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index b9718c1a5f9c..ad9e65a643cd 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -24,7 +24,7 @@ namespace tir { Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->state_ = ScheduleState(mod, debug_mask, enable_check); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; @@ -41,7 +41,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand } Schedule TracedScheduleNode::Copy() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->func_working_on_ = this->func_working_on_; @@ -53,9 +53,9 @@ Schedule TracedScheduleNode::Copy() { /******** Schedule: Sampling ********/ -ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV TracedScheduleNode::SampleCategorical(const ffi::Array& candidates, + const ffi::Array& probs, + ffi::Optional decision) { ExprRV result = CreateRV(tir::SampleCategorical(&this->rand_state_, 
candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); @@ -67,11 +67,11 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, return result; } -Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, - int max_innermost_factor, - Optional> decision) { +ffi::Array TracedScheduleNode::SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision) { // use None RV object to denotes auto-infer tile factors. - Array results = + ffi::Array results = CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision), /*convert_negone_to_none=*/true); @@ -84,10 +84,10 @@ Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n return results; } -Array TracedScheduleNode::SamplePartitionedTile(const LoopRV& loop_rv, int n, - int partition_pos, int innerpart_factor, - Optional> decision) { - Array results = CreateRV(tir::SamplePartitionedTile( +ffi::Array TracedScheduleNode::SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision) { + ffi::Array results = CreateRV(tir::SamplePartitionedTile( &this->rand_state_, this->GetSRef(loop_rv), n, partition_pos, innerpart_factor, &decision)); static const InstructionKind& kind = InstructionKind::Get("SamplePartitionedTile"); @@ -101,7 +101,7 @@ Array TracedScheduleNode::SamplePartitionedTile(const LoopRV& loop_rv, i } LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, - Optional decision) { + ffi::Optional decision) { LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); @@ -116,7 +116,8 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, /******** Schedule: Get blocks & loops ********/ -BlockRV TracedScheduleNode::GetBlock(const String& name, const Optional& 
func_name) { +BlockRV TracedScheduleNode::GetBlock(const ffi::String& name, + const ffi::Optional& func_name) { GlobalVar gv = NullValue(); if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); @@ -137,8 +138,8 @@ BlockRV TracedScheduleNode::GetBlock(const String& name, const Optional& return result; } -Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { - Array results = ConcreteScheduleNode::GetLoops(block_rv); +ffi::Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetLoops(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetLoops"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -148,8 +149,8 @@ Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { - Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); +ffi::Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -159,8 +160,8 @@ Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { - Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); +ffi::Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + ffi::Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -170,8 +171,8 @@ Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { return results; } -Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { - Array results = ConcreteScheduleNode::GetProducers(block_rv); +ffi::Array 
TracedScheduleNode::GetProducers(const BlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetProducers(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetProducers"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -181,8 +182,8 @@ Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { - Array results = ConcreteScheduleNode::GetConsumers(block_rv); +ffi::Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetConsumers(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetConsumers"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -192,8 +193,8 @@ Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { - Array results = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv); +ffi::Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { + ffi::Array results = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv); static const InstructionKind& kind = InstructionKind::Get("GetOutputBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -205,7 +206,7 @@ Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv /******** Schedule: Transform loops ********/ -LoopRV TracedScheduleNode::Merge(const Array& loop_rvs) { +LoopRV TracedScheduleNode::Merge(const ffi::Array& loop_rvs) { LoopRV result = ConcreteScheduleNode::Merge(loop_rvs); static const InstructionKind& kind = InstructionKind::Get("Merge"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -215,7 +216,7 @@ LoopRV TracedScheduleNode::Merge(const Array& loop_rvs) { return result; } -LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs, bool preserve_unit_loops) { +LoopRV TracedScheduleNode::Fuse(const ffi::Array& loop_rvs, bool 
preserve_unit_loops) { LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops); static const InstructionKind& kind = InstructionKind::Get("Fuse"); @@ -226,13 +227,13 @@ LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs, bool preserve_uni return result; } -Array TracedScheduleNode::Split(const LoopRV& loop_rv, - const Array>& factor_rvs, - bool preserve_unit_iters, bool disable_predication) { - Array results = +ffi::Array TracedScheduleNode::Split(const LoopRV& loop_rv, + const ffi::Array>& factor_rvs, + bool preserve_unit_iters, bool disable_predication) { + ffi::Array results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters, disable_predication); - Array inputs; + ffi::Array inputs; inputs.reserve(1 + factor_rvs.size()); inputs.push_back(loop_rv); for (const auto& obj : factor_rvs) { @@ -243,18 +244,18 @@ Array TracedScheduleNode::Split(const LoopRV& loop_rv, trace_->Append( /*inst=*/Instruction(/*kind=*/kind, /*inputs=*/inputs, - /*attrs=*/Array({preserve_unit_iters, disable_predication}), + /*attrs=*/ffi::Array({preserve_unit_iters, disable_predication}), /*outputs=*/results)); return results; } -Array TracedScheduleNode::LoopPartition(const LoopRV& loop_rv, - const Array>& factor_rvs, - bool preserve_unit_iters) { - Array results = +ffi::Array TracedScheduleNode::LoopPartition( + const LoopRV& loop_rv, const ffi::Array>& factor_rvs, + bool preserve_unit_iters) { + ffi::Array results = ConcreteScheduleNode::LoopPartition(loop_rv, factor_rvs, preserve_unit_iters); - Array inputs; + ffi::Array inputs; inputs.reserve(1 + factor_rvs.size()); inputs.push_back(loop_rv); for (const auto& obj : factor_rvs) { @@ -269,7 +270,7 @@ Array TracedScheduleNode::LoopPartition(const LoopRV& loop_rv, return results; } -void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { +void TracedScheduleNode::Reorder(const ffi::Array& ordered_loop_rvs) { ConcreteScheduleNode::Reorder(ordered_loop_rvs); static const InstructionKind& kind 
= InstructionKind::Get("Reorder"); @@ -280,7 +281,7 @@ void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { } void TracedScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, - const Array new_order) { + const ffi::Array new_order) { ConcreteScheduleNode::ReorderBlockIterVar(block_rv, new_order); static const InstructionKind& kind = InstructionKind::Get("ReorderBlockIterVar"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -332,7 +333,7 @@ void TracedScheduleNode::Vectorize(const LoopRV& loop_rv) { /*outputs=*/{})); } -void TracedScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { +void TracedScheduleNode::Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) { ConcreteScheduleNode::Bind(loop_rv, thread_axis); static const InstructionKind& kind = InstructionKind::Get("Bind"); @@ -354,8 +355,8 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { /******** Schedule: Insert cache stages ********/ BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, - const Array consumer_blocks) { + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { BlockRV result = ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope, consumer_blocks); @@ -368,8 +369,8 @@ BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_i } BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, - const Array consumer_blocks) { + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope, consumer_blocks); @@ -382,7 +383,7 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer } BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, + const ffi::String& 
storage_scope, const IndexMap& index_map) { BlockRV result = ConcreteScheduleNode::ReindexCacheRead(block_rv, read_buffer_index, storage_scope, index_map); @@ -398,7 +399,7 @@ BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_b } BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, + const ffi::String& storage_scope, const IndexMap& index_map) { BlockRV result = ConcreteScheduleNode::ReindexCacheWrite(block_rv, write_buffer_index, storage_scope, index_map); @@ -413,11 +414,11 @@ BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write return result; } -Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) { - Array result = +ffi::Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) { + ffi::Array result = ConcreteScheduleNode::CacheInplace(block_rv, read_buffer_index, storage_scope); - Array results; + ffi::Array results; for (const BlockRV& r : result) { results.push_back(r); } @@ -429,10 +430,12 @@ Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int rea return result; } -Array TracedScheduleNode::CacheIndex(const BlockRV& block_rv, const String& storage_scope, - int cse_thresh) { - Array result = ConcreteScheduleNode::CacheIndex(block_rv, storage_scope, cse_thresh); - Array outputs; +ffi::Array TracedScheduleNode::CacheIndex(const BlockRV& block_rv, + const ffi::String& storage_scope, + int cse_thresh) { + ffi::Array result = + ConcreteScheduleNode::CacheIndex(block_rv, storage_scope, cse_thresh); + ffi::Array outputs; for (const BlockRV& r : result) { outputs.push_back(r); } @@ -445,13 +448,13 @@ Array TracedScheduleNode::CacheIndex(const BlockRV& block_rv, const Str } BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) { - 
BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); + BufferIndexType buffer_index_type, bool skip_simplify) { + BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type, skip_simplify); static const InstructionKind& kind = InstructionKind::Get("ReIndex"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type)}, + /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), Bool(skip_simplify)}, /*outputs=*/{result})); return result; } @@ -459,7 +462,7 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, /******** Schedule: Data movement ********/ BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int read_buffer_index, const String& storage_scope) { + int read_buffer_index, const ffi::String& storage_scope) { BlockRV result = ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); @@ -472,7 +475,7 @@ BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_r } BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int write_buffer_index, const String& storage_scope) { + int write_buffer_index, const ffi::String& storage_scope) { BlockRV result = ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); @@ -529,6 +532,17 @@ void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { /*outputs=*/{})); } +void TracedScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv, + const BlockRV& epilogue_block_rv) { + ConcreteScheduleNode::FuseReductionEpilogue(reduction_block_rv, epilogue_block_rv); + + static const InstructionKind& kind = InstructionKind::Get("FuseReductionEpilogue"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{reduction_block_rv, epilogue_block_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} 
+ /******** Schedule: Reduction ********/ BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) { @@ -565,7 +579,7 @@ void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, } void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, - const String& storage_scope) { + const ffi::String& storage_scope) { ConcreteScheduleNode::SetScope(block_rv, buffer_index, storage_scope); static const InstructionKind& kind = InstructionKind::Get("SetScope"); trace_->Append(/*inst=*/Instruction( @@ -576,7 +590,7 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, } void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index, - const String& dtype) { + const ffi::String& dtype) { ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype); static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType"); trace_->Append(/*inst=*/Instruction( @@ -599,7 +613,7 @@ BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_i return new_block; } -BlockRV TracedScheduleNode::Blockize(const Array& blocks, bool preserve_unit_iters) { +BlockRV TracedScheduleNode::Blockize(const ffi::Array& blocks, bool preserve_unit_iters) { BlockRV new_block = ConcreteScheduleNode::Blockize(blocks, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Blockize"); trace_->Append(/*inst=*/Instruction( @@ -610,7 +624,7 @@ BlockRV TracedScheduleNode::Blockize(const Array& blocks, bool preserve return new_block; } -void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) { ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Tensorize"); @@ -621,7 +635,7 @@ void TracedScheduleNode::Tensorize(const LoopRV& 
loop_rv, const String& intrin, /*outputs=*/{})); } -void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin, +void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) { ConcreteScheduleNode::Tensorize(block_rv, intrin, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Tensorize"); @@ -634,7 +648,7 @@ void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin /******** Schedule: Annotation ********/ -void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, +void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) { ConcreteScheduleNode::Annotate(loop_rv, ann_key, ann_val); static const InstructionKind& kind = InstructionKind::Get("Annotate"); @@ -644,7 +658,7 @@ void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, /*outputs=*/{})); } -void TracedScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, +void TracedScheduleNode::Annotate(const BlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) { ConcreteScheduleNode::Annotate(block_rv, ann_key, ann_val); static const InstructionKind& kind = InstructionKind::Get("Annotate"); @@ -654,7 +668,7 @@ void TracedScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key /*outputs=*/{})); } -void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { +void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) { ConcreteScheduleNode::Unannotate(loop_rv, ann_key); static const InstructionKind& kind = InstructionKind::Get("Unannotate"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -663,7 +677,7 @@ void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key /*outputs=*/{})); } -void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) { +void 
TracedScheduleNode::Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) { ConcreteScheduleNode::Unannotate(block_rv, ann_key); static const InstructionKind& kind = InstructionKind::Get("Unannotate"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -677,7 +691,7 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value, + const ffi::Optional& pad_value, bool assume_injective_transform) { ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map, pad_value, assume_injective_transform); @@ -704,7 +718,7 @@ void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const Ind void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) { + const ffi::Array& axis_separators) { ConcreteScheduleNode::SetAxisSeparator(block_rv, buffer_index, buffer_index_type, axis_separators); static const InstructionKind& kind = InstructionKind::Get("SetAxisSeparator"); @@ -727,7 +741,7 @@ BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const Loop return new_block; } -void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array& padding) { +void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) { ConcreteScheduleNode::PadEinsum(block_rv, padding); static const InstructionKind& kind = InstructionKind::Get("PadEinsum"); trace_->Append(/*inst=*/Instruction( @@ -760,8 +774,9 @@ void TracedScheduleNode::EnterPostproc() { /*outputs=*/{})); } -void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) { +void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, + const ffi::String& buf_type, + const 
ffi::Array& buf_index_array) { ConcreteScheduleNode::UnsafeHideBufferAccess(block_rv, buf_type, buf_index_array); static const InstructionKind& kind = InstructionKind::Get("UnsafeHideBufferAccess"); trace_->Append(/*inst=*/Instruction( diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 024c3fb873f2..cfe9b83e7cc6 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -20,6 +20,7 @@ #define TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ #include "./concrete_schedule.h" +#include namespace tvm { namespace tir { @@ -32,70 +33,76 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: static void RegisterReflection() { - // No fields to register as they are not visited + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } ~TracedScheduleNode() = default; public: - Optional trace() const final { return trace_; } + ffi::Optional trace() const final { return trace_; } Schedule Copy() final; public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = std::nullopt) final; - Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = std::nullopt) final; - Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, - int innerpart_factor, - Optional> decision = std::nullopt) final; + ExprRV SampleCategorical(const ffi::Array& candidates, const ffi::Array& probs, + ffi::Optional decision = std::nullopt) final; + ffi::Array SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision = std::nullopt) final; + ffi::Array SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision = std::nullopt) final; LoopRV SampleComputeLocation(const BlockRV& block_rv, - Optional decision = std::nullopt) final; + ffi::Optional decision = std::nullopt) final; 
/******** Schedule: Get blocks & loops ********/ - BlockRV GetBlock(const String& name, const Optional& func_name) final; - Array GetLoops(const BlockRV& block_rv) final; - Array GetChildBlocks(const BlockRV& block_rv) final; - Array GetChildBlocks(const LoopRV& loop_rv) final; - Array GetProducers(const BlockRV& block_rv) final; - Array GetConsumers(const BlockRV& block_rv) final; - Array GetOutputBlocks(const BlockRV& scope_block_rv) final; + BlockRV GetBlock(const ffi::String& name, const ffi::Optional& func_name) final; + ffi::Array GetLoops(const BlockRV& block_rv) final; + ffi::Array GetChildBlocks(const BlockRV& block_rv) final; + ffi::Array GetChildBlocks(const LoopRV& loop_rv) final; + ffi::Array GetProducers(const BlockRV& block_rv) final; + ffi::Array GetConsumers(const BlockRV& block_rv) final; + ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) final; /******** Schedule: Transform loops ********/ - LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) final; - LoopRV Merge(const Array& loop_rvs) final; - Array Split(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters, bool disable_predication) final; - Array LoopPartition(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters) final; - void Reorder(const Array& ordered_loop_rvs) final; - void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) final; + LoopRV Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) final; + LoopRV Merge(const ffi::Array& loop_rvs) final; + ffi::Array Split(const LoopRV& loop_rv, + const ffi::Array>& factor_rvs, + bool preserve_unit_iters, bool disable_predication) final; + ffi::Array LoopPartition(const LoopRV& loop_rv, + const ffi::Array>& factor_rvs, + bool preserve_unit_iters) final; + void Reorder(const ffi::Array& ordered_loop_rvs) final; + void ReorderBlockIterVar(const BlockRV& block_rv, const ffi::Array new_order) final; LoopRV AddUnitLoop(const BlockRV& block_rv) final; LoopRV 
AddUnitLoop(const LoopRV& loop_rv) final; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) final; void Vectorize(const LoopRV& loop_rv) final; - void Bind(const LoopRV& loop_rv, const String& thread_axis) final; + void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) final; void Unroll(const LoopRV& loop_rv) final; /******** Schedule: Insert cache stages ********/ - BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, - const Array consumer_blocks = {}) final; - BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, - const Array consumer_blocks = {}) final; + BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) final; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) final; BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map) final; + const ffi::String& storage_scope, const IndexMap& index_map) final; BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map) final; - Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) final; + const ffi::String& storage_scope, const IndexMap& index_map) final; + ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) final; - Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, + BufferIndexType buffer_index_type, bool skip_simplify) final; + ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, int cse_thresh) final; 
/******** Schedule: Data movement ********/ BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) final; + const ffi::String& storage_scope) final; BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) final; + const ffi::String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) final; @@ -103,41 +110,43 @@ class TracedScheduleNode : public ConcreteScheduleNode { int index = -1) final; void ComputeInline(const BlockRV& block_rv) final; void ReverseComputeInline(const BlockRV& block_rv) final; + void FuseReductionEpilogue(const BlockRV& reduction_block, const BlockRV& epilogue_block) final; /******** Schedule: Reduction ********/ BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final; BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; /******** Schedule: Block annotation ********/ void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) final; - void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; - void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final; + void SetScope(const BlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) final; + void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const ffi::String& dtype) final; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; - BlockRV Blockize(const Array& blocks, bool preserve_unit_iters) final; - void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final; - void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) final; + BlockRV Blockize(const ffi::Array& 
blocks, bool preserve_unit_iters) final; + void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, + bool preserve_unit_iters) final; + void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) final; /******** Schedule: Annotation ********/ - void Annotate(const LoopRV& loop_rv, const String& ann_key, const Any& ann_val) override; - void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; - void Annotate(const BlockRV& block_rv, const String& ann_key, const Any& ann_val) override; - void Unannotate(const BlockRV& block_rv, const String& ann_key) override; + void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) override; + void Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) override; + void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) override; + void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map, const Optional& pad_value, + const IndexMap& index_map, const ffi::Optional& pad_value, bool assume_injective_transform) override; void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) final; + const ffi::Array& axis_separators) final; /******** Schedule: Padding ********/ BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final; - void PadEinsum(const BlockRV& block_rv, const Array& padding) final; + void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) final; /******** Schedule: Buffer transformation ********/ void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) final; /******** Schedule: Misc ********/ void 
EnterPostproc() final; - void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) final; + void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + const ffi::Array& buf_index_array) final; void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) final; }; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 256f44e14894..9c3da9f32bea 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -27,18 +27,19 @@ namespace tir { /******** Annotation ********/ -Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { - Map annotations = block->annotations; +Block WithAnnotation(const BlockNode* block, const ffi::String& attr_key, + const ObjectRef& attr_value) { + ffi::Map annotations = block->annotations; annotations.Set(attr_key, attr_value); - ObjectPtr new_block = make_object(*block); + ObjectPtr new_block = ffi::make_object(*block); new_block->annotations = std::move(annotations); return Block(new_block); } /******** Buffer Related ********/ -Buffer WithScope(const Buffer& buffer, const String& scope) { - ObjectPtr new_buffer = make_object(*buffer.get()); - ObjectPtr new_var = make_object(*buffer->data.get()); +Buffer WithScope(const Buffer& buffer, const ffi::String& scope) { + ObjectPtr new_buffer = ffi::make_object(*buffer.get()); + ObjectPtr new_var = ffi::make_object(*buffer->data.get()); const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode); new_var->type_annotation = PointerType(ptr_type->element_type, scope); new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation); @@ -47,7 +48,7 @@ Buffer WithScope(const Buffer& buffer, const String& scope) { } Buffer WithDType(const Buffer& buffer, const DataType& dtype) { - ObjectPtr new_buffer = make_object(*buffer.get()); 
+ ObjectPtr new_buffer = ffi::make_object(*buffer.get()); new_buffer->dtype = dtype; const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode); new_buffer->data = @@ -56,11 +57,11 @@ Buffer WithDType(const Buffer& buffer, const DataType& dtype) { return Buffer(new_buffer); } -Array ReplaceBuffer(Array regions, const Buffer& source, - const Buffer& target) { +ffi::Array ReplaceBuffer(ffi::Array regions, const Buffer& source, + const Buffer& target) { regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion { if (region->buffer.same_as(source)) { - ObjectPtr n = make_object(*region.get()); + ObjectPtr n = ffi::make_object(*region.get()); n->buffer = target; return BufferRegion(n); } @@ -69,11 +70,11 @@ Array ReplaceBuffer(Array regions, const Buffer& sou return regions; } -Array ReplaceBuffer(Array regions, - const Map& buffer_map) { +ffi::Array ReplaceBuffer(ffi::Array regions, + const ffi::Map& buffer_map) { regions.MutateByApply([&buffer_map](BufferRegion region) -> BufferRegion { if (buffer_map.count(region->buffer)) { - ObjectPtr n = make_object(*region.get()); + ObjectPtr n = ffi::make_object(*region.get()); n->buffer = buffer_map[region->buffer]; return BufferRegion(n); } @@ -82,22 +83,24 @@ Array ReplaceBuffer(Array regions, return regions; } -Array ReplaceBuffer(Array match_buffers, const Buffer& source, - const Buffer& target) { - match_buffers.MutateByApply([&source, - &target](MatchBufferRegion match_buffer) -> MatchBufferRegion { - if (match_buffer->source->buffer.same_as(source)) { - ObjectPtr n = make_object(*match_buffer.get()); - n->source = BufferRegion(target, n->source->region); - return MatchBufferRegion(n); - } - return match_buffer; - }); +ffi::Array ReplaceBuffer(ffi::Array match_buffers, + const Buffer& source, const Buffer& target) { + match_buffers.MutateByApply( + [&source, &target](MatchBufferRegion match_buffer) -> MatchBufferRegion { + if (match_buffer->source->buffer.same_as(source)) { + 
ObjectPtr n = + ffi::make_object(*match_buffer.get()); + n->source = BufferRegion(target, n->source->region); + return MatchBufferRegion(n); + } + return match_buffer; + }); return match_buffers; } -Array ReplaceBufferRegion(Array regions, const Buffer& source_buffer, - const BufferRegion& target) { +ffi::Array ReplaceBufferRegion(ffi::Array regions, + const Buffer& source_buffer, + const BufferRegion& target) { regions.MutateByApply([&source_buffer, &target](const BufferRegion& region) -> BufferRegion { if (region->buffer.same_as(source_buffer)) { return target; @@ -107,30 +110,31 @@ Array ReplaceBufferRegion(Array regions, const Buffe return regions; } -Array ReplaceBufferRegion(Array match_buffers, - const Buffer& source_buffer, - const BufferRegion& target) { - match_buffers.MutateByApply([&source_buffer, &target]( - const MatchBufferRegion& match_buffer) -> MatchBufferRegion { - if (match_buffer->source->buffer.same_as(source_buffer)) { - ObjectPtr n = make_object(*match_buffer.get()); - n->source = target; - return MatchBufferRegion(n); - } - return match_buffer; - }); +ffi::Array ReplaceBufferRegion(ffi::Array match_buffers, + const Buffer& source_buffer, + const BufferRegion& target) { + match_buffers.MutateByApply( + [&source_buffer, &target](const MatchBufferRegion& match_buffer) -> MatchBufferRegion { + if (match_buffer->source->buffer.same_as(source_buffer)) { + ObjectPtr n = + ffi::make_object(*match_buffer.get()); + n->source = target; + return MatchBufferRegion(n); + } + return match_buffer; + }); return match_buffers; } /******** ReplaceBufferMutator ********/ ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, - Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : block_sref_reuse_(block_sref_reuse) { buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer); } -ReplaceBufferMutator::ReplaceBufferMutator(const Map& buffer_map, - Map* block_sref_reuse) +ReplaceBufferMutator::ReplaceBufferMutator(const 
ffi::Map& buffer_map, + ffi::Map* block_sref_reuse) : block_sref_reuse_(block_sref_reuse) { for (const auto& [old_buffer, new_buffer] : buffer_map) { buffer_var_map_[old_buffer->data.get()] = new_buffer; @@ -139,7 +143,7 @@ ReplaceBufferMutator::ReplaceBufferMutator(const Map& buffer_map PrimExpr ReplaceBufferMutator::VisitExpr_(const VarNode* var) { auto it = buffer_var_map_.find(var); - return it != buffer_var_map_.end() ? it->second->data : GetRef(var); + return it != buffer_var_map_.end() ? it->second->data : ffi::GetRef(var); } Stmt ReplaceBufferMutator::VisitStmt_(const BufferStoreNode* op) { @@ -203,12 +207,12 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { }; // Step 1. Mutate `match_buffers`. If an old buffer appears as a source of MatchBufferRegion, - Array match_buffers = block->match_buffers.Map(f_mutate_match_buffer); + ffi::Array match_buffers = block->match_buffers.Map(f_mutate_match_buffer); // Step 2. Mutate the read/write region. - Array reads = block->reads.Map(f_mutate_read_write_region); - Array writes = block->writes.Map(f_mutate_read_write_region); + ffi::Array reads = block->reads.Map(f_mutate_read_write_region); + ffi::Array writes = block->writes.Map(f_mutate_read_write_region); // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block. - Array alloc_buffers = block->alloc_buffers.Map(f_mutate_alloc_buffers); + ffi::Array alloc_buffers = block->alloc_buffers.Map(f_mutate_alloc_buffers); // Step 4. Recursively mutate the block. 
Block mutated_block = Downcast(StmtMutator::VisitStmt_(block)); @@ -216,7 +220,7 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { writes.same_as(mutated_block->writes) && alloc_buffers.same_as(mutated_block->alloc_buffers) && match_buffers.same_as(mutated_block->match_buffers)) { - return GetRef(block); + return ffi::GetRef(block); } else { ObjectPtr n = CopyOnWrite(mutated_block.get()); n->reads = std::move(reads); @@ -226,7 +230,7 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { Block new_block(n); if (block_sref_reuse_ != nullptr) { - block_sref_reuse_->Set(GetRef(block), new_block); + block_sref_reuse_->Set(ffi::GetRef(block), new_block); } return new_block; } @@ -241,17 +245,17 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ explicit OnlyLeafError(IRModule mod, Block leaf_block, Block scope_root) : mod_(mod), leaf_block_(leaf_block), scope_root_(scope_root) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Cannot remove the only leaf in the scope"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "Block {0} is the only leaf in the scope {1}, which cannot be removed; Otherwise the " "scope will be empty."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } + ffi::Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } IRModule mod_; Block leaf_block_; @@ -295,21 +299,21 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } if (const auto* seq = body.as()) { - ObjectPtr n = make_object(*block); - auto new_seq = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + ObjectPtr n = ffi::make_object(*block); + auto new_seq = RemoveFromSeqStmt(ffi::GetRef(seq), ffi::GetRef(last_stmt)); // Re-attach AllocateConst nodes auto new_body = MergeNest(allocs, 
new_seq); n->body = new_body; - *src_stmt = GetRef(block); + *src_stmt = ffi::GetRef(block); *tgt_stmt = Stmt(std::move(n)); return; } } if (const auto* loop = sref->StmtAs()) { if (const auto* seq = loop->body.as()) { - ObjectPtr n = make_object(*loop); - n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); - *src_stmt = GetRef(loop); + ObjectPtr n = ffi::make_object(*loop); + n->body = RemoveFromSeqStmt(ffi::GetRef(seq), ffi::GetRef(last_stmt)); + *src_stmt = ffi::GetRef(loop); *tgt_stmt = Stmt(std::move(n)); return; } @@ -317,12 +321,12 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ ICHECK(sref != nullptr && sref->stmt != nullptr); const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block_sref); const auto* scope_block = TVM_SREF_TO_BLOCK(sref); - throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); + throw OnlyLeafError(self->mod, ffi::GetRef(leaf_block), ffi::GetRef(scope_block)); } -Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, - const String& intrin_name, bool allow_padding) { - Optional opt_tensorize_info = +ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const ffi::String& intrin_name, bool allow_padding) { + ffi::Optional opt_tensorize_info = GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name).value()->desc, allow_padding); if (!opt_tensorize_info) return std::nullopt; @@ -342,7 +346,7 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block sch->PadEinsum(block_rv, info->block_iter_paddings.value()); // Now we need to find out all the padded Block's. - Array inlined_producers, inlined_consumers; + ffi::Array inlined_producers, inlined_consumers; for (const auto& producer : sch->GetProducers(block_rv)) { // PadEinsum will not modify the producer if it does not need padding. 
if (original_producers.count(sch->GetSRef(producer).get())) { @@ -387,9 +391,9 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block } } // Construct a mapping from tir loops back to LoopRVs - Map loop2rv; + ffi::Map loop2rv; { - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); for (const LoopRV& loop_rv : loop_rvs) { loop2rv.Set(sch->GetSRef(loop_rv), loop_rv); } @@ -417,17 +421,18 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block ICHECK_EQ(total % inner, 0); // Do the split. Leave the outer extent as std::nullopt (unspecified) so that the split factors // can be used for different extents (needed during tuning). - Array split = sch->Split(loop2rv.at(block_loop_sref), {std::nullopt, Integer(inner)}); + ffi::Array split = + sch->Split(loop2rv.at(block_loop_sref), {std::nullopt, Integer(inner)}); ICHECK_EQ(split.size(), 2); inner_loops.insert(sch->GetSRef(split[1]).operator->()); // The inner split will be reordered to the loop domain that is tensorized - int desc_loop_index = info->desc_loop_indexer.at(GetRef(desc_loop)).IntValue(); + int desc_loop_index = info->desc_loop_indexer.at(ffi::GetRef(desc_loop)).IntValue(); reorder_suffix[desc_loop_index] = split[1]; } // Reorder the loops std::vector reorder_list; bool meet = false; - Array all_loops = sch->GetLoops(block_rv); + ffi::Array all_loops = sch->GetLoops(block_rv); for (const LoopRV& loop : all_loops) { if (inner_loops.count(sch->GetSRef(loop).operator->())) { meet = true; @@ -441,16 +446,17 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block return reorder_suffix[0]; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.schedule.TileWithTensorIntrin", TileWithTensorIntrin); -}); +} /******** BlockBufferAccessSimplifier ********/ -void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_access_regions) { +void 
BlockBufferAccessSimplifier::SimplifyAccessRegion( + ffi::Array* old_access_regions) { auto fmutate = [this](const BufferRegion& buffer_region) { - Array new_buffer_region; - Array simplified_min; + ffi::Array new_buffer_region; + ffi::Array simplified_min; for (const auto& range : buffer_region->region) { simplified_min.push_back(range->min); } @@ -466,7 +472,7 @@ void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_ (*old_access_regions).MutateByApply(fmutate); } -void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array* indices) { +void BlockBufferAccessSimplifier::SimplifyBufferIndices(ffi::Array* indices) { *indices = this->IterMapSimplifyWithContext(*indices, true); } @@ -492,8 +498,8 @@ PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) { /******** PrimFunc-level analysis and transformation ********/ -void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, Array* leaf_blocks) { - Array blocks = sch->GetChildBlocks(cur_block_rv); +void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, ffi::Array* leaf_blocks) { + ffi::Array blocks = sch->GetChildBlocks(cur_block_rv); if (blocks.empty()) { leaf_blocks->push_back(cur_block_rv); } else { @@ -503,14 +509,14 @@ void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, Array* lea } } -Optional NormalizePrimFunc(Schedule sch) { +ffi::Optional NormalizePrimFunc(Schedule sch) { BlockRV root_block = sch->GetBlock("root"); - Array leaf_blocks; + ffi::Array leaf_blocks; GetLeafBlocksHelper(sch, root_block, &leaf_blocks); for (const BlockRV& block : leaf_blocks) { StmtSRef block_sref = sch->GetSRef(block); - Array loops = GetLoops(block_sref); - Array binds = GetBlockRealize(sch->state(), block_sref)->iter_values; + ffi::Array loops = GetLoops(block_sref); + ffi::Array binds = GetBlockRealize(sch->state(), block_sref)->iter_values; if (loops.size() == 0) continue; if (loops.size() != binds.size()) { return std::nullopt; @@ -526,14 +532,14 @@ Optional 
NormalizePrimFunc(Schedule sch) { } } - Array> block_loops; - Array> block_iters; - Array block_is_reduction; + ffi::Array> block_loops; + ffi::Array> block_iters; + ffi::Array block_is_reduction; for (const BlockRV& block : leaf_blocks) { - Array iters = sch->Get(block)->iter_vars; + ffi::Array iters = sch->Get(block)->iter_vars; bool has_spatial_iter = false; - Array index_map_inputs; - Array index_map_outputs; + ffi::Array index_map_inputs; + ffi::Array index_map_outputs; for (const IterVar& iter : sch->Get(block)->iter_vars) { Var var = iter->var.copy_with_suffix(""); index_map_inputs.push_back(var); @@ -559,13 +565,13 @@ Optional NormalizePrimFunc(Schedule sch) { sch->GetSRef(root_block)); block_is_reduction.push_back(Bool(is_reduction)); } - return Array{leaf_blocks, block_loops, block_iters, block_is_reduction}; + return ffi::Array{leaf_blocks, block_loops, block_iters, block_is_reduction}; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.schedule.NormalizePrimFunc", NormalizePrimFunc); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 73d6a0d85371..6e26f48320db 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -41,7 +41,8 @@ namespace tir { * \param attr_value The annotation value to be added * \return A new block with the given annotation as its last annotation */ -Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); +Block WithAnnotation(const BlockNode* block, const ffi::String& attr_key, + const ObjectRef& attr_value); /******** Buffer Related ********/ @@ -51,7 +52,7 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec * \param scope The target storage scope. * \return The new buffer with target storage scope. 
*/ -Buffer WithScope(const Buffer& buffer, const String& scope); +Buffer WithScope(const Buffer& buffer, const ffi::String& scope); /*! * \brief Create a new buffer by changint the data type. @@ -68,8 +69,8 @@ Buffer WithDType(const Buffer& buffer, const DataType& dtype); * \param target The buffer to be replaced to * \return The new sequence of regions after replacement */ -Array ReplaceBuffer(Array regions, const Buffer& source, - const Buffer& target); +ffi::Array ReplaceBuffer(ffi::Array regions, const Buffer& source, + const Buffer& target); /*! * \brief Replaces the buffer within the specific sequence of regions @@ -77,8 +78,8 @@ Array ReplaceBuffer(Array regions, const Buffer& sou * \param buffer_map The mapping from old buffers to new buffers * \return The new sequence of regions after replacement */ -Array ReplaceBuffer(Array regions, - const Map& buffer_map); +ffi::Array ReplaceBuffer(ffi::Array regions, + const ffi::Map& buffer_map); /*! * \brief Replaces the buffer within the specific sequence of match_buffers @@ -87,8 +88,8 @@ Array ReplaceBuffer(Array regions, * \param target The buffer to be replaced to * \return The new sequence of match_buffers after replacement */ -Array ReplaceBuffer(Array match_buffers, const Buffer& source, - const Buffer& target); +ffi::Array ReplaceBuffer(ffi::Array match_buffers, + const Buffer& source, const Buffer& target); /*! * \brief Replaces the buffer region within the specific sequence of regions @@ -97,8 +98,9 @@ Array ReplaceBuffer(Array match_buffers, c * \param target The buffer region to be replaced to * \return The new sequence of regions after replacement */ -Array ReplaceBufferRegion(Array regions, const Buffer& source_buffer, - const BufferRegion& target); +ffi::Array ReplaceBufferRegion(ffi::Array regions, + const Buffer& source_buffer, + const BufferRegion& target); /*! 
* \brief Replaces the buffer region within the specific sequence of match_buffers @@ -107,9 +109,9 @@ Array ReplaceBufferRegion(Array regions, const Buffe * \param target The buffer region to be replaced to * \return The new sequence of match_buffers after replacement */ -Array ReplaceBufferRegion(Array match_buffers, - const Buffer& source_buffer, - const BufferRegion& target); +ffi::Array ReplaceBufferRegion(ffi::Array match_buffers, + const Buffer& source_buffer, + const BufferRegion& target); /*! * \brief A helper mutator which recursively replaces the old buffer with the new buffer and @@ -129,9 +131,10 @@ class ReplaceBufferMutator : public StmtExprMutator { * sref. */ ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, - Map* block_sref_reuse); + ffi::Map* block_sref_reuse); - ReplaceBufferMutator(const Map& buffer_map, Map* block_sref_reuse); + ReplaceBufferMutator(const ffi::Map& buffer_map, + ffi::Map* block_sref_reuse); protected: using StmtExprMutator::VisitExpr_; @@ -162,7 +165,7 @@ class ReplaceBufferMutator : public StmtExprMutator { */ std::unordered_map buffer_var_map_; /*! 
\brief The block sref reuse map for the following replacement */ - Map* block_sref_reuse_; + ffi::Map* block_sref_reuse_; }; /******** Block Removal ********/ @@ -214,8 +217,10 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ * \return LoopRV corresponding to the outermost loop of a * block tiled according to the given intrin, std::nullopt if a valid loop mapping is not found */ -Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, - const String& intrin_name, bool allow_padding = false); +ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, + const tir::BlockRV& block_rv, + const ffi::String& intrin_name, + bool allow_padding = false); /******** Block mutation ********/ @@ -242,8 +247,8 @@ class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - void SimplifyAccessRegion(Array* old_access_regions); - void SimplifyBufferIndices(Array* indices); + void SimplifyAccessRegion(ffi::Array* old_access_regions); + void SimplifyBufferIndices(ffi::Array* indices); Stmt VisitStmt_(const BlockNode* op) final; Stmt VisitStmt_(const BufferStoreNode* op) final; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 0c35c5f043a2..cd48cb13d5aa 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -56,12 +56,12 @@ namespace tir { * \param loop_srefs The loop StmtSRefs to be converted * \return The conversion result loops */ -inline Array LoopSRefs2Loops(const Array& loop_srefs) { - Array loops; +inline ffi::Array LoopSRefs2Loops(const ffi::Array& loop_srefs) { + ffi::Array loops; loops.reserve(loop_srefs.size()); for (StmtSRef loop_sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - loops.push_back(GetRef(loop)); + loops.push_back(ffi::GetRef(loop)); } return loops; } @@ -72,8 +72,9 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { * \param 
block_rvs The random variables to be converted * \return The conversion result srefs */ -inline Array BlockRVs2StmtSRefs(const Schedule& sch, const Array& block_rvs) { - Array block_srefs; +inline ffi::Array BlockRVs2StmtSRefs(const Schedule& sch, + const ffi::Array& block_rvs) { + ffi::Array block_srefs; block_srefs.reserve(block_rvs.size()); for (const BlockRV& block_rv : block_rvs) { block_srefs.push_back(sch->GetSRef(block_rv)); @@ -110,7 +111,7 @@ inline bool CanRelaxStorageUnderThread(const runtime::StorageScope& storage_scop */ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { ICHECK_GT(seq->size(), 1); - Array new_stmts; + ffi::Array new_stmts; new_stmts.reserve(seq->size()); for (const Stmt& stmt : seq->seq) { if (to_remove.same_as(stmt)) { @@ -132,7 +133,7 @@ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { * \return If the Stmt is SeqStmt, then returns the sequence; * Otherwise, returns a single-element Array with the Stmt inside. 
*/ -inline Array AsArray(const Stmt& stmt) { +inline ffi::Array AsArray(const Stmt& stmt) { if (const auto* seq_stmt = stmt.as()) { return seq_stmt->seq; } @@ -160,7 +161,7 @@ inline bool IsSingleStmt(const Stmt& stmt) { * \param iter_var_type The type of the new IterVar * \return The newly created IterVar */ -inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_var_type) { +inline IterVar IterVarFromLoop(const For& loop, ffi::String name, IterVarType iter_var_type) { return IterVar(Range::FromMinExtent(loop->min, loop->extent), Var(std::move(name), loop->loop_var.dtype()), iter_var_type); } @@ -221,10 +222,11 @@ inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { * \return The single variable in the expression, or std::nullopt if the expression is neither a * variable or a constant shift from a variable */ -inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* constant) { +inline ffi::Optional AnalyzeVarWithShift(const PrimExpr& expr, + ffi::Optional* constant) { if (const auto* var = expr.as()) { *constant = std::nullopt; - return GetRef(var); + return ffi::GetRef(var); } arith::PVar var; arith::PVar shift; @@ -252,8 +254,8 @@ inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* * \return std::nullopt if not found; otherwise the annotation value */ template -inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) { - const Map* annotations = &stmt->annotations; +inline ffi::Optional GetAnn(const TStmtNode* stmt, const ffi::String& ann_key) { + const ffi::Map* annotations = &stmt->annotations; for (const auto& ann : *annotations) { if (ann.first == ann_key) { return Downcast(ann.second); @@ -270,7 +272,7 @@ inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) * \return std::nullopt if not found; otherwise the annotation value */ template -inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) { +inline ffi::Optional GetAnn(const StmtSRef& sref, const 
ffi::String& ann_key) { if (const auto* loop = sref->StmtAs()) { return GetAnn(loop, ann_key); } else if (const auto* block = sref->StmtAs()) { @@ -288,8 +290,8 @@ inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) * \param ann_val The annotation value to be checked * \return Whether a Block/For has a specific pair of annotation key and values */ -inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& ann_val) { - Optional result = GetAnn(sref, ann_key); +inline bool HasAnn(const StmtSRef& sref, const ffi::String& ann_key, const ffi::String& ann_val) { + ffi::Optional result = GetAnn(sref, ann_key); return result.has_value() && result.value() == ann_val; } @@ -300,8 +302,8 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& an * \param ann_val The boolean annotation value to be checked * \return Whether a Block/For has a specific pair of annotation key and values */ -inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { - Optional result = GetAnn(sref, ann_key); +inline bool HasAnn(const StmtSRef& sref, const ffi::String& ann_key, bool ann_val) { + ffi::Optional result = GetAnn(sref, ann_key); return result.defined() && result.value() == ann_val; } @@ -319,13 +321,13 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv, tir::LoopRV* fused_reduce_loop, size_t* num_spatial_loops) { - Array loops = sch->GetLoops(block_rv); - Array loop_srefs; + ffi::Array loops = sch->GetLoops(block_rv); + ffi::Array loop_srefs; for (const tir::LoopRV& loop_rv : loops) { loop_srefs.push_back(sch->GetSRef(loop_rv)); } - Array new_order; + ffi::Array new_order; // Step 1. Add spatial loops. 
*num_spatial_loops = 0; for (size_t i = 0; i < loops.size(); ++i) { @@ -335,7 +337,7 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl } } // Step 2. Add reduction loops. - Array reduction_loops; + ffi::Array reduction_loops; for (size_t i = 0; i < loops.size(); ++i) { if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) { new_order.push_back(loops[i]); @@ -366,7 +368,7 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl * \param buffer_index_type The BufferIndexType value to convert * \return The string representation of BufferIndexType */ -inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { +inline ffi::String BufferIndexType2Str(BufferIndexType buffer_index_type) { if (buffer_index_type == BufferIndexType::kRead) { return "read"; } else { @@ -409,8 +411,8 @@ inline bool HasBlock(const Schedule& sch, const std::string& block_name) { * \param rv_map The substitution map for variables. * \return The transformed objects. */ -Array TranslateInputRVs(const Array& inputs, - const std::unordered_map& rv_map); +ffi::Array TranslateInputRVs(const ffi::Array& inputs, + const std::unordered_map& rv_map); /*! * \brief Update the variable substitution map according to the new outputs. @@ -418,7 +420,7 @@ Array TranslateInputRVs(const Array& inputs, * \param new_outputs The new outputs of the same schedule instruction. * \param rv_map The substitution map for variables. */ -void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, +void TranslateAddOutputRVs(const ffi::Array& old_outputs, const ffi::Array& new_outputs, std::unordered_map* rv_map); /*! @@ -427,7 +429,7 @@ void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_ * \param remove_postproc If postprocessing instructions are removed. * \return Number of instructions. 
*/ -int GetNumValidInstructions(const Array& insts, bool remove_postproc); +int GetNumValidInstructions(const ffi::Array& insts, bool remove_postproc); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index 5cd2d6556572..47b3df5fdaa3 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -40,12 +40,12 @@ class DeviceRegionAnnotater : public StmtMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { // If a target attribute already exists, use it as-is. - return GetRef(op); + return ffi::GetRef(op); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { // These attributes are only allowed in device-side code, so // they should be annotated with the function's default target. - Stmt body = GetRef(op); + Stmt body = ffi::GetRef(op); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); } else { // All other annotations are ignored @@ -75,10 +75,10 @@ Pass AnnotateDeviceRegions() { return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.AnnotateDeviceRegions", AnnotateDeviceRegions); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/annotate_irregular_loop.cc b/src/tir/transforms/annotate_irregular_loop.cc new file mode 100644 index 000000000000..c715922d60b3 --- /dev/null +++ b/src/tir/transforms/annotate_irregular_loop.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class IrregularLoopAnnotator : public StmtMutator { + public: + static Stmt Annotate(const Stmt& body) { return IrregularLoopAnnotator().VisitStmt(body); } + + private: + IrregularLoopAnnotator() = default; + + Stmt VisitStmt_(const ForNode* op) final { + bool cur_has_jump = has_jump_; + has_jump_ = false; + For res = Downcast(StmtMutator::VisitStmt_(op)); + if (has_jump_) { + CHECK(op->kind == ForKind::kSerial) + << "Loop kind " << op->kind << " is invalid for irregular loop " << op->loop_var; + for (const char* key : {attr::pragma_auto_unroll_max_step, attr::pragma_unroll_explicit, + attr::pragma_loop_partition_hint, attr::software_pipeline_stage}) { + CHECK(!res->annotations.count(key)) + << "Annotation `" << key << "` is invalid for irregular loop " << op->loop_var; + } + res.CopyOnWrite()->annotations.Set(attr::irregular_loop_mark, 1); + } + std::swap(cur_has_jump, has_jump_); + return res; + } + + Stmt VisitStmt_(const WhileNode* op) final { + bool cur_has_jump = has_jump_; + has_jump_ = false; + Stmt res = StmtMutator::VisitStmt_(op); + std::swap(cur_has_jump, has_jump_); + return res; + } + + Stmt VisitStmt_(const EvaluateNode* op) final { + if (const CallNode* call = op->value.as()) { + if (call->op.same_as(builtin::continue_loop()) || 
call->op.same_as(builtin::break_loop())) { + has_jump_ = true; + } + } + return ffi::GetRef(op); + } + + bool has_jump_{false}; +}; + +namespace transform { + +Pass AnnotateIrregularLoop() { + auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc { + func.CopyOnWrite()->body = IrregularLoopAnnotator::Annotate(func->body); + return func; + }; + + return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateIrregularLoop", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.AnnotateIrregularLoop", AnnotateIrregularLoop); +} + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 5b9e005b7ea3..1b85d7d21132 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -76,7 +76,7 @@ void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::stri Bind_(arg, value, arg_name, with_let); } -void ArgBinder::BindArray(const Array& arg, const Array& value, +void ArgBinder::BindArray(const ffi::Array& arg, const ffi::Array& value, const std::string& arg_name) { ICHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; for (size_t i = 0; i < arg.size(); ++i) { @@ -93,8 +93,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st << "Argument " << arg_name << " Buffer bind data type mismatch"; if (value->data_alignment % arg->data_alignment != 0) { LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " - << " required_alignment=" << arg->data_alignment - << ", provided_alignment=" << value->data_alignment; + << " required alignment=" << arg->data_alignment + << ", provided alignment=" << value->data_alignment; } if (value->elem_offset.defined()) { @@ -218,12 +218,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, 
init_nest_.emplace_back(LetStmt( buf_strides->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); init_nest_.emplace_back(DeclBuffer(buf_strides, nop)); - PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); + PrimExpr v_strides_is_null = Call(DataType::Bool(), builtin::isnullptr(), {buf_strides->data}); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); PrimExpr expect_stride = make_const(stype, 1); - Array conds; + ffi::Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); diff --git a/src/tir/transforms/arg_binder.h b/src/tir/transforms/arg_binder.h index 68cbbb677311..fad5e4d70222 100644 --- a/src/tir/transforms/arg_binder.h +++ b/src/tir/transforms/arg_binder.h @@ -79,7 +79,7 @@ class ArgBinder { * \param value The target expression value * \param arg_name argument name. */ - void BindArray(const Array& arg, const Array& value, + void BindArray(const ffi::Array& arg, const ffi::Array& value, const std::string& arg_name); /*! * \brief Bind symbolic buffer to another symbolic buffer @@ -145,7 +145,7 @@ class ArgBinder { */ const std::vector& init_nest() const { return init_nest_; } /*! \return Handle data type of the data */ - const Map& def_handle_dtype() const { return def_handle_dtype_; } + const ffi::Map& def_handle_dtype() const { return def_handle_dtype_; } private: // Internal bind function @@ -158,7 +158,7 @@ class ArgBinder { /*! \brief Initialize nest */ std::vector init_nest_; /*! \brief handle data type in the defintiions */ - Map def_handle_dtype_; + ffi::Map def_handle_dtype_; /*! \brief asserts generated */ std::vector asserts_; /*! \brief internal analyzer. 
*/ diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index 06d596adb44d..2b4598a99fa7 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -40,7 +40,7 @@ namespace tir { class ParamsCollector : public StmtExprVisitor { public: - explicit ParamsCollector(const Map& constant_map) + explicit ParamsCollector(const ffi::Map& constant_map) : constant_map_(constant_map) {} std::vector CollectParams(tir::Stmt body) { this->VisitStmt(body); @@ -75,16 +75,16 @@ class ParamsCollector : public StmtExprVisitor { private: std::vector constant_list_; - Map constant_map_; + ffi::Map constant_map_; }; -PrimFunc BindParams(PrimFunc f, const Array& constants) { - Map constant_map; +PrimFunc BindParams(PrimFunc f, const ffi::Array& constants) { + ffi::Map constant_map; // Remove constants from the primfunc signature size_t num_constants = constants.size(); size_t start = f->params.size() - num_constants; - Array params; + ffi::Array params; for (unsigned i = 0; i < start; i++) { params.push_back(f->params[i]); } @@ -101,9 +101,9 @@ PrimFunc BindParams(PrimFunc f, const Array& constants) { // Allocate constants within the primfunc for (auto i : constant_list) { - auto var = GetRef(i); + auto var = ffi::GetRef(i); int ndim = constant_map[var]->ndim; - Array extents; + ffi::Array extents; for (int i = 0; i < ndim; i++) { int shape = constant_map[var]->shape[i]; @@ -126,7 +126,7 @@ PrimFunc BindParams(PrimFunc f, const Array& constants) { namespace transform { -Pass BindParams(const Array& constants) { +Pass BindParams(const ffi::Array& constants) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { return BindParams(f, constants); }; diff --git a/src/tir/transforms/bind_target.cc b/src/tir/transforms/bind_target.cc index 46a40228eaa1..9ec0a506a314 100644 --- a/src/tir/transforms/bind_target.cc +++ b/src/tir/transforms/bind_target.cc @@ -71,7 +71,7 @@ class FunctionClassifierVisitor : public StmtExprVisitor 
{ // Only analyze externally exposed functions as potential callers // since they represent the entry points where host/device calls originate for (const auto& [gvar, func] : mod->functions) { - bool is_externally_exposed = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_externally_exposed = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); const auto* prim_func = func.as(); if (is_externally_exposed && prim_func != nullptr) { @@ -144,7 +144,7 @@ class CallSubstitutor : public StmtExprMutator { * \brief Constructor with function replacement mapping. * \param replacements Map from original GlobalVar to host-specific GlobalVar */ - explicit CallSubstitutor(const Map& replacements) + explicit CallSubstitutor(const ffi::Map& replacements) : replacements_(replacements) {} /*! @@ -212,7 +212,7 @@ class CallSubstitutor : public StmtExprMutator { /*! \brief Whether the current statement is under a GPU scope */ bool is_under_gpu_scope_ = false; /*! \brief Mapping from original functions to host-specific duplicates */ - Map replacements_; + ffi::Map replacements_; }; /*! 
@@ -238,7 +238,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { auto target_without_host = target.WithoutHost(); auto mod_copy_on_write = mod.CopyOnWrite(); - auto new_mod = GetRef(mod_copy_on_write); + auto new_mod = ffi::GetRef(mod_copy_on_write); // Step 1: Analyze function call patterns auto [host_called_global_vars, device_called_global_vars] = @@ -257,7 +257,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { // 2.4 If the function is not called by any host or device, skip binding // Track duplicated functions for call replacement - Map host_function_replacements; + ffi::Map host_function_replacements; GlobalVarSupply gvar_supply(new_mod); for (auto [gvar, func] : mod->functions) { @@ -266,9 +266,10 @@ IRModule BindTarget(IRModule mod, const Target& target) { // Skip non-PrimFunc entries continue; } - auto prim_func = GetRef(prim_func_node); + auto prim_func = ffi::GetRef(prim_func_node); - bool is_externally_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_externally_exposed = + prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (auto func_target = func->GetAttr(tvm::attr::kTarget)) { // Rule 1: If the function has a target, and the target has a host, and the function does not @@ -308,7 +309,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { // Create duplicate with host target for host callers host_func = WithAttr(std::move(host_func), tvm::attr::kTarget, target_host); - String host_func_name = gvar->name_hint + "_host"; + ffi::String host_func_name = gvar->name_hint + "_host"; GlobalVar host_gvar = gvar_supply->FreshGlobal(host_func_name, false); new_mod->Add(host_gvar, host_func); @@ -341,7 +342,8 @@ IRModule BindTarget(IRModule mod, const Target& target) { continue; } - bool is_externally_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_externally_exposed = + prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_externally_exposed) { // 
Update calls in externally exposed functions to use host duplicates PrimFunc new_func = substitutor.Substitute(Downcast(func)); @@ -371,10 +373,10 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreateModulePass(fpass, 0, "tir.BindTarget", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.BindTarget", BindTarget); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 6d5537e7756e..99d990ece627 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -58,12 +58,13 @@ class BoundCollector : public StmtVisitor { StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. - std::unordered_map> mem_to_shape; + std::unordered_map> mem_to_shape; }; class BoundChecker : public StmtExprMutator { public: - explicit BoundChecker(const std::unordered_map>& mem_to_shape) + explicit BoundChecker( + const std::unordered_map>& mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt VisitStmt_(const AllocateNode* op) final { @@ -95,13 +96,13 @@ class BoundChecker : public StmtExprMutator { PrimExpr condition = MakeCondition(); if (!condition.as()) { Stmt nop = Evaluate(1); - Stmt then_case = GetRef(op); + Stmt then_case = ffi::GetRef(op); Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop); Stmt body = IfThenElse(condition, then_case, else_case); return body; } } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -116,7 +117,7 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, Array new_shape, const DataType& type) { + void Update(const Var& buffer_var, ffi::Array new_shape, const DataType& type) { // Sanity check at first. 
if (!ShapeIsValid(new_shape)) { return; @@ -129,7 +130,7 @@ class BoundChecker : public StmtExprMutator { mem_to_shape_[buffer_var.get()] = new_shape; } - bool ShapeIsValid(const Array& shape) const { + bool ShapeIsValid(const ffi::Array& shape) const { if (!shape.defined()) { return false; } @@ -142,7 +143,7 @@ class BoundChecker : public StmtExprMutator { return true; } - bool IndicesAreValid(const Array& indices) const { + bool IndicesAreValid(const ffi::Array& indices) const { if (!indices.defined()) { return false; } @@ -176,12 +177,12 @@ class BoundChecker : public StmtExprMutator { return expr.defined() && expr.dtype().is_scalar(); } - bool CanInstrument(const Array& indices, const Var& buffer_var) const { + bool CanInstrument(const ffi::Array& indices, const Var& buffer_var) const { return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndicesAreValid(indices) && !unsafe_rewritten_; } - void Collect(Array indices, Var buffer_var) { + void Collect(ffi::Array indices, Var buffer_var) { store_scope_bound_collector_.push_back( std::make_pair(indices, mem_to_shape_[buffer_var.get()])); } @@ -189,8 +190,8 @@ class BoundChecker : public StmtExprMutator { PrimExpr MakeCondition() { PrimExpr condition; for (const auto& pair : store_scope_bound_collector_) { - Array indices = pair.first; - Array shape = pair.second; + ffi::Array indices = pair.first; + ffi::Array shape = pair.second; ICHECK_EQ(indices.size(), shape.size()) << "Mismatch between dimension of physical shape and physical indices"; @@ -200,7 +201,7 @@ class BoundChecker : public StmtExprMutator { PrimExpr upper_bound = shape[i]; if (const RampNode* ramp_index = index.as()) { - index = arith::UnwrapVectorExpr(GetRef(ramp_index), ramp_index->lanes); + index = arith::UnwrapVectorExpr(ffi::GetRef(ramp_index), ramp_index->lanes); } // Try to simplify index and bound. @@ -226,11 +227,11 @@ class BoundChecker : public StmtExprMutator { // Whether we face tvm_if_then_else intrinsic. 
bool unsafe_rewritten_{false}; // Pool which collects the pair of index and shape for specific store/load. - std::vector, Array>> store_scope_bound_collector_; + std::vector, ffi::Array>> store_scope_bound_collector_; // Error message. const char* const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. - std::unordered_map> mem_to_shape_; + std::unordered_map> mem_to_shape_; // internal analyzer arith::Analyzer analyzer_; }; @@ -256,10 +257,10 @@ Pass InstrumentBoundCheckers() { return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InstrumentBoundCheckers", InstrumentBoundCheckers); -}); +} } // namespace transform diff --git a/src/tir/transforms/canonicalize_loop.cc b/src/tir/transforms/canonicalize_loop.cc new file mode 100644 index 000000000000..93511bf84bb2 --- /dev/null +++ b/src/tir/transforms/canonicalize_loop.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/transforms/canonicalize_loop.cc + * \brief Canonicalize all loops to start from zero and step one. 
+ */ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +class LoopCanonicalizer : public StmtExprMutator { + public: + LoopCanonicalizer() = default; + + private: + Stmt VisitStmt_(const ForNode* op) final { + if (is_zero(op->min) && op->HasTrivialStep()) { + return StmtExprMutator::VisitStmt_(op); + } + arith::Analyzer analyzer; + const auto* loop_var = op->loop_var.get(); + PrimExpr step = op->step.value_or(make_const(loop_var->dtype, 1)); + + // report warning for negative step, since it would be a forever loop + if (!analyzer.CanProveGreaterEqual(step, 1)) { + // TODO(tvm): prove dynamic shaped step + LOG(FATAL) << "Loop step for " << op->loop_var << " may not be positive: " << step; + } + + new_iter_info_[loop_var] = std::make_pair(step, op->min); + auto n = CopyOnWrite(op); + n->body = VisitStmt(op->body); + n->min = make_zero(loop_var->dtype); + n->extent = analyzer.Simplify(ceildiv(op->extent, step)); + n->step = std::nullopt; + new_iter_info_.erase(loop_var); + return For(n); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = new_iter_info_.find(op); + if (it != new_iter_info_.end()) { + const auto& [stride, offset] = it->second; + return ffi::GetRef(op) * stride + offset; + } + return ffi::GetRef(op); + } + + /*! \brief Map iter variable `x` to `x * stride + offset`. 
*/ + std::unordered_map> new_iter_info_; +}; + +PrimFunc CanonicalizeLoop(PrimFunc func) { + PrimFuncNode* fptr = func.CopyOnWrite(); + fptr->body = LoopCanonicalizer()(func->body); + return func; +} + +namespace transform { + +Pass CanonicalizeLoop() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return CanonicalizeLoop(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.CanonicalizeLoop", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.CanonicalizeLoop", CanonicalizeLoop); +} + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 2945c8e20f97..bd9d67352659 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -113,10 +113,10 @@ Pass CombineContextCall() { return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.CombineContextCall", CombineContextCall); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index 23c7d88d47c9..9b9619fae937 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -150,8 +150,8 @@ Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) { // Builds the variable name, which is cse_vi where i will go up from 1 std::string prefix = "cse_v"; std::string name = prefix.append(std::to_string(num_last_try_)); - // Builds a String using the std::string - String string_name(name); + // Builds a ffi::String using the std::string + ffi::String string_name(name); // Check that the name that we want to use for the new variable isn't already being used // 
(names don't really have to be unique as they are just hints, and having the same name @@ -280,11 +280,11 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { [](const std::pair& pair) { return pair.first; }; std::vector vector_vars_known = VectorMap(context_, forget_value); // 2.2 - Transform the std::vector into an Array - Array array_vars_known = Array(vector_vars_known); + ffi::Array array_vars_known = ffi::Array(vector_vars_known); // --- End of chunk needed for reusing the UndefinedVars() analysis --- // We use the UndefinedVars() analysis to get the undefined vars of the computation - Array vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); + ffi::Array vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); // Check if we can introduce it : if it contains no undefined variables and if we want // to introduce it according to the predicate @@ -375,7 +375,7 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) { // If the `value` and the `body` of the let-in have been rewritten to the same thing if (value_new.same_as(op->value) && body_new.same_as(op->body)) { // then return a reference to the same node - return GetRef(op); + return ffi::GetRef(op); } else { // Otherwise return a let-in built with the new `value_new` and the new `body_new` that // have just been obtained @@ -460,11 +460,11 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { [](const std::pair& pair) { return pair.first; }; std::vector vector_vars_known = VectorMap(context_, forget_value); // 2.2 - Transform the std::vector into an Array - Array array_vars_known = Array(vector_vars_known); + ffi::Array array_vars_known = ffi::Array(vector_vars_known); // --- End of chunk needed for reusing the UndefinedVars() analysis --- // We use the UndefinedVars() analysis to get the undefined vars of the computation - Array vars_undefined = UndefinedVars(computation_wrapped_in_stmt, 
array_vars_known); + ffi::Array vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); // Check if we can introduce it : if it contains no undefined variables and if we want // to introduce it according to the predicate @@ -556,7 +556,7 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const LetStmtNode* op) { // If the `value` and the `body` of the let-in have been rewritten to the same thing if (value_new.same_as(op->value) && body_new.same_as(op->body)) { // Return a reference to the same node - return GetRef(op); + return ffi::GetRef(op); } else { // Otherwise return a let-in built with the new `value_new` and the new `body_new` that // have just been obtained @@ -597,12 +597,12 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) { // If the `min`, `extent` and `body` of the for loop have been rewritten to the same thing if (min_new.same_as(op->min) && extent_new.same_as(op->extent) && body_new.same_as(op->body)) { // Return a reference to the same node - return GetRef(op); + return ffi::GetRef(op); } else { // Otherwise return a for node built with the new `min_new`, `extent_new` and `body_new` // that have just been obtained return For(op->loop_var, min_new, extent_new, op->kind, body_new, op->thread_binding, - op->annotations, op->span); + op->annotations, op->step, op->span); } } @@ -638,10 +638,10 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) { } // The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.CommonSubexprElimTIR", CommonSubexprElimTIR); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index f71d2cf42a02..1c52c6f97f5d 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ 
b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -447,7 +447,7 @@ void ComputationsDoneBy::VisitStmt_(const IfThenElseNode* op) { // Copy the `table_of_computations_` into the cache // for the future queries - Stmt ref_to_op = GetRef(op); + Stmt ref_to_op = ffi::GetRef(op); cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; } @@ -482,7 +482,7 @@ void ComputationsDoneBy::VisitStmt_(const ForNode* op) { // Copy the `table_of_computations_` into the cache // for the future queries - Stmt ref_to_op = GetRef(op); + Stmt ref_to_op = ffi::GetRef(op); cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; } @@ -512,7 +512,7 @@ void ComputationsDoneBy::VisitStmt_(const WhileNode* op) { // Copy the `table_of_computations_` into the cache // for the future queries - Stmt ref_to_op = GetRef(op); + Stmt ref_to_op = ffi::GetRef(op); cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; } @@ -646,7 +646,7 @@ void DirectSubexpr::VisitExpr(const PrimExpr& expr) { * \param var_name The variable name to check for * \return A boolean telling if `expr` uses `var_name` */ -bool UsesVarName::ExprUsesVarName(const PrimExpr& expr, String var_name) { +bool UsesVarName::ExprUsesVarName(const PrimExpr& expr, ffi::String var_name) { UsesVarName uses_var_name(var_name); uses_var_name.VisitExpr(expr); @@ -659,7 +659,7 @@ bool UsesVarName::ExprUsesVarName(const PrimExpr& expr, String var_name) { * \param var_name The variable name to check for * \return A boolean telling if `stmt` uses `var_name` */ -bool UsesVarName::StmtUsesVarName(const Stmt& stmt, String var_name) { +bool UsesVarName::StmtUsesVarName(const Stmt& stmt, ffi::String var_name) { UsesVarName uses_var_name(var_name); uses_var_name.VisitStmt(stmt); @@ -668,9 +668,9 @@ bool UsesVarName::StmtUsesVarName(const Stmt& stmt, String var_name) { /*! * \brief Protected constructor of UsesVarName. 
- * \param var_name The String that we are looking for + * \param var_name The ffi::String that we are looking for */ -UsesVarName::UsesVarName(String var_name) : var_name_(var_name) {} +UsesVarName::UsesVarName(ffi::String var_name) : var_name_(var_name) {} /*! * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions. diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index 31a81dabdbf2..ab1e76592a90 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -158,18 +158,18 @@ class DirectSubexpr : public ExprVisitor { class UsesVarName : public StmtExprVisitor { public: // Toplevel (static) methods - static bool ExprUsesVarName(const PrimExpr& expr, String var_name); - static bool StmtUsesVarName(const Stmt& stmt, String var_name); + static bool ExprUsesVarName(const PrimExpr& expr, ffi::String var_name); + static bool StmtUsesVarName(const Stmt& stmt, ffi::String var_name); protected: // Constructor - explicit UsesVarName(String var_name); + explicit UsesVarName(ffi::String var_name); void VisitExpr(const PrimExpr& expr) override; void VisitStmt(const Stmt& stmt) override; private: - String var_name_; + ffi::String var_name_; bool uses_var_name_ = false; }; diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index a1e99313b663..0ba4e75c3004 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -49,9 +49,9 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate, arith::Analyzer* analyzer) { std::unordered_map var_dom; for (const auto& it : dom_map) { - var_dom[GetRef(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0)); + var_dom[ffi::GetRef(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0)); } - Optional> eval_res = + ffi::Optional> eval_res = arith::EstimateRegionUpperBound(region, 
var_dom, predicate, analyzer); if (eval_res.defined()) { @@ -146,7 +146,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef(op)); } + void VisitExpr_(const VarNode* op) final { VisitBufferVar(ffi::GetRef(op)); } void VisitStmt_(const ForNode* op) final { Range loop_range = Range::FromMinExtent(op->min, op->extent); @@ -243,10 +243,10 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } // Step 2. Record explicit read/write region annotations - auto record_explicit_region = [&](const String& attr_key, BufferIndexType index_type) { + auto record_explicit_region = [&](const ffi::String& attr_key, BufferIndexType index_type) { auto it = op->annotations.find(attr_key); if (it != op->annotations.end()) { - Array buffer_indices = Downcast>((*it).second); + ffi::Array buffer_indices = Downcast>((*it).second); for (const auto& index : buffer_indices) { int buffer_index = index->value; if (buffer_index >= 0 && buffer_index < static_cast(op->reads.size())) { @@ -430,9 +430,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor { ICHECK(it != relaxed_accesses_.end()) << buffer << " is allocated but not accessed within block scope"; - const Array& original_shape = buffer->shape; + const ffi::Array& original_shape = buffer->shape; const NDIntSet& nd_int_set = it->second; - Array& result_region = buffer_access_region_[buffer]; + ffi::Array& result_region = buffer_access_region_[buffer]; result_region.resize(nd_int_set.size()); for (size_t i = 0; i < nd_int_set.size(); ++i) { @@ -566,7 +566,7 @@ class BufferCompactor : public StmtExprMutator { // Step 0. Check there is no Init part. ICHECK(!op->init.defined()); // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo. - Array alloc_buffers = + ffi::Array alloc_buffers = op->alloc_buffers.Map([this](const Buffer& buf) { return RewriteAllocBuffer(buf); }); // Step 2. 
Recursively rewrite BufferLoad/BufferStore. Block block = Downcast(StmtExprMutator::VisitStmt_(op)); @@ -600,7 +600,7 @@ class BufferCompactor : public StmtExprMutator { if (op->dtype != new_buffer->dtype) { return allocate; } - Array new_shape = GetBufferAllocationShape(new_buffer); + ffi::Array new_shape = GetBufferAllocationShape(new_buffer); auto n = allocate.CopyOnWrite(); ICHECK(n->buffer_var.same_as(new_buffer->data)); n->extents = new_shape; @@ -615,7 +615,7 @@ class BufferCompactor : public StmtExprMutator { return buffer; } - void RewriteBufferAccess(Buffer* buffer, Array* indices) const { + void RewriteBufferAccess(Buffer* buffer, ffi::Array* indices) const { auto it = buffer_info_.find((*buffer)->data); if (it == buffer_info_.end()) { return; @@ -623,7 +623,7 @@ class BufferCompactor : public StmtExprMutator { const BufferAllocInfo& info = it->second; ICHECK_EQ(indices->size(), info.region.size()); int ndim = info.region.size(); - Array new_indices; + ffi::Array new_indices; new_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { new_indices.push_back((*indices)[i] - info.region[i]->min); @@ -650,8 +650,8 @@ class BufferCompactor : public StmtExprMutator { *region = std::move(new_region); } - void RewriteBufferRegions(Array* regions) const { - Array new_regions; + void RewriteBufferRegions(ffi::Array* regions) const { + ffi::Array new_regions; new_regions.reserve(regions->size()); for (const auto& region : *regions) { BufferRegion buffer_region = region; @@ -662,12 +662,12 @@ class BufferCompactor : public StmtExprMutator { *regions = std::move(new_regions); } - void RewriteMatchBuffers(Array* match_buffers) const { - Array result; + void RewriteMatchBuffers(ffi::Array* match_buffers) const { + ffi::Array result; result.reserve(match_buffers->size()); for (const auto& match_buffer : *match_buffers) { const BufferRegion& buffer_region = match_buffer->source; - auto p = make_object(*buffer_region.get()); + auto p = 
ffi::make_object(*buffer_region.get()); RewriteBufferRegion(&p->buffer, &p->region); result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p))); } @@ -678,7 +678,8 @@ class BufferCompactor : public StmtExprMutator { std::unordered_map buffer_info_; }; -Array CalcStrides(const BufferAllocInfo& alloc_info, const Array& shape) { +ffi::Array CalcStrides(const BufferAllocInfo& alloc_info, + const ffi::Array& shape) { std::vector strides; if (alloc_info.dim_aligns.size()) { ICHECK(alloc_info.dim_aligns.size() == shape.size()); @@ -725,9 +726,9 @@ Stmt BufferCompactorCompact( } // prepare new buffer - Array shape = region.Map([](const Range& range) { return range->extent; }); - Array strides = CalcStrides(alloc_info, shape); - ObjectPtr n = make_object(*buffer.get()); + ffi::Array shape = region.Map([](const Range& range) { return range->extent; }); + ffi::Array strides = CalcStrides(alloc_info, shape); + ObjectPtr n = ffi::make_object(*buffer.get()); n->shape = std::move(shape); n->strides = std::move(strides); alloc_info.new_buffer = Buffer(std::move(n)); @@ -756,10 +757,10 @@ Pass CompactBufferAllocation(bool is_strict) { return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.CompactBufferAllocation", CompactBufferAllocation); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index bd340df97e61..f187252b2e31 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -54,7 +54,7 @@ class OpaqueBlockConverter : public StmtExprMutator { if (it != var_substitutes_.end()) { return it->second; } - return GetRef(var); + return ffi::GetRef(var); } Stmt VisitStmt_(const BlockNode* block) final { @@ -74,7 +74,7 @@ class OpaqueBlockConverter : 
public StmtExprMutator { // Step 1. Visit the predicate and iter_values, without any variable bindings for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.insert(iter->var.get()); PrimExpr predicate = VisitExpr(realize->predicate); - Array iter_values = realize->iter_values; + ffi::Array iter_values = realize->iter_values; iter_values.MutateByApply([this](PrimExpr expr) { return VisitExpr(std::move(expr)); }); for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.erase(iter->var.get()); @@ -96,7 +96,7 @@ class OpaqueBlockConverter : public StmtExprMutator { // Step 5. Return if (predicate.same_as(realize->predicate) && iter_values.same_as(realize->iter_values) && new_block.same_as(realize->block) && realize->iter_values.size() == 0) { - return GetRef(realize); + return ffi::GetRef(realize); } else { return BlockRealize({}, predicate, new_block); } @@ -123,10 +123,10 @@ Pass ConvertBlocksToOpaque() { return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.ConvertBlocksToOpaque", ConvertBlocksToOpaque); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc index 9b2554779360..691d8b885c59 100644 --- a/src/tir/transforms/convert_for_loops_serial.cc +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -43,7 +43,7 @@ class ForLoopSerialConverter : public StmtExprMutator { Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) { if (op->kind == ForKind::kParallel) { return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding, - op->annotations, op->span); + op->annotations, op->step, op->span); } return StmtExprMutator::VisitStmt_(op); } @@ -67,10 +67,10 @@ Pass ConvertForLoopsToSerial() { return CreatePrimFuncPass(pass_func, 0, 
"tir.ConvertForLoopsToSerial", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.ConvertForLoopsToSerial", ConvertForLoopsToSerial); -}); +} } // namespace transform diff --git a/src/tir/transforms/decorate_device_scope.cc b/src/tir/transforms/decorate_device_scope.cc index a8c6b07c7602..ab0078a50ae0 100644 --- a/src/tir/transforms/decorate_device_scope.cc +++ b/src/tir/transforms/decorate_device_scope.cc @@ -45,10 +45,10 @@ Pass DecorateDeviceScope() { return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.DecorateDeviceScope", DecorateDeviceScope); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 5e1e5efa0e4c..74c299456a4b 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -34,20 +34,20 @@ namespace transform { void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread_per_block, int64_t max_threadblocks = 256) { // fetch the loops - Array loops = sch->GetLoops(block); + ffi::Array loops = sch->GetLoops(block); for (const tir::LoopRV& loop : loops) { // skip block if already scheduled if (sch->Get(loop)->thread_binding.defined()) { return; } } - Array iters = sch->Get(block)->iter_vars; + ffi::Array iters = sch->Get(block)->iter_vars; // when there is no loops, tir will add a dummy iter var for the block // so loops.size() == 0 && iters.size() == 1 ICHECK(loops.size() == iters.size() || (loops.size() == 0 && iters.size() == 1)); - Array data_parallel_loops; + ffi::Array data_parallel_loops; // only fuse data parallel loops for (size_t i = 0; i < loops.size(); ++i) { if (iters[i]->iter_type == tir::IterVarType::kDataPar) { @@ -68,14 
+68,14 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread } // schedule the fused loop if (product > max_thread_per_block * max_threadblocks) { - Array splits = sch->Split( + ffi::Array splits = sch->Split( fused, /*factors=*/{std::nullopt, Integer(max_threadblocks), Integer(max_thread_per_block)}); sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); sch->Bind(splits[2], "threadIdx.x"); } else { - Array splits = sch->Split( + ffi::Array splits = sch->Split( fused, /*factors=*/{std::nullopt, Integer(std::min(product, max_thread_per_block))}); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); @@ -83,11 +83,11 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread } IRModule MarkScheduled(const IRModule& mod) { - Map result; + ffi::Map result; for (const auto& [gv, base_func] : mod->functions) { if (const auto* prim_func_node = base_func.as()) { - tir::PrimFunc prim_func = GetRef(prim_func_node); + tir::PrimFunc prim_func = ffi::GetRef(prim_func_node); tir::PrimFunc new_prim_func = WithAttr(std::move(prim_func), tir::attr::kIsScheduled, true); result.Set(gv, new_prim_func); } else { @@ -105,7 +105,7 @@ bool IsScheduledOnGPU(const BaseFunc& func) { // the target from context. tvm::Target target = tvm::Target::Current(); // the Target in kTarget attribute of PrimFunc - Optional func_target = func->attrs.GetAttr(tvm::attr::kTarget); + ffi::Optional func_target = func->attrs.GetAttr(tvm::attr::kTarget); if (func_target.defined()) { target = func_target.value(); } @@ -131,7 +131,7 @@ Pass DefaultGPUSchedule() { // get the target from context. 
tvm::Target target = tvm::Target::Current(); // get the target from kTarget attribute - Optional func_target = + ffi::Optional func_target = func->attrs.GetAttr(tvm::attr::kTarget); if (func_target.defined()) { target = func_target.value(); @@ -139,14 +139,14 @@ Pass DefaultGPUSchedule() { ICHECK(target.defined()) << "The target is missing either in the current context or in " "the prim_func's attribute."; // get the max thread per block from target. - Optional opt_max_thread_per_block = + ffi::Optional opt_max_thread_per_block = target->GetAttr("max_num_threads"); ICHECK(opt_max_thread_per_block.defined()) << "max_num_threads is not set for target " << target; int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue(); sch->WorkOn(gv->name_hint); - Array blocks = meta_schedule::BlockCollector::Collect(sch); + ffi::Array blocks = meta_schedule::BlockCollector::Collect(sch); for (const tir::BlockRV& block : blocks) { auto childs = sch->GetChildBlocks(block); if (!childs.empty()) { @@ -164,10 +164,10 @@ Pass DefaultGPUSchedule() { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.DefaultGPUSchedule", DefaultGPUSchedule); -}); +} } // namespace transform diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 301c6c13b9f0..be5da45d9f6f 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -36,14 +36,14 @@ namespace tvm { namespace tir { -using ConstArrayType = Array; +using ConstArrayType = ffi::Array; class Applicator : public tir::StmtMutator { protected: // returns index of the a in constant_array_, if not found - appends - size_t DeDup(const runtime::NDArray& a) { + size_t DeDup(const runtime::Tensor& a) { tvm::StructuralEqual eql; auto it = std::find_if(constant_array_.begin(), constant_array_.end(), - [&eql, a](const runtime::NDArray& v) { return 
eql(a, v); }); + [&eql, a](const runtime::Tensor& v) { return eql(a, v); }); if (it != constant_array_.end()) { return it - constant_array_.begin(); } @@ -62,7 +62,7 @@ class Applicator : public tir::StmtMutator { // and add array index. ICHECK(acn->data) << "data field should be defined"; auto node = CopyOnWrite(acn); - node->irmod_storage_idx = Optional(Integer(DeDup(node->data.value()))); + node->irmod_storage_idx = ffi::Optional(Integer(DeDup(node->data.value()))); return Stmt(node); } @@ -75,7 +75,7 @@ tvm::transform::Pass ExtractPrimFuncConstants() { auto prim_func_pass = [=](PrimFunc foo, IRModule m, tvm::transform::PassContext ctx) { auto* func = foo.CopyOnWrite(); if (!m->attrs.defined()) { - m->attrs = DictAttrs(Map()); + m->attrs = DictAttrs(ffi::Map()); } auto* attrs = m->attrs.CopyOnWrite(); ConstArrayType constant_array_ = @@ -88,11 +88,11 @@ tvm::transform::Pass ExtractPrimFuncConstants() { if (constant_list.size()) { attrs->dict.Set(tvm::attr::kConstants, constant_list); } - return GetRef(func); + return ffi::GetRef(func); }; auto pass_func = [=](IRModule module, tvm::transform::PassContext pc) { - auto m = GetRef(module.CopyOnWrite()); + auto m = ffi::GetRef(module.CopyOnWrite()); for (const auto& kv : m->functions) { if (auto func = kv.second.as()) { m->Update(kv.first, prim_func_pass(func.value(), m, pc)); @@ -104,10 +104,10 @@ tvm::transform::Pass ExtractPrimFuncConstants() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.ExtractPrimFuncConstants", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.ExtractPrimFuncConstants", ExtractPrimFuncConstants); -}); +} } // namespace transform diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 1515bfadb59a..1a9ba390703f 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -65,21 +65,21 @@ class BufferFlattener : 
public arith::IRMutatorWithAnalyzer { << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. " << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer."; - Block block = GetRef(op); + Block block = ffi::GetRef(op); - Array alloc_buffers = op->alloc_buffers; + ffi::Array alloc_buffers = op->alloc_buffers; alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); }); if (!alloc_buffers.same_as(op->alloc_buffers)) { block.CopyOnWrite()->alloc_buffers = alloc_buffers; } - Array reads = op->reads; + ffi::Array reads = op->reads; reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); if (!reads.same_as(op->reads)) { block.CopyOnWrite()->reads = reads; } - Array writes = op->writes; + ffi::Array writes = op->writes; writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); if (!writes.same_as(op->writes)) { block.CopyOnWrite()->writes = writes; @@ -91,7 +91,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const AllocateNode* op) final { // Determine the flattened extents first, before stripping of // DeclBuffer. 
- auto new_extents = [&]() -> Array { + auto new_extents = [&]() -> ffi::Array { if (op->extents.size() == 1) { // No flattening required for buffers that are already flat return op->extents; @@ -219,7 +219,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { } } - Array GetSimplifiedElemOffset(const Buffer& buffer, const Array& indices) { + ffi::Array GetSimplifiedElemOffset(const Buffer& buffer, + const ffi::Array& indices) { auto flattened_indices = buffer->ElemOffset(indices); return this->IterMapSimplifyWithContext(flattened_indices, false); } @@ -243,17 +244,17 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { return region; } - Array min_values; - Array max_values; + ffi::Array min_values; + ffi::Array max_values; for (const auto& range : region->region) { min_values.push_back(range->min); max_values.push_back(range->min + range->extent - 1); } - Array flattened_min = GetSimplifiedElemOffset(orig_buf, min_values); - Array flattened_max = GetSimplifiedElemOffset(orig_buf, max_values); + ffi::Array flattened_min = GetSimplifiedElemOffset(orig_buf, min_values); + ffi::Array flattened_max = GetSimplifiedElemOffset(orig_buf, max_values); - Array flattened_ranges; + ffi::Array flattened_ranges; ICHECK_EQ(flattened_min.size(), flattened_max.size()); for (size_t i = 0; i < flattened_min.size(); i++) { flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1)); @@ -266,7 +267,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { std::unordered_map buffer_remap_; /*! \brief The updated external buffer map. 
*/ - Map updated_extern_buffer_map_; + ffi::Map updated_extern_buffer_map_; }; PrimFunc FlattenBuffer(PrimFunc f) { return BufferFlattener::Flatten(f); } @@ -280,10 +281,10 @@ Pass FlattenBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.FlattenBuffer", FlattenBuffer); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc index d291e40f3c31..711c2a739f59 100644 --- a/src/tir/transforms/force_narrow_index_to_i32.cc +++ b/src/tir/transforms/force_narrow_index_to_i32.cc @@ -56,7 +56,7 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer { ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); return IntImm(DataType::Int(32), op->value); } - return GetRef(op); + return ffi::GetRef(op); } Stmt VisitStmt_(const BlockNode* block) final { @@ -87,10 +87,10 @@ Pass ForceNarrowIndexToInt32() { return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.ForceNarrowIndexToInt32", ForceNarrowIndexToInt32); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index 1548ea1da625..ebd90583c93d 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -81,24 +81,23 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter(flag) & hoisted_let_bindings; } - - static constexpr const char* _type_key = "tir.transform.HoistExpressionConfig"; - TVM_DECLARE_FINAL_OBJECT_INFO(HoistExpressionConfigNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.HoistExpressionConfig", + HoistExpressionConfigNode, 
Object); }; class HoistExpressionConfig : public Attrs { public: HoistExpressionConfig(int hoisted_conditionals, int hoisted_let_bindings) { - auto node = make_object(); + auto node = ffi::make_object(); node->hoisted_conditionals = hoisted_conditionals; node->hoisted_let_bindings = hoisted_let_bindings; data_ = std::move(node); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistExpressionConfig, Attrs, - HoistExpressionConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistExpressionConfig, Attrs, + HoistExpressionConfigNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ HoistExpressionConfigNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { HoistExpressionConfigNode::RegisterReflection(); } TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistExpression", HoistExpressionConfig); @@ -111,18 +110,17 @@ struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapterloop_var, GetRef(op)}); + active_loops.push_back({op->loop_var, ffi::GetRef(op)}); active_loop_vars.insert(op->loop_var.get()); Parent::VisitStmt_(op); @@ -272,7 +270,7 @@ class HoistInfoCollector : public StmtExprVisitor { active_block_vars.insert(var.get()); active_loop_vars.insert(var.get()); - active_loops.push_back({var, GetRef(op)}); + active_loops.push_back({var, ffi::GetRef(op)}); Parent::VisitStmt_(op); @@ -562,10 +560,10 @@ Pass HoistExpression() { "tir.HoistExpression"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.HoistExpression", HoistExpression); -}); +} Pass HoistIfThenElse() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -600,10 +598,10 @@ Pass HoistIfThenElse() { "tir.HoistIfThenElse"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.HoistIfThenElse", HoistIfThenElse); -}); +} Pass HoistIfThenElseBasic() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -623,10 +621,10 
@@ Pass HoistIfThenElseBasic() { "tir.HoistIfThenElseBasic"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.HoistIfThenElseBasic", HoistIfThenElseBasic); -}); +} } // namespace transform diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 8ced9c82253d..e874dc0564cf 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -41,18 +41,17 @@ struct InjectDoubleBufferConfigNode : public AttrsNodeReflAdapter(); - Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; + ffi::Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; ICHECK(entry.loop != nullptr); auto& alloc_nest = loop_allocs_[entry.loop]; alloc_nest.emplace_back(Allocate(op->buffer_var, op->dtype, new_extents, op->condition, @@ -249,7 +248,7 @@ class DoubleBufferInjector : public StmtExprMutator { PrimExpr VisitExpr_(const VarNode* op) final { ICHECK(!dbuffer_info_.count(op)); - return GetRef(op); + return ffi::GetRef(op); } private: @@ -327,10 +326,10 @@ Pass InjectDoubleBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InjectDoubleBuffer", InjectDoubleBuffer); -}); +} } // namespace transform diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index f90752e26418..cdbe17508339 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -59,7 +59,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - Array PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) { + ffi::Array 
PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) { ICHECK(permute_); // Index after vectorizing by 8 PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR), @@ -104,7 +104,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } static bool CheckAnnotation(const Any& annotation) { - if (auto opt_str = annotation.as()) { + if (auto opt_str = annotation.as()) { // Support string annotation for backward compatibility return *opt_str != ""; } else if (auto* node = annotation.as()) { @@ -165,7 +165,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { return buffer_row_size; } - Array HandleBufferIndices(Buffer buffer, Array indices) { + ffi::Array HandleBufferIndices(Buffer buffer, ffi::Array indices) { auto buffer_row_size = CheckAndGetBufferRowSize(buffer); // Mutate the last two indices @@ -216,7 +216,8 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { return load; } - PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional offset = std::nullopt) { + PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, + ffi::Optional offset = std::nullopt) { // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to // smem_offset CHECK(access_ptr->IsInstance()) @@ -296,10 +297,10 @@ Pass InjectPermutedLayout() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPermutedLayout", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InjectPermutedLayout", InjectPermutedLayout); -}); +} } // namespace transform diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index f0a88ba98192..0e9820aa659e 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -81,8 +81,8 @@ class PTXAsyncCopyInjector : public StmtMutator { if (indices_lanes == 1) { auto src_offset = load->indices[0]; auto dst_offset = 
store->indices[0]; - Array args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)}; + ffi::Array args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)}; // use arguments size to indicate whether or not to use predicated cp.async if (predicated) { args.push_back(predicate_value); @@ -200,10 +200,10 @@ Pass InjectPTXAsyncCopy() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy); -}); +} } // namespace transform diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc index 848e8491945f..8cdef1be44a5 100644 --- a/src/tir/transforms/inject_ptx_ldg32.cc +++ b/src/tir/transforms/inject_ptx_ldg32.cc @@ -41,7 +41,7 @@ class PTXRewriter : public StmtMutator { // addr[0] -> global_addr / addr[1] -> local_addr addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local"); predicate_buffer = - decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(1), "predicate", "local"); + decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local"); } Stmt result = StmtMutator::VisitStmt_(allocate); if (!has_buffer_2) { @@ -95,8 +95,8 @@ class PTXRewriter : public StmtMutator { BufferStore value_store(store->buffer, imm_value, {new_indice}); Evaluate ptx_load(Call(store->buffer->dtype, tvm::tir::builtin::ptx_ldg32(), {store->buffer->data, new_predicate, new_lhs, new_indice})); - Array tmp_seq = {addr_store, local_addr_store, predicate_store, value_store, - ptx_load}; + ffi::Array tmp_seq = {addr_store, local_addr_store, predicate_store, value_store, + ptx_load}; SeqStmt seq_stmt = SeqStmt(tmp_seq); return seq_stmt; } @@ -124,10 +124,10 @@ Pass InjectPTXLDG32(bool 
enable_inject_ptx_intrin) { // The pass can now be invoked via the pass infrastructure, but we also add a // Python binding for it -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InjectPTXLDG32", InjectPTXLDG32); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index a68308261a19..c3b41e05899b 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -50,7 +50,7 @@ struct RollingBufferInfo { int rolling_axis; int rolling_extent; std::vector axis_overlaps; - std::vector> axis_iter_vars; + std::vector> axis_iter_vars; }; class RollingBufferInjector : public StmtExprMutator { @@ -70,7 +70,7 @@ class RollingBufferInjector : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { // Manage the stack of iter_vars - for_loops.push_back(GetRef(op)); + for_loops.push_back(ffi::GetRef(op)); auto stmt{StmtExprMutator::VisitStmt_(op)}; op = stmt.as(); @@ -82,7 +82,7 @@ class RollingBufferInjector : public StmtExprMutator { if (it != hoist_buffer_to_for.end()) { // If the loop corresponds to an iter_var that needs a BufferRealize // hoisting to its scope, perform the hoisting - Stmt body{GetRef(op)}; + Stmt body{ffi::GetRef(op)}; for (auto realise : it->second) { auto attrs{buffer_to_attrs[realise->buffer]}; Stmt new_realize{BufferRealize(realise->buffer, realise->bounds, realise->condition, body, @@ -108,7 +108,7 @@ class RollingBufferInjector : public StmtExprMutator { // Keep a dictionary associating attribute statements with the buffers // they reference. We'll need this if the buffer gets hoisted and we // need to hoist all of its attributes at the same time. 
- buffer_to_attrs[buffer].push_back(GetRef(op)); + buffer_to_attrs[buffer].push_back(ffi::GetRef(op)); if (op->attr_key == attr::rolling_buffer_scope && Downcast(op->value)->value) { // If the attribute is indicating that a buffer should be a rolling @@ -122,13 +122,13 @@ class RollingBufferInjector : public StmtExprMutator { // If a BufferRealize has been identified as needing to be made into // a rolling buffer, begin the analysis. - std::vector> bound_iter_vars{}; + std::vector> bound_iter_vars{}; std::vector bound_overlaps{}; // We use the bound information of the BufferRealize to calculate // how we can legally roll auto stride{0}; auto divisor{1}; - Optional iter_var{}; + ffi::Optional iter_var{}; for (auto bound : buffer_realize->bounds) { divisor = 1; if (auto floor_div = bound->min.as()) { @@ -143,7 +143,7 @@ class RollingBufferInjector : public StmtExprMutator { iter_var = nullptr; } else if (auto var = bound->min.as()) { // If the bound is just a Var, that implies the stride is 1 - iter_var = GetRef(var); + iter_var = ffi::GetRef(var); stride = 1; } else { // Otherwise, it's the iter var multiplied by the stride @@ -154,7 +154,7 @@ class RollingBufferInjector : public StmtExprMutator { ICHECK(a) << "Rolling buffer injection failed: the buffer striding is unsupported"; auto b = mul->b.as(); ICHECK(b) << "Rolling buffer injection failed: the buffer striding is unsupported"; - iter_var = GetRef(a); + iter_var = ffi::GetRef(a); stride = b->value; } stride = std::ceil(static_cast(stride) / divisor); @@ -167,7 +167,7 @@ class RollingBufferInjector : public StmtExprMutator { } // Pick the outermost iter_var that's mentioned in the bounds // to be the rolling axis - Optional roll_iter_var{}; + ffi::Optional roll_iter_var{}; int roll_axis{1}; for (auto loop : for_loops) { auto loop_var{loop->loop_var}; @@ -175,7 +175,7 @@ class RollingBufferInjector : public StmtExprMutator { auto it{std::find_if( bound_iter_vars.begin(), bound_iter_vars.end(), - [&](Optional 
var) { return var && (var.get() == loop_var.get()); })}; + [&](ffi::Optional var) { return var && (var.get() == loop_var.get()); })}; if (it != bound_iter_vars.end()) { auto i{std::distance(bound_iter_vars.begin(), it)}; @@ -195,7 +195,7 @@ class RollingBufferInjector : public StmtExprMutator { bound_iter_vars, }; rolling_buffer_to_info[buffer] = rolling_buffer_info; - Array new_bounds{}; + ffi::Array new_bounds{}; auto shape{buffer->shape}; for (size_t i{0}; i < shape.size(); ++i) { auto extent{shape[i]}; @@ -225,7 +225,7 @@ class RollingBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const BufferRealizeNode* op) final { - buffer_to_buffer_realize.insert({op->buffer, GetRef(op)}); + buffer_to_buffer_realize.insert({op->buffer, ffi::GetRef(op)}); auto stmt{StmtExprMutator::VisitStmt_(op)}; op = stmt.as(); @@ -266,7 +266,7 @@ class RollingBufferInjector : public StmtExprMutator { auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; if (iter_var && rolling_buffer_info.axis_overlaps[i] > 0) { Var var{iter_var.value()}; - const Map dmap{std::make_pair(var, IntSet::Interval(0, 0))}; + const ffi::Map dmap{std::make_pair(var, IntSet::Interval(0, 0))}; auto term_2{arith::Analyzer{}.int_set(op->indices[i], dmap).min()}; auto condition = Or(LT(var, 1), GE(term_2, rolling_buffer_info.axis_overlaps[i])); buffer_store = IfThenElse(likely(condition), buffer_store); @@ -316,10 +316,10 @@ Pass InjectRollingBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectRollingBuffer", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InjectRollingBuffer", InjectRollingBuffer); -}); +} } // namespace transform diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index b89d3b89fa82..950e3fb8c850 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -48,7 +48,7 @@ 
namespace software_pipeline { * \param buffer_data_to_buffer The map from buffer data to buffer. * \return The result block. */ -Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) { +Block MakeBlock(const Stmt& body, const ffi::Map& buffer_data_to_buffer) { if (const BlockRealizeNode* block_realize = body.as()) { if (is_one(block_realize->predicate)) { // no need to create a new block @@ -56,7 +56,8 @@ Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) } } Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body); - Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); + ffi::Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer); BlockNode* n = block.CopyOnWrite(); n->reads = access[0]; n->writes = access[1]; @@ -88,8 +89,8 @@ class PipelineOpaqueAccessRewriter { * \param fragment_info Information about tensor core fragment */ PipelineOpaqueAccessRewriter( - const Map& buffer_data_to_buffer, const Map& buffer_remap, - const For& pipeline_loop, + const ffi::Map& buffer_data_to_buffer, + const ffi::Map& buffer_remap, const For& pipeline_loop, const std::unordered_map& fragment_info) : buffer_data_to_buffer_(buffer_data_to_buffer), buffer_remap_(buffer_remap), @@ -109,13 +110,13 @@ class PipelineOpaqueAccessRewriter { const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); auto it = buffer_remap_.find(buffer); if (it != buffer_remap_.end()) { - Array new_args = call->args; + ffi::Array new_args = call->args; const Buffer& new_buffer = (*it).second; new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); - return Call(call->dtype, call->op, new_args, call->span); + return Call(call->dtype, call->op, new_args, call->annotations, call->span); } } else if (call->op.same_as(mma_sync)) { - Array new_args = call->args; + ffi::Array new_args = call->args; for (int i = 0; i < 4; i++) { const Var& buffer_var = Downcast(call->args[i * 
2]); const PrimExpr& index = call->args[i * 2 + 1]; @@ -126,7 +127,7 @@ class PipelineOpaqueAccessRewriter { new_args.Set(i * 2 + 1, new_index); } } - return Call(call->dtype, call->op, new_args, call->span); + return Call(call->dtype, call->op, new_args, call->annotations, call->span); } else if (call->op.same_as(access_ptr)) { return RewriteBufferAccess(call, {1}); } else if (call->op.same_as(ptx_mma)) { @@ -160,11 +161,11 @@ class PipelineOpaqueAccessRewriter { } PrimExpr RewriteBufferAccess(const Call& call, const std::vector arg_indices) { - auto product = [](const Array& input) { + auto product = [](const ffi::Array& input) { return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, make_const(DataType::Int(32), 1), input); }; - Array new_args = call->args; + ffi::Array new_args = call->args; for (int i : arg_indices) { const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[i])); auto it = buffer_remap_.find(buffer); @@ -189,11 +190,11 @@ class PipelineOpaqueAccessRewriter { new_args.Set(i + 1, new_index); } } - return Call(call->dtype, call->op, new_args, call->span); + return Call(call->dtype, call->op, new_args, call->annotations, call->span); } - const Map& buffer_data_to_buffer_; - const Map& buffer_remap_; + const ffi::Map& buffer_data_to_buffer_; + const ffi::Map& buffer_remap_; const For& pipeline_loop_; const std::unordered_map& fragment_info_; }; @@ -215,8 +216,8 @@ class PipelineBodyRewriter : public StmtExprMutator { * of a two-stage software pipeline, only one version of these buffers are accessed. 
* \param fragment_info Information about tensor core fragment */ - PipelineBodyRewriter(const Map& buffer_data_to_buffer, - const Map& buffer_remap, For pipeline_loop, + PipelineBodyRewriter(const ffi::Map& buffer_data_to_buffer, + const ffi::Map& buffer_remap, For pipeline_loop, bool access_all_versions, const std::unordered_map& fragment_info) : buffer_data_to_buffer_(buffer_data_to_buffer), @@ -299,8 +300,8 @@ class PipelineBodyRewriter : public StmtExprMutator { return opaque_access_rewriter_.Rewrite(call); } - Map buffer_data_to_buffer_; - Map buffer_remap_; + ffi::Map buffer_data_to_buffer_; + ffi::Map buffer_remap_; For pipeline_loop_; bool access_all_versions_; PipelineOpaqueAccessRewriter opaque_access_rewriter_; @@ -312,24 +313,24 @@ class PipelineBodyRewriter : public StmtExprMutator { class PipelineRewriter : public StmtExprMutator { public: static Stmt Rewrite( - Map buffer_data_to_buffer, + ffi::Map buffer_data_to_buffer, const std::unordered_set& double_buffers, - const Array pipeline_allocs, const For& pipeline_loop, + const ffi::Array pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) { + const ffi::Map preserved_annotations) { PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, pipeline_info, fragment_info, preserved_annotations); return rewriter.BuildPipeline(); } private: - PipelineRewriter(Map buffer_data_to_buffer, + PipelineRewriter(ffi::Map buffer_data_to_buffer, const std::unordered_set& double_buffers, - const Array& pipeline_allocs, const For& pipeline_loop, + const ffi::Array& pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) + const ffi::Map preserved_annotations) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), double_buffers_(double_buffers), @@ -365,7 +366,7 @@ class 
PipelineRewriter : public StmtExprMutator { // introduce extra lowerbound when the loop length is smaller than num stages // to ensure the epilogue interval do not overlap the prologue interval. PrimExpr epigogue_start = pipeline_loop_->min + pipeline_loop_->extent; - Optional extra_epilogue_lower_bound = std::nullopt; + ffi::Optional extra_epilogue_lower_bound = std::nullopt; if (max_stage_ > 1 && !analyzer_.CanProveGreaterEqual(pipeline_loop_->extent, max_stage_)) { if (is_const_int(epigogue_start)) { epigogue_start = max(epigogue_start, pipeline_loop_->min + max_stage_); @@ -382,7 +383,7 @@ class PipelineRewriter : public StmtExprMutator { SeqStmt stmt = SeqStmt({prologue, body, epilogue}); // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting. - Array alloc_buffers; + ffi::Array alloc_buffers; for (const auto& alloc : pipeline_allocs_) { alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); buffer_data_to_buffer_.erase(alloc->data); @@ -527,7 +528,7 @@ class PipelineRewriter : public StmtExprMutator { * \return The resized buffer. */ Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { - ObjectPtr new_buffer = make_object(*(buffer.get())); + ObjectPtr new_buffer = ffi::make_object(*(buffer.get())); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); if (new_buffer->strides.size()) { ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); @@ -546,7 +547,7 @@ class PipelineRewriter : public StmtExprMutator { // async invocations exactly. When it is valid, it is the "sum of extents of loops that have // been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. This // is only needed to compute wait count for epilogue without async producers. 
- Optional producer_head{PrimExpr(-1)}; + ffi::Optional producer_head{PrimExpr(-1)}; bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; } }; @@ -578,9 +579,9 @@ class PipelineRewriter : public StmtExprMutator { // A symbolic expression representing the index the latest async operation associated with this // stage has written into, at the "current" iteration. - Optional producer_head; + ffi::Optional producer_head; // The predicate of BlockRealize containing the async operation of this stage. - Optional predicate; + ffi::Optional predicate; // Indices into a list of blocks, where async_commit_queue scope should be attached. // If multiple async producers are interleaved with their consumer in between, we need separate // async_commit_queue for each producer. Thus, we need multiple sets of indices. @@ -670,7 +671,7 @@ class PipelineRewriter : public StmtExprMutator { auto& dep_local_state = (*async_states_local)[producer_stage_idx]; const auto num_commit_group = dep_local_state.commit_groups.size(); - std::vector> producer_head_per_commit; + std::vector> producer_head_per_commit; if (num_commit_group == 0) { // Epilogue, no async producer. Since "local" producer_head is not available, use @@ -728,7 +729,7 @@ class PipelineRewriter : public StmtExprMutator { // Given pipelined blocks and async-related information, generate final loop statements with async // scopes (if any). 
- Array CompletePipelineLoopStatements( + ffi::Array CompletePipelineLoopStatements( const std::vector& blocks, const std::map& async_states_local, arith::Analyzer* ana_normalized) const { @@ -768,7 +769,7 @@ class PipelineRewriter : public StmtExprMutator { } } - Array stmts; + ffi::Array stmts; for (size_t i = 0; i < new_blocks.size();) { if (commit_group_indices[i] == -1) { @@ -776,7 +777,7 @@ class PipelineRewriter : public StmtExprMutator { stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); ++i; } else { - Array group_bodies; + ffi::Array group_bodies; auto stage_id = commit_group_indices[i]; auto predicate = new_blocks[i].predicate; for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) { @@ -812,7 +813,7 @@ class PipelineRewriter : public StmtExprMutator { * \return The result loop. */ Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop, - Optional extra_loop_lower_bound = std::nullopt) { + ffi::Optional extra_loop_lower_bound = std::nullopt) { PrimExpr new_loop_var; PrimExpr extent = end - start; @@ -942,7 +943,7 @@ class PipelineRewriter : public StmtExprMutator { if (!is_unit_loop) { new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop), - std::nullopt, preserved_annotations_); + std::nullopt, preserved_annotations_, std::nullopt); } // Update producer heads in the global async states. 
@@ -966,17 +967,17 @@ class PipelineRewriter : public StmtExprMutator { } arith::Analyzer analyzer_; - Map buffer_data_to_buffer_; + ffi::Map buffer_data_to_buffer_; const std::unordered_set& double_buffers_; - Array pipeline_allocs_; + ffi::Array pipeline_allocs_; For pipeline_loop_; PipelineInfo pipeline_info_; const std::unordered_map& fragment_info_; int max_stage_ = -1; - Map buffer_remap_; - Array ordered_stmts_; + ffi::Map buffer_remap_; + ffi::Array ordered_stmts_; std::map async_states; - Map preserved_annotations_; + ffi::Map preserved_annotations_; }; /*! @@ -988,10 +989,10 @@ class PipelineRewriter : public StmtExprMutator { * destination to the source. */ void BuildDependencyGraph( - const Array& blocks, - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { - std::unordered_map> buffer_writers; + const ffi::Array& blocks, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { + std::unordered_map> buffer_writers; for (const Block& block : blocks) { for (const BufferRegion& read : block->reads) { @@ -1016,7 +1017,7 @@ void BuildDependencyGraph( class PipelineInjector : private StmtExprMutator { public: static Stmt Inject(const PrimFunc& func) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); PipelineInjector injector(global_symbol); for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -1027,7 +1028,8 @@ class PipelineInjector : private StmtExprMutator { } private: - explicit PipelineInjector(Optional global_symbol) : global_symbol_(global_symbol) {} + explicit PipelineInjector(ffi::Optional global_symbol) + : global_symbol_(global_symbol) {} /*! 
* \brief Check the pipeline satisfies the following conditions: @@ -1037,7 +1039,8 @@ class PipelineInjector : private StmtExprMutator { * case 1: stage(A) < stage(B) * case 2: stage(A) == stage(B) and order(A) < order(B) */ - void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array& original_order) { + void ValidatePipelineBody(const PipelineInfo& pipeline_info, + const ffi::Array& original_order) { std::unordered_set used_orders; std::unordered_map stage_max_order; std::unordered_map order_to_block; @@ -1050,13 +1053,13 @@ class PipelineInjector : private StmtExprMutator { used_orders.insert(order); } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; BuildDependencyGraph(original_order, &dep_src2dst, nullptr); for (const auto& pair : dep_src2dst) { const Block& src = pair.first; const auto& src_info = pipeline_info.at(src); - const Array& dsts = pair.second; + const ffi::Array& dsts = pair.second; for (const Block& dst : dsts) { const auto& dst_info = pipeline_info.at(dst); CHECK_LE(src_info.stage, dst_info.stage) @@ -1081,7 +1084,7 @@ class PipelineInjector : private StmtExprMutator { // the for-loop. If the for-loop has BlockRealize as its child, the pipeline body will be the // child of the block. Stmt pipeline_body{nullptr}; - Array pipeline_allocs; + ffi::Array pipeline_allocs; if (const auto* realize = for_node->body.as()) { const auto& block = realize->block; for (const auto& buffer : block->alloc_buffers) { @@ -1102,7 +1105,7 @@ class PipelineInjector : private StmtExprMutator { // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be // converted into a block. 
PipelineInfo pipeline_info; - Array original_order; // pipeline body blocks in the original order + ffi::Array original_order; // pipeline body blocks in the original order auto f_add_child = [&](const Stmt& child) { original_order.push_back(MakeBlock(child, buffer_data_to_buffer_)); @@ -1128,9 +1131,9 @@ class PipelineInjector : private StmtExprMutator { } auto pipeline_stages = - Downcast>(op->annotations.at(attr::software_pipeline_stage)); + Downcast>(op->annotations.at(attr::software_pipeline_stage)); auto pipeline_orders = - Downcast>(op->annotations.at(attr::software_pipeline_order)); + Downcast>(op->annotations.at(attr::software_pipeline_order)); CHECK_EQ(pipeline_stages.size(), original_order.size()) << "PrimFunc " << global_symbol_ << " has original order " << original_order.Map([](const auto& block) { return block->name_hint; }) @@ -1142,14 +1145,14 @@ class PipelineInjector : private StmtExprMutator { std::unordered_set pipeline_async_stages; if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) { - for (auto s : Downcast>(annot.value())) { + for (auto s : Downcast>(annot.value())) { pipeline_async_stages.insert(s->value); } } - Map preserved_annotations; + ffi::Map preserved_annotations; for (const auto& kv : op->annotations) { - const String& key = kv.first; + const ffi::String& key = kv.first; if (kv.first != attr::software_pipeline_stage && kv.first != attr::software_pipeline_order && kv.first != attr::software_pipeline_async_stages) { preserved_annotations.Set(key, kv.second); @@ -1169,7 +1172,7 @@ class PipelineInjector : private StmtExprMutator { // Step 4: Rewrite the pipeline body. 
Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, - pipeline_allocs, GetRef(op), pipeline_info, + pipeline_allocs, ffi::GetRef(op), pipeline_info, fragment_info_, preserved_annotations); if (const auto* realize = op->body.as()) { @@ -1186,7 +1189,7 @@ class PipelineInjector : private StmtExprMutator { * \param n The block pointer to which the buffer allocations are added. * \param alloc_buffers The buffer allocations to be added. */ - void AddAllocBuffers(BlockNode* n, const Array alloc_buffers) { + void AddAllocBuffers(BlockNode* n, const ffi::Array alloc_buffers) { for (const Buffer& alloc_buffer : alloc_buffers) { n->alloc_buffers.push_back(alloc_buffer); Region region; @@ -1236,10 +1239,10 @@ class PipelineInjector : private StmtExprMutator { return false; } - Map buffer_data_to_buffer_; + ffi::Map buffer_data_to_buffer_; std::unordered_map fragment_info_; std::unordered_set double_buffers; - Optional global_symbol_; + ffi::Optional global_symbol_; }; } // namespace software_pipeline @@ -1260,10 +1263,10 @@ Pass InjectSoftwarePipeline() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InjectSoftwarePipeline", InjectSoftwarePipeline); -}); +} } // namespace transform diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index d0f84842a4fe..ce30e5840cc7 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -208,7 +208,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { if (touched_var_.count(op)) { visit_touched_var_ = true; } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const { return analyzer_->Simplify(index + var_ * alloc_extent); @@ -227,9 +227,9 @@ class VTInjector : public 
arith::IRMutatorWithAnalyzer { PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = RewriteIndex(offset, stride); - return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}); + return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}, op->annotations); } else if (op->op.same_as(builtin::tvm_context_id())) { - return allow_share_ ? GetRef(op) : var_; + return allow_share_ ? ffi::GetRef(op) : var_; } else { return StmtExprMutator::VisitExpr_(op); } @@ -287,14 +287,14 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const AttrStmtNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } else if (!allow_share_ && !vt_loop_injected_ && (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } else { Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return AttrStmt(op->node, op->attr_key, value, body); } @@ -304,12 +304,12 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const LetStmtNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(op->var, value, body); } @@ -319,7 +319,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { ICHECK(is_zero(op->min)); PrimExpr extent = this->VisitExpr(op->extent); if (visit_touched_var_ && 
!vt_loop_injected_) { - Stmt stmt = InjectVTLoop(GetRef(op), true); + Stmt stmt = InjectVTLoop(ffi::GetRef(op), true); ++max_loop_depth_; return stmt; } @@ -327,7 +327,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt body = this->VisitStmt(op->body); ++max_loop_depth_; if (extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->extent = std::move(extent); @@ -339,12 +339,12 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const IfThenElseNode* op) final { PrimExpr condition = this->VisitExpr(op->condition); if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; ICHECK_EQ(max_loop_depth_, 0); Stmt then_case = this->VisitStmt(op->then_case); - Optional else_case = std::nullopt; + ffi::Optional else_case = std::nullopt; if (op->else_case) { int temp = max_loop_depth_; max_loop_depth_ = 0; @@ -353,7 +353,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return ffi::GetRef(op); } else { return IfThenElse(condition, then_case, else_case); } @@ -379,15 +379,15 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } // Allocate Stmt VisitStmt_(const AllocateNode* op) final { - Allocate node = GetRef(op); + Allocate node = ffi::GetRef(op); PrimExpr condition = this->VisitExpr(op->condition); - Array extents = + ffi::Array extents = op->extents.Map([this](const PrimExpr& extent) { return this->VisitExpr(extent); }); if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; @@ -417,7 +417,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { if (extents.same_as(op->extents) 
&& body.same_as(op->body) && condition.same_as(op->condition)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Allocate(op->buffer_var, op->dtype, extents, condition, body); } @@ -439,7 +439,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { // only unroll if number of vthreads are small if (max_loop_depth_ == 0 && num_threads_ < 16) { // do unrolling if it is inside innermost content. - Array seq; + ffi::Array seq; for (int i = 0; i < num_threads_; ++i) { seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}})); } @@ -524,10 +524,10 @@ Pass InjectVirtualThread() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InjectVirtualThread", InjectVirtualThread); -}); +} } // namespace transform diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index 8521607f893e..ce69053311d1 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -103,7 +103,7 @@ bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func, // Only inline private functions. Externally-exposed functions // must be preserved so to avoid breaking callsites outside of // the IRModule. 
- bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return false; // We do not currently implement any analysis for termination of @@ -128,10 +128,10 @@ bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func, return true; } -Map CollectInlinablePrimFuncs(const IRModule& mod) { +ffi::Map CollectInlinablePrimFuncs(const IRModule& mod) { auto recursive_functions = CollectRecursiveFunctions(mod); - Map output; + ffi::Map output; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto prim_func = opt.value(); @@ -146,7 +146,7 @@ Map CollectInlinablePrimFuncs(const IRModule& mod) { class PrimFuncInliner : StmtExprMutator { public: - explicit PrimFuncInliner(Map inlinable_funcs) + explicit PrimFuncInliner(ffi::Map inlinable_funcs) : inlinable_funcs_(inlinable_funcs) { for (const auto& [gvar, callee] : inlinable_funcs_) { removable_funcs_.insert(gvar); @@ -176,7 +176,7 @@ class PrimFuncInliner : StmtExprMutator { } } - Optional GetInlinedFunction(const EvaluateNode* eval) { + ffi::Optional GetInlinedFunction(const EvaluateNode* eval) { auto call = eval->value.as(); if (!call) return std::nullopt; @@ -222,7 +222,8 @@ class PrimFuncInliner : StmtExprMutator { return StmtExprMutator::VisitExpr_(call); } - Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, const Array& args) const { + Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, + const ffi::Array& args) const { CHECK_EQ(callee->params.size(), args.size()) << "Callee " << gvar << " accepts " << callee->params.size() << " parameters (" << callee->params << "), but is called with " << args.size() << " arguments (" << args @@ -232,7 +233,7 @@ class PrimFuncInliner : StmtExprMutator { << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << 
callee->buffer_map; - Map> param_map; + ffi::Map> param_map; for (size_t i = 0; i < callee->params.size(); i++) { param_map.Set(callee->params[i], args[i]); } @@ -243,7 +244,7 @@ class PrimFuncInliner : StmtExprMutator { } // Map from GlobalVar to PrimFuncs which may be inlined. - Map inlinable_funcs_; + ffi::Map inlinable_funcs_; /* \brief Set of callees that may be removed * @@ -253,7 +254,7 @@ class PrimFuncInliner : StmtExprMutator { */ PSet removable_funcs_; - Optional current_target_ = std::nullopt; + ffi::Optional current_target_ = std::nullopt; }; } // namespace @@ -293,10 +294,10 @@ Pass InlinePrivateFunctions() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InlinePrivateFunctions", InlinePrivateFunctions); -}); +} } // namespace transform diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 3f94fb0cfc6e..0e83b9113b98 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -41,48 +41,48 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { Stmt s = *ri; if (const auto* for_ = s.as()) { - auto n = make_object(*for_); + auto n = ffi::make_object(*for_); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* let = s.as()) { - auto n = make_object(*let); + auto n = ffi::make_object(*let); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* attr = s.as()) { - auto n = make_object(*attr); + auto n = ffi::make_object(*attr); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* ite = s.as()) { - auto n = make_object(*ite); + auto n = ffi::make_object(*ite); ICHECK(is_no_op(n->then_case)); ICHECK(!n->else_case); n->then_case = body; body = Stmt(n); } else if (const auto* seq = s.as()) { 
- auto n = make_object(*seq); + auto n = ffi::make_object(*seq); ICHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); n->seq.Set(n->size() - 1, body); body = Stmt(n); } else if (const auto* assert_ = s.as()) { - auto n = make_object(*assert_); + auto n = ffi::make_object(*assert_); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* alloc = s.as()) { - auto n = make_object(*alloc); + auto n = ffi::make_object(*alloc); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* alloc = s.as()) { - auto n = make_object(*alloc); + auto n = ffi::make_object(*alloc); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* decl_buffer = s.as()) { - auto n = make_object(*decl_buffer); + auto n = ffi::make_object(*decl_buffer); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); @@ -130,7 +130,7 @@ class IRConvertSSA final : public StmtExprMutator { if (defined_params.count(var_ptr)) return; if (defined_.count(var_ptr)) { - auto var = GetRef(var_ptr); + auto var = ffi::GetRef(var_ptr); redefines.emplace_back(this, var); } else { defined_.insert(var_ptr); @@ -148,7 +148,7 @@ class IRConvertSSA final : public StmtExprMutator { // Update the buffer map, based on the redefined parameters auto buffer_map = [&]() { - Map buffer_map; + ffi::Map buffer_map; bool made_change = false; for (const auto& [var, buffer] : func->buffer_map) { auto new_var = GetRemappedVar(var); @@ -174,17 +174,39 @@ class IRConvertSSA final : public StmtExprMutator { return DictAttrs(); } - Map dict; + ffi::Map dict; bool made_change = false; for (const auto& [key, old_value] : func->attrs->dict) { auto value = old_value; if (auto* expr = value.as()) { - value = VisitExpr(GetRef(expr)); + value = VisitExpr(ffi::GetRef(expr)); } else if (auto* stmt = value.as()) { - value = VisitStmt(GetRef(stmt)); + value = VisitStmt(ffi::GetRef(stmt)); + } else if (auto opt_arr = value.try_cast>()) { + // Handle container types 
like Array[...] that may contain Vars/Buffers/Exprs/Stmts + auto arr = opt_arr.value(); + bool arr_changed = false; + std::vector rewritten; + rewritten.reserve(arr.size()); + for (const ObjectRef& elem : arr) { + ObjectRef new_elem = elem; + if (auto* e = elem.as()) { + new_elem = VisitExpr(ffi::GetRef(e)); + } else if (auto* s = elem.as()) { + new_elem = VisitStmt(ffi::GetRef(s)); + } else if (auto* v = elem.as()) { + new_elem = GetRemappedVar(ffi::GetRef(v)); + } else if (auto* b = elem.as()) { + new_elem = GetRemappedBuffer(ffi::GetRef(b)); + } + arr_changed = arr_changed || !new_elem.same_as(elem); + rewritten.push_back(new_elem); + } + if (arr_changed) { + value = ffi::Array(rewritten); + } } - made_change = made_change || !value.same_as(old_value); dict.Set(key, value); } @@ -195,9 +217,7 @@ class IRConvertSSA final : public StmtExprMutator { return func->attrs; } }(); - auto body = VisitStmt(func->body); - // If anything changed, update the returned function if (!params.same_as(func->params) || !buffer_map.same_as(func->buffer_map) || !attrs.same_as(func->attrs) || !body.same_as(func->body)) { @@ -212,7 +232,8 @@ class IRConvertSSA final : public StmtExprMutator { return func; } - PrimExpr VisitExpr_(const VarNode* op) final { return GetRemappedVar(GetRef(op)); } + PrimExpr VisitExpr_(const VarNode* op) final { return GetRemappedVar(ffi::GetRef(op)); } + PrimExpr VisitExpr_(const LetNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { @@ -248,13 +269,13 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* op) final { - Block block = GetRef(op); + Block block = ffi::GetRef(op); // The BlockNode is the point of definition for the IterVar // instances. These re-defines must be present before visiting // the body of the BlockNode. 
std::vector redefines; - Array iter_vars = op->iter_vars.Map([&](IterVar iter_var) { + ffi::Array iter_vars = op->iter_vars.Map([&](IterVar iter_var) { if (defined_.count(iter_var->var.get())) { redefines.emplace_back(this, iter_var->var); iter_var.CopyOnWrite()->var = redefines.back().new_var; @@ -263,9 +284,9 @@ class IRConvertSSA final : public StmtExprMutator { } return iter_var; }); - Array reads = + ffi::Array reads = block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); }); - Array writes = + ffi::Array writes = block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); }); if (!reads.same_as(block->reads) || !writes.same_as(block->writes) || @@ -312,8 +333,8 @@ class IRConvertSSA final : public StmtExprMutator { Var new_buffer_var = GetRemappedVar(buf->data); PrimExpr elem_offset = VisitExpr(buf->elem_offset); auto visit_expr = [this](const PrimExpr& expr) { return VisitExpr(expr); }; - Array shape = buf->shape.Map(visit_expr); - Array strides = buf->strides.Map(visit_expr); + ffi::Array shape = buf->shape.Map(visit_expr); + ffi::Array strides = buf->strides.Map(visit_expr); // If no mapping is required, return the original buffer. 
if (new_buffer_var.same_as(buf->data) && elem_offset.same_as(buf->elem_offset) && @@ -362,9 +383,9 @@ class IRConvertSSA final : public StmtExprMutator { if (defined_.count(v.get())) { ScopedRedefine redefine(this, v); Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, - op->annotations); + auto n = ffi::make_object(*stmt.as()); + n->loop_var = redefine.new_var; + return For(n); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -432,7 +453,7 @@ class IRConvertSSA final : public StmtExprMutator { IterVar new_iter_var; if (dom.same_as(iter_var->dom) && var.same_as(iter_var->var)) { - new_iter_var = GetRef(iter_var); + new_iter_var = ffi::GetRef(iter_var); } else { new_iter_var = IterVar(dom, var, iter_var->iter_type, iter_var->thread_tag, iter_var->span); } @@ -442,7 +463,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt output; if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) { - output = GetRef(op); + output = ffi::GetRef(op); } else { output = AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span); } @@ -530,14 +551,14 @@ class IRConvertSSA final : public StmtExprMutator { Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } -String GetPtrStorageScope(Var buffer_var) { +ffi::String GetPtrStorageScope(Var buffer_var) { const auto* ptr_type = buffer_var->type_annotation.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; return ptr_type->storage_scope; } -Array GetBufferAllocationShape(const Buffer& buffer) { - Array alloc_shape = buffer->shape; +ffi::Array GetBufferAllocationShape(const Buffer& buffer) { + ffi::Array alloc_shape = buffer->shape; if (buffer->strides.size()) { ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); for (size_t i = buffer->strides.size() - 1; i > 0; --i) { @@ -549,14 +570,14 @@ Array 
GetBufferAllocationShape(const Buffer& buffer) { return alloc_shape; } -Array ConvertIndices(const MatchBufferRegion& match_buffer, - const Array& indices) { +ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, + const ffi::Array& indices) { const Buffer& target = match_buffer->buffer; const BufferRegion& source = match_buffer->source; ICHECK_EQ(indices.size(), target->shape.size()); arith::Analyzer analyzer; - Array result; + ffi::Array result; result.reserve(source->region.size()); size_t offset = source->region.size() - indices.size(); for (size_t i = 0; i < offset; ++i) { @@ -595,7 +616,7 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region return result; } -Optional ConditionalBoundsContext::TrySolveCondition() { +ffi::Optional ConditionalBoundsContext::TrySolveCondition() { // extract equations and related vars from condition expression. // currently only extract simple integral equations which could be solvable. arith::Analyzer analyzer; @@ -603,8 +624,8 @@ Optional ConditionalBoundsContext::TrySolveCondition() { if (is_const_int(condition)) { return std::nullopt; } - Array equations; - Array vars; + ffi::Array equations; + ffi::Array vars; std::function fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) { if (e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance()) { @@ -615,7 +636,7 @@ Optional ConditionalBoundsContext::TrySolveCondition() { return; } else if (const VarNode* var = obj.as()) { if (var->dtype.is_int() || var->dtype.is_uint()) { - cand_vars.push_back(GetRef(var)); + cand_vars.push_back(ffi::GetRef(var)); } } else { is_simple &= obj->IsInstance() || obj->IsInstance() || @@ -648,7 +669,7 @@ Optional ConditionalBoundsContext::TrySolveCondition() { return std::nullopt; } // build dom ranges for related vars - Map ranges; + ffi::Map ranges; for (const Var& v : vars) { arith::IntSet dom; auto relax_it = relax_map_->find(v.get()); @@ -684,7 
+705,7 @@ ConditionalBoundsContext::ConditionalBoundsContext( origin_pending_conditions_num_(pending_conditions->size()) {} void ConditionalBoundsContext::EnterWithScope() { - Optional constraints = TrySolveCondition(); + ffi::Optional constraints = TrySolveCondition(); if (!constraints.defined()) { // fail to process the condition, add to unresolved pending_conditions_->push_back(condition_); @@ -831,11 +852,11 @@ namespace transform { Pass ConvertSSA() { auto pass_func = [](IRModule mod, PassContext ctx) { tir::IRConvertSSA converter; - Map functions; + ffi::Map functions; bool made_change = false; for (auto [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { - auto updated = converter.VisitPrimFunc(GetRef(ptr)); + auto updated = converter.VisitPrimFunc(ffi::GetRef(ptr)); if (!updated.same_as(base_func)) { made_change = true; base_func = updated; @@ -851,10 +872,10 @@ Pass ConvertSSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.ConvertSSA", ConvertSSA); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index cc58f96b83fb..fdf4def699ec 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -69,7 +69,7 @@ Stmt MergeNest(const std::vector>& nest, Stmt body); * original array */ template -inline Array UpdateArray(Array arr, F fupdate) { +inline ffi::Array UpdateArray(ffi::Array arr, F fupdate) { std::vector new_arr(arr.size()); bool changed = false; for (size_t i = 0; i < arr.size(); ++i) { @@ -81,7 +81,7 @@ inline Array UpdateArray(Array arr, F fupdate) { if (!changed) { return arr; } else { - return Array(new_arr); + return ffi::Array(new_arr); } } @@ -95,8 +95,8 @@ inline Array UpdateArray(Array arr, F fupdate) { */ inline PrimExpr TVMStructGet(DataType dtype, Var handle, 
int index, builtin::TVMStructFieldKind kind) { - Array args = {handle, make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind))}; + ffi::Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind))}; return Call(dtype, builtin::tvm_struct_get(), args); } @@ -142,8 +142,8 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { * \return the set stmt. */ inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) { - Array args = {handle, make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind)), value}; + ffi::Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind)), value}; return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args)); } @@ -195,7 +195,7 @@ inline PrimExpr ConstInt32(size_t index) { * \return PrimExpr representing the TVMValue */ inline PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {StringImm(type), ConstInt32(num)}; + ffi::Array args = {StringImm(type), ConstInt32(num)}; return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args); } @@ -211,15 +211,15 @@ Stmt ConvertSSA(Stmt stmt); * \param buffer_var The input buffer variable. * \return A string representing the storage scope of this buffer variable. */ -String GetPtrStorageScope(Var buffer_var); +ffi::String GetPtrStorageScope(Var buffer_var); /*! * \brief Convert match buffer target buffer access indices to original one. * \param indices The indices of the target buffer * \return The indices of source buffer. */ -Array ConvertIndices(const MatchBufferRegion& match_buffer, - const Array& indices); +ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, + const ffi::Array& indices); /*! * \brief Convert match buffer target buffer region to original one. 
@@ -233,7 +233,7 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region * \param buffer The buffer object. * \return shape The shape considering buffer strides. */ -Array GetBufferAllocationShape(const Buffer& buffer); +ffi::Array GetBufferAllocationShape(const Buffer& buffer); /*! * \brief Context helper to update domain map within conditional scope. @@ -261,7 +261,7 @@ class ConditionalBoundsContext { void ExitWithScope(); /*! \brief Helper to solve related variable's bound within conditional scope.*/ - Optional TrySolveCondition(); + ffi::Optional TrySolveCondition(); /*! \brief the condition holds on true branch. */ const PrimExpr& condition_; @@ -322,12 +322,12 @@ std::pair GetAsyncWaitAttributes(const AttrStmtNode* op); * function body. * \return The updated function. */ -PrimFunc BindParams(PrimFunc f, const Array& constants); +PrimFunc BindParams(PrimFunc f, const ffi::Array& constants); /*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */ using StorageAlignTuple = ffi::Tuple; /*! \brief A list of StorageAlignTuple, used by StorageAlign */ -using StorageAlignAnnotation = Array; +using StorageAlignAnnotation = ffi::Array; /*! * \brief Collect storage alignment annotations for all buffer vars within body. * \param body The stmt to collect. 
diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc index 8995beb2ce9e..45bbf4af52de 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/tir/transforms/lift_thread_binding.cc @@ -32,14 +32,14 @@ namespace tvm { namespace tir { -std::pair>>, +std::pair>>, ObjectPtrHash, ObjectPtrEqual>, - Map> + ffi::Map> FindLoopLCA(const Stmt& root) { class LCAFinder : public StmtVisitor { public: void VisitStmt_(const ForNode* op) final { - stack.push_back(GetRef(op)); + stack.push_back(ffi::GetRef(op)); StmtVisitor::VisitStmt_(op); if (op->kind == ForKind::kThreadBinding) { UpdateLCA(op); @@ -50,7 +50,7 @@ FindLoopLCA(const Stmt& root) { void UpdateLCA(const ForNode* loop) { std::string thread_tag = loop->thread_binding.value()->thread_tag; { - Map* tgt = &annotations[thread_tag]; + ffi::Map* tgt = &annotations[thread_tag]; for (const auto& kv : loop->annotations) { tgt->Set(kv.first, kv.second); } @@ -78,14 +78,14 @@ FindLoopLCA(const Stmt& root) { std::unordered_map> lca; std::unordered_map iters; - std::unordered_map> annotations; - Map var_subst; + std::unordered_map> annotations; + ffi::Map var_subst; std::vector stack; }; LCAFinder finder; finder(root); - std::unordered_map>>, ObjectPtrHash, - ObjectPtrEqual> + std::unordered_map>>, + ObjectPtrHash, ObjectPtrEqual> result; std::vector sorted_thread_tags; for (const auto& kv : finder.lca) { @@ -104,7 +104,7 @@ FindLoopLCA(const Stmt& root) { for (const auto& thread_tag : sorted_thread_tags) { Stmt lca = finder.lca[thread_tag].back(); const IterVar& iter = finder.iters[thread_tag]; - const Map& annotations = finder.annotations[thread_tag]; + const ffi::Map& annotations = finder.annotations[thread_tag]; result[lca].emplace_back(iter, annotations); } return {result, finder.var_subst}; @@ -117,7 +117,7 @@ FindLoopLCA(const Stmt& root) { class ThreadBindingLifter : public StmtExprMutator { public: Stmt VisitStmt_(const ForNode* _op) final { - For op = GetRef(_op); + For 
op = ffi::GetRef(_op); bool is_kernel_root = false; if (op->kind == ForKind::kThreadBinding) { if (iter_lca.empty()) { @@ -133,7 +133,7 @@ class ThreadBindingLifter : public StmtExprMutator { ForKind::kThreadBinding, std::move(body), IterVar(Range(nullptr), Var(iter_var->thread_tag, iter_var->var->dtype), kThreadIndex, iter_var->thread_tag), - annotation); + annotation, std::nullopt); } } if (is_kernel_root) { @@ -149,24 +149,24 @@ class ThreadBindingLifter : public StmtExprMutator { } void SetKernelRoot(const ForNode* op) { - auto result = FindLoopLCA(GetRef(op)); + auto result = FindLoopLCA(ffi::GetRef(op)); this->iter_lca = std::move(result.first); this->var_subst = std::move(result.second); } PrimExpr VisitExpr_(const VarNode* op) final { - auto it = var_subst.find(GetRef(op)); + auto it = var_subst.find(ffi::GetRef(op)); if (it != var_subst.end()) { return (*it).second; } else { - return GetRef(op); + return ffi::GetRef(op); } } - std::unordered_map>>, ObjectPtrHash, - ObjectPtrEqual> + std::unordered_map>>, + ObjectPtrHash, ObjectPtrEqual> iter_lca; - Map var_subst; + ffi::Map var_subst; }; PrimFunc LiftThreadBinding(PrimFunc f) { @@ -184,10 +184,10 @@ Pass LiftThreadBinding() { return CreatePrimFuncPass(pass_func, 0, "tir.LiftThreadBinding", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LiftThreadBinding", LiftThreadBinding); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index f083a9d6d4df..fd9bd2d6531c 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -59,16 +59,16 @@ struct LoopPartitionConfigNode : public AttrsNodeReflAdapterloop_var.get(); if (partition_hint_vars.count(var)) { - candidates.insert(GetRef(op)); + candidates.insert(ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); return; } @@ -122,7 +122,7 @@ class 
CandidateSelector final : public StmtExprVisitor { record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var) && !no_split_) { - candidates.insert(GetRef(op)); + candidates.insert(ffi::GetRef(op)); } record_.erase(var); } else { @@ -137,7 +137,7 @@ class CandidateSelector final : public StmtExprVisitor { Var var = iv->var; // always treat var with hint to be partitioned if (partition_hint_vars.count(var.get())) { - candidates.insert(GetRef(op)); + candidates.insert(ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); return; } @@ -146,7 +146,7 @@ class CandidateSelector final : public StmtExprVisitor { record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { - candidates.insert(GetRef(op)); + candidates.insert(ffi::GetRef(op)); } record_.erase(var.get()); return; @@ -213,7 +213,7 @@ class CandidateSelector final : public StmtExprVisitor { #define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \ void VisitExpr_(const OpNodeT* op) final { \ if (has_partition_hint_) { \ - DeduceCondition(GetRef(op)); \ + DeduceCondition(ffi::GetRef(op)); \ return; \ } \ StmtExprVisitor::VisitExpr_(op); \ @@ -421,7 +421,7 @@ class LoopPartitioner : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent), true); - auto fs = GetRef(op); + auto fs = ffi::GetRef(op); if (selector.candidates.count(fs)) { Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); if (s.defined()) return s; @@ -443,7 +443,7 @@ class LoopPartitioner : public StmtMutator { const IterVarNode* iv = op->node.as(); ICHECK(iv); Var var = iv->var; - auto as = GetRef(op); + auto as = ffi::GetRef(op); if (selector.candidates.count(as)) { Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true); if (s.defined()) return s; @@ -489,7 +489,7 @@ class LoopPartitioner : public StmtMutator { std::pair 
LoopPartitioner::GetIntervalAndCondset( const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value, bool has_partition_hint) { - Array sets; + ffi::Array sets; ExpressionSet cond_set; for (const auto& kv : partitions) { @@ -760,14 +760,18 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) { const ForNode* for_node = static_cast(node); ICHECK(for_node); + if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) && !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { ICHECK(for_node->kind != ForKind::kThreadBinding); - return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body, - for_node->thread_binding, for_node->annotations); + auto new_loop = ffi::make_object(*for_node); + new_loop->min = IntImm(for_node->min.dtype(), 0); + new_loop->extent = extent; + new_loop->body = body; + return For(new_loop); } } @@ -819,10 +823,10 @@ Pass LoopPartition() { return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LoopPartition", LoopPartition); -}); +} } // namespace transform diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index 71c6c945e8f3..1b7bf14c38ae 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -52,7 +52,8 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { } // if for loop is not a memcpy of a contiguous region, it might be a cuda cp.async behavior - std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); + std::optional mem_copy = + 
IdentifyMemCpy(ffi::GetRef(loop), analyzer_); if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || mem_copy->source->region.size() != 1) { return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); @@ -159,7 +160,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { std::set queue_ids_; std::optional async_queue_id_ = std::nullopt; bool dma_bypass_cache_; - Map input_iters = Map(); + ffi::Map input_iters = ffi::Map(); }; namespace transform { @@ -176,10 +177,10 @@ Pass LowerAsyncDMA() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerAsyncDMA", LowerAsyncDMA); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 00cc2f226a60..2f7ac3ddb1c0 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -105,7 +105,7 @@ bool IsDominantBlock(const Block& scope_block, const Block& block) { * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the * check again. */ -bool IsReductionBlock(const BlockRealize& realize, const Map& loop_range_map, +bool IsReductionBlock(const BlockRealize& realize, const ffi::Map& loop_range_map, const Block& scope_block, arith::Analyzer* analyzer) { const auto* block = realize->block.as(); // Cond 1. The block has the `init` statement. @@ -123,11 +123,11 @@ bool IsReductionBlock(const BlockRealize& realize, const Map& loop_r } // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its // output buffers. - if (!IsDominantBlock(scope_block, GetRef(block))) { + if (!IsDominantBlock(scope_block, ffi::GetRef(block))) { return false; } // Cond 5. The reduction block vars are not used to index the output buffers. 
- return ReductionIterNotIndexOutputBuffer(GetRef(block)); + return ReductionIterNotIndexOutputBuffer(ffi::GetRef(block)); } /*! @@ -137,11 +137,12 @@ bool IsReductionBlock(const BlockRealize& realize, const Map& loop_r * computation results or not, which is used for determine the buffer name prefix * \return The created buffers */ -Array MakeScratchpads(const Array& reduction_buffers, bool is_cross_thread_buffer) { - Array new_buffers; +ffi::Array MakeScratchpads(const ffi::Array& reduction_buffers, + bool is_cross_thread_buffer) { + ffi::Array new_buffers; new_buffers.reserve(reduction_buffers.size()); for (const Buffer& buffer : reduction_buffers) { - String name = is_cross_thread_buffer ? "cross" : "in"; + ffi::String name = is_cross_thread_buffer ? "cross" : "in"; name = name + "_thread_" + buffer->name; new_buffers.push_back(Buffer(/*ptr=*/Var(name, PointerType(PrimType(buffer->dtype), "local")), /*dtype=*/buffer->dtype, @@ -162,8 +163,8 @@ Array MakeScratchpads(const Array& reduction_buffers, bool is_cr */ class BufferReplacer : private StmtExprMutator { public: - static Stmt Run(Array src_buffers, Array tgt_buffers, Stmt stmt) { - Map buffer_map; + static Stmt Run(ffi::Array src_buffers, ffi::Array tgt_buffers, Stmt stmt) { + ffi::Map buffer_map; ICHECK_EQ(src_buffers.size(), tgt_buffers.size()); int n_buffers = src_buffers.size(); for (int i = 0; i < n_buffers; ++i) { @@ -173,11 +174,12 @@ class BufferReplacer : private StmtExprMutator { } private: - explicit BufferReplacer(Map buffer_map) : buffer_map_(std::move(buffer_map)) {} + explicit BufferReplacer(ffi::Map buffer_map) + : buffer_map_(std::move(buffer_map)) {} PrimExpr VisitExpr_(const BufferLoadNode* load) final { auto it = buffer_map_.find(load->buffer); - return it != buffer_map_.end() ? BufferLoad((*it).second, {0}) : GetRef(load); + return it != buffer_map_.end() ? 
BufferLoad((*it).second, {0}) : ffi::GetRef(load); } Stmt VisitStmt_(const BufferStoreNode* store) final { @@ -190,7 +192,7 @@ class BufferReplacer : private StmtExprMutator { } } - Map buffer_map_; + ffi::Map buffer_map_; }; /*! @@ -217,7 +219,7 @@ class InThreadReducerMaker : private StmtMutator { private: void VisitStmt_(const BlockNode* block) final { - Array iter_vars = block->iter_vars; + ffi::Array iter_vars = block->iter_vars; for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type == kCommReduce) { reduction_block_vars_.push_back(iter_var); @@ -227,17 +229,17 @@ class InThreadReducerMaker : private StmtMutator { } /*! \brief the map from thread tag to its extent */ - Array reduction_block_vars_; + ffi::Array reduction_block_vars_; }; - static Optional Make(const BlockRealizeNode* src_realize, - Optional tgt_realize, Stmt stmt) { + static ffi::Optional Make(const BlockRealizeNode* src_realize, + ffi::Optional tgt_realize, Stmt stmt) { return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt)); } private: explicit InThreadReducerMaker(const BlockRealizeNode* src_realize, - Optional tgt_realize) + ffi::Optional tgt_realize) : src_realize_(src_realize), tgt_realize_(tgt_realize) {} Stmt VisitStmt_(const BlockRealizeNode* realize) final { if (realize == src_realize_) { @@ -245,11 +247,11 @@ class InThreadReducerMaker : private StmtMutator { ? 
tgt_realize_.value() : Stmt{nullptr}; } - return GetRef(realize); + return ffi::GetRef(realize); } Stmt VisitStmt_(const ForNode* loop) final { - if (Optional opt_res = Downcast>(StmtMutator::VisitStmt_(loop))) { + if (ffi::Optional opt_res = Downcast>(StmtMutator::VisitStmt_(loop))) { For res = opt_res.value(); if (res->thread_binding.defined()) { UnderLoopReductionBlockVarCollector collector; @@ -267,10 +269,10 @@ class InThreadReducerMaker : private StmtMutator { } Stmt VisitStmt_(const SeqStmtNode* seq) final { - Array stmts; + ffi::Array stmts; stmts.reserve(seq->size()); for (const Stmt& stmt : seq->seq) { - if (Optional opt_res = VisitStmt(stmt)) { + if (ffi::Optional opt_res = VisitStmt(stmt)) { stmts.push_back(opt_res.value()); } } @@ -278,7 +280,7 @@ class InThreadReducerMaker : private StmtMutator { } const BlockRealizeNode* src_realize_; - Optional tgt_realize_; + ffi::Optional tgt_realize_; }; /*! @@ -293,19 +295,19 @@ class InThreadReducerMaker : private StmtMutator { * \param combiner_rhs The RHS values of the combiner * \param reduction_loops The reduction loops */ -Stmt TransformReductionBlock(const BlockRealizeNode* realize, // - const Optional>& it_buffers, // - const Array& ct_buffers, // - const Array& wb_buffers, // - const Array& old_wb_indices, // - const CommReducer& reducer, // - const Array& combiner_rhs, // +Stmt TransformReductionBlock(const BlockRealizeNode* realize, // + const ffi::Optional>& it_buffers, // + const ffi::Array& ct_buffers, // + const ffi::Array& wb_buffers, // + const ffi::Array& old_wb_indices, // + const CommReducer& reducer, // + const ffi::Array& combiner_rhs, // const std::vector& reduction_loops) { int n_buffers = wb_buffers.size(); const BlockNode* block = realize->block.get(); - auto f_create_buffer_regions = [](Array buffers) { - Array regions; + auto f_create_buffer_regions = [](ffi::Array buffers) { + ffi::Array regions; regions.reserve(buffers.size()); for (const Buffer& buffer : buffers) { 
regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)})); @@ -313,8 +315,8 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // return regions; }; - Array ct_buffer_regions = f_create_buffer_regions(ct_buffers); - Optional> it_buffer_regions = std::nullopt; + ffi::Array ct_buffer_regions = f_create_buffer_regions(ct_buffers); + ffi::Optional> it_buffer_regions = std::nullopt; if (it_buffers.defined()) { it_buffer_regions = f_create_buffer_regions(it_buffers.value()); } @@ -323,11 +325,11 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // // - Stmt 2: do in-thread reduction // - Stmt 3: do cross-thread reduction // - Stmt 4: write cross-thread reduction result to the original buffer - Array stmts; + ffi::Array stmts; stmts.reserve(4); // Stmt 1: initialize the buffer for in-thread reduction if (it_buffers.defined()) { - Array inits; + ffi::Array inits; inits.reserve(n_buffers); for (int i = 0; i < n_buffers; ++i) { inits.push_back( @@ -344,31 +346,32 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // } // Stmt 2: do in-thread reduction { - Optional new_realize = std::nullopt; + ffi::Optional new_realize = std::nullopt; // If need to generate in-thread reduction, // then replace `wb_buffers` with `it_buffers` accordingly in given BlockRealize // otherwise, directly remove given BlockRealize if (it_buffers.defined()) { - ObjectPtr new_block = make_object(*block); + ObjectPtr new_block = ffi::make_object(*block); new_block->reads = std::move(new_block->reads); new_block->writes = it_buffer_regions.value(); new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body)); new_block->init = std::nullopt; - ObjectPtr n = make_object(*realize); + ObjectPtr n = ffi::make_object(*realize); n->block = Block(new_block); new_realize = BlockRealize(n); } - For loop = GetRef(reduction_loops[0]); - if (Optional stmt = 
InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) { + For loop = ffi::GetRef(reduction_loops[0]); + if (ffi::Optional stmt = + InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) { stmts.push_back(stmt.value()); } } // Stmt 3: do cross-thread reduction { // Step 3.1. Create the parameters to the intrinsic - Array parameters; + ffi::Array parameters; parameters.reserve(reduction_loops.size() + 4); // 1-st argument: number of buffers parameters.push_back(make_const(DataType::UInt(32), n_buffers)); @@ -393,12 +396,12 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // } } // Step 3.2. Create the block and the block-realize. - Array iter_vars{nullptr}; - Array bindings{nullptr}; - Array reads{nullptr}; + ffi::Array iter_vars{nullptr}; + ffi::Array bindings{nullptr}; + ffi::Array reads{nullptr}; if (it_buffers.defined()) { - iter_vars = Array{}; - bindings = Array{}; + iter_vars = ffi::Array{}; + bindings = ffi::Array{}; reads = it_buffer_regions.value(); } else { iter_vars = block->iter_vars; @@ -426,9 +429,9 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // { ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); int n_iter = static_cast(block->iter_vars.size()); - Array iter_vars; - Array bindings; - Map var_map; + ffi::Array iter_vars; + ffi::Array bindings; + ffi::Map var_map; iter_vars.reserve(n_iter); bindings.reserve(n_iter); for (int i = 0; i < n_iter; ++i) { @@ -437,8 +440,8 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // if (iter_var->iter_type != kCommReduce) { IterVar new_iter_var{nullptr}; { - ObjectPtr n = make_object(*iter_var.get()); - ObjectPtr v = make_object(*iter_var->var.get()); + ObjectPtr n = ffi::make_object(*iter_var.get()); + ObjectPtr v = ffi::make_object(*iter_var->var.get()); n->var = Var(v); new_iter_var = IterVar(n); } @@ -447,13 +450,13 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // var_map.Set(iter_var->var, 
new_iter_var->var); } } - Array wb_updates; - Array wb_regions; + ffi::Array wb_updates; + ffi::Array wb_regions; wb_updates.reserve(n_buffers); wb_regions.reserve(n_buffers); int n_dim = static_cast(old_wb_indices.size()); - Array region = Substitute(block->writes[0]->region, var_map); - Array wb_indices; + ffi::Array region = Substitute(block->writes[0]->region, var_map); + ffi::Array wb_indices; wb_indices.reserve(n_dim); for (int d = 0; d < n_dim; ++d) { wb_indices.push_back(Substitute(old_wb_indices[d], var_map)); @@ -475,13 +478,13 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // } PostOrderVisit(realize->predicate, [&wb_predicate, &reduction_loop_vars](const ObjectRef& obj) { if (const auto* and_node = obj.as()) { - Array sub_exprs = {and_node->a, and_node->b}; + ffi::Array sub_exprs = {and_node->a, and_node->b}; for (PrimExpr sub_expr : sub_exprs) { if (sub_expr->IsInstance()) { continue; } bool is_reduction = [sub_expr, &reduction_loop_vars]() { - Array vars = UndefinedVars(sub_expr); + ffi::Array vars = UndefinedVars(sub_expr); for (Var var : vars) { if (reduction_loop_vars.find(var.get()) != reduction_loop_vars.end()) { return true; @@ -520,7 +523,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) { const ForNode* loop = *rit; if (loop->thread_binding.defined()) { - ObjectPtr n = make_object(*loop); + ObjectPtr n = ffi::make_object(*loop); n->body = std::move(new_stmt); new_stmt = For(n); } @@ -541,14 +544,14 @@ class CrossThreadReductionTransformer : public StmtMutator { } // Step 1. If the block is not a reduction block, cross-thread reduction is not needed. - if (!IsReductionBlock(GetRef(realize), loop_range_map_, - GetRef(block_stack_.back()), &analyzer_)) { + if (!IsReductionBlock(ffi::GetRef(realize), loop_range_map_, + ffi::GetRef(block_stack_.back()), &analyzer_)) { return {}; } // Step 2. 
Collect all the vars that appear in the bindings of reduction block iters. std::unordered_set reduction_vars; - GetVarsTouchedByBlockIters(GetRef(realize), nullptr, &reduction_vars); + GetVarsTouchedByBlockIters(ffi::GetRef(realize), nullptr, &reduction_vars); // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters. // We call these loops "reduction-related". @@ -628,7 +631,7 @@ class CrossThreadReductionTransformer : public StmtMutator { * - the RHS values of the reduction updates, * - the indices which is used to access the reduction buffers when storing the reduction results */ - std::tuple, Array, Array> + std::tuple, ffi::Array, ffi::Array> CheckCanApplyCrossThreadReduction(const BlockNode* block, const std::vector& reduction_loops) const { // Condition 1. All the reduction-related loops should be the deepest among all statements @@ -669,19 +672,19 @@ class CrossThreadReductionTransformer : public StmtMutator { // Condition 3. Get the identity values of the block init and the BufferStore block combiner // updates of the reduction. Extract the commutative reducer, combiner lhs and combiner rhs from // the reduction identities and the reduction combiner. - Array init_values{nullptr}; - Array updates{nullptr}; + ffi::Array init_values{nullptr}; + ffi::Array updates{nullptr}; CommReducer reducer{nullptr}; - Array combiner_lhs{nullptr}; - Array combiner_rhs{nullptr}; + ffi::Array combiner_lhs{nullptr}; + ffi::Array combiner_rhs{nullptr}; std::tie(init_values, updates) = - GetInitValuesAndUpdatesFromReductionBlock(std::nullopt, GetRef(block)); + GetInitValuesAndUpdatesFromReductionBlock(std::nullopt, ffi::GetRef(block)); std::tie(reducer, combiner_lhs, combiner_rhs) = GetReducerAndCombinerLhsRhs(std::nullopt, init_values, updates); // Condition 4. All reduction buffers should be all local or all non-local. 
int is_local_buf = -1; - Array reduction_buffers; + ffi::Array reduction_buffers; reduction_buffers.reserve(updates.size()); for (const BufferStore& buf_store : updates) { reduction_buffers.push_back(buf_store->buffer); @@ -702,7 +705,7 @@ class CrossThreadReductionTransformer : public StmtMutator { // Condition 5. The block should be the last block under the first reduction-related loop. bool visit = false; - PreOrderVisit(GetRef(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { + PreOrderVisit(ffi::GetRef(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { if (const auto* realize = obj.as()) { CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction " "block isn't the last block under its first reduction-related loop"; @@ -772,7 +775,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } Stmt VisitStmt_(const BlockNode* block) final { - Map old_loop_range_map; + ffi::Map old_loop_range_map; block_stack_.push_back(block); std::swap(old_loop_range_map, loop_range_map_); @@ -801,9 +804,9 @@ class CrossThreadReductionTransformer : public StmtMutator { // which condition the block violates. int n_bound_reduction_loops = 0; CommReducer reducer{nullptr}; - Array reduction_buffers{nullptr}; - Array combiner_rhs{nullptr}; - Array wb_indices{nullptr}; + ffi::Array reduction_buffers{nullptr}; + ffi::Array combiner_rhs{nullptr}; + ffi::Array wb_indices{nullptr}; std::tie(n_bound_reduction_loops, reducer, reduction_buffers, combiner_rhs, wb_indices) = CheckCanApplyCrossThreadReduction(block, reduction_loops); // Step 2. Before doing the cross-thread reduction, in-thread reduction is needed when @@ -814,10 +817,11 @@ class CrossThreadReductionTransformer : public StmtMutator { !is_one(realize->predicate); // Step 3. Create intermediate buffers, storing them in `ct_buffers` and // `it_buffers`. Let the scope block allocate these new buffers. 
- Array& new_buffers = block2new_buffers_[block_stack_.back()]; - Array ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true); + ffi::Array& new_buffers = block2new_buffers_[block_stack_.back()]; + ffi::Array ct_buffers = + MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true); new_buffers.insert(new_buffers.end(), ct_buffers.begin(), ct_buffers.end()); - Optional> it_buffers = std::nullopt; + ffi::Optional> it_buffers = std::nullopt; if (need_in_thread_reduction) { it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false); new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end()); @@ -849,7 +853,7 @@ class CrossThreadReductionTransformer : public StmtMutator { // Step 1. Generate loop var for each unbound thread. // Update the block predicate with clauses of `thread_var == min`. PrimExpr predicate = realize->predicate; - Array loop_vars; + ffi::Array loop_vars; loop_vars.reserve(unbound_thread2range.size()); for (auto [scope, range] : unbound_thread2range) { std::string dim_index(1, static_cast(scope.dim_index + 'x')); @@ -859,7 +863,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } // Step 2. Update the BlockRealize with the new predicate. - ObjectPtr p_realize = make_object(*realize); + ObjectPtr p_realize = ffi::make_object(*realize); p_realize->predicate = std::move(predicate); // Step 3. Wrap the updated BlockRealize with the new loops. @@ -874,7 +878,9 @@ class CrossThreadReductionTransformer : public StmtMutator { /*body=*/body, // /*thread_binding=*/ IterVar(NullValue(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex, - "threadIdx." + dim_index)); + "threadIdx." 
+ dim_index), + /*annotations=*/{}, + /*step=*/std::nullopt); } return body; } @@ -910,9 +916,9 @@ class CrossThreadReductionTransformer : public StmtMutator { std::vector statement_stack_; std::vector loop_stack_; std::vector block_stack_; - std::unordered_map> block2new_buffers_; + std::unordered_map> block2new_buffers_; std::unordered_map loop2new_stmt_; - Map loop_range_map_; + ffi::Map loop_range_map_; arith::Analyzer analyzer_; int block_idx_depth = 0; @@ -936,10 +942,10 @@ Pass LowerCrossThreadReduction() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerCrossThreadReduction", LowerCrossThreadReduction); -}); +} } // namespace transform diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index f77276e1553c..90725fe5befc 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -64,7 +64,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { PrimExpr VisitExpr_(const FloatImmNode* imm) final { auto type_code = imm->dtype.code(); - auto e = GetRef(imm); + auto e = ffi::GetRef(imm); if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); ICHECK(lower) << "FloatImm lowering function for target " << target_ << " type " @@ -75,7 +75,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto itr = var_remap_.find(var); if (itr != var_remap_.end()) { @@ -251,10 +251,10 @@ Pass LowerCustomDatatypes() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("tir.transform.LowerCustomDatatypes", LowerCustomDatatypes); -}); +} } // namespace transform diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 529956d372f3..d7a89f87a811 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -43,21 +43,26 @@ struct KernelInfo { // The externally visible symbol which may refer to the PrimFunc // when launching a device kernel. - String global_symbol; + ffi::String global_symbol; // The parameters accepted by the PrimFunc. Used to rewrite // `launch_args` to be in terms of the calling scope. - Array params; + ffi::Array params; // The launch parameters that should annotate the PrimFunc, if the // kernel is ever called from the host. - Array launch_params; + ffi::Array launch_params; // Additional arguments which must be provided to the host-side // ffi::Function. These may be in terms of the function's parameters // (e.g. a function that computes the average of `N` elements, and // which must be launched with `N` CUDA threads). - Array launch_args; + ffi::Array launch_args; + + // The extent of each thread + ffi::Map thread_extent; + // The amount of dynamic shared memory used + ffi::Optional dyn_shmem_size{std::nullopt}; }; /*! 
@@ -80,16 +85,18 @@ class DeviceInfoCollector : public StmtVisitor { } collector.info_.global_symbol = - func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); + func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); collector.info_.launch_args = collector.info_.launch_params.Map( [&](const auto& param) { return collector.GetArgument(param); }); + collector.info_.dyn_shmem_size = collector.dyn_shmem_size; + collector.info_.thread_extent = collector.thread_extent; return collector.info_; } private: - PrimExpr GetArgument(const String& launch_param) const { + PrimExpr GetArgument(const ffi::String& launch_param) const { if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { CHECK(dyn_shmem_size.defined()) << "Compute kernel requires launch parameter \"" << launch_param @@ -142,9 +149,9 @@ class DeviceInfoCollector : public StmtVisitor { // recording what thread axis have been visited. std::unordered_set defined_thread; // The extent of each thread - Map thread_extent; + ffi::Map thread_extent; // The amount of dynamic shared memory used - Optional dyn_shmem_size{std::nullopt}; + ffi::Optional dyn_shmem_size{std::nullopt}; }; class ReturnRemover : public StmtExprMutator { @@ -229,10 +236,16 @@ class DeviceKernelMutator : public StmtExprMutator { {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, {tvm::attr::kGlobalSymbol, info.global_symbol}}); - } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { + } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); } + const auto& info = device_info_map_.at(gvar.get()); + const auto& thread_extent = info.thread_extent; + func = WithAttr(std::move(func), "thread_extent", thread_extent); + if (info.dyn_shmem_size.defined()) { + func = WithAttr(std::move(func), "dyn_shared_memory_buf", info.dyn_shmem_size.value()); + } return func; } @@ -266,12 +279,12 @@ class 
DeviceKernelMutator : public StmtExprMutator { // calling a custom TIRToRuntime target) do not require a kernel // launch, but need to be replaced with call_extern. extern_function_call_.insert(gvar); - Array args; + ffi::Array args; args.push_back(StringImm(gvar->name_hint)); for (const auto& arg : node->args) { args.push_back(arg); } - return Call(node->dtype, builtin::call_extern(), args); + return Call(node->dtype, builtin::call_extern(), args, node->annotations); } ICHECK(dev_info.launch_params.defined()) @@ -285,8 +298,8 @@ class DeviceKernelMutator : public StmtExprMutator { // caller's parameters. The param_map allows substitution of // parameter values into the thread extents, to generate // expressions that are valid within the caller. - Map param_map = [&]() { - Map param_map; + ffi::Map param_map = [&]() { + ffi::Map param_map; CHECK_EQ(node->args.size(), dev_info.params.size()) << "Function " << gvar->name_hint << " accepts " << dev_info.params.size() << " arguments as input, but is called using " << node->args.size() << " arguments"; @@ -298,7 +311,7 @@ class DeviceKernelMutator : public StmtExprMutator { device_kernel_launch_.insert(gvar); - Array call_args; + ffi::Array call_args; call_args.push_back(StringImm(dev_info.global_symbol)); for (PrimExpr arg : node->args) { call_args.push_back(arg); @@ -309,10 +322,10 @@ class DeviceKernelMutator : public StmtExprMutator { auto dtype = node->dtype.is_void() ? 
DataType::Int(32) : node->dtype; - return Call(dtype, builtin::tvm_call_packed(), call_args); + return Call(dtype, builtin::tvm_call_packed(), call_args, node->annotations); } - Optional current_target_; + ffi::Optional current_target_; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; std::unordered_set extern_function_call_; @@ -336,7 +349,7 @@ Pass LowerDeviceKernelLaunch() { IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { - auto prim_func = mutator.RewriteKernelLaunchSite(gvar, GetRef(ptr)); + auto prim_func = mutator.RewriteKernelLaunchSite(gvar, ffi::GetRef(ptr)); if (!prim_func.same_as(base_func)) { updates->Add(gvar, prim_func); } @@ -352,7 +365,7 @@ Pass LowerDeviceKernelLaunch() { IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { - auto prim_func = mutator.UpdateKernelAttributes(gvar, GetRef(ptr)); + auto prim_func = mutator.UpdateKernelAttributes(gvar, ffi::GetRef(ptr)); if (!prim_func.same_as(base_func)) { updates->Add(gvar, prim_func); } @@ -370,10 +383,10 @@ Pass LowerDeviceKernelLaunch() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.LowerDeviceKernelLaunch", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerDeviceKernelLaunch", LowerDeviceKernelLaunch); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 65a55dff36ed..d3a2001c6ec5 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -131,10 +131,10 @@ Pass LowerDeviceStorageAccessInfo() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace 
refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerDeviceStorageAccessInfo", LowerDeviceStorageAccessInfo); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index d3994b066dbc..5ae654077316 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -45,7 +45,7 @@ class InitBlockLower : public StmtMutator { return Block(n); } - static Stmt DoLowering(const Stmt& init, const Array& iter_vars) { + static Stmt DoLowering(const Stmt& init, const ffi::Array& iter_vars) { std::vector conditions; for (const IterVar& var : iter_vars) { if (var->iter_type == IterVarType::kCommReduce) { @@ -80,10 +80,10 @@ Pass LowerInitBlock() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerInitBlock", LowerInitBlock); -}); +} } // namespace transform diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 2915a741e80e..4c35fdb2902f 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -68,9 +68,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CallNode* op) final { if (auto* ptr_op = op->op.as()) { for (const auto& f_attr_map : attr_maps_) { - FLowerGeneral f = f_attr_map.get(GetRef(ptr_op), nullptr); + FLowerGeneral f = f_attr_map.get(ffi::GetRef(ptr_op), nullptr); if (f != nullptr) { - PrimExpr e = GetRef(op); + PrimExpr e = ffi::GetRef(op); PrimExpr r = f(e); ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; if (!r.same_as(e)) { @@ -97,7 +97,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // We use floordiv for integer analysis, // but will need to lower them to native truncdiv instructions PrimExpr 
VisitExpr_(const FloorDivNode* op) final { - auto e = GetRef(op); + auto e = ffi::GetRef(op); PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; @@ -290,7 +290,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using namespace arith; PVar x, y; PVar c; - auto e = GetRef(op); + auto e = ffi::GetRef(op); if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); @@ -301,7 +301,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const EQNode* op) final { using namespace arith; PVar x, y; - auto e = GetRef(op); + auto e = ffi::GetRef(op); if ((floormod(x, y) == 0).Match(e)) { return VisitExpr((truncmod(x, y) == 0).Eval()); } @@ -311,7 +311,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const NENode* op) final { using namespace arith; PVar x, y; - auto e = GetRef(op); + auto e = ffi::GetRef(op); if ((floormod(x, y) != 0).Match(e)) { return VisitExpr((truncmod(x, y) != 0).Eval()); } @@ -387,7 +387,7 @@ Pass LowerIntrin() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - auto mtriple = target.value()->GetAttr("mtriple", ""); + auto mtriple = target.value()->GetAttr("mtriple", ""); n->body = IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body)); return f; @@ -395,10 +395,10 @@ Pass LowerIntrin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerIntrin", LowerIntrin); -}); +} } // namespace transform diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index d301e910f922..dc3cc0dbab39 
100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -52,9 +52,9 @@ class MatchBufferLower : public StmtExprMutator { Stmt stmt = StmtExprMutator ::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - Array reads = + ffi::Array reads = op->reads.Map(std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); - Array writes = op->writes.Map( + ffi::Array writes = op->writes.Map( std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); if (reads.same_as(op->reads) && writes.same_as(op->writes) && op->match_buffers.empty()) { @@ -74,7 +74,7 @@ class MatchBufferLower : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); + Var v = ffi::GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { return (*it).second; @@ -115,7 +115,7 @@ class MatchBufferLower : public StmtExprMutator { } else { const Buffer& buffer = (*it).first; const BufferRegion& source = (*it).second; - Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + ffi::Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in lower match buffer pass."; return BufferLoad(source->buffer, indices); @@ -152,8 +152,8 @@ class MatchBufferLower : public StmtExprMutator { // Step.1.2. 
Check data alignment if (source_buffer->data_alignment % buffer->data_alignment != 0) { LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " - << " required_alignment=" << buffer->data_alignment - << ", provided_alignment=" << source_buffer->data_alignment; + << " required alignment=" << buffer->data_alignment + << ", provided alignment=" << source_buffer->data_alignment; } if (is_zero(buffer->elem_offset)) { ICHECK(is_zero(source_buffer->elem_offset)) @@ -170,13 +170,13 @@ class MatchBufferLower : public StmtExprMutator { // Step.2.2. Update element offset // We use the ElemOffset method to avoid duplicating the index calculation. { - Array indices; + ffi::Array indices; indices.reserve(source->region.size()); for (const Range& range : source->region) { indices.push_back(range->min); } - Array buffer_start_indices = source_buffer->ElemOffset(indices); + ffi::Array buffer_start_indices = source_buffer->ElemOffset(indices); if (buffer_start_indices.size() == 1) { Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset"); CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) @@ -184,7 +184,7 @@ class MatchBufferLower : public StmtExprMutator { << " does not satisfy the offset_factor " << buffer->offset_factor << "."; } else { // Non-zero elem_offset is ill-defined for non-flat memory. - // If needed in the future, will require `Array + // If needed in the future, will require `ffi::Array // elem_offsets`, with one offset for each flattened index. Bind(buffer->elem_offset, make_const(buffer->elem_offset.dtype(), 0)); } @@ -220,8 +220,15 @@ class MatchBufferLower : public StmtExprMutator { } void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") { - CHECK_EQ(arg.dtype(), value.dtype()) - << "The data type mismatched: " << arg->dtype << " vs. 
" << value->dtype; + if (arg.dtype() != value.dtype()) { + if (arg.dtype().is_int() && value.dtype().is_int() && + arg.dtype().lanes() == value.dtype().lanes()) { + value = cast(arg.dtype(), value); + } else { + CHECK_EQ(arg.dtype(), value.dtype()) + << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype; + } + } // Handle recursive case value = Substitute(std::move(value), var_map_); if (arg->IsInstance()) { @@ -246,9 +253,9 @@ class MatchBufferLower : public StmtExprMutator { private: /*! \brief Buffer region mapping. */ - Map match_buffers_; + ffi::Map match_buffers_; /*! \brief Var mapping for buffer signature (data, strides, element_offset, etc.) */ - Map var_map_; + ffi::Map var_map_; /*! \brief The analyzer */ arith::Analyzer analyzer_; }; @@ -268,10 +275,10 @@ Pass LowerMatchBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerMatchBuffer", LowerMatchBuffer); -}); +} } // namespace transform diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 75bfece625d8..c0363dd8982f 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -57,9 +57,9 @@ class OpaqueBlockLower : public StmtExprMutator { // Step 3. 
Handle allocations in reverse order for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { const Buffer& buffer = new_block->alloc_buffers[i - 1]; - Array allocation_shape = GetBufferAllocationShape(buffer); + ffi::Array allocation_shape = GetBufferAllocationShape(buffer); body = DeclBuffer(buffer, std::move(body)); - Map allocate_annotations; + ffi::Map allocate_annotations; auto it = storage_align_.find(buffer->data); if (it != storage_align_.end()) { StorageAlignAnnotation allocate_aligns; @@ -90,25 +90,28 @@ class OpaqueBlockLower : public StmtExprMutator { // handling unit loop unit_loop_vars_[op->loop_var] = min; } + // Step 2. Visit recursively Stmt body = this->VisitStmt(op->body); + // Step 3. Handle annotations std::vector> pragma_attrs; - Map new_annotations = + ffi::Map new_annotations = HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false); // Step 4. Create new For loop accordingly if (op->kind == ForKind::kThreadBinding) { // Case 1. Thread binding ICHECK(op->thread_binding.defined()); - String thread_tag = op->thread_binding.value()->thread_tag; + ffi::String thread_tag = op->thread_binding.value()->thread_tag; body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); - } else if (is_one(extent) && op->annotations.empty()) { + } else if (is_one(extent) && op->annotations.empty() && + !op->annotations.count(attr::irregular_loop_mark)) { // Case 2. Unit loop return body; } else { // Case 3. An ordinary loop body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body), - std::nullopt, new_annotations); + std::nullopt, new_annotations, op->step); } // Step 5. 
Insert nested attrs for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { @@ -118,7 +121,7 @@ class OpaqueBlockLower : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto it = unit_loop_vars_.find(var); if (it == unit_loop_vars_.end()) { return var; @@ -132,16 +135,16 @@ class OpaqueBlockLower : public StmtExprMutator { } } - static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag, + static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, ffi::String thread_tag, Stmt body) { IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), /*var=*/std::move(var), /*iter_type=*/IterVarType::kThreadIndex, /*thread_tag=*/thread_tag); - String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || - thread_tag == "vthread.y" || thread_tag == "vthread.z") - ? attr::virtual_thread - : attr::thread_extent; + ffi::String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || + thread_tag == "vthread.y" || thread_tag == "vthread.z") + ? attr::virtual_thread + : attr::thread_extent; return AttrStmt(/*node=*/std::move(iter_var), /*attr_key=*/std::move(attr_key), /*value=*/std::move(extent), @@ -149,12 +152,12 @@ class OpaqueBlockLower : public StmtExprMutator { } /*! \brief Convert attr value from annotation map into PrimExpr. 
*/ - PrimExpr ConvertAttrValue(const String& key, const Any& obj) { + PrimExpr ConvertAttrValue(const ffi::String& key, const Any& obj) { if (obj == nullptr) { return PrimExpr(); } else if (auto expr = obj.try_cast()) { return expr.value(); - } else if (auto str = obj.try_cast()) { + } else if (auto str = obj.try_cast()) { return std::move(StringImm(str.value())); } else { LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj.GetTypeKey() @@ -171,13 +174,13 @@ class OpaqueBlockLower : public StmtExprMutator { * (3) the non-pragma block annotations are dropped * \return New annotation dict with preserved keys. Also update pragma attr pairs ordered by key. */ - Map HandleAnnotations( - const Map& annotations, + ffi::Map HandleAnnotations( + const ffi::Map& annotations, std::vector>* pragma_attrs, bool is_block) { - Map preserved_annotations; + ffi::Map preserved_annotations; pragma_attrs->clear(); for (const auto& kv : annotations) { - const String& key = kv.first; + const ffi::String& key = kv.first; if (attr::IsPragmaKey(key)) { pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second)); } else if (!is_block) { @@ -215,10 +218,10 @@ Pass LowerOpaqueBlock() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerOpaqueBlock", LowerOpaqueBlock); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 3b972482b728..c8873e8fd5e1 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -92,7 +92,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return node; } - Optional GetRemappedBuffer(const Buffer& buf) { + ffi::Optional GetRemappedBuffer(const Buffer& buf) { if (auto it = buf_remap_.find(buf.get()); it != 
buf_remap_.end()) { return it->second; } @@ -162,7 +162,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const IntImmNode* size_of_args = call->args[0].as(); ICHECK(size_of_args) << call->args[0]->GetTypeKey(); ICHECK_EQ(size, size_of_args->value); - Array inits = combiner->identity_element; + ffi::Array inits = combiner->identity_element; std::vector values(size); std::vector types(size); PrimExpr cond = call->args[size + 1]; @@ -292,7 +292,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) { std::vector reduce_results; DataType mask_dtype = DataType::UInt(32); - PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); + PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}, call->annotations); if (reduce_extent <= warp_size_) { std::tie(reduce_results, new_alloc_bufs) = @@ -433,12 +433,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } std::pair, std::vector> MakeWarpAllreduce( - std::vector src_values, // - std::vector dtypes, // - const CommReducerNode* combiner, // - PrimExpr reduce_index, int reduce_extent, // - PrimExpr group_index, // - PrimExpr mask, Optional predicate, // + std::vector src_values, // + std::vector dtypes, // + const CommReducerNode* combiner, // + PrimExpr reduce_index, int reduce_extent, // + PrimExpr group_index, // + PrimExpr mask, ffi::Optional predicate, // std::vector* seq) { int n_buffers = src_values.size(); @@ -449,8 +449,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // This is the index to the reduction variable, one reduction // variable per warp. Local scope seems easier to reason without // relying on a pattern match pass to fix it later. 
- Array zero_indices = {0}; - Array shape = {1}; + ffi::Array zero_indices = {0}; + ffi::Array shape = {1}; std::vector load_values; load_values.reserve(n_buffers); @@ -473,7 +473,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The mask for this reducer, as this reducer may sit inside // a divergent control flow. Here it uses a variable to cache the current // active channels. - Optional mask_buffer; + ffi::Optional mask_buffer; if (need_warp_shuffle_mask_) { mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local"); seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices)); @@ -489,7 +489,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } for (int offset = start_offset; offset > 0; offset /= 2) { // Load reduction values, no synchronization needed. - Array a, b; + ffi::Array a, b; for (int i = 0; i < n_buffers; ++i) { Buffer shared_buf = shared_bufs[i]; BufferLoad val(shared_buf, zero_indices); @@ -519,7 +519,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Do reductions. - Array ret = (*combiner)(a, b); + ffi::Array ret = (*combiner)(a, b); // Store the reduction result to itself. std::vector stores; @@ -554,7 +554,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // make allreduce. 
Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector& types, - const Array& shared_bufs, PrimExpr reduce_index, + const ffi::Array& shared_bufs, PrimExpr reduce_index, PrimExpr group_index, int reduce_extent, int group_extent, int contiguous_reduce_extent) { // Get next power of two @@ -569,7 +569,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent); // make reduction auto fload = [&](int offset) { - Array a, b; + ffi::Array a, b; for (size_t i = 0; i < size; ++i) { BufferLoad b_load(shared_bufs[i], {BufIndex(reduce_index + offset, group_index, reduce_extent)}); @@ -580,10 +580,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { ICHECK_EQ(a_load->dtype, types[i]); a.push_back(a_load); } - Array ret = (*combiner)(a, b); + ffi::Array ret = (*combiner)(a, b); return ret; }; - auto fstore = [&](const Array& ret) { + auto fstore = [&](const ffi::Array& ret) { std::vector stores(size); for (size_t i = 0; i < size; ++i) { stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index}); @@ -633,7 +633,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // here to reduce thread divergence. auto loads = fload(reduce_align); - Array in_warp_local_vars; + ffi::Array in_warp_local_vars; for (auto expr : loads) { Var var( "w_" + std::to_string(reduce_align) + "_" + std::to_string(in_warp_local_vars.size()), @@ -696,9 +696,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Emit warp shuffle calls. 
- PrimExpr WarpShuffle(const Op& op, Optional mask_buffer, PrimExpr val, + PrimExpr WarpShuffle(const Op& op, ffi::Optional mask_buffer, PrimExpr val, PrimExpr delta_or_lane) { - Array indices = {0}; + ffi::Array indices = {0}; PrimExpr mask; if (mask_buffer.defined()) { mask = BufferLoad(mask_buffer.value(), indices); @@ -706,7 +706,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { mask = IntImm(DataType::Int(32), 0); } PrimExpr width = IntImm(DataType::Int(32), warp_size_); - Array args{mask, val, delta_or_lane, width, width}; + ffi::Array args{mask, val, delta_or_lane, width, width}; return Call(val.dtype(), op, args); } @@ -810,10 +810,10 @@ Pass LowerThreadAllreduce() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerThreadAllreduce", LowerThreadAllreduce); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e74f5c7c9046..66e13791f3b2 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -40,7 +40,7 @@ namespace tir { class BuiltinLower : public StmtExprMutator { public: static PrimFunc Build(PrimFunc func) { - Optional device_type = std::nullopt; + ffi::Optional device_type = std::nullopt; if (auto target = func->GetAttr(tvm::attr::kTarget)) { device_type = Integer(target.value()->kind->default_device_type); } @@ -50,7 +50,7 @@ class BuiltinLower : public StmtExprMutator { return func; } - explicit BuiltinLower(Optional device_type = std::nullopt) + explicit BuiltinLower(ffi::Optional device_type = std::nullopt) : device_type_(device_type) {} // NOTE: Right now, we make the following scoping requirement @@ -256,7 +256,7 @@ class BuiltinLower : public StmtExprMutator { Stmt throw_last_error = Evaluate(Call(DataType::Int(32), 
builtin::tvm_throw_last_error(), {})); Stmt alloc_nullptr_check = IfThenElse( - Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error); + Call(DataType::Bool(), builtin::isnullptr(), {op->buffer_var}), throw_last_error); PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), {cast(DataType::Int(32), device_type_.value()), cast(DataType::Int(32), device_id_.value()), op->buffer_var}); @@ -317,7 +317,7 @@ class BuiltinLower : public StmtExprMutator { } if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->min = std::move(min); @@ -370,7 +370,7 @@ class BuiltinLower : public StmtExprMutator { << "but was instead the expression " << device_type_ << " with type " << device_type_.value()->GetTypeKey(); - String device_name = runtime::DLDeviceType2Str(as_int->value); + ffi::String device_name = runtime::DLDeviceType2Str(as_int->value); return StringImm("device_api." + device_name + "." 
+ method_name); } @@ -594,9 +594,9 @@ class BuiltinLower : public StmtExprMutator { scope.run_sizes.shape_stack = restore_shape_stack; scope.run_sizes.array_stack = restore_array_stack; scope.run_sizes.arg_stack = arg_stack_begin; - Array packed_args = {op->args[name_offset], scope.stack_ffi_any, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + num_args)}; + ffi::Array packed_args = {op->args[name_offset], scope.stack_ffi_any, + ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + num_args)}; if (pass_last_arg_as_traced_value) { // pass in last element as traced value // used by call_packed_traced @@ -617,7 +617,7 @@ class BuiltinLower : public StmtExprMutator { Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); Stmt body = SeqStmt( - {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error), + {IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), throw_last_error), let->body, free_stmt}); DataType dtype = @@ -626,7 +626,7 @@ class BuiltinLower : public StmtExprMutator { std::string fdevapi_prefix = "device_api."; fdevapi_prefix += runtime::DLDeviceType2Str(device_type_.as()->value); - Array args = { + ffi::Array args = { GetDeviceMethodName("alloc_nd"), device_type_.value(), device_id_.value(), @@ -657,8 +657,8 @@ class BuiltinLower : public StmtExprMutator { // The prepration sequence to be emitted before the current statement. 
std::vector> prep_seq_stack_; - Optional device_type_{std::nullopt}; - Optional device_id_{std::nullopt}; + ffi::Optional device_type_{std::nullopt}; + ffi::Optional device_id_{std::nullopt}; bool is_precheck_{false}; @@ -679,10 +679,10 @@ Pass LowerTVMBuiltin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerTVMBuiltin", LowerTVMBuiltin); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_vtcm_alloc.cc b/src/tir/transforms/lower_vtcm_alloc.cc index 7cddfb678514..c3b03f8623cd 100644 --- a/src/tir/transforms/lower_vtcm_alloc.cc +++ b/src/tir/transforms/lower_vtcm_alloc.cc @@ -40,7 +40,7 @@ class VtcmAllocator : public StmtExprMutator { std::string storage_scope = GetStorageScope(op->buffer_var); if (IsVtcmStorage(storage_scope)) { Stmt body = this->VisitStmt(op->body); - Array args; + ffi::Array args; args.push_back(StringImm(storage_scope)); args.push_back(IntImm(DataType::Int(64), op->extents.size())); args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->extents)); @@ -73,10 +73,10 @@ Pass LowerVtcmAlloc() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerVtcmAlloc", LowerVtcmAlloc); -}); +} } // namespace transform diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 5708ab0746f2..ceb3ed826529 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -150,7 +150,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { } void UpdatePattern(const PrimExpr& index) { - Array m = arith::DetectLinearEquation(index, {warp_index_}); + ffi::Array m = arith::DetectLinearEquation(index, 
{warp_index_}); ICHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed. Could not simplify the store index `" << index << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in " @@ -254,14 +254,14 @@ class WarpAccessRewriter : protected StmtExprMutator { protected: PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector& indices) { - Array new_args = op->args; + ffi::Array new_args = op->args; for (int i : indices) { if (op->args[i].get() == buffer_) { PrimExpr local_index = SplitIndexByGroup(op->args[i + 1]).first; new_args.Set(i + 1, local_index); } } - return Call(op->dtype, op->op, new_args); + return Call(op->dtype, op->op, new_args, op->annotations); } PrimExpr VisitExpr_(const CallNode* op) override { @@ -426,7 +426,7 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } - std::unordered_map new_storage_scopes_; + std::unordered_map new_storage_scopes_; private: Stmt VisitStmt_(const AllocateNode* op) { @@ -462,10 +462,10 @@ Pass LowerWarpMemory() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerWarpMemory", LowerWarpMemory); -}); +} } // namespace transform diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 7477fe86363d..e8a1b564a43b 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,6 +20,7 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. 
*/ +#include #include #include #include @@ -124,7 +125,8 @@ class ReturnRewriter : public StmtMutator { class SubroutineCallRewriter : public StmtExprMutator { public: - static Optional Apply(const Map& packed_func_methods, Stmt stmt) { + static ffi::Optional Apply(const ffi::Map& packed_func_methods, + Stmt stmt) { SubroutineCallRewriter rewriter(packed_func_methods); stmt = rewriter.VisitStmt(std::move(stmt)); if (rewriter.made_change_) { @@ -135,16 +137,16 @@ class SubroutineCallRewriter : public StmtExprMutator { } private: - explicit SubroutineCallRewriter(const Map& packed_func_methods) + explicit SubroutineCallRewriter(const ffi::Map& packed_func_methods) : packed_func_methods(packed_func_methods) {} PrimExpr VisitExpr_(const CallNode* op) override { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); if (auto* gvar_ptr = node->op.as()) { - auto gvar = GetRef(gvar_ptr); + auto gvar = ffi::GetRef(gvar_ptr); if (auto symbol = packed_func_methods.Get(gvar)) { - Array cpacked_args; + ffi::Array cpacked_args; cpacked_args.push_back(tir::StringImm(symbol.value())); for (auto arg : node->args) { cpacked_args.push_back(arg); @@ -159,7 +161,7 @@ class SubroutineCallRewriter : public StmtExprMutator { return node; } - const Map& packed_func_methods; + const ffi::Map& packed_func_methods; bool made_change_{false}; }; @@ -181,7 +183,7 @@ inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) { * \returns The global_symbol to be used for the function at call * sites, or std::nullopt if the function is to remain unchanged. */ -Optional RequiresPackedAPI(const PrimFunc& func) { +ffi::Optional RequiresPackedAPI(const PrimFunc& func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. 
if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { @@ -191,12 +193,12 @@ Optional RequiresPackedAPI(const PrimFunc& func) { } // Internal function calls do not need the ffi::Function API - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (!global_symbol.has_value()) { return std::nullopt; } - return global_symbol; + return global_symbol.value(); } PrimFunc MakePackedAPI(PrimFunc func) { @@ -223,6 +225,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { } auto* func_ptr = func.CopyOnWrite(); + // set the global symbol to the packed function name const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); @@ -246,8 +249,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { // local function definitions // load i-th argument as type t auto f_load_arg_value = [&](DataType arg_type, int i) { - Array call_args{v_packed_args, IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)}; + ffi::Array call_args{v_packed_args, IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)}; // load 64 bit version DataType api_type = APIType(arg_type); PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); @@ -297,14 +300,16 @@ PrimFunc MakePackedAPI(PrimFunc func) { type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr || type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin, tvm::tir::StringImm(msg.str()), nop)); - // if type_index is NDArray, we need to add the offset of the DLTensor header + // if type_index is Tensor, we need to add the offset of the DLTensor header // which always equals 16 bytes, this ensures that T.handle always shows up as a DLTensor* + const int64_t object_cell_offset = sizeof(TVMFFIObject); + static_assert(object_cell_offset == 24); arg_value = f_load_arg_value(param.dtype(), i); - PrimExpr handle_from_ndarray = + PrimExpr handle_from_tensor = Call(DataType::Handle(), 
tir::builtin::handle_add_byte_offset(), - {arg_value, IntImm(DataType::Int(32), 16)}); + {arg_value, IntImm(DataType::Int(32), object_cell_offset)}); arg_value = - Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value); + Select(type_index == ffi::TypeIndex::kTVMFFITensor, handle_from_tensor, arg_value); } else if (dtype.is_bool()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be boolean"; @@ -337,13 +342,13 @@ PrimFunc MakePackedAPI(PrimFunc func) { var_def.emplace_back(arg_value, param); if (func_ptr->buffer_map.count(param)) { // buffer binding now depends on type index - // if the index is NDArray handle, we need to offset to get the DLTensor* + // if the index is Tensor handle, we need to offset to get the DLTensor* buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } } // signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result) - Array args{v_self_handle, v_packed_args, v_num_packed_args, v_result}; + ffi::Array args{v_self_handle, v_packed_args, v_num_packed_args, v_result}; // Arg definitions are defined before buffer binding to avoid the use before // def errors. @@ -360,10 +365,12 @@ PrimFunc MakePackedAPI(PrimFunc func) { binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." 
+ var->name_hint); arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); } - - func = WithAttrs(std::move(func), - {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, - {tvm::attr::kTarget, target_host}}); + // reset global symbol to attach prefix + func = WithAttrs( + std::move(func), + {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, + {tvm::attr::kTarget, target_host}, + {tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); Stmt body = ReturnRewriter(v_result)(func_ptr->body); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, @@ -390,11 +397,11 @@ PrimFunc MakePackedAPI(PrimFunc func) { func_ptr->body = body; func_ptr->params = args; - Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); + ffi::Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined << " are used, but are not passed in as API arguments"; - func_ptr->buffer_map = Map(); + func_ptr->buffer_map = ffi::Map(); func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function. 
@@ -405,7 +412,7 @@ namespace transform { Pass MakePackedAPI() { auto pass_func = [](IRModule mod, PassContext ctx) { - Map packed_func_methods; + ffi::Map packed_func_methods; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto prim_func = opt.value(); @@ -444,10 +451,10 @@ Pass MakePackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.MakePackedAPI", []() { return MakePackedAPI(); }); -}); +} } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 8276d26fcfa8..0e26560ac622 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -45,8 +45,8 @@ namespace { class SubroutineCallRewriter : public StmtExprMutator { public: - static Optional Apply(const std::unordered_set& external_methods, - Stmt stmt) { + static ffi::Optional Apply(const std::unordered_set& external_methods, + Stmt stmt) { SubroutineCallRewriter rewriter(external_methods); stmt = rewriter.VisitStmt(std::move(stmt)); if (rewriter.made_change_) { @@ -65,7 +65,7 @@ class SubroutineCallRewriter : public StmtExprMutator { if (auto gvar = node->op.as()) { if (external_methods_.count(gvar)) { - Array args = node->args.Map([](const PrimExpr& arg) -> PrimExpr { + ffi::Array args = node->args.Map([](const PrimExpr& arg) -> PrimExpr { if (auto* as_call = arg.as()) { if (as_call->op.same_as(builtin::tvm_stack_make_array())) { PrimExpr data_ptr = as_call->args[0]; @@ -102,7 +102,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { } // Internal function calls do not need API updates - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (!global_symbol.has_value()) { return 
func; } @@ -133,7 +133,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { std::vector device_init; // Collect variables and buffers to map between - Array args; + ffi::Array args; for (const Var& param : func->params) { // Ideally all func params should have Buffers defined in the buffer_map @@ -156,7 +156,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { func_ptr->body = body; func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); - func_ptr->buffer_map = Map(); + func_ptr->buffer_map = ffi::Map(); // return the function. return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}}); @@ -169,7 +169,7 @@ Pass MakeUnpackedAPI() { std::unordered_set external_methods; for (const auto& [gvar, base_func] : mod->functions) { if (auto* prim_func = base_func.as()) { - if (prim_func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (prim_func->GetAttr(tvm::attr::kGlobalSymbol)) { external_methods.insert(gvar.get()); } } @@ -201,10 +201,10 @@ Pass MakeUnpackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.MakeUnpackedAPI", MakeUnpackedAPI); -}); +} } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 73f5d7746da9..8d0b71e75e5d 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -73,7 +73,7 @@ class IntermediateStageRewriter { BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; - Block new_block = GetRef(block); + Block new_block = ffi::GetRef(block); new_block.CopyOnWrite()->body = std::move(new_buffer_store); return {target_buffer, 
new_buffer, new_block, local_stage}; @@ -119,7 +119,7 @@ class IntermediateStageRewriter { /*! \brief Create the intermediate stage. */ Stmt MakeLocalStage(const BlockNode* block, const Buffer& new_buffer, - Array local_stage_indices, + ffi::Array local_stage_indices, std::vector relaxed_loops, const BufferStoreNode* store) { // Step 0: Create the body of the local stage, which is BufferStore to the intermediate buffer. Stmt local_stage = BufferStore(new_buffer, store->value, local_stage_indices); @@ -135,9 +135,9 @@ class IntermediateStageRewriter { Downcast(local_stage)); // Step 2: Add outer loops - Map subst_map; + ffi::Map subst_map; for (const ForNode* relaxed_loop : relaxed_loops) { - ObjectPtr for_node = make_object(*relaxed_loop); + ObjectPtr for_node = ffi::make_object(*relaxed_loop); for_node->loop_var = for_node->loop_var.copy_with_suffix(""); for_node->body = std::move(local_stage); local_stage = For(for_node); @@ -148,10 +148,10 @@ class IntermediateStageRewriter { } /*! \brief Create the intermediate buffer with the extents of the relaxed outer loops. */ - std::pair> CreateIntermediateBuffer( + std::pair> CreateIntermediateBuffer( const std::vector relaxed_loops, const Buffer& buffer) const { - Array buffer_indices; - Array new_buffer_shape; + ffi::Array buffer_indices; + ffi::Array new_buffer_shape; // Create the intermediate buffer for the local stage. The shape of the new buffer is the // extents of the relaxed outer loops. 
@@ -172,14 +172,14 @@ class IntermediateStageRewriter { class SharedMemoryLocalStageInserter : public StmtMutator { public: Stmt VisitStmt_(const ForNode* op) final { - ancestor_loop_or_blocks_.push_back(GetRef(op)); + ancestor_loop_or_blocks_.push_back(ffi::GetRef(op)); Stmt new_stmt = StmtMutator::VisitStmt_(op); ancestor_loop_or_blocks_.pop_back(); return new_stmt; } Stmt VisitStmt_(const BlockRealizeNode* op) final { - ancestor_loop_or_blocks_.push_back(GetRef(op)); + ancestor_loop_or_blocks_.push_back(ffi::GetRef(op)); Stmt new_stmt = StmtMutator::VisitStmt_(op); ancestor_loop_or_blocks_.pop_back(); return new_stmt; @@ -206,8 +206,8 @@ class SharedMemoryLocalStageInserter : public StmtMutator { op->alloc_buffers.begin(), op->alloc_buffers.end()); // Visit children and insert local stages (if any) to the proper location. - Array new_alloc_buffers; - Array new_seq; + ffi::Array new_alloc_buffers; + ffi::Array new_seq; // Helper function to check if the subtree (body of the block) contains any target buffers. // If so, the allocated intermediate buffer and the local stage should be lifted to the current @@ -236,7 +236,7 @@ class SharedMemoryLocalStageInserter : public StmtMutator { } } if (!changed) { - return GetRef(op); + return ffi::GetRef(op); } } else { int subtree_start = target_buffers_.size(); @@ -244,12 +244,12 @@ class SharedMemoryLocalStageInserter : public StmtMutator { int subtree_end = target_buffers_.size(); f_check_subtree(subtree_start, subtree_end); if (body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } new_seq.push_back(body); } - Block new_block = GetRef(op); + Block new_block = ffi::GetRef(op); BlockNode* new_block_node = new_block.CopyOnWrite(); // Add new buffer allocations if any. 
if (new_alloc_buffers.size() > 0) { @@ -260,9 +260,10 @@ class SharedMemoryLocalStageInserter : public StmtMutator { } std::vector ancestor_loop_or_blocks_; // ancestor loops or block realize - Map buffer_remap_; // mapping from the target buffer to the intermediate buffer - Map buffer_local_stage_; // mapping from the target buffer to the local stage - Array target_buffers_; // the target buffers for rewriting + ffi::Map + buffer_remap_; // mapping from the target buffer to the intermediate buffer + ffi::Map buffer_local_stage_; // mapping from the target buffer to the local stage + ffi::Array target_buffers_; // the target buffers for rewriting }; namespace transform { @@ -276,11 +277,11 @@ Pass ManifestSharedMemoryLocalStage() { return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.ManifestSharedMemoryLocalStage", ManifestSharedMemoryLocalStage); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc index 43a976fa892f..0d5b27044232 100644 --- a/src/tir/transforms/memhammer_coalesce.cc +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -40,13 +40,13 @@ Stmt FuseNestLoops(Stmt body) { } suffix += "_fused"; Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); - Map subst_map; + ffi::Map subst_map; PrimExpr tot = fused_var; for (int i = n - 1; i >= 0; i--) { subst_map.Set(loops[i]->loop_var, floormod(tot, loops[i]->extent)); tot = floordiv(tot, loops[i]->extent); } - auto f_substitute = [&](const Var& v) -> Optional { + auto f_substitute = [&](const Var& v) -> ffi::Optional { return subst_map.Get(v).value_or(v); }; PrimExpr fused_extent = 1; @@ -74,19 +74,19 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { // generate thread binding loops std::vector factors{-1}; 
std::vector thread_axis; - if (Optional o_t = constraints.thread_extent.Get("threadIdx.z")) { + if (ffi::Optional o_t = constraints.thread_extent.Get("threadIdx.z")) { int t = o_t.value()->value; tot_threads *= t; factors.push_back(t); thread_axis.push_back("threadIdx.z"); } - if (Optional o_t = constraints.thread_extent.Get("threadIdx.y")) { + if (ffi::Optional o_t = constraints.thread_extent.Get("threadIdx.y")) { int t = o_t.value()->value; tot_threads *= t; factors.push_back(t); thread_axis.push_back("threadIdx.y"); } - if (Optional o_t = constraints.thread_extent.Get("threadIdx.x")) { + if (ffi::Optional o_t = constraints.thread_extent.Get("threadIdx.x")) { int t = o_t.value()->value; tot_threads *= t; factors.push_back(t); @@ -114,7 +114,7 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { substitute_value += new_loop_vars[i]; } // Construct the new loop nest - Stmt body = Substitute(loop->body, [&](const Var& v) -> Optional { + Stmt body = Substitute(loop->body, [&](const Var& v) -> ffi::Optional { if (v.same_as(loop->loop_var)) { return substitute_value; } else { @@ -128,7 +128,8 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body)); for (int i = n - 2; i >= 1; i--) { body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body), - IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1])); + IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1]), + {}, std::nullopt); } return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body)); } @@ -152,17 +153,17 @@ Stmt CoalescedAccess::Rewrite(const Stmt& stmt, const ConstraintSet& constraints * the index mapping * \return The mapping in the form of j0, ..., jm, where j0, ... 
jm = f(i0, ..., in) */ -Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { +ffi::Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { Stmt body = stmt; while (const ForNode* loop = body.as()) { body = loop->body; } const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); BufferRegion write_region = constraints.write_region; - const Array& write_index = buf_store->indices; + const ffi::Array& write_index = buf_store->indices; ICHECK(write_region->region.size() == write_index.size() && write_region->buffer.same_as(buf_store->buffer)); - Array result; + ffi::Array result; arith::Analyzer analyzer; for (int i = 0; i < static_cast(write_region->region.size()); i++) { PrimExpr pattern = analyzer.Simplify(write_index[i] - write_region->region[i]->min); @@ -176,10 +177,10 @@ Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body = stmt; - Map var_range; - Array loop_vars; + ffi::Map var_range; + ffi::Array loop_vars; // Step 1. Get index mapping - Array mapping_pattern = GetMapping(stmt, constraints); + ffi::Array mapping_pattern = GetMapping(stmt, constraints); while (const ForNode* loop = body.as()) { var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); loop_vars.push_back(loop->loop_var); @@ -191,14 +192,15 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, auto iter_map = arith::DetectIterMap(mapping_pattern, var_range, Bool(true), arith::Bijective, &analyzer); CHECK_EQ(iter_map->indices.size(), loop_vars.size()); - Map inverse_mapping = arith::InverseAffineIterMap(iter_map->indices, loop_vars); + ffi::Map inverse_mapping = + arith::InverseAffineIterMap(iter_map->indices, loop_vars); // Step 3. 
Generate new body BufferRegion read_region = constraints.read_region; BufferRegion write_region = constraints.write_region; - Array write_index; - Array read_index; - Array new_loop_vars; - Map substitute_map; + ffi::Array write_index; + ffi::Array read_index; + ffi::Array new_loop_vars; + ffi::Map substitute_map; // Step 3.1 construct target buffer indices for (int i = 0, j = 0; i < static_cast(write_region->region.size()); i++) { if (is_one(write_region->region[i]->extent)) { diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc index 2ecb740ba327..d4826e609319 100644 --- a/src/tir/transforms/memhammer_intermediate_stage.cc +++ b/src/tir/transforms/memhammer_intermediate_stage.cc @@ -25,7 +25,7 @@ Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_bo Stmt* ith_loop = nullptr) { Stmt ret = inner_body; for (int i = static_cast(loops.size() - 1); i >= 0; i--) { - ObjectPtr new_loop = make_object(*loops[i]); + ObjectPtr new_loop = ffi::make_object(*loops[i]); new_loop->body = ret; ret = For(new_loop); if (ith == i) { @@ -71,7 +71,7 @@ std::pair LiftThreadBindingLoops(Stmt stmt) { */ class IndexPatternFinder : public ExprVisitor { public: - IndexPatternFinder(const Map& var_range, Array* resulting_index) + IndexPatternFinder(const ffi::Map& var_range, ffi::Array* resulting_index) : var_range_(var_range), resulting_index_(resulting_index) {} struct Operator { enum class OpKind { Mul, FloorDiv, FloorMod }; @@ -87,19 +87,19 @@ class IndexPatternFinder : public ExprVisitor { * \param rewrite_indices The access indices after rank promotion * \return The new buffer shape after rank promotion. 
*/ - static Array getRankPromotedShape(Array indices, - const Map& var_range, - Array* rewrite_indices) { - Map var_dom = arith::AsIntSet(var_range); - Array new_shape; + static ffi::Array getRankPromotedShape(ffi::Array indices, + const ffi::Map& var_range, + ffi::Array* rewrite_indices) { + ffi::Map var_dom = arith::AsIntSet(var_range); + ffi::Array new_shape; for (const PrimExpr& expr : indices) { - Array indices_dim; + ffi::Array indices_dim; IndexPatternFinder extractor(var_range, &indices_dim); extractor(expr); if (!extractor.success_) { return {}; } - Array access_shape = extractor.access_shape_; + ffi::Array access_shape = extractor.access_shape_; PrimExpr product_shape = 1; for (PrimExpr e : access_shape) { product_shape *= e; @@ -119,8 +119,8 @@ class IndexPatternFinder : public ExprVisitor { if (!success_) { return; } - if (Optional range = var_range_.Get(GetRef(op))) { - PrimExpr index = GetRef(op); + if (ffi::Optional range = var_range_.Get(ffi::GetRef(op))) { + PrimExpr index = ffi::GetRef(op); int64_t max = range.value()->extent.as()->value; int64_t extent = max; for (int i = static_cast(operator_stack.size()) - 1; i >= 0; i--) { @@ -190,9 +190,9 @@ class IndexPatternFinder : public ExprVisitor { operator_stack.pop_back(); } - Map var_range_; - Array access_shape_; - Array* resulting_index_; + ffi::Map var_range_; + ffi::Array access_shape_; + ffi::Array* resulting_index_; std::vector operator_stack; bool success_ = true; }; @@ -225,15 +225,16 @@ class BufferLoadReplacer : public StmtExprMutator { * \return a pair. The first is the stmt after transformation. * The second is the SeqStmt that contains 2 stages (one original and another inserted). 
*/ -std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, - Optional compute_location, - const Array& outer_loops, Buffer* alloc_buffer) { +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::String storage_scope, + ffi::Optional compute_location, + const ffi::Array& outer_loops, + Buffer* alloc_buffer) { Stmt body = stmt; std::vector loops; std::vector loops_under_compute_location; std::vector relaxed_thread_loops; bool need_relax = !compute_location.defined(); - Map var_range; + ffi::Map var_range; PrimExpr vector_bytes = -1; // Step 1. Perform rank promotion on the buffer access, turning a strided-changing dimension into // several contiguous-changing dimensions @@ -253,7 +254,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String } body = loop->body; } - Optional predicate; + ffi::Optional predicate; if (const auto* op = body.as()) { // the predicate is generated by coalescing predicate = op->condition; @@ -261,9 +262,13 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String } for (const For& loop : outer_loops) { if (loop->kind == ForKind::kThreadBinding) { - const String& thread_tag = loop->thread_binding.value()->thread_tag; - if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope), - runtime::ThreadScope::Create(thread_tag))) { + const ffi::String& thread_tag = loop->thread_binding.value()->thread_tag; + auto thread_scope = runtime::ThreadScope::Create(thread_tag); + if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope), thread_scope)) { + if (is_write_cache && thread_scope.dim_index == 0) { + // writing C_reindex_m16n8k8_matrixC_shared_dyn is warp execution + continue; + } var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); relaxed_thread_loops.push_back(loop.get()); } @@ -296,11 +301,11 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String } const BufferStoreNode* buf_store = TVM_TYPE_AS(body, 
BufferStoreNode); - Array cache_indices; - Array new_shape; + ffi::Array cache_indices; + ffi::Array new_shape; bool use_rank_promotion = false; if (!is_write_cache && buf_store->value.as()) { - Array indices = + ffi::Array indices = is_write_cache ? buf_store->indices : buf_store->value.as()->indices; new_shape = IndexPatternFinder::getRankPromotedShape(indices, var_range, &cache_indices); // write cache disabled for now @@ -309,8 +314,8 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String use_rank_promotion = true; } } - Array new_loop_vars; - Map subst_map; + ffi::Array new_loop_vars; + ffi::Map subst_map; if (!use_rank_promotion) { cache_indices.clear(); for (const ForNode* loop : relaxed_thread_loops) { @@ -339,8 +344,8 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String cache_indices.push_back(loop->loop_var); } } - Array subst_indices; - Array subst_cache_indices; + ffi::Array subst_indices; + ffi::Array subst_cache_indices; if (is_write_cache) { for (PrimExpr e : buf_store->indices) { subst_indices.push_back(Substitute(e, subst_map)); @@ -366,8 +371,8 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String if (is_write_cache) { // copy from wmma to new cache buffer BufferLoad new_buffer_load{new_buffer, cache_indices}; - generate_body = - BufferLoadReplacer(target_buffer_load->buffer, new_buffer_load)(GetRef(buf_store)); + generate_body = BufferLoadReplacer(target_buffer_load->buffer, + new_buffer_load)(ffi::GetRef(buf_store)); generate_body = Substitute(generate_body, subst_map); } else { generate_body = @@ -384,14 +389,14 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String for (int i = static_cast(loops_under_compute_location.size()) - 1; i >= 0; i--) { const ForNode* orig_loop = loops_under_compute_location[i]; - ObjectPtr new_loop = make_object(*orig_loop); + ObjectPtr new_loop = ffi::make_object(*orig_loop); new_loop->loop_var = new_loop_vars[i + relaxed_thread_loops.size()]; new_loop->body 
= generate_body; generate_body = For(new_loop); } for (int i = static_cast(relaxed_thread_loops.size()) - 1; i >= 0; i--) { const ForNode* orig_loop = relaxed_thread_loops[i]; - ObjectPtr new_loop = make_object(*orig_loop); + ObjectPtr new_loop = ffi::make_object(*orig_loop); new_loop->loop_var = new_loop_vars[i]; new_loop->body = generate_body; new_loop->kind = ForKind::kSerial; @@ -402,7 +407,8 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String Stmt rewrite_body; if (is_write_cache) { BufferLoad new_buffer_load{new_buffer, cache_indices}; - rewrite_body = BufferStore(new_buffer, GetRef(target_buffer_load), cache_indices); + rewrite_body = + BufferStore(new_buffer, ffi::GetRef(target_buffer_load), cache_indices); } else { rewrite_body = BufferStore(buf_store->buffer, BufferLoad(new_buffer, cache_indices), buf_store->indices); @@ -412,7 +418,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String } for (int i = static_cast(loops_under_compute_location.size()) - 1; i >= 0; i--) { const ForNode* orig_loop = loops_under_compute_location[i]; - ObjectPtr new_loop = make_object(*orig_loop); + ObjectPtr new_loop = ffi::make_object(*orig_loop); new_loop->body = rewrite_body; rewrite_body = For(new_loop); } diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc index 15dd58d4ca75..498de4796cd4 100644 --- a/src/tir/transforms/memhammer_lower_auto_copy.cc +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -90,8 +90,8 @@ class AutoPadder { * \param buffers the given buffers * \return the list of new padded buffers */ - Array PadSharedMemory(const Array& buffers) { - Array result; + ffi::Array PadSharedMemory(const ffi::Array& buffers) { + ffi::Array result; for (const Buffer& buffer : buffers) { runtime::StorageScope scope = runtime::StorageScope::Create(buffer.scope()); @@ -113,7 +113,7 @@ class AutoPadder { low_dim_iter_space[i] = last_dim_iter_space; } PrimExpr stride = 1; - Array 
reverse_strides; + ffi::Array reverse_strides; int pad_min = padding_min_.Get(buffer).value_or(Integer(1)).IntValue(); // Step 2. For each dimension, select a padding that has minimal bank conflict for (int k = n - 2; k >= 0; k--) { // dims @@ -165,8 +165,8 @@ class AutoPadder { reverse_strides.push_back(stride); } // Step 3. create the new padded buffer - ObjectPtr b = make_object(*buffer.get()); - Array strides; + ObjectPtr b = ffi::make_object(*buffer.get()); + ffi::Array strides; for (int i = static_cast(reverse_strides.size()) - 1; i >= 0; i--) { strides.push_back(reverse_strides[i]); } @@ -190,7 +190,7 @@ class AutoPadder { Stmt RewriteBufferAccess(const Stmt& stmt) { class Rewriter : public StmtExprMutator { public: - explicit Rewriter(const Map& buffer_map) : buffer_map_(buffer_map) {} + explicit Rewriter(const ffi::Map& buffer_map) : buffer_map_(buffer_map) {} private: PrimExpr VisitExpr_(const BufferLoadNode* _op) final { @@ -217,7 +217,7 @@ class AutoPadder { // after mutation. Otherwise we just return the original block. bool changed = false; // Step 1. Mutate the read region. - Array reads; + ffi::Array reads; for (const BufferRegion& read : op->reads) { if (buffer_map_.count(read->buffer)) { changed = true; @@ -227,7 +227,7 @@ class AutoPadder { } } // Step 2. Mutate the write region. - Array writes; + ffi::Array writes; for (const BufferRegion& write : op->writes) { if (buffer_map_.count(write->buffer)) { changed = true; @@ -238,7 +238,7 @@ class AutoPadder { } // Step 4. Mutate `match_buffers`. If an old buffer appears as a source of // MatchBufferRegion, the storage scope of the target buffer also needs to be set. 
- Array match_buffers; + ffi::Array match_buffers; for (const MatchBufferRegion& match_buffer : op->match_buffers) { if (buffer_map_.count(match_buffer->source->buffer)) { changed = true; @@ -262,10 +262,10 @@ class AutoPadder { block->match_buffers = std::move(match_buffers); return Stmt(block); } else { - return GetRef(op); + return ffi::GetRef(op); } } - const Map& buffer_map_; + const ffi::Map& buffer_map_; }; Rewriter rewriter(padded_buffer_map_); return rewriter(stmt); @@ -287,7 +287,7 @@ class AutoPadder { if (!success_) { return; } - int extent = var_range_[GetRef(op)]->extent.as()->value; + int extent = var_range_[ffi::GetRef(op)]->extent.as()->value; if (extent > 1) { stack_.push({{extent, 1}}); } else { @@ -396,7 +396,7 @@ class AutoPadder { } public: - explicit PatternCollector(const Map& var_range) : var_range_(var_range) {} + explicit PatternCollector(const ffi::Map& var_range) : var_range_(var_range) {} /*! * \brief Collect the iteration space for given indices. The iteration space is the possible @@ -409,9 +409,8 @@ class AutoPadder { * \return The iteration space. The first array represents dimensions, and the second array * represents the iteration space of one dimension */ - static std::vector> CollectIterationSpace(const Array& indices, - const Map& var_range, - int data_bits) { + static std::vector> CollectIterationSpace( + const ffi::Array& indices, const ffi::Map& var_range, int data_bits) { PatternCollector collector(var_range); std::vector> ret; for (int i = 0; i < static_cast(indices.size()); i++) { @@ -444,30 +443,30 @@ class AutoPadder { } std::stack> stack_; - const Map& var_range_; + const ffi::Map& var_range_; bool success_ = true; }; /*! 
A utility class for calling CollectIterationSpace to each buffer access*/ class IterSpaceAnalyzer : public StmtExprVisitor { public: - IterSpaceAnalyzer(const Map& substitute_map, AutoPadder* self, int data_bits, - const Map warp_thread_extent) + IterSpaceAnalyzer(const ffi::Map& substitute_map, AutoPadder* self, + int data_bits, const ffi::Map warp_thread_extent) : substitute_map_(substitute_map), self(self), data_bits_(data_bits), warp_thread_extent_(warp_thread_extent) {} private: - bool CheckVarContiguous(PrimExpr e, Var var, const Map& subst_map) { - PrimExpr e1 = Substitute(e, [var](const Var& v) -> Optional { + bool CheckVarContiguous(PrimExpr e, Var var, const ffi::Map& subst_map) { + PrimExpr e1 = Substitute(e, [var](const Var& v) -> ffi::Optional { if (v.same_as(var)) { return Integer(0); } else { return v; } }); - PrimExpr e2 = Substitute(e, [var](const Var& v) -> Optional { + PrimExpr e2 = Substitute(e, [var](const Var& v) -> ffi::Optional { if (v.same_as(var)) { return Integer(1); } else { @@ -508,7 +507,7 @@ class AutoPadder { void VisitStmt_(const BufferStoreNode* op) final { runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); if (scope.rank == runtime::StorageRank::kShared) { - Array substitued_indices; + ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : op->indices) { substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); @@ -536,7 +535,7 @@ class AutoPadder { void VisitExpr_(const BufferLoadNode* op) final { runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); if (scope.rank == runtime::StorageRank::kShared) { - Array substitued_indices; + ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : op->indices) { substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); @@ -572,13 +571,13 @@ class AutoPadder { runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); 
if (scope.rank == runtime::StorageRank::kShared) { Region region = r->source->region; - Array indices; + ffi::Array indices; for (int i = 0; i < static_cast(region.size()); i++) { Var var("region" + std::to_string(i)); indices.push_back(region[i]->min + var); var_range_.Set(var, Range::FromMinExtent(0, region[i]->extent)); } - Array substitued_indices; + ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : indices) { substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); @@ -595,11 +594,11 @@ class AutoPadder { } } - Map substitute_map_; + ffi::Map substitute_map_; AutoPadder* self; int data_bits_; - Map warp_thread_extent_; - Map var_range_; + ffi::Map warp_thread_extent_; + ffi::Map var_range_; int vector_length_ = -1; Var vector_var; }; @@ -611,11 +610,12 @@ class AutoPadder { * \param data_bits The length of dtype in bits * \param thread_extent The extents of all thread binding loops */ - void AnalyzeSharedMemoryAccess(const Stmt& stmt, const Array& outer_loops, int data_bits, - const Map& thread_extent) { - Map warp_thread_extent; + void AnalyzeSharedMemoryAccess(const Stmt& stmt, const ffi::Array& outer_loops, + int data_bits, + const ffi::Map& thread_extent) { + ffi::Map warp_thread_extent; Integer prod = 1; - Array thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"}; + ffi::Array thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"}; arith::Analyzer analyzer; for (int i = 0; i < 3; i++) { Integer extent = thread_extent.Get(thread_tags[i]).value_or(1); @@ -628,7 +628,7 @@ class AutoPadder { prod *= extent; } } - Map substitute_map; + ffi::Map substitute_map; for (const For& loop : outer_loops) { substitute_map.Set(loop->loop_var, loop->min); } @@ -638,11 +638,11 @@ class AutoPadder { private: /*! \brief A map from the old buffers to the new padded buffers */ - Map padded_buffer_map_; + ffi::Map padded_buffer_map_; /*! 
\brief A map from each buffer to the iteration spaces of the accesses*/ std::unordered_map>>> iter_spaces_; /*! \brief A map from each buffer to their minimal padding size */ - Map padding_min_; + ffi::Map padding_min_; /*! \brief max padding size in relative to the original shape*/ const double max_pad_factor_ = 0.25; @@ -651,7 +651,8 @@ class AutoPadder { class AutoCopyMutator : public StmtExprMutator { public: - explicit AutoCopyMutator(Map thread_extent) : thread_extent_(thread_extent) {} + explicit AutoCopyMutator(ffi::Map thread_extent) + : thread_extent_(thread_extent) {} /** * \brief Replace old buffers with padded buffers in the stmt * \param stmt The stmt to rewrite @@ -708,16 +709,16 @@ class AutoCopyMutator : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* op) final { - outer_loops_.push_back(GetRef(op)); + outer_loops_.push_back(ffi::GetRef(op)); Stmt stmt = StmtMutator::VisitStmt_(op); outer_loops_.pop_back(); return stmt; } /*! \brief Thread extents collected. */ - Map thread_extent_; + ffi::Map thread_extent_; /*! \brief The outer loops during recursive visit */ - Array outer_loops_; + ffi::Array outer_loops_; /*! 
\brief Calculating optimal padding size */ AutoPadder padder; @@ -736,7 +737,7 @@ class AutoCopyMutator : public StmtExprMutator { */ class ThreadExtentCollector : public StmtVisitor { public: - static Map CollectThreadExtent(const Stmt& stmt) { + static ffi::Map CollectThreadExtent(const Stmt& stmt) { ThreadExtentCollector collector; collector(stmt); return collector.thread_extent_; @@ -744,7 +745,7 @@ class ThreadExtentCollector : public StmtVisitor { private: void VisitStmt_(const BlockNode* op) final { - if (Optional warp_execution = GetAnn(op, "warp_execution")) { + if (ffi::Optional warp_execution = GetAnn(op, "warp_execution")) { if (warp_execution.value()->value != 0) { thread_extent_.Set("threadIdx.x", Integer(32)); } @@ -754,14 +755,14 @@ class ThreadExtentCollector : public StmtVisitor { void VisitStmt_(const ForNode* op) final { if (op->thread_binding.defined() && op->thread_binding.value()->iter_type == kThreadIndex) { if (const auto* extent = op->extent.as()) { - thread_extent_.Set(op->thread_binding.value()->thread_tag, GetRef(extent)); + thread_extent_.Set(op->thread_binding.value()->thread_tag, ffi::GetRef(extent)); } } StmtVisitor::VisitStmt_(op); } /*! \brief the map from thread tag to its extent */ - Map thread_extent_; + ffi::Map thread_extent_; }; namespace transform { @@ -777,10 +778,10 @@ Pass LowerAutoCopy() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.LowerAutoCopy", LowerAutoCopy); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h index 46c9a97c527d..5751aa119e36 100644 --- a/src/tir/transforms/memhammer_rewrite_rule.h +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -37,9 +37,9 @@ namespace tir { /*! 
\brief The set containing all possible constraints of a data copy */ struct ConstraintSet { /*! \brief The extents of the thread binding loops */ - Map thread_extent; + ffi::Map thread_extent; /*! \brief The outer loops surrounding the data copy */ - Array outer_loops; + ffi::Array outer_loops; /*! \brief The read region of the data copy */ BufferRegion read_region; /*! \brief The write region of the data copy */ @@ -51,12 +51,12 @@ struct ConstraintSet { /*! \brief The vectorization length in bytes */ int vector_bytes = 1; - explicit ConstraintSet(Map thread_extent, // - Array outer_loops, // - BufferRegion read_region, // - BufferRegion write_region, // - int data_bits, // - const Map& ann) + explicit ConstraintSet(ffi::Map thread_extent, // + ffi::Array outer_loops, // + BufferRegion read_region, // + BufferRegion write_region, // + int data_bits, // + const ffi::Map& ann) : thread_extent(thread_extent), outer_loops(outer_loops), read_region(read_region), @@ -74,9 +74,9 @@ struct ConstraintSet { /*! \brief The set containing all possible outputs of a rewrite rule */ struct OutputSet { /*! \brief New buffers allocated after rewrite */ - Array alloc_buffer; + ffi::Array alloc_buffer; /*! \brief The minimal padding size of a buffer in base 2 logarithm */ - Map padding_min; + ffi::Map padding_min; }; /*! @@ -248,9 +248,9 @@ class WmmaToShared : public RewriteRule { * \return a pair. The first is the stmt after transformation. * The second is the SeqStmt that contains 2 stages (one original and another inserted). 
*/ -std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, - Optional compute_location, - const Array& outer_loops, Buffer* alloc_buffer); +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::String storage_scope, + ffi::Optional compute_location, + const ffi::Array& outer_loops, Buffer* alloc_buffer); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc index 5a0d0fa2105c..e69ac30366b1 100644 --- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -28,7 +28,7 @@ namespace tir { * \return A pair. The first is the stmt after transformation. * The second is the compute location where we may add write cache. */ -std::pair> TileWmmaBlock(Stmt stmt) { +std::pair> TileWmmaBlock(Stmt stmt) { Stmt body = stmt; std::vector loops; while (const ForNode* loop = body.as()) { @@ -52,7 +52,7 @@ std::pair> TileWmmaBlock(Stmt stmt) { /*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"), }; body = Substitute(std::move(body), - Map{ + ffi::Map{ {loops[n - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]}, {loops[n - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]}, }); @@ -70,21 +70,23 @@ std::pair> TileWmmaBlock(Stmt stmt) { } For compute_location = Downcast(body); for (int i = n - 3; i >= 0; i--) { - body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body), - loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(body); + body = new_loop; } return {body, compute_location}; } -Array RelaxIndices(const Array& indices, const Array& shape, - const Map& var_dom) { - Array int_set; +ffi::Array RelaxIndices(const ffi::Array& indices, + const ffi::Array& shape, + const ffi::Map& var_dom) { + ffi::Array int_set; int_set.reserve(indices.size()); for (auto& indice : 
indices) { int_set.push_back(arith::EvalSet(indice, var_dom)); } int ndim = int_set.size(); - Array region; + ffi::Array region; region.reserve(ndim); for (int i = 0; i < ndim; ++i) { region.push_back(int_set[i].CoverRange(Range::FromMinExtent(0, shape[i]))); @@ -110,7 +112,7 @@ Stmt RewriteWmmaLoad(Stmt stmt) { } int n = loops.size(); - Map var_dom{ + ffi::Map var_dom{ {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, }; @@ -141,8 +143,8 @@ Stmt RewriteWmmaLoad(Stmt stmt) { /*data_alignment=*/64, /*offset_factor=*/16, /*buffer_type=*/kDefault); - Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); - Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); Stmt wmma_body = BlockRealize( /*iter_values=*/{}, /*predicate=*/Bool(true), @@ -186,8 +188,9 @@ Stmt RewriteWmmaLoad(Stmt stmt) { }, /*annotations=*/{})); for (int i = n - 3; i >= 0; i--) { - wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, - std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(wmma_body); + wmma_body = new_loop; } return wmma_body; } @@ -209,7 +212,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { } int n = loops.size(); - Map var_dom{ + ffi::Map var_dom{ {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, }; @@ -249,8 +252,8 @@ Stmt RewriteWmmaStore(Stmt stmt) { /*offset_factor=*/16, /*buffer_type=*/kDefault); - Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, 
var_dom); - Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); Stmt wmma_body = BlockRealize( /*iter_values=*/{}, // /*predicate=*/Bool(true), @@ -289,8 +292,9 @@ Stmt RewriteWmmaStore(Stmt stmt) { }, /*annotations=*/{})); for (int i = n - 3; i >= 0; i--) { - wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, - std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(wmma_body); + wmma_body = new_loop; } return wmma_body; } @@ -333,7 +337,7 @@ class WmmaToGlobalRewriter : public StmtExprMutator { Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body{nullptr}; - Optional compute_location{nullptr}; + ffi::Optional compute_location; std::tie(body, compute_location) = TileWmmaBlock(stmt); SeqStmt seq{nullptr}; Buffer cache_buffer; @@ -347,7 +351,7 @@ Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, return rewriter(body); } -std::pair> TileMmaToGlobalBlock(Stmt stmt) { +std::pair> TileMmaToGlobalBlock(Stmt stmt) { // i, j = sch.get_loops(block)[2:] // i_0, i_1 = sch.split(i, factors=[None, 8]) // j_0, j_1 = sch.split(j, factors=[None, 8]) @@ -376,7 +380,7 @@ std::pair> TileMmaToGlobalBlock(Stmt stmt) { /*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"), }; body = Substitute(std::move(body), - Map{ + ffi::Map{ {loops[n - 2]->loop_var, new_loop_vars[0] * 8 + new_loop_vars[2]}, {loops[n - 1]->loop_var, new_loop_vars[1] * 8 + new_loop_vars[3]}, }); @@ -394,8 +398,9 @@ std::pair> TileMmaToGlobalBlock(Stmt stmt) { } For compute_location = Downcast(body); for (int i = n - 3; i >= 0; i--) { - body = For(loops[i]->loop_var, loops[i]->min, 
loops[i]->extent, loops[i]->kind, std::move(body), - loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(body); + body = new_loop; } return {body, compute_location}; } @@ -418,7 +423,7 @@ Stmt RewriteMmaStore(Stmt stmt) { } int n = loops.size(); - Map var_dom{ + ffi::Map var_dom{ {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, }; @@ -468,8 +473,8 @@ Stmt RewriteMmaStore(Stmt stmt) { /*buffer_type=*/kDefault); // Step 3.2. Generate new r/w region - Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); - Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); // Step 3.3. 
Generate new inner loop body // for v in T.vectorized(2): @@ -483,21 +488,21 @@ Stmt RewriteMmaStore(Stmt stmt) { /*reads=*/{BufferRegion(src_buffer, read_region)}, /*writes=*/{BufferRegion(tgt_buffer, write_region)}, /*name_hint=*/"mma_store", - AttrStmt(/*node=*/IterVar( - /*dom=*/Range::FromMinExtent(0, 32), - /*var=*/tx, - /*iter_type=*/IterVarType::kThreadIndex, - /*thread_tag=*/"threadIdx.x"), - /*attr_key=*/"thread_extent", - /*value=*/Integer(32), - /*body=*/ - For(vec, 0, 2, ForKind::kVectorized, - /*body=*/ - BufferStore(new_tgt_buffer, - BufferLoad(new_src_buffer, - {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), - {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), - /*annotations=*/{})), + AttrStmt( + /*node=*/IterVar( + /*dom=*/Range::FromMinExtent(0, 32), + /*var=*/tx, + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/"threadIdx.x"), + /*attr_key=*/"thread_extent", + /*value=*/Integer(32), + /*body=*/ + For(vec, 0, 2, ForKind::kVectorized, + /*body=*/ + BufferStore( + new_tgt_buffer, + BufferLoad(new_src_buffer, {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), + {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}))), /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/ @@ -509,8 +514,9 @@ Stmt RewriteMmaStore(Stmt stmt) { // Step 3.4. 
wrap outer loops for (int i = n - 3; i >= 0; i--) { - mma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, - std::move(mma_body), loops[i]->thread_binding, loops[i]->annotations); + auto new_loop = ffi::GetRef(loops[i]); + new_loop.CopyOnWrite()->body = std::move(mma_body); + mma_body = new_loop; } return mma_body; } @@ -542,7 +548,7 @@ class MmaToGlobalRewriter : public StmtExprMutator { Stmt MmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body{nullptr}; - Optional compute_location{nullptr}; + ffi::Optional compute_location; std::tie(body, compute_location) = TileMmaToGlobalBlock(stmt); SeqStmt seq{nullptr}; Buffer cache_buffer; diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 63342bd2ec8d..4b6e768e8d6f 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -125,7 +125,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } } @@ -156,7 +156,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } } @@ -168,9 +168,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { for (const auto& index : load->indices) { this->VisitExpr(index); } - } else { - 
StmtExprVisitor::VisitExpr_(op); } + StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const VarNode* buf) final { @@ -178,7 +177,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } } @@ -215,6 +214,10 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::virtual_thread) { VisitNewScope(op); + } else if (op->attr_key == "kWarpSpecializationScope") { + IfThenElse body = Downcast(op->body); + this->VisitStmt(body->then_case); + this->VisitStmt(body->else_case.value()); } else { StmtExprVisitor::VisitStmt_(op); } @@ -352,8 +355,8 @@ class SharedMemoryRewriter : public StmtExprMutator { << "MergeSharedMemoryAllocations expects flat memory buffers, " << "and is to be run after " << "FlattenBuffer"; - Array indices = {node->indices[0] + - this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; + ffi::Array indices = { + node->indices[0] + this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; auto writer = node.CopyOnWrite(); writer->buffer = GetUpdatedBuffer(node->buffer); @@ -397,7 +400,7 @@ class SharedMemoryRewriter : public StmtExprMutator { PrimExpr offset = this->VisitExpr(op->args[2]); PrimExpr extent = this->VisitExpr(op->args[3]); return Call(op->dtype, op->op, - {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); + {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}, op->annotations); } else if (op->op.same_as(builtin::ptx_cp_async())) { ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); DataType dtype = op->dtype; @@ -414,11 +417,11 @@ class SharedMemoryRewriter : public StmtExprMutator { if (op->args.size() 
== 5) return Call(dtype, op->op, {merged_buf_var_, mul(extra_offset + offset, PrimExpr(index_factor)), - op->args[2], op->args[3], op->args[4]}); + op->args[2], op->args[3], op->args[4]}, op->annotations); else return Call(dtype, op->op, {merged_buf_var_, mul(extra_offset + offset, PrimExpr(index_factor)), - op->args[2], op->args[3], op->args[4], op->args[5]}); + op->args[2], op->args[3], op->args[4], op->args[5]}, op->annotations); } else { return StmtExprMutator::VisitExpr_(op); } @@ -696,10 +699,10 @@ Pass MergeSharedMemoryAllocations() { return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.MergeSharedMemoryAllocations", MergeSharedMemoryAllocations); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index b09a4dc17b26..3ad05337b591 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -77,7 +77,7 @@ class DataTypeVisitor final : public StmtExprVisitor { explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {} void VisitExpr(const PrimExpr& e) { - if (e.dtype().is_int()) { + if (e.dtype().is_int() || e.dtype().is_uint()) { int bits = max_bits_; if (bound_.find(e) == bound_.end()) { analyzer_.const_int_bound(e, &bound_); @@ -212,7 +212,7 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { Stmt operator()(Stmt s) { visitor_(s); for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) { - PrimExpr e = GetRef(i->first); + PrimExpr e = ffi::GetRef(i->first); if (e.dtype() == i->second) { i = visitor_.vmap.erase(i); } else { @@ -268,7 +268,7 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ if (op->a.same_as(a) 
&& op->b.same_as(b) && a.dtype() == b.dtype()) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ if (a.dtype() != b.dtype()) { \ bool is_enabled = is_enabled_; \ @@ -321,10 +321,10 @@ Pass NarrowDataType(int target_bits) { return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.NarrowDataType", NarrowDataType); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index bcd5f53dd4f4..779076a89f6f 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -53,7 +53,7 @@ class CollectManagedAllocations : public StmtExprVisitor { /*! \brief Collect the allocate buffer order. */ class BufferAllocateOrderCollector : public StmtExprVisitor { public: - static Array Collect(const PrimFunc& func) { + static ffi::Array Collect(const PrimFunc& func) { BufferAllocateOrderCollector collector; for (const auto& kv : func->buffer_map) { collector.buffer_alloc_recorder_.push_back(kv.second); @@ -98,16 +98,16 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { } /*! \brief The buffer allocated order recorder. */ - Array buffer_alloc_recorder_; + ffi::Array buffer_alloc_recorder_; }; class BufferAllocationLocator : public StmtExprMutator { public: explicit BufferAllocationLocator(const PrimFunc& func) { - Map> buffer_lca = DetectBufferAccessLCA(func); + ffi::Map> buffer_lca = DetectBufferAccessLCA(func); // The buffer_alloc_recorder Array is used to keep the buffer allocation order // since the buffer_lca Map is unordered. 
- Array buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func); + ffi::Array buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func); std::unordered_set arg_buffer_vars; CollectManagedAllocations collector; collector(func->body); @@ -145,7 +145,7 @@ class BufferAllocationLocator : public StmtExprMutator { } auto node = Downcast(StmtMutator::VisitStmt_(op)); - Array new_block_alloc_bufs; + ffi::Array new_block_alloc_bufs; for (const Buffer& buf : it->second) { if (managed_allocations_.count(buf->data.get())) { buffer_data_to_buffer_.erase(buf->data); @@ -162,7 +162,7 @@ class BufferAllocationLocator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { ICHECK(!op->init.defined()); - Array alloc_buffers; + ffi::Array alloc_buffers; auto it = alloc_buffers_.find(op); if (it != alloc_buffers_.end()) { alloc_buffers = it->second; @@ -206,7 +206,7 @@ class BufferAllocationLocator : public StmtExprMutator { throw; } - Stmt InjectOpaqueBlock(Stmt body, const Array& alloc_buffers) { + Stmt InjectOpaqueBlock(Stmt body, const ffi::Array& alloc_buffers) { ICHECK(!alloc_buffers.empty()); Block opaque_block(/*iter_vars=*/{}, /*reads=*/{}, @@ -216,7 +216,7 @@ class BufferAllocationLocator : public StmtExprMutator { /*init=*/std::nullopt, /*alloc_buffers=*/alloc_buffers); ObjectPtr n = CopyOnWrite(opaque_block.get()); - Array> access = + ffi::Array> access = GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); n->reads = access[0]; n->writes = access[1]; @@ -224,8 +224,9 @@ class BufferAllocationLocator : public StmtExprMutator { return realize; } - Array RemoveRedundantBufferRegion(const Array& region) const { - Array result; + ffi::Array RemoveRedundantBufferRegion( + const ffi::Array& region) const { + ffi::Array result; for (const BufferRegion& buffer_region : region) { if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) { result.push_back(buffer_region); @@ -235,9 +236,9 @@ class BufferAllocationLocator : public 
StmtExprMutator { } /*! \brief The map from stmt to the buffers to be allocated under it. */ - std::unordered_map> alloc_buffers_; + std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ - Map buffer_data_to_buffer_; + ffi::Map buffer_data_to_buffer_; /*! \brief Buffers that are allocated within a BlockNode, and may be moved. */ std::unordered_set managed_allocations_; }; @@ -258,11 +259,11 @@ Pass PlanAndUpdateBufferAllocationLocation() { return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.PlanAndUpdateBufferAllocationLocation", PlanAndUpdateBufferAllocationLocation); -}); +} } // namespace transform diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index b1f3476eab73..136a92abde31 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -36,7 +36,7 @@ transform::Pass AnnotateEntryFunc() { auto [gvar, base_func] = *mod->functions.begin(); if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { if (auto ptr = base_func.as()) { - mod->Update(gvar, WithAttr(GetRef(ptr), tir::attr::kIsEntryFunc, true)); + mod->Update(gvar, WithAttr(ffi::GetRef(ptr), tir::attr::kIsEntryFunc, true)); } } return mod; @@ -47,11 +47,11 @@ transform::Pass AnnotateEntryFunc() { bool has_external_non_primfuncs = false; IRModule with_annotations; for (const auto& [gvar, base_func] : mod->functions) { - bool is_external = base_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_external = base_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_external) { if (auto ptr = base_func.as()) { - with_annotations->Add(gvar, - WithAttr(GetRef(ptr), tir::attr::kIsEntryFunc, true)); + with_annotations->Add( + gvar, WithAttr(ffi::GetRef(ptr), tir::attr::kIsEntryFunc, true)); } else 
{ has_external_non_primfuncs = true; } @@ -79,12 +79,12 @@ transform::Pass Filter(ffi::TypedFunction fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.transform.AnnotateEntryFunc", AnnotateEntryFunc) .def("tir.transform.Filter", Filter); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/profile_instrumentation.cc b/src/tir/transforms/profile_instrumentation.cc index d7763ee543b8..513f0d730e8c 100644 --- a/src/tir/transforms/profile_instrumentation.cc +++ b/src/tir/transforms/profile_instrumentation.cc @@ -284,10 +284,10 @@ Pass InstrumentProfileIntrinsics() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstrumentProfileIntrinsics", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InstrumentProfileIntrinsics", InstrumentProfileIntrinsics); -}); +} } // namespace transform diff --git a/src/tir/transforms/reduce_branching_through_overcompute.cc b/src/tir/transforms/reduce_branching_through_overcompute.cc index f6bfd8fa4273..9a03b143d0f9 100644 --- a/src/tir/transforms/reduce_branching_through_overcompute.cc +++ b/src/tir/transforms/reduce_branching_through_overcompute.cc @@ -51,18 +51,17 @@ struct ReduceBranchingThroughOvercomputeConfigNode "to statically prove that overcompute is valid.", refl::DefaultValue(false)); } - - static constexpr const char* _type_key = "tir.transform.ReduceBranchingThroughOvercomputeConfig"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ReduceBranchingThroughOvercomputeConfigNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.ReduceBranchingThroughOvercomputeConfig", + ReduceBranchingThroughOvercomputeConfigNode, BaseAttrsNode); }; class ReduceBranchingThroughOvercomputeConfig : public Attrs { public: - 
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReduceBranchingThroughOvercomputeConfig, Attrs, - ReduceBranchingThroughOvercomputeConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ReduceBranchingThroughOvercomputeConfig, Attrs, + ReduceBranchingThroughOvercomputeConfigNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ ReduceBranchingThroughOvercomputeConfigNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ReduceBranchingThroughOvercomputeConfigNode::RegisterReflection(); } TVM_REGISTER_PASS_CONFIG_OPTION("tir.ReduceBranchingThroughOvercompute", ReduceBranchingThroughOvercomputeConfig); @@ -176,11 +175,11 @@ Pass ReduceBranchingThroughOvercompute() { return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.ReduceBranchingThroughOvercompute", ReduceBranchingThroughOvercompute); -}); +} } // namespace transform diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index 14ad70122798..c7184e07a036 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -70,13 +70,13 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; -PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { +PrimFunc RemapThreadAxis(PrimFunc func, ffi::Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { tmap[kv.first] = kv.second; } - if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { ICHECK(opt != nullptr) << "Require attribute " << tir::attr::kKernelLaunchParams; auto launch_params = opt.value(); // replace the thread axis attribute @@ -97,17 +97,17 @@ PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { namespace transform { -Pass RemapThreadAxis(Map thread_map) { +Pass RemapThreadAxis(ffi::Map 
thread_map) { auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) { return RemapThreadAxis(std::move(f), thread_map); }; return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.RemapThreadAxis", RemapThreadAxis); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remove_assume.cc b/src/tir/transforms/remove_assume.cc index 95d55ed0a3f5..6475befa1cf8 100644 --- a/src/tir/transforms/remove_assume.cc +++ b/src/tir/transforms/remove_assume.cc @@ -62,10 +62,10 @@ Pass RemoveAssume() { return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tir.RemoveAssume"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.RemoveAssume", RemoveAssume); -}); +} } // namespace transform diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 9db6f9f32808..6cc80535085f 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -59,17 +59,16 @@ struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter "For use in debug and testing purposes.", refl::DefaultValue(0)); } - - static constexpr const char* _type_key = "tir.transform.RemoveNoOpConfig"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(RemoveNoOpConfigNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.RemoveNoOpConfig", RemoveNoOpConfigNode, + BaseAttrsNode); }; class RemoveNoOpConfig : public Attrs { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ RemoveNoOpConfigNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { RemoveNoOpConfigNode::RegisterReflection(); } 
TVM_REGISTER_PASS_CONFIG_OPTION("tir.RemoveNoOp", RemoveNoOpConfig); @@ -181,20 +180,20 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const EvaluateNode* op) final { if (HasSideEffect(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Evaluate(0); } } Stmt VisitStmt_(const BufferStoreNode* op) final { - BufferStore store = GetRef(op); + BufferStore store = ffi::GetRef(op); // Helper function that returns a statement containing only the // side effects of evaluating this BufferStore, but not the store // itself. auto only_side_effects = [&]() { - Array statements; + ffi::Array statements; statements.push_back(MakeEvaluate(store->value)); for (const auto& index : store->indices) { statements.push_back(MakeEvaluate(index)); @@ -204,7 +203,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { if (touch_pattern_.has_value()) { // A write that is later overwritten is a no-op. - Stmt context = context_ ? GetRef(context_) : store; + Stmt context = context_ ? ffi::GetRef(context_) : store; if (touch_pattern_->IsOverwrittenWithoutEffect(store, context)) { touch_pattern_->RemoveStore(store); return only_side_effects(); @@ -217,7 +216,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices, store->predicate) == 0; if (touch_pattern_.has_value()) { - Stmt context_arg = context_ ? GetRef(context_) : Stmt(store); + Stmt context_arg = context_ ? 
ffi::GetRef(context_) : Stmt(store); stores_existing_value = touch_pattern_->SimplifyInContext(stores_existing_value, context_arg, analyzer_); } else { @@ -257,7 +256,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { } private: - bool ArrayValueEqual(const Array& a, const Array& b) { + bool ArrayValueEqual(const ffi::Array& a, const ffi::Array& b) { if (a.size() != b.size()) { return false; } @@ -280,8 +279,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { return Evaluate(0); } } - Stmt MakeEvaluate(const Array& values) { - Array stmts; + Stmt MakeEvaluate(const ffi::Array& values) { + ffi::Array stmts; for (PrimExpr e : values) { if (SideEffect(e) > CallEffectKind::kReadState) { stmts.push_back(Evaluate(e)); @@ -333,10 +332,10 @@ Pass RemoveNoOp() { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.RemoveNoOp", RemoveNoOp); -}); +} } // namespace transform diff --git a/src/tir/transforms/remove_store_undef.cc b/src/tir/transforms/remove_store_undef.cc index 62b4391ef336..93cdd4ed145a 100644 --- a/src/tir/transforms/remove_store_undef.cc +++ b/src/tir/transforms/remove_store_undef.cc @@ -172,10 +172,10 @@ Pass RemoveStoreUndef() { "tir.RemoveStoreUndef"); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.RemoveStoreUndef", RemoveStoreUndef); -}); +} } // namespace transform diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 3c1e12bc3af9..5b2b5704c5c9 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -35,8 +35,9 @@ namespace tir { class RemoveLayoutRewriteBlock : public StmtMutator { public: - static std::tuple, std::unordered_map, - 
std::unordered_map>> + static std::tuple, + std::unordered_map, + std::unordered_map>> Rewrite(PrimFunc f) { RemoveLayoutRewriteBlock rewriter; @@ -54,7 +55,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator { if (it == block->annotations.end() || !is_one(Downcast((*it).second))) { // The block is not a weight layout block // Remove allocates if needed - Array alloc_buffers; + ffi::Array alloc_buffers; for (const Buffer& buffer : block->alloc_buffers) { if (!rewritten_buffers_.count(buffer)) { alloc_buffers.push_back(buffer); @@ -91,7 +92,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator { n->reads = {}; n->writes = {}; - Array load_indices; + ffi::Array load_indices; for (auto ind : load->indices) { ICHECK(ind->IsInstance()); load_indices.push_back(Downcast(ind)); @@ -105,14 +106,14 @@ class RemoveLayoutRewriteBlock : public StmtMutator { private: /*! \brief The buffer map from original layout buffer to rewritten buffer */ - Map buf_map_; + ffi::Map buf_map_; /*! \brief The buffer map from original layout buffer to rewritten buffer */ std::unordered_set rewritten_buffers_; /*! \brief Maps a buffer load to an index map associated with the load / store in a layout rewrite block. */ std::unordered_map buffer_var_to_index_map_; /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. 
*/ - std::unordered_map> buffer_var_to_rewritten_shape_; + std::unordered_map> buffer_var_to_rewritten_shape_; }; // After RemoveLayoutRewriteBlock, the body of a compute update block references a @@ -149,18 +150,18 @@ class AllocateConstRewrite : public StmtExprMutator { AllocateConstRewrite( const BufferVarMap& buffer_var_map, const std::unordered_map& buffer_var_to_index_map, - const std::unordered_map>& buffer_var_to_rewritten_shape, - bool skip_ndarray_rewrite) + const std::unordered_map>& buffer_var_to_rewritten_shape, + bool skip_tensor_rewrite) : buffer_var_map_(buffer_var_map), buffer_var_to_index_map_(buffer_var_to_index_map), buffer_var_to_rewritten_shape_(buffer_var_to_rewritten_shape), - skip_ndarray_rewrite_(skip_ndarray_rewrite) {} + skip_tensor_rewrite_(skip_tensor_rewrite) {} private: Stmt VisitStmt_(const BlockNode* op) final { Block block = Downcast(StmtMutator::VisitStmt_(op)); auto n = CopyOnWrite(block.get()); - Array new_reads; + ffi::Array new_reads; for (auto read_region : op->reads) { if (auto it = new_load_buf_.find(read_region->buffer->data.get()); it != new_load_buf_.end()) { @@ -178,13 +179,13 @@ class AllocateConstRewrite : public StmtExprMutator { it != buffer_var_to_index_map_.end()) { ICHECK(buffer_var_to_rewritten_shape_.count(alloc->buffer_var.get())); auto new_body = StmtMutator::VisitStmt(alloc->body); - auto rewritten_ndarray = RewriteNDArray( + auto rewritten_tensor = RewriteTensor( alloc->data.value(), it->second, buffer_var_to_rewritten_shape_[alloc->buffer_var.get()]); - Array rewritten_extents; - for (auto s : rewritten_ndarray.Shape()) { + ffi::Array rewritten_extents; + for (auto s : rewritten_tensor.Shape()) { rewritten_extents.push_back(PrimExpr(static_cast(s))); } - return AllocateConst(alloc->buffer_var, alloc->dtype, rewritten_extents, rewritten_ndarray, + return AllocateConst(alloc->buffer_var, alloc->dtype, rewritten_extents, rewritten_tensor, new_body, alloc->annotations, alloc->span); } return 
StmtMutator::VisitStmt_(alloc); @@ -193,18 +194,18 @@ class AllocateConstRewrite : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { if (auto it = buffer_var_map_.find(op->buffer->data.get()); it != buffer_var_map_.end()) { auto new_buffer = - Buffer(GetRef(it->second), op->buffer->dtype, op->buffer->shape, op->buffer->strides, - op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment, - op->buffer->offset_factor, op->buffer->buffer_type); + Buffer(ffi::GetRef(it->second), op->buffer->dtype, op->buffer->shape, + op->buffer->strides, op->buffer->elem_offset, it->second->name_hint, + op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type); new_load_buf_[op->buffer->data.get()] = new_buffer; return BufferLoad(new_buffer, op->indices, op->predicate); } return ExprMutator::VisitExpr_(op); } - runtime::NDArray RewriteNDArray(runtime::NDArray src, const IndexMap& index_map, - const Array& dst_shape) { - if (skip_ndarray_rewrite_) { + runtime::Tensor RewriteTensor(runtime::Tensor src, const IndexMap& index_map, + const ffi::Array& dst_shape) { + if (skip_tensor_rewrite_) { // Only the shape of the destination array needs to be correct. std::vector dst_shape_int; for (auto s : dst_shape) { @@ -213,7 +214,7 @@ class AllocateConstRewrite : public StmtExprMutator { } return src.CreateView(dst_shape_int, src.DataType()); } else { - return index_map->MapNDArray(src); + return index_map->MapTensor(src); } } @@ -223,11 +224,11 @@ class AllocateConstRewrite : public StmtExprMutator { in a layout rewrite block. */ std::unordered_map buffer_var_to_index_map_; /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */ - std::unordered_map> buffer_var_to_rewritten_shape_; + std::unordered_map> buffer_var_to_rewritten_shape_; /*! \brief Maps load buffer variables to newly created buffers */ std::unordered_map new_load_buf_; - /*! 
\brief Whether or not to skip rewriting of NDArray contents */ - bool skip_ndarray_rewrite_; + /*! \brief Whether or not to skip rewriting of Tensor contents */ + bool skip_tensor_rewrite_; }; class CollectAllocateConstBufferVars : public StmtVisitor { @@ -242,7 +243,7 @@ class CollectAllocateConstBufferVars : public StmtVisitor { class WeightLayoutRewriteBlockRemover : public StmtMutator { public: - static PrimFunc Remove(PrimFunc f, bool skip_ndarray_rewrite) { + static PrimFunc Remove(PrimFunc f, bool skip_tensor_rewrite) { CollectAllocateConstBufferVars collector; collector(f->body); @@ -260,10 +261,10 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator { PrimFuncNode* n = f_.CopyOnWrite(); AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map, - buffer_var_to_rewritten_shape, skip_ndarray_rewrite); + buffer_var_to_rewritten_shape, skip_tensor_rewrite); n->body = rewriter(std::move(n->body)); - Map buffer_map; + ffi::Map buffer_map; for (const auto& [param, buffer] : f_->buffer_map) { auto it = buf_map.find(buffer); if (it != buf_map.end()) { @@ -279,18 +280,18 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator { namespace transform { -Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite) { - auto pass_func = [skip_ndarray_rewrite](PrimFunc f, IRModule m, PassContext ctx) { - return WeightLayoutRewriteBlockRemover::Remove(std::move(f), skip_ndarray_rewrite); +Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite) { + auto pass_func = [skip_tensor_rewrite](PrimFunc f, IRModule m, PassContext ctx) { + return WeightLayoutRewriteBlockRemover::Remove(std::move(f), skip_tensor_rewrite); }; return CreatePrimFuncPass(pass_func, 0, "tir.RemoveWeightLayoutRewriteBlock", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.RemoveWeightLayoutRewriteBlock", RemoveWeightLayoutRewriteBlock); -}); +} } // namespace transform diff --git 
a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index 167453c04fe0..69002a9e1d78 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -37,7 +37,7 @@ namespace tir { Stmt stmt = StmtExprMutator::VisitStmt_(op); \ op = stmt.as(); \ ICHECK(op != nullptr); \ - auto n = make_object(*op); \ + auto n = ffi::make_object(*op); \ n->FIELD = std::move(new_var); \ return Stmt(n); \ } @@ -47,7 +47,7 @@ class RenewDefMutator : public StmtExprMutator { static PrimFunc Transform(const PrimFunc& func) { RenewDefMutator generator; // Redefine params - Array params; + ffi::Array params; for (const auto& param : func->params) { params.push_back(generator.ReDefineVar(param)); } @@ -56,8 +56,8 @@ class RenewDefMutator : public StmtExprMutator { const Buffer& buffer = func->buffer_map.at(param); for (const PrimExpr& e : buffer->shape) { if (const auto* v = e.as()) { - if (generator.remap_.count(GetRef(v)) == 0) { - generator.ReDefineVar(GetRef(v)); + if (generator.remap_.count(ffi::GetRef(v)) == 0) { + generator.ReDefineVar(ffi::GetRef(v)); } } } @@ -65,7 +65,7 @@ class RenewDefMutator : public StmtExprMutator { } // Redefine buffers in order // TODO(Siyuan Feng): checking var is used after define - Map buffer_map; + ffi::Map buffer_map; for (const auto& param : func->params) { if (param->dtype.is_handle()) { const Buffer& buffer = func->buffer_map.at(param); @@ -105,32 +105,32 @@ class RenewDefMutator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { // Step 0. Re-define Itervars - Array iter_vars = + ffi::Array iter_vars = op->iter_vars.Map(std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1)); // Step 1. Re-define buffers allocate under the block - Array alloc_buffers = op->alloc_buffers.Map( + ffi::Array alloc_buffers = op->alloc_buffers.Map( std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, /*define=*/true)); // Step 2. 
Re-define match_buffers - Array match_buffers = op->match_buffers.Map( + ffi::Array match_buffers = op->match_buffers.Map( std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); // Step 3. Visit body - Optional init = std::nullopt; + ffi::Optional init = std::nullopt; if (op->init.defined()) { init = this->VisitStmt(op->init.value()); } Stmt body = this->VisitStmt(op->body); // Step 4. Revisit access region - Array reads = + ffi::Array reads = op->reads.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); - Array writes = + ffi::Array writes = op->writes.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); // Step 5. Regenerate block. Since the defs are changed, we need to create a new block - auto n = make_object(*op); + auto n = ffi::make_object(*op); n->iter_vars = std::move(iter_vars); n->alloc_buffers = std::move(alloc_buffers); n->match_buffers = std::move(match_buffers); @@ -150,7 +150,7 @@ class RenewDefMutator : public StmtExprMutator { if (buffer.same_as(op->buffer)) { return stmt; } else { - auto n = make_object(*op); + auto n = ffi::make_object(*op); n->buffer = std::move(buffer); return BufferStore(n); } @@ -164,7 +164,7 @@ class RenewDefMutator : public StmtExprMutator { if (buffer.same_as(op->buffer)) { return expr; } else { - auto n = make_object(*op); + auto n = ffi::make_object(*op); n->buffer = std::move(buffer); return BufferLoad(n); } @@ -172,7 +172,7 @@ class RenewDefMutator : public StmtExprMutator { private: Var ReDefineVar(const Var& var) { - Var new_var = Var(make_object(*var.get())); + Var new_var = Var(ffi::make_object(*var.get())); this->AddDefRemap(var, new_var); return new_var; } @@ -204,13 +204,13 @@ class RenewDefMutator : public StmtExprMutator { // update data Var data = Downcast(redefine_if_is_var(buffer->data)); // update shape - Array shape = buffer->shape.Map(redefine_if_is_var); + ffi::Array shape = buffer->shape.Map(redefine_if_is_var); // update 
strides - Array strides = buffer->strides.Map(redefine_if_is_var); + ffi::Array strides = buffer->strides.Map(redefine_if_is_var); // update elem_offset PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset); - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->data = std::move(data); n->shape = std::move(shape); n->strides = std::move(strides); @@ -243,13 +243,13 @@ class RenewDefMutator : public StmtExprMutator { return Downcast((*it).second); } Var data = Downcast(VisitExpr(buffer->data)); - Array shape = + ffi::Array shape = buffer->shape.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); - Array strides = + ffi::Array strides = buffer->strides.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); PrimExpr elem_offset = VisitExpr(buffer->elem_offset); - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->data = std::move(data); n->shape = std::move(shape); n->strides = std::move(strides); @@ -277,7 +277,7 @@ class RenewDefMutator : public StmtExprMutator { BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { Buffer buffer = VisitBuffer(buffer_region->buffer); - Array region = buffer_region->region.Map( + ffi::Array region = buffer_region->region.Map( std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1)); if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { return buffer_region; @@ -286,15 +286,15 @@ class RenewDefMutator : public StmtExprMutator { } } - Map remap_; + ffi::Map remap_; }; PrimFunc RenewDefs(const PrimFunc& func) { return RenewDefMutator::Transform(func); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.RenewDefs", RenewDefs); -}); +} } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc index 
bcb143f3323e..04dbcca510e1 100644 --- a/src/tir/transforms/renormalize_split_pattern.cc +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -206,10 +206,10 @@ Pass RenormalizeSplitPattern() { return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.RenormalizeSplitPattern", RenormalizeSplitPattern); -}); +} } // namespace transform diff --git a/src/tir/transforms/replace_global_vars.cc b/src/tir/transforms/replace_global_vars.cc index 3e8437063775..b16926056b7d 100644 --- a/src/tir/transforms/replace_global_vars.cc +++ b/src/tir/transforms/replace_global_vars.cc @@ -35,8 +35,8 @@ namespace { using tvm::transform::GlobalVarReplacer; struct Mutator : StmtExprMutator { - Map replacements; - explicit Mutator(Map replacements) : replacements(replacements) {} + ffi::Map replacements; + explicit Mutator(ffi::Map replacements) : replacements(replacements) {} PrimExpr VisitExpr_(const CallNode* node) override { auto call = Downcast(StmtExprMutator::VisitExpr_(node)); @@ -53,7 +53,7 @@ struct Mutator : StmtExprMutator { TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) .set_dispatch([](const ObjectRef& obj, - Map replacements) -> BaseFunc { + ffi::Map replacements) -> BaseFunc { Mutator mutator(replacements); auto func = Downcast(obj); auto new_body = mutator(func->body); @@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) // If the function is externally exposed, and is being replaced // by a GlobalVar with a new name, then the function's // kGlobalSymbol must be updated to match. 
- if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { auto name = opt.value(); for (const auto& [before, after] : replacements) { if (before->name_hint == name) { diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 1d311f9bac13..3dfbcb9967d5 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -140,10 +140,10 @@ Pass RewriteUnsafeSelect() { return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.RewriteUnsafeSelect", RewriteUnsafeSelect); -}); +} } // namespace transform diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 2b087c924f58..a3365db9b700 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -77,9 +77,8 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { "branch", refl::DefaultValue(false)); } - - static constexpr const char* _type_key = "tir.transform.SimplifyConfig"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.SimplifyConfig", SimplifyConfigNode, + BaseAttrsNode); RewriteSimplifier::Extension GetEnabledExtensions() const { RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; @@ -115,7 +114,7 @@ std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt& void VisitBuffer(const Buffer& buf) { // Collect variables that should remain defined - VarUseDefAnalyzer usage(Array{}); + VarUseDefAnalyzer usage(ffi::Array{}); usage(buf->data); for (const auto& dim : buf->shape) { usage(dim); @@ -140,17 +139,17 @@ std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt& class SimplifyConfig : public Attrs { public: - 
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs, SimplifyConfigNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); } TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: static PrimFunc Apply(PrimFunc func, Analyzer* analyzer, - Optional config_opt = std::nullopt) { + ffi::Optional config_opt = std::nullopt) { auto config = config_opt.value_or(AttrsWithDefaultValues()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); @@ -194,7 +193,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } Stmt VisitStmt(const Stmt& stmt) override { - Optional cache = this->current_stmt_; + ffi::Optional cache = this->current_stmt_; this->current_stmt_ = stmt; Stmt output = Parent::VisitStmt(stmt); this->current_stmt_ = std::move(cache); @@ -249,7 +248,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { if (can_inline && !used_in_buffer_def) { return body; } else if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); @@ -259,7 +258,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } Stmt VisitStmt_(const IfThenElseNode* op) override { - if (Optional cond = ProveCondition(op->condition)) { + if (ffi::Optional cond = ProveCondition(op->condition)) { if (cond.value()->value) { return this->VisitStmt(op->then_case); } else if (op->else_case) { @@ -274,7 +273,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CallNode* op) override { if (op->op.same_as(builtin::if_then_else())) { - if (Optional cond = ProveCondition(op->args[0])) { + if 
(ffi::Optional cond = ProveCondition(op->args[0])) { if (cond.value()->value) { return this->VisitExpr(op->args[1]); } else { @@ -303,7 +302,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } private: - bool ArrayDeepEqual(const Array& lhs, const Array& rhs) { + bool ArrayDeepEqual(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return false; } @@ -320,7 +319,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { * Uses more aggressive optimization, such as performing additional * inlining and tracking known buffer values. */ - Optional ProveCondition(PrimExpr condition) const { + ffi::Optional ProveCondition(PrimExpr condition) const { condition = Substitute(condition, non_inlined_bindings_); if (config_->propagate_knowns_to_prove_conditional) { ICHECK(touch_pattern_.has_value()); @@ -338,8 +337,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { SimplifyConfig config_; std::optional touch_pattern_; - Map non_inlined_bindings_; - Optional current_stmt_{std::nullopt}; + ffi::Map non_inlined_bindings_; + ffi::Optional current_stmt_{std::nullopt}; std::unordered_set used_in_buffer_def_; }; @@ -363,10 +362,10 @@ Pass Simplify() { return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.Simplify", Simplify); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/skip_assert.cc b/src/tir/transforms/skip_assert.cc index 6a9e62cd1ec7..b2c473c97c96 100644 --- a/src/tir/transforms/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -48,10 +48,10 @@ Pass SkipAssert() { return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.SkipAssert", SkipAssert); -}); +} } // namespace transform diff --git 
a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 796514e02762..130cc177f0b1 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -53,7 +53,7 @@ class HostDeviceSplitter : public StmtMutator { private: Stmt SplitDeviceFunc(Stmt body, Target device_target) { - auto [params, buffers_to_declare] = [&]() -> std::tuple, Array> { + auto [params, buffers_to_declare] = [&]() -> std::tuple, ffi::Array> { VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/true); use_def(body); @@ -98,7 +98,7 @@ class HostDeviceSplitter : public StmtMutator { GlobalVar kernel_symbol_global = var_supply_(); (*device_mod_)->Add(kernel_symbol_global, device_func); - Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); + ffi::Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); if (can_propagate_errors) { Var kernel_error_code("kernel_error_code", success->dtype); @@ -137,14 +137,14 @@ Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { GlobalVarSupply global_var_supply(mod); - IRModule device_mod = IRModule(Map({})); - IRModule updates = IRModule(Map({})); + IRModule device_mod = IRModule(ffi::Map({})); + IRModule updates = IRModule(ffi::Map({})); for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { PrimFunc func = opt.value(); - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); auto kernel_name = name_prefix + "_kernel"; auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar { @@ -166,10 +166,10 @@ Pass SplitHostDevice() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def("tir.transform.SplitHostDevice", SplitHostDevice); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 8c7a7035defa..2a38e64cc7e2 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -171,7 +171,7 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) { for (AccessEntry& e : s.access) { if (e.buffer.defined()) { ICHECK(e.touched.size()); - Array new_touched; + ffi::Array new_touched; for (const auto& touched : e.touched) { new_touched.push_back(arith::EvalSet(touched, relax_map)); } @@ -250,7 +250,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { PrimExpr offset = op->args[2]; PrimExpr extent = op->args[3]; const IntImmNode* flag = op->args[4].as(); - StorageScope scope = GetScope(GetRef(buffer)); + StorageScope scope = GetScope(ffi::GetRef(buffer)); // The buffer scope. if (Enabled(buffer, scope)) { ICHECK(allow_append_); diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index a0e03b35cdaa..10b26f7c2ab2 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -56,7 +56,7 @@ class StorageAccessVisitor : public StmtExprVisitor { /*! \brief An access entry */ struct AccessEntry { /*! \brief The thread index that access this entry */ - Array threads; + ffi::Array threads; /*! \brief The buffer variable, if any */ Var buffer = NullValue(); /*! \brief The access data type */ @@ -65,7 +65,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * * Has one IntSet for each index in the buffer being accessed. */ - Array touched; + ffi::Array touched; /*! \brief The type of access */ AccessType type; /*! \brief The storage scope */ @@ -98,7 +98,7 @@ class StorageAccessVisitor : public StmtExprVisitor { /*! \return whether we are in device environment. */ bool in_device_env() const { return in_device_env_; } /*! 
\return environment threads */ - const Array& env_threads() const { return env_threads_; } + const ffi::Array& env_threads() const { return env_threads_; } /*! * \brief Whether we need analyze the buffer in current scope. * \param buffer The buffer to be checked @@ -138,7 +138,7 @@ class StorageAccessVisitor : public StmtExprVisitor { // the current free stmt entry. StmtEntry curr_stmt_; // The involving threads - Array env_threads_; + ffi::Array env_threads_; }; } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 7112f62a1088..151f29e5f36d 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -406,7 +406,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (it != alloc_map_.end()) { Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var); - Array indices = node->indices; + ffi::Array indices = node->indices; indices.Set(indices.size() - 1, RemapIndex(node->buffer->dtype, indices[indices.size() - 1], it->second)); @@ -453,7 +453,7 @@ class StoragePlanRewriter : public StmtExprMutator { } return it->second->alloc_var; } else { - return GetRef(op); + return ffi::GetRef(op); } } PrimExpr VisitExpr_(const CallNode* op) final { @@ -473,7 +473,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]}); + return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, op->annotations); } else { return StmtExprMutator::VisitExpr_(op); } @@ -510,7 +510,7 @@ class StoragePlanRewriter : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), - op->thread_binding, op->annotations); + op->thread_binding, 
op->annotations, op->step); } else { return StmtExprMutator::VisitStmt_(op); } @@ -840,7 +840,7 @@ class StoragePlanRewriter : public StmtExprMutator { ICHECK(alloc_info.count(var)); const AllocEntry& entry = alloc_info.at(var); const AllocateNode* alloc = entry.alloc; - auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef(var))); + auto storage_scope = StorageScope::Create(GetPtrStorageScope(ffi::GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -1145,7 +1145,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * missing a type annotation, assume that it has the same underlying * type as it is later accessed, with scalar element types. */ - VectorTypeAccessChecker(const Array& params, const Map& buffer_map, + VectorTypeAccessChecker(const ffi::Array& params, + const ffi::Map& buffer_map, bool allow_untyped_pointers = false, bool detect_scalar_read_patterns = true) : allow_untyped_pointers_(allow_untyped_pointers), @@ -1196,7 +1197,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitStmt_(const AllocateNode* op) final { - const Array& extents = op->extents; + const ffi::Array& extents = op->extents; PrimExpr extent = extents[extents.size() - 1]; OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); @@ -1204,7 +1205,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitStmt_(const AllocateConstNode* op) final { - const Array& extents = op->extents; + const ffi::Array& extents = op->extents; PrimExpr extent = extents.size() ? 
extents[extents.size() - 1] : NullValue(); OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateConstNode); @@ -1271,8 +1272,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * * @param is_buffer_load Whether the access is BufferLoad */ - void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array& indices, - bool is_buffer_load) { + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, + const ffi::Array& indices, bool is_buffer_load) { auto it = info_map_.find(buffer); ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer << ") occurred before its declaration."; @@ -1471,7 +1472,7 @@ class VectorTypeRewriter : public StmtExprMutator { } const auto& info = it->second; - Array indices = node->indices; + ffi::Array indices = node->indices; const PrimExpr& last_dim_index = indices[indices.size() - 1]; const RampNode* ramp_index = indices[indices.size() - 1].as(); @@ -1536,7 +1537,7 @@ class VectorTypeRewriter : public StmtExprMutator { Stmt body = this->VisitStmt(op->body); Var var = (it == rewrite_map_.end()) ? 
op->var : it->second.new_buffer_var; if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } return LetStmt(var, value, body); } @@ -1553,7 +1554,7 @@ class VectorTypeRewriter : public StmtExprMutator { if (info_it != rewrite_map_.end()) { auto& info = info_it->second; - Array shape = buf->shape; + ffi::Array shape = buf->shape; PrimExpr last_dim = shape[shape.size() - 1]; shape.Set(shape.size() - 1, last_dim / make_const(last_dim.dtype(), info.factor())); @@ -1591,7 +1592,7 @@ class VectorTypeRewriter : public StmtExprMutator { int factor = info.factor(); extent = extent / make_const(extent.dtype(), factor); index = index / make_const(index.dtype(), factor); - Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; + ffi::Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); } else { @@ -1612,7 +1613,7 @@ class VectorTypeRewriter : public StmtExprMutator { Var new_buffer_var = info.new_buffer_var; - Array extents = op->extents; + ffi::Array extents = op->extents; PrimExpr last_extent = extents[extents.size() - 1]; extents.Set(extents.size() - 1, last_extent / make_const(last_extent.dtype(), info.factor())); return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); @@ -1633,7 +1634,7 @@ class VectorTypeRewriter : public StmtExprMutator { int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); - Array extents = op->extents; + ffi::Array extents = op->extents; extents.Set(extents.size() - 1, extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); return AllocateConst(new_buffer_var, info.new_element_dtype, extents, op->data, op->body); @@ -1652,7 +1653,7 @@ class VectorTypeRewriter : public StmtExprMutator { auto* n = func.CopyOnWrite(); // Remap any remaining references to the old buffer variables - Map var_remap; + ffi::Map var_remap; for 
(const auto& pair : rewrite_map_) { const auto& info = pair.second; var_remap.Set(info.old_buffer_var, info.new_buffer_var); @@ -1660,7 +1661,7 @@ class VectorTypeRewriter : public StmtExprMutator { n->body = Substitute(n->body, var_remap); // Remap the argument list to use the new buffer variables. - Array new_params; + ffi::Array new_params; for (const auto& old_param : n->params) { auto it = rewrite_map_.find(old_param.get()); if (it == rewrite_map_.end()) { @@ -1674,7 +1675,7 @@ class VectorTypeRewriter : public StmtExprMutator { // Remap the Buffer objects in PrimFunc::buffer_map so that the // buffers use the new buffer variables - Map new_buffer_map; + ffi::Map new_buffer_map; for (const auto& pair : n->buffer_map) { Var key = pair.first; Buffer old_buffer = pair.second; @@ -1742,7 +1743,7 @@ Pass StorageRewrite() { enable_reuse = false; } - Optional target = f->GetAttr("target"); + ffi::Optional target = f->GetAttr("target"); if (target.defined() && (target.value()->kind->name == "vulkan" || target.value()->kind->name == "webgpu")) { // Require exactly same-dtype matching in smem reuse for Vulkan and WebGPU @@ -1763,10 +1764,10 @@ Pass StorageRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.StorageRewrite", StorageRewrite); -}); +} Pass PointerValueTypeRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -1775,10 +1776,10 @@ Pass PointerValueTypeRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.PointerValueTypeRewrite", PointerValueTypeRewrite); -}); +} } // namespace transform diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 
8285ee96279c..7c1b5b05d093 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -59,7 +59,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(k); ICHECK(layout); - std::string scope = GetPtrStorageScope(GetRef(buffer_var)); + std::string scope = GetPtrStorageScope(ffi::GetRef(buffer_var)); if (fragments.count(buffer_var)) { // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; @@ -92,7 +92,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(n); ICHECK(k); - std::string scope = GetPtrStorageScope(GetRef(buffer_var)); + std::string scope = GetPtrStorageScope(ffi::GetRef(buffer_var)); if (fragments.count(buffer_var)) { FragmentInfo info = fragments[buffer_var]; ICHECK_EQ(m->value, info.m); @@ -218,10 +218,10 @@ Pass InferFragment() { return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.InferFragment", InferFragment); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index f8b0a83d4d43..d41d474a0864 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -401,7 +401,7 @@ class ThreadSyncInserter : public StmtExprMutator { // private functions. 
Stmt InitGlobalBarrier(const AttrStmtNode* op) { ICHECK(op != nullptr); - Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; + ffi::Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); Stmt body = op->body; for (const auto& kv : rw_stats_) { @@ -463,7 +463,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { namespace transform { -Pass ThreadSync(String storage_scope) { +Pass ThreadSync(ffi::String storage_scope) { auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = ThreadSync(std::move(n->body), storage_scope); @@ -472,10 +472,10 @@ Pass ThreadSync(String storage_scope) { return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.ThreadSync", ThreadSync); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc index bee45716b17d..60b6ffda3219 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/tir/transforms/transform_mma_buffer_layout.cc @@ -44,7 +44,7 @@ namespace tir { class MmaBufferLayoutTransformer : public StmtExprMutator { public: Stmt VisitStmt_(const BlockNode* op) { - Block block = GetRef(op); + Block block = ffi::GetRef(op); auto* n = block.CopyOnWrite(); auto fmutate = [this](const Buffer& buffer) { // m16n8k8.matrix[A/B/C] buffers are composed ofseveral small blocks. 
Assume the block's @@ -164,10 +164,10 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) { - if (buffer_var_map_.count(GetRef(op))) { - return buffer_var_map_[GetRef(op)]; + if (buffer_var_map_.count(ffi::GetRef(op))) { + return buffer_var_map_[ffi::GetRef(op)]; } - return GetRef(op); + return ffi::GetRef(op); } private: @@ -187,10 +187,10 @@ Pass TransformMmaBufferLayout() { return CreatePrimFuncPass(pass_func, 0, "tir.TransformMmaBufferLayout", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.TransformMmaBufferLayout", TransformMmaBufferLayout); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index a9e47055e2a7..502acd5a467e 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -60,14 +60,14 @@ class ThreadBindingUnifier : public StmtExprMutator { if (op->kind != ForKind::kThreadBinding) { return StmtExprMutator::VisitStmt_(op); } - Map annotations = op->annotations; + ffi::Map annotations = op->annotations; Stmt stmt = UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(), Range::FromMinExtent(op->min, op->extent)); if (annotations.empty()) { return stmt; } if (const auto* loop = stmt.as()) { - For new_loop = GetRef(loop); + For new_loop = ffi::GetRef(loop); new_loop.CopyOnWrite()->annotations = std::move(annotations); return new_loop; @@ -79,7 +79,8 @@ class ThreadBindingUnifier : public StmtExprMutator { /*extent=*/IntImm(dtype, 1), // /*kind=*/ForKind::kSerial, stmt, // /*thread_binding=*/std::nullopt, // - /*annotation=*/std::move(annotations)); + /*annotation=*/std::move(annotations), + /*step=*/std::nullopt); } } @@ -88,7 +89,7 @@ class ThreadBindingUnifier : public StmtExprMutator { const Range& dom) { // Step 1. Fetch the thread tag. 
IterVar new_iter_var{nullptr}; - const String& thread_tag = old_iter_var->thread_tag; + const ffi::String& thread_tag = old_iter_var->thread_tag; // Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the // thread block depth is 0 before the increment, it means we are entering a new kernel, and @@ -107,7 +108,7 @@ class ThreadBindingUnifier : public StmtExprMutator { // Step 3. See if an IterVar for this kind of thread binding was created before. If so, we use // the created IterVar. Otherwise, we create a new IterVar for this thread binding and store the // IterVar in mapping `thread_tag2iter_var_map_`. - Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); + ffi::Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); if (it != thread_tag2iter_var_map_.end()) { new_iter_var = (*it).second; ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min)); @@ -155,7 +156,8 @@ class ThreadBindingUnifier : public StmtExprMutator { result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent, ForKind::kThreadBinding, result, IterVar(NullValue(), Var(""), IterVarType::kThreadIndex, - thread_binding->thread_tag)); + thread_binding->thread_tag), + {}, std::nullopt); launch_threads_.pop_back(); } return result; @@ -164,22 +166,22 @@ class ThreadBindingUnifier : public StmtExprMutator { PrimExpr VisitExpr_(const VarNode* var) final { // If this variable appears as a key in `var_substitution_map_`, we substitute it with its // corresponding value in the mapping. - Map::iterator it = var_substitution_map_.find(GetRef(var)); - return it != var_substitution_map_.end() ? (*it).second : GetRef(var); + ffi::Map::iterator it = var_substitution_map_.find(ffi::GetRef(var)); + return it != var_substitution_map_.end() ? (*it).second : ffi::GetRef(var); } /*! 
* \brief A mapping from a thread tag to its corresponding IterVar that is shared by all * occurrences of the thread tag */ - Map thread_tag2iter_var_map_; + ffi::Map thread_tag2iter_var_map_; /*! * \brief A list of IterVar corresponding to threads in current kernel. This will be used to * generate for-loops to launch threads. */ - Array launch_threads_; + ffi::Array launch_threads_; /*! \brief A mapping from old variables to new variables, which is used for substitution */ - Map var_substitution_map_; + ffi::Map var_substitution_map_; /*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */ int thread_block_depth_ = 0; /*! \brief An analyzer used for equality proof */ @@ -201,10 +203,10 @@ Pass UnifyThreadBinding() { return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.UnifyThreadBinding", UnifyThreadBinding); -}); +} } // namespace transform diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index fdddf2091141..7b92bad12d34 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -30,9 +30,7 @@ #include #include -#include #include -#include #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" @@ -64,17 +62,16 @@ struct UnrollLoopConfigNode : public AttrsNodeReflAdapter .def_ro("unroll_local_access", &UnrollLoopConfigNode::unroll_local_access, "Whether to always unroll local access", refl::DefaultValue(false)); } - - static constexpr const char* _type_key = "tir.transform.UnrollLoopConfig"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(UnrollLoopConfigNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.UnrollLoopConfig", UnrollLoopConfigNode, + BaseAttrsNode); }; class UnrollLoopConfig : public Attrs { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnrollLoopConfig, Attrs, 
UnrollLoopConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, Attrs, UnrollLoopConfigNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ UnrollLoopConfigNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { UnrollLoopConfigNode::RegisterReflection(); } TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); @@ -83,7 +80,7 @@ class VarLocalAccessMarker : public ExprVisitor { explicit VarLocalAccessMarker(std::unordered_set* var_touched_local) : var_touched_local_(var_touched_local) {} - void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef(op)); } + void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(ffi::GetRef(op)); } private: std::unordered_set* var_touched_local_; @@ -157,8 +154,9 @@ class LoopUnroller : public StmtExprMutator { } else { if (auto_unroll) { if (op->kind != ForKind::kUnrolled) { - return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, - op->thread_binding, op->annotations); + auto n = CopyOnWrite(op); + n->kind = ForKind::kUnrolled; + return For(n); } } return stmt; @@ -176,7 +174,7 @@ class LoopUnroller : public StmtExprMutator { } } } - return GetRef(op); + return ffi::GetRef(op); } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -222,8 +220,8 @@ class LoopUnroller : public StmtExprMutator { ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); Stmt body = op->body; - Map vmap; - Array unrolled; + ffi::Map vmap; + ffi::Array unrolled; for (int i = 0; i < value; ++i) { vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i)); Stmt step = Substitute(body, vmap); @@ -293,10 +291,10 @@ Pass UnrollLoop() { return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.UnrollLoop", UnrollLoop); -}); +} } // namespace transform diff --git 
a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 92f0a6de98e1..bd875eca56de 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -60,7 +60,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor { var_remap_->erase(it); } } - Array drop_buffers; + ffi::Array drop_buffers; for (auto kv : *buffer_remap_) { if (opaque_var_access_.count(kv.first->data)) { drop_buffers.push_back(kv.first); @@ -79,7 +79,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor { // remap all intermediate constant buffer to promote data types (fp16/fp32) if (MatchDType(op->dtype) && op->ConstantAllocationSize() != 0) { DataType dtype = promote_dtype_.with_lanes(op->dtype.lanes()); - String storage_scope = "global"; + ffi::String storage_scope = "global"; if (auto* ptr_type = op->buffer_var->type_annotation.as()) { storage_scope = ptr_type->storage_scope; } @@ -106,7 +106,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { StmtExprVisitor::VisitExpr_(op); - Var buffer_var = GetRef(op); + Var buffer_var = ffi::GetRef(op); if (buffer_var.dtype().is_handle()) { opaque_var_access_.insert(buffer_var); } @@ -153,7 +153,7 @@ class FP8ComputeLegalizePlanner : public ComputeLegalizePlanner { PrimExpr origin_b = PromoteToTarget(this->VisitExpr(op->b)); \ \ if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return FUNC(origin_a, origin_b); \ } \ @@ -189,7 +189,7 @@ class ComputeLegalizer : public StmtExprMutator { } if (op_val.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return cast(op->dtype, op_val); } @@ -201,7 +201,7 @@ class ComputeLegalizer : public StmtExprMutator { PrimExpr false_value = PromoteToTarget(this->VisitExpr(op->false_value)); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) 
&& false_value.same_as(op->false_value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Select(condition, true_value, false_value); } @@ -210,7 +210,7 @@ class ComputeLegalizer : public StmtExprMutator { PrimExpr VisitExpr_(const BroadcastNode* op) final { PrimExpr value = PromoteToTarget(this->VisitExpr(op->value)); if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Broadcast(value, op->lanes); } @@ -220,7 +220,7 @@ class ComputeLegalizer : public StmtExprMutator { auto fexpr = [this](const PrimExpr& e) { return PromoteToTarget(this->VisitExpr(e)); }; auto vectors = op->vectors.Map(fexpr); if (vectors.same_as(op->vectors)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Shuffle(vectors, op->indices); } @@ -233,14 +233,14 @@ class ComputeLegalizer : public StmtExprMutator { } // update normal computations to return f32 instead. auto fmutate = [this](const PrimExpr& e) { return PromoteToTarget(this->VisitExpr(e)); }; - Array args = op->args.Map(fmutate); + ffi::Array args = op->args.Map(fmutate); if (MatchDType(op->dtype)) { - return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args); + return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args, op->annotations); } if (args.same_as(op->args)) { - return GetRef(op); + return ffi::GetRef(op); } else { - return Call(op->dtype, op->op, args); + return Call(op->dtype, op->op, args, op->annotations); } } @@ -248,11 +248,11 @@ class ComputeLegalizer : public StmtExprMutator { if (MatchDType(op->dtype)) { return FloatImm(promote_dtype_, op->value); } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto itr = var_remap_.find(var); if (itr != var_remap_.end()) { @@ -273,7 +273,7 @@ class ComputeLegalizer : public StmtExprMutator { PrimExpr body = VisitExpr(op->body); if (value.same_as(op->value) && var.same_as(op->var) && 
body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(var, value, body); } @@ -302,7 +302,7 @@ class ComputeLegalizer : public StmtExprMutator { Stmt body = VisitStmt(op->body); if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(var, value, body); } @@ -312,12 +312,12 @@ class ComputeLegalizer : public StmtExprMutator { PrimExpr value = this->VisitExpr(op->value); auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); Buffer new_buf = GetRemappedBuffer(op->buffer); if (value.same_as(op->value) && indices.same_as(op->indices) && new_buf.same_as(op->buffer)) { - return GetRef(op); + return ffi::GetRef(op); } else { if (MatchDType(new_buf->dtype)) { int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; @@ -329,7 +329,7 @@ class ComputeLegalizer : public StmtExprMutator { // this happens when buffer get rewritten to f32 // but values remain as fp8/bf16 ICHECK(MatchDType(value->dtype)); - value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value); + value = DTypeConversion(value, new_buf->dtype.with_lanes(value.dtype().lanes())); } ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " "data type legalizer pass."; @@ -526,7 +526,7 @@ class StorageLegalizer : public StmtExprMutator { private: PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto itr = var_remap_.find(var); if (itr != var_remap_.end()) { return itr->second; @@ -538,7 +538,7 @@ class StorageLegalizer : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { if (MatchDType(op->dtype)) { DataType dtype = GetStorageUIntDType(op->dtype); - String storage_scope = "global"; + ffi::String storage_scope = "global"; if (auto* ptr_type = 
op->buffer_var->type_annotation.as()) { storage_scope = ptr_type->storage_scope; } @@ -563,7 +563,7 @@ class StorageLegalizer : public StmtExprMutator { } Stmt body = VisitStmt(op->body); if (buf.same_as(op->buffer) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return DeclBuffer(buf, body, op->span); } @@ -575,7 +575,7 @@ class StorageLegalizer : public StmtExprMutator { PrimExpr body = VisitExpr(op->body); if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(var, value, body); } @@ -587,7 +587,7 @@ class StorageLegalizer : public StmtExprMutator { Stmt body = VisitStmt(op->body); if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(var, value, body); } @@ -598,7 +598,7 @@ class StorageLegalizer : public StmtExprMutator { Buffer new_buf = GetRemappedBuffer(op->buffer); auto indices = op->indices.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); if (new_buf.same_as(op->buffer) && indices.same_as(op->indices) && value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { if (MatchDType(op->value.dtype())) { ICHECK(new_buf->dtype.is_uint()); @@ -654,7 +654,7 @@ class StorageLegalizer : public StmtExprMutator { return reinterpret(GetStorageUIntDType(op->dtype), value); } if (op->args[0].same_as(value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return reinterpret(op->dtype, value); } @@ -759,10 +759,10 @@ Pass BF16ComputeLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.BF16ComputeLegalize", BF16ComputeLegalize); -}); +} Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ 
-775,26 +775,26 @@ Pass BF16StorageLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize); -}); +} -Pass FP8ComputeLegalize(String promote_dtype_str) { +Pass FP8ComputeLegalize(ffi::String promote_dtype) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto target = f->GetAttr(tvm::attr::kTarget).value(); if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } - return FP8ComputeLegalizer(DataType(StringToDLDataType(promote_dtype_str))).Legalize(f); + return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.FP8ComputeLegalize", FP8ComputeLegalize); -}); +} Pass FP8StorageLegalize() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -807,10 +807,10 @@ Pass FP8StorageLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.FP8StorageLegalize", FP8StorageLegalize); -}); +} } // namespace transform } // namespace tir diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 9af990d1e2bf..e12ab9696a99 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -37,7 +37,7 @@ namespace tvm { namespace tir { -Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { +Var WithStorageScope(const VarNode* buffer_var, ffi::String storage_scope) { auto* ptr_type = 
buffer_var->type_annotation.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), @@ -45,7 +45,7 @@ Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { } UpdatePointerStorageScope::UpdatePointerStorageScope( - const std::unordered_map& new_storage_scopes) { + const std::unordered_map& new_storage_scopes) { for (auto& kv : new_storage_scopes) { new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); } @@ -54,7 +54,7 @@ UpdatePointerStorageScope::UpdatePointerStorageScope( PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { auto it = new_var_remap_.find(op); if (it == new_var_remap_.end()) { - return GetRef(op); + return ffi::GetRef(op); } return it->second; } diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h index 1f1399fba76b..a2f7027ce4f8 100644 --- a/src/tir/transforms/update_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -36,7 +36,7 @@ namespace tir { class UpdatePointerStorageScope : public StmtExprMutator { public: explicit UpdatePointerStorageScope( - const std::unordered_map& new_storage_scopes); + const std::unordered_map& new_storage_scopes); virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const BufferLoadNode*); diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 53509ce49710..21f3dc43ba28 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -119,13 +119,13 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { using Parent::VisitStmt_; // This struct stores all the relevant data related to asssume statement - struct assume_struct { // Consider the example : T.assume(i < 14 or A[i] == 0) - PrimExpr buffer_context; // 
The context of the assume statement (the bound on the axis) - PrimExpr buffer_predicate; // The condition inside assume statement (i < 14) excluding - // bufferload expression (A[i] == 0) - tir::BufferLoad buffer_load; // Storing the buffer load Eg: A[i] in A[i] == 0 - PrimExpr buffer_value; // Storing the value for the buffer Eg : 0 in A[i] == 0 - Array buffer_indices; // Storing the indices of the buffer Eg : i + struct assume_struct { // Consider the example : T.assume(i < 14 or A[i] == 0) + PrimExpr buffer_context; // The context of the assume statement (the bound on the axis) + PrimExpr buffer_predicate; // The condition inside assume statement (i < 14) excluding + // bufferload expression (A[i] == 0) + tir::BufferLoad buffer_load; // Storing the buffer load Eg: A[i] in A[i] == 0 + PrimExpr buffer_value; // Storing the value for the buffer Eg : 0 in A[i] == 0 + ffi::Array buffer_indices; // Storing the indices of the buffer Eg : i }; // List of conditions in a scope std::vector conditions_; @@ -162,7 +162,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { With analyzer_context; size_t old_num_constraints{0}; size_t new_num_constraints{0}; - Optional assume{std::nullopt}; + ffi::Optional assume{std::nullopt}; // Disable default-generated copy/move assignment and constructors InternalConstraintContext(const InternalConstraintContext&) = delete; @@ -209,7 +209,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { return buf_value; } } - return GetRef(op); + return ffi::GetRef(op); } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -358,7 +358,7 @@ Pass UseAssumeToReduceBranches() { // the primfunc has op_pattern defined and is an elementwise op. // AnnotateTIROpPattern pass will set op_pattern in op attributes of the primfunc. 
if (n->attrs.GetAttr("op_pattern").defined()) { - Optional opt_pattern = f->GetAttr("op_pattern"); + ffi::Optional opt_pattern = f->GetAttr("op_pattern"); if (opt_pattern.defined()) { relax::OpPatternKind pattern; pattern = static_cast(Downcast(opt_pattern)->value); @@ -382,10 +382,10 @@ Pass UseAssumeToReduceBranches() { return CreatePrimFuncPass(pass_func, 0, "tir.UseAssumeToReduceBranches", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.UseAssumeToReduceBranches", UseAssumeToReduceBranches); -}); +} } // namespace transform diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 8e350924501e..5c61a0c78e9f 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -75,7 +75,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { bool EnableBufferLevelPredication(Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - Optional enable_buffer_predication = + ffi::Optional enable_buffer_predication = pass_ctx->GetConfig("tir.enable_buffer_level_predication"); if (enable_buffer_predication.defined()) { return enable_buffer_predication.value(); @@ -160,7 +160,7 @@ class TryPredicateBufferAccesses : public StmtExprMutator { num_accesses_analyzed_ += 1; // Do not try to predicate non-vectorized accesses - Array indices = node->indices; + ffi::Array indices = node->indices; if (!indices.size() || !indices[0]->IsInstance()) { return node; } @@ -233,7 +233,7 @@ class VecAllocAccess : public StmtExprMutator { // Extend the least significant dimension by a factor of // var_lanes_. Typically, this will be a 1-d index into a flat // memory space. 
- Array shape = node->buffer->shape; + ffi::Array shape = node->buffer->shape; shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_)); // TODO(Lunderberg): Move this pass to be prior to @@ -243,7 +243,7 @@ class VecAllocAccess : public StmtExprMutator { // are updated for consistency. // Update strides if defined. - Array strides; + ffi::Array strides; for (size_t i = 0; i < strides.size(); i++) { PrimExpr stride = strides[i]; if (i != strides.size() - 1) { @@ -262,7 +262,7 @@ class VecAllocAccess : public StmtExprMutator { // Extend the last index by the number of lanes in the vectorized // variable. - Array indices = node->indices; + ffi::Array indices = node->indices; indices.Set(indices.size() - 1, analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); @@ -322,7 +322,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); @@ -369,7 +369,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->a); if (a.same_as(op->a)) { - return GetRef(op); + return ffi::GetRef(op); } else { return !(a); } @@ -396,7 +396,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor elems; + ffi::Array elems; for (int i = 0; i < lanes; ++i) { elems.push_back( Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes)); @@ -408,10 +408,10 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); if (value.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return ffi::GetRef(op); } if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Broadcast(op->value, op->lanes); } @@ -422,7 +422,7 @@ class 
Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->true_value); PrimExpr f = this->VisitExpr(op->false_value); if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { - return GetRef(op); + return ffi::GetRef(op); } else { int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor(); @@ -438,7 +438,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { if (value.dtype().is_scalable_vector()) { return Cast(op->dtype.with_scalable_vscale_factor(value.dtype().vscale_factor()), value); @@ -448,15 +448,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); } + PrimExpr VisitExpr_(const FloatImmNode* op) final { return ffi::GetRef(op); } - PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef(op); } + PrimExpr VisitExpr_(const IntImmNode* op) final { return ffi::GetRef(op); } - PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef(op); } + PrimExpr VisitExpr_(const StringImmNode* op) final { return ffi::GetRef(op); } // Variable PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (var.same_as(var_)) { return ramp_; @@ -473,12 +473,12 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->args[0]); if (cond.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return ffi::GetRef(op); } PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr f = this->VisitExpr(op->args[2]); if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { - return GetRef(op); + return ffi::GetRef(op); } else { int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int f_lanes = f.dtype().get_lanes_or_vscale_factor(); @@ -487,9 +487,9 @@ class Vectorizer : public StmtMutator, public 
ExprFunctordtype.with_scalable_vscale_factor(lanes), op->op, {cond, t, f}); + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {cond, t, f}, op->annotations); } else { - return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->annotations); } } } @@ -498,17 +498,17 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.same_as(builtin::reinterpret())); PrimExpr value = this->VisitExpr(op->args[0]); if (value.same_as(op->args[0])) { - return GetRef(op); + return ffi::GetRef(op); } else { int lanes = value.dtype().get_lanes_or_vscale_factor(); if (value.dtype().is_scalable_vector()) { - return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value}); + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value}, op->annotations); } else { int new_lanes = (op->dtype != DataType::Float4E2M1FN() && op->args[0].dtype() != DataType::Float4E2M1FN()) ? (value.dtype().bits() * value.dtype().lanes()) / op->dtype.bits() : value.dtype().lanes(); - return Call(op->dtype.with_lanes(new_lanes), op->op, {value}); + return Call(op->dtype.with_lanes(new_lanes), op->op, {value}, op->annotations); } } } @@ -518,18 +518,18 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.same_as(builtin::texture2d_load())) { int lane = 0; - Array fcd = MutateArray({op->args.back()}, &lane); + ffi::Array fcd = MutateArray({op->args.back()}, &lane); auto new_args = op->args; new_args.pop_back(); new_args.push_back(fcd[0]); - return Call(op->dtype.with_lanes(4), op->op, new_args); + return Call(op->dtype.with_lanes(4), op->op, new_args, op->annotations); } else if (op->op.same_as(builtin::texture2d_store())) { int lane = 0; // Vectorize the value to store - Array value{op->args.back()}; - Array mutated_value = MutateArray(value, &lane); - Array new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; - return Call(op->dtype.with_lanes(lane), op->op, 
new_args); + ffi::Array value{op->args.back()}; + ffi::Array mutated_value = MutateArray(value, &lane); + ffi::Array new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; + return Call(op->dtype.with_lanes(lane), op->op, new_args, op->annotations); } else if (op->op.same_as(builtin::reinterpret())) { return MutateReinterpretExpr_(op); } @@ -539,32 +539,32 @@ class Vectorizer : public StmtMutator, public ExprFunctor new_args; + ffi::Array new_args; for (auto arg : op->args) { auto new_arg = this->VisitExpr(arg); if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return ffi::GetRef(op); } new_args.push_back(new_arg); } if (op->args.same_as(new_args)) { - return GetRef(op); + return ffi::GetRef(op); } else { - return Call(op->dtype, op->op, new_args); + return Call(op->dtype, op->op, new_args, op->annotations); } } else { int lane = 0; - Array new_args; + ffi::Array new_args; if (op->op.same_as(builtin::call_llvm_pure_intrin())) { // op->args[1], will give us total number of arguments to intrinsic - Array op_expr_args; + ffi::Array op_expr_args; for (size_t i = 1; i < op->args.size(); ++i) { // Collect all intrinsic arguments op_expr_args.push_back(op->args[i]); } // Generate RAMP nodes for intrinsic arguments - Array updated_args = MutateArray(op_expr_args, &lane); + ffi::Array updated_args = MutateArray(op_expr_args, &lane); new_args.push_back(op->args[0]); // Collect updated intrinsic arguments for (size_t i = 0; i < updated_args.size(); ++i) { @@ -575,18 +575,18 @@ class Vectorizer : public StmtMutator, public ExprFunctorargs.same_as(new_args)) { - return GetRef(op); + return ffi::GetRef(op); } else { - return Call(op->dtype.with_lanes(lane), op->op, new_args); + return Call(op->dtype.with_lanes(lane), op->op, new_args, op->annotations); } } } // BufferLoad PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto load = GetRef(op); + auto load = ffi::GetRef(op); auto fmutate = [this](const 
PrimExpr& index) { return this->VisitExpr(index); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); if (!indices.same_as(op->indices)) { auto writer = load.CopyOnWrite(); @@ -619,7 +619,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvar] = op->var; PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -631,10 +631,10 @@ class Vectorizer : public StmtMutator, public ExprFunctorvectors.size() << " and the index size is " << op->indices.size(); int lane_vectors = 0; int lane_indices = 0; - Array vectors = MutateArray(op->vectors, &lane_vectors); - Array indices = MutateArray(op->indices, &lane_indices); + ffi::Array vectors = MutateArray(op->vectors, &lane_vectors); + ffi::Array indices = MutateArray(op->indices, &lane_indices); if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } int new_vec_length = Downcast(var_lanes_)->value / op->vectors[0].dtype().lanes(); @@ -689,10 +689,10 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); + auto store = ffi::GetRef(op); auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); PrimExpr value = this->VisitExpr(op->value); @@ -746,14 +746,16 @@ class Vectorizer : public StmtMutator, public ExprFunctorextent.dtype().is_scalable_or_fixed_length_vector()); PrimExpr extent = this->VisitExpr(op->extent); if (extent.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); + return Scalarize(ffi::GetRef(op)); } Stmt body = this->VisitStmt(op->body); if (extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { - return For(op->loop_var, op->min, extent, op->kind, body, 
op->thread_binding, - op->annotations); + auto n = CopyOnWrite(op); + n->extent = extent; + n->body = body; + return For(n); } } // IfThenElse @@ -766,7 +768,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitStmt(op->then_case); - Optional else_case = std::nullopt; + ffi::Optional else_case = std::nullopt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } @@ -782,11 +784,11 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); + return Scalarize(ffi::GetRef(op)); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return ffi::GetRef(op); } else { return IfThenElse(condition, then_case, else_case); } @@ -802,7 +804,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); + Scalarize(ffi::GetRef(op)); } ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; let_binding_[op->var] = value; @@ -816,7 +818,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvar] = op->var; Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(op->var, value, body); } @@ -828,16 +830,16 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->condition); if (condition.dtype().is_scalable_or_fixed_length_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; - return Scalarize(GetRef(op)); + return Scalarize(ffi::GetRef(op)); } // Mutate the extents - Array extents; + ffi::Array extents; for (const auto& extent : op->extents) { PrimExpr new_ext = this->VisitExpr(extent); if (new_ext.dtype().is_scalable_or_fixed_length_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; - return Scalarize(GetRef(op)); + return Scalarize(ffi::GetRef(op)); } extents.push_back(new_ext); } @@ -887,7 
+889,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor MutateArray(Array arr, int* p_lanes) { + ffi::Array MutateArray(ffi::Array arr, int* p_lanes) { if (arr.size() == 0) return arr; int& lanes = *p_lanes; bool changed = false; @@ -907,7 +909,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(new_arr); + return ffi::Array(new_arr); } template PrimExpr BinaryVec(const T* op) { @@ -915,7 +917,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor(); @@ -929,7 +931,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor(); @@ -1021,10 +1023,10 @@ Pass VectorizeLoop(bool enable_vectorize) { return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.transform.VectorizeLoop", VectorizeLoop); -}); +} } // namespace transform diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index 1ca901c6fbf5..c90b20877101 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -47,12 +47,13 @@ using namespace tvm::runtime; } \ }) -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.broadcast_to", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = broadcast_to(args[0].cast(), args[1].cast>()); + *rv = broadcast_to(args[0].cast(), + args[1].cast>()); }) .TOPI_DEF_BCAST_OP("topi.add", topi::add) 
.TOPI_DEF_BCAST_OP("topi.subtract", topi::subtract) @@ -79,7 +80,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .TOPI_DEF_BCAST_OP("topi.not_equal", topi::not_equal) .TOPI_DEF_BCAST_OP("topi.greater_equal", topi::greater_equal) .TOPI_DEF_BCAST_OP("topi.less_equal", topi::less_equal); -}); +} } // namespace topi } // namespace tvm diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index 9586b9c5575e..42c8c768d275 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -136,26 +136,26 @@ class EinsumBuilder { * \param equation The Einsum equation * \param input_shapes The shapes of the input tensors */ - EinsumBuilder(EinsumEquation equation, Array> input_shapes) + EinsumBuilder(EinsumEquation equation, ffi::Array> input_shapes) : equation_(equation), input_shapes_(input_shapes) {} /*! * \brief Run the shape inference * \return The inferred shape of the output */ - Array InferShape() { + ffi::Array InferShape() { CHECK_EQ(equation_.inputs.size(), input_shapes_.size()) << "Number of operands does not match the " "equation"; - std::vector> + std::vector> ellipis_shapes; // the sub-shape covered by the ellipsis for each operand // Step 1: Collect the broadcasted extent for each label for (int operand_index = 0; operand_index < static_cast(input_shapes_.size()); ++operand_index) { const EinsumEquation::Subscript subscript = equation_.inputs[operand_index]; - const Array& input_shape = input_shapes_[operand_index]; + const ffi::Array& input_shape = input_shapes_[operand_index]; int current_dim = 0; for (auto label : subscript) { @@ -182,14 +182,16 @@ class EinsumBuilder { // Step 2: Infer the shape of the ellipsis if exists // The ellipsis may cover different number of dimensions for each operand, these sub-shapes // need to be broadcasted to the shape with the maximum number of dimensions - Array ellipsis_shape; + ffi::Array ellipsis_shape; if (ellipis_shapes.size()) { - ellipsis_shape = *std::max_element( - ellipis_shapes.begin(), ellipis_shapes.end(), - [](const Array& a, 
const Array& b) { return a.size() < b.size(); }); - for (const Array& shape : ellipis_shapes) { + ellipsis_shape = + *std::max_element(ellipis_shapes.begin(), ellipis_shapes.end(), + [](const ffi::Array& a, const ffi::Array& b) { + return a.size() < b.size(); + }); + for (const ffi::Array& shape : ellipis_shapes) { auto common_shape = detail::BroadcastShape(ellipsis_shape, shape).common_shape; - ellipsis_shape = Array(common_shape.begin(), common_shape.end()); + ellipsis_shape = ffi::Array(common_shape.begin(), common_shape.end()); } } @@ -205,10 +207,10 @@ class EinsumBuilder { return output_shape_; } - PrimExpr BuildOutputExpr(const Array inputs, const Array& indices) { + PrimExpr BuildOutputExpr(const ffi::Array inputs, const ffi::Array& indices) { std::unordered_map label_to_index; - Array ellipsis_indices; - Array reduce_axes; + ffi::Array ellipsis_indices; + ffi::Array reduce_axes; PrepareOutputIndicesMapping(indices, &label_to_index, &ellipsis_indices); PrepareReductionIndicesMapping(indices, &label_to_index, &ellipsis_indices, &reduce_axes); @@ -234,14 +236,15 @@ class EinsumBuilder { /*! 
* \brief Prepare mapping from label (including ellipsis) to the output indices */ - void PrepareOutputIndicesMapping(const Array& indices, + void PrepareOutputIndicesMapping(const ffi::Array& indices, std::unordered_map* label_to_index, - Array* ellipsis_indices) { + ffi::Array* ellipsis_indices) { int i = 0; for (auto label : equation_.output) { if (label == EinsumEquation::kEllipsis) { auto ellipsis_ndim = ellipsis_shape_.value().size(); - *ellipsis_indices = Array(indices.begin() + i, indices.begin() + i + ellipsis_ndim); + *ellipsis_indices = + ffi::Array(indices.begin() + i, indices.begin() + i + ellipsis_ndim); i += ellipsis_ndim; } else { label_to_index->emplace(label, indices[i++]); @@ -255,8 +258,9 @@ class EinsumBuilder { * necessary) to the reduction axes */ void PrepareReductionIndicesMapping( - const Array& indices, std::unordered_map* label_to_index, - Array* ellipsis_indices, Array* reduction_axes) { + const ffi::Array& indices, + std::unordered_map* label_to_index, + ffi::Array* ellipsis_indices, ffi::Array* reduction_axes) { // Collect labels that need to be reduced, which is the union(input_labels) - output_labels std::set reduction_labels; for (const EinsumEquation::Subscript& subscript : equation_.inputs) { @@ -288,18 +292,18 @@ class EinsumBuilder { } } - Array GetIndicesForOperand( + ffi::Array GetIndicesForOperand( int operand_index, const std::unordered_map& label_to_index, - const Array& ellipsis_indices) { + const ffi::Array& ellipsis_indices) { const EinsumEquation::Subscript& subscript = equation_.inputs[operand_index]; - Array indices; // the indices for the operand - const Array input_shape = input_shapes_[operand_index]; + ffi::Array indices; // the indices for the operand + const ffi::Array input_shape = input_shapes_[operand_index]; int i = 0; // index of the operand shape for (char label : subscript) { if (label == EinsumEquation::kEllipsis) { // Ellipsis - Array ellipsis_shape = ellipsis_shape_.value(); + ffi::Array ellipsis_shape 
= ellipsis_shape_.value(); int ellipsis_ndim = static_cast(input_shape.size()) - static_cast(subscript.size()) + 1; // use last 'ellipsis_ndim' axes @@ -320,24 +324,24 @@ class EinsumBuilder { } EinsumEquation equation_; - Array> input_shapes_; + ffi::Array> input_shapes_; // intermediate results of shape inference // The output shape - Array output_shape_; + ffi::Array output_shape_; // The extent of each label with broadcast rules applied std::unordered_map label_to_extent_; // The shape of the ellipsis if ellipsis is used. The shape covered by the // ellipsis in each operand might be different from this, this is the common // shape among them according to the broadcast rules. - Optional> ellipsis_shape_; + ffi::Optional> ellipsis_shape_; }; -Tensor einsum(const std::string& subscripts_str, const Array inputs, std::string name, +Tensor einsum(const std::string& subscripts_str, const ffi::Array inputs, std::string name, std::string tag) { EinsumEquation equation = EinsumEquation::FromString(subscripts_str); - Array> input_shapes; + ffi::Array> input_shapes; for (const Tensor& input : inputs) { input_shapes.push_back(input->shape); } @@ -345,23 +349,25 @@ Tensor einsum(const std::string& subscripts_str, const Array inputs, std auto output_shape = einsum_builder.InferShape(); return te::compute( output_shape, - [&](const Array& indices) { return einsum_builder.BuildOutputExpr(inputs, indices); }, + [&](const ffi::Array& indices) { + return einsum_builder.BuildOutputExpr(inputs, indices); + }, name, tag); } -Array InferEinsumShape(const std::string& subscripts, - const std::vector>& operands) { +ffi::Array InferEinsumShape(const std::string& subscripts, + const std::vector>& operands) { EinsumEquation equation = EinsumEquation::FromString(subscripts); EinsumBuilder einsum_builder = EinsumBuilder(equation, operands); return einsum_builder.InferShape(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; 
refl::GlobalDef().def_packed("topi.einsum", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = einsum(args[0].cast(), args[1].cast>()); + *rv = einsum(args[0].cast(), args[1].cast>()); }); -}); +} } // namespace topi } // namespace tvm diff --git a/src/topi/elemwise.cc b/src/topi/elemwise.cc index b60256cea5f5..922c40619908 100644 --- a/src/topi/elemwise.cc +++ b/src/topi/elemwise.cc @@ -31,7 +31,7 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.acos", [](ffi::PackedArgs args, @@ -100,13 +100,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.elemwise_sum", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = elemwise_sum(args[0].cast>()); + *rv = elemwise_sum(args[0].cast>()); }) .def_packed("topi.sign", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = sign(args[0].cast()); }) .def_packed("topi.full", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = full(args[0].cast>(), args[1].cast(), + *rv = full(args[0].cast>(), args[1].cast(), args[2].cast()); }) .def_packed("topi.full_like", @@ -119,7 +119,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("topi.bitwise_not", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = bitwise_not(args[0].cast()); }); -}); +} } // namespace topi } // namespace tvm diff --git a/src/topi/nn.cc b/src/topi/nn.cc index d872bac2ce30..1f8118231fae 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -45,7 +45,7 @@ using namespace tvm; using namespace tvm::runtime; /* Ops from nn.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed( @@ -62,21 +62,21 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.nn.pad", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = pad(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast()); + *rv = pad(args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast()); }) 
.def_packed("topi.nn.space_to_batch_nd", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = space_to_batch_nd( - args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), args[4].cast()); }) .def_packed("topi.nn.batch_to_space_nd", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = batch_to_space_nd( - args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), args[4].cast()); }) .def_packed("topi.nn.nll_loss", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -84,44 +84,44 @@ TVM_FFI_STATIC_INIT_BLOCK({ nll_loss(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast()); }); -}); +} /* Ops from nn/dense.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.dense", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::dense(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); -}); +} /* Ops from nn/bias_add.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.bias_add", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::bias_add(args[0].cast(), args[1].cast(), args[2].cast()); }); -}); +} /* Ops from nn/dilate.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.dilate", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::dilate(args[0].cast(), args[1].cast>(), + *rv = nn::dilate(args[0].cast(), args[1].cast>(), args[2].cast()); }); -}); +} /* Ops from nn/flatten.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.flatten", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::flatten(args[0].cast()); }); -}); +} /* Ops from nn/mapping.h 
*/ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.nn.scale_shift_nchw", @@ -134,18 +134,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = nn::scale_shift_nhwc(args[0].cast(), args[1].cast(), args[2].cast()); }); -}); +} /* Ops from nn/pooling.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.nn.pool_grad", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool_grad( args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }) @@ -158,53 +158,53 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("topi.nn.adaptive_pool1d", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool1d(args[0].cast(), - args[1].cast>(), + args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }) .def_packed("topi.nn.adaptive_pool", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool(args[0].cast(), - args[1].cast>(), + args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }) .def_packed("topi.nn.adaptive_pool3d", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool3d(args[0].cast(), - args[1].cast>(), + args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }) .def_packed("topi.nn.pool1d", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool1d( - args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }) .def_packed("topi.nn.pool2d", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool2d( - args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), + args[0].cast(), args[1].cast>(), + 
args[2].cast>(), args[3].cast>(), + args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }) .def_packed("topi.nn.pool3d", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::pool3d(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), + *rv = nn::pool3d(args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }); -}); +} /* Ops from nn/softmax.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.nn.softmax", @@ -219,10 +219,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = nn::lrn(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), args[5].cast()); }); -}); +} /* Ops from nn/bnn.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.nn.binarize_pack", @@ -232,46 +232,46 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("topi.nn.binary_dense", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::binary_dense(args[0].cast(), args[1].cast()); }); -}); +} /* Ops from nn/layer_norm.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.layer_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::layer_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast>(), + args[2].cast(), args[3].cast>(), args[4].cast()); }); -}); +} /* Ops from nn/group_norm.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.group_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::group_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), - args[5].cast>(), args[6].cast()); + args[5].cast>(), args[6].cast()); 
}); -}); +} /* Ops from nn/instance_norm.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.instance_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::instance_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), - args[4].cast>(), args[5].cast()); + args[4].cast>(), args[5].cast()); }); -}); +} /* Ops from nn/rms_norm.h */ -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.rms_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::rms_norm(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast()); + args[2].cast>(), args[3].cast()); }); -}); +} } // namespace topi } // namespace tvm diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 7b10c7771b32..0f2a7f49fc73 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -32,7 +32,7 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.sum", @@ -76,9 +76,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ args[2].cast()); }) .def_packed("topi.collapse_sum", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); + *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); }); -}); +} } // namespace topi } // namespace tvm diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 433a641ad068..d9545e637405 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -37,7 +37,7 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.expand_dims", @@ -48,13 +48,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("topi.transpose", [](ffi::PackedArgs args, 
ffi::Any* rv) { *rv = transpose(args[0].cast(), - args[1].cast>>()); + args[1].cast>>()); }) .def_packed("topi.flip", [](ffi::PackedArgs args, ffi::Any* rv) { // pass empty seq_lengths tensor to reverse_sequence - *rv = - reverse_sequence(args[0].cast(), Tensor(), args[1].cast()); + *rv = reverse_sequence(args[0].cast(), te::Tensor(), + args[1].cast()); }) .def_packed("topi.reverse_sequence", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -63,13 +63,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.reshape", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = reshape(args[0].cast(), args[1].cast>()); + *rv = reshape(args[0].cast(), args[1].cast>()); }) .def_packed("topi.sliding_window", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = sliding_window(args[0].cast(), args[1].cast(), - args[2].cast>(), - args[3].cast>()); + args[2].cast>(), + args[3].cast>()); }) .def_packed("topi.squeeze", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -77,19 +77,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.concatenate", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = concatenate(args[0].cast>(), args[1].cast()); + *rv = concatenate(args[0].cast>(), args[1].cast()); }) .def_packed("topi.stack", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = stack(args[0].cast>(), args[1].cast()); + *rv = stack(args[0].cast>(), args[1].cast()); }) .def_packed("topi.shape", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = shape(args[0].cast(), args[1].cast()); }) - .def_packed("topi.ndarray_size", + .def_packed("topi.tensor_size", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = ndarray_size(args[0].cast(), args[1].cast()); + *rv = tensor_size(args[0].cast(), args[1].cast()); }) .def_packed("topi.split", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -97,9 +97,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = split_n_sections(args[0].cast(), args[1].cast(), args[2].cast()); } else { - *rv = - split_indices_array(args[0].cast(), - args[1].cast>(), args[2].cast()); + *rv = split_indices_array(args[0].cast(), + 
args[1].cast>(), + args[2].cast()); } }) .def_packed("topi.layout_transform", @@ -144,7 +144,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.meshgrid", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = meshgrid(args[0].cast>(), args[1].cast()); + *rv = meshgrid(args[0].cast>(), + args[1].cast()); }) .def_packed("topi.repeat", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -153,7 +154,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.tile", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = tile(args[0].cast(), args[1].cast>()); + *rv = tile(args[0].cast(), args[1].cast>()); }) .def_packed("topi.gather", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -172,9 +173,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.sparse_to_dense", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = - sparse_to_dense(args[0].cast(), args[1].cast>(), - args[2].cast(), args[3].cast()); + *rv = sparse_to_dense(args[0].cast(), + args[1].cast>(), + args[2].cast(), args[3].cast()); }) .def_packed("topi.matmul", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -202,25 +203,25 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = tensordot(args[0].cast(), args[1].cast(), args[2].cast()); } else { - Array axes = args[3].cast>(); + ffi::Array axes = args[3].cast>(); *rv = tensordot(args[0].cast(), args[1].cast(), - args[2].cast>(), axes); + args[2].cast>(), axes); } }) .def_packed( "topi.strided_slice", [](ffi::PackedArgs args, ffi::Any* rv) { - Tensor x = args[0].cast(); - Array begin = args[1].cast>(); - Array end = args[2].cast>(); - Array strides = args[3].cast>(); - Array axes = args[4].cast>(); + te::Tensor x = args[0].cast(); + ffi::Array begin = args[1].cast>(); + ffi::Array end = args[2].cast>(); + ffi::Array strides = args[3].cast>(); + ffi::Array axes = args[4].cast>(); bool assume_inbound = args[6].cast(); if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && IsConstIntArray(x->shape)) { - Array begin_static = args[1].cast>(); - Array end_static = args[2].cast>(); - Array 
strides_static = args[3].cast>(); + ffi::Array begin_static = args[1].cast>(); + ffi::Array end_static = args[2].cast>(); + ffi::Array strides_static = args[3].cast>(); auto slice_mode = args[5].cast(); if (axes.size()) { *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, @@ -245,7 +246,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def("topi.relax_dynamic_strided_slice", [](te::Tensor x, te::Tensor begin, te::Tensor end, te::Tensor strides, - Array output_shape) { + ffi::Array output_shape) { return relax::dynamic_strided_slice(x, begin, end, strides, output_shape); }) .def_packed("topi.one_hot", @@ -266,8 +267,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ k1, k2, super_diag_right_align, sub_diag_right_align); }) .def("topi.adv_index", - [](te::Tensor x, Array indices) { return adv_index(x, indices); }); -}); + [](te::Tensor x, ffi::Array indices) { return adv_index(x, indices); }); +} } // namespace topi } // namespace tvm diff --git a/src/topi/utils.cc b/src/topi/utils.cc index 6e5c997739d7..6bc1570bd196 100644 --- a/src/topi/utils.cc +++ b/src/topi/utils.cc @@ -28,25 +28,25 @@ namespace tvm { namespace topi { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("topi.utils.is_empty_shape", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::detail::is_empty_shape(args[0].cast>()); + *rv = topi::detail::is_empty_shape(args[0].cast>()); }) .def_packed("topi.utils.bilinear_sample_nchw", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = detail::bilinear_sample_nchw( - args[0].cast(), args[1].cast>(), + args[0].cast(), args[1].cast>(), args[2].cast(), args[3].cast()); }) .def_packed("topi.utils.bilinear_sample_nhwc", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = detail::bilinear_sample_nhwc(args[0].cast(), - args[1].cast>(), + args[1].cast>(), args[2].cast(), args[3].cast()); }); -}); +} } // namespace topi } // namespace tvm diff --git a/src/topi/vision.cc b/src/topi/vision.cc index 
8e6a5f4cbc06..7babb0591676 100644 --- a/src/topi/vision.cc +++ b/src/topi/vision.cc @@ -31,12 +31,12 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.vision.reorg", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = vision::reorg(args[0].cast(), args[1].cast()); }); -}); +} } // namespace topi } // namespace tvm diff --git a/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc b/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc index 6f9c9f0f6f7b..febf484f8161 100644 --- a/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc @@ -22,30 +22,31 @@ #include "../src/runtime/hexagon/hexagon_buffer.h" +using namespace tvm; using namespace tvm::runtime; using namespace tvm::runtime::hexagon; using namespace tvm::ffi; TEST(HexagonBuffer, default_scope) { - Optional scope; + ffi::Optional scope; HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope); EXPECT_EQ(hb.GetStorageScope(), HexagonBuffer::StorageScope::kDDR); } TEST(HexagonBuffer, ddr_scope) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope); EXPECT_EQ(hb.GetStorageScope(), HexagonBuffer::StorageScope::kDDR); } TEST(HexagonBuffer, vtcm_scope) { - Optional scope(String("global.vtcm")); + ffi::Optional scope(ffi::String("global.vtcm")); HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope); EXPECT_EQ(hb.GetStorageScope(), HexagonBuffer::StorageScope::kVTCM); } TEST(HexagonBuffer, invalid_scope) { - Optional scope(String("invalid")); + ffi::Optional scope(ffi::String("invalid")); EXPECT_THROW(HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope), InternalError); } @@ -268,7 +269,7 @@ TEST(HexagonBuffer, macro_copies_overlapping_regions_merged) { } TEST(HexagonBuffer, copy_from) { - Optional scope(String("global")); + 
ffi::Optional scope(ffi::String("global")); HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope); std::vector data{0, 1, 2, 3, 4, 5, 6, 7}; @@ -281,7 +282,7 @@ TEST(HexagonBuffer, copy_from) { } TEST(HexagonBuffer, copy_from_invalid_size) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); std::vector data{0, 1, 2, 3, 4, 5, 6, 7}; // HexagonBuffer too small @@ -290,7 +291,7 @@ TEST(HexagonBuffer, copy_from_invalid_size) { } TEST(HexagonBuffer, copy_from_smaller_size) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); std::vector data{0, 1, 2, 3, 4, 5, 6, 7}; // HexagonBuffer is big @@ -299,25 +300,25 @@ TEST(HexagonBuffer, copy_from_smaller_size) { } TEST(HexagonBuffer, nd) { - Optional def; + ffi::Optional def; HexagonBuffer hb_default(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, def); EXPECT_EQ(hb_default.GetStorageScope(), HexagonBuffer::StorageScope::kDDR); - Optional global(String("global")); + ffi::Optional global(ffi::String("global")); HexagonBuffer hb_global(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, global); EXPECT_EQ(hb_global.GetStorageScope(), HexagonBuffer::StorageScope::kDDR); - Optional vtcm(String("global.vtcm")); + ffi::Optional vtcm(ffi::String("global.vtcm")); HexagonBuffer hb_vtcm(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, vtcm); EXPECT_EQ(hb_vtcm.GetStorageScope(), HexagonBuffer::StorageScope::kVTCM); - Optional invalid(String("invalid")); + ffi::Optional invalid(ffi::String("invalid")); EXPECT_THROW(HexagonBuffer hb_invalid(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, invalid), InternalError); } TEST(HexagonBuffer, nd_copy_from) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); HexagonBuffer hb(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope); std::vector data{0, 1, 2, 3, 4, 5, 6, 7}; @@ -335,10 +336,10 @@ TEST(HexagonBuffer, nd_copy_from) { } TEST(HexagonBuffer, 1d_copy_from_1d) { - Optional 
global(String("global")); + ffi::Optional global(ffi::String("global")); HexagonBuffer from(8 /* nbytes */, 8 /* alignment */, global); - Optional vtcm(String("global.vtcm")); + ffi::Optional vtcm(ffi::String("global.vtcm")); HexagonBuffer to(8 /* nbytes */, 8 /* alignment */, vtcm); std::vector data{0, 1, 2, 3, 4, 5, 6, 7}; @@ -352,10 +353,10 @@ TEST(HexagonBuffer, 1d_copy_from_1d) { } TEST(HexagonBuffer, 2d_copy_from_1d) { - Optional vtcm(String("global.vtcm")); + ffi::Optional vtcm(ffi::String("global.vtcm")); HexagonBuffer hb1d(8 /* nbytes */, 8 /* alignment */, vtcm); - Optional global(String("global")); + ffi::Optional global(ffi::String("global")); HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, global); std::vector data{0, 1, 2, 3, 4, 5, 6, 7}; @@ -374,10 +375,10 @@ TEST(HexagonBuffer, 2d_copy_from_1d) { } TEST(HexagonBuffer, 1d_copy_from_2d) { - Optional vtcm(String("global.vtcm")); + ffi::Optional vtcm(ffi::String("global.vtcm")); HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, vtcm); - Optional global(String("global.vtcm")); + ffi::Optional global(ffi::String("global.vtcm")); HexagonBuffer hb1d(8 /* nbytes */, 8 /* alignment */, global); std::vector data{0, 1, 2, 3, 4, 5, 6, 7}; @@ -391,7 +392,7 @@ TEST(HexagonBuffer, 1d_copy_from_2d) { } TEST(HexagonBuffer, nd_copy_from_nd_invalid_size) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); HexagonBuffer hb1d(8 /* nbytes */, 8 /* alignment */, scope); HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope); @@ -405,7 +406,7 @@ TEST(HexagonBuffer, nd_copy_from_nd_invalid_size) { } TEST(HexagonBuffer, nd_copy_from_nd_smaller_size) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); HexagonBuffer hb1d(8 /* nbytes */, 8 /* alignment */, scope); HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope); @@ -419,7 +420,7 @@ TEST(HexagonBuffer, nd_copy_from_nd_smaller_size) { 
} TEST(HexagonBuffer, md_copy_from_nd) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); HexagonBuffer hb3d(3 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope); HexagonBuffer hb4d(4 /* ndim */, 3 /* nbytes */, 8 /* alignment */, scope); @@ -436,7 +437,7 @@ TEST(HexagonBuffer, md_copy_from_nd) { } TEST(HexagonBuffer, copy_to) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope); std::vector data_in{0, 1, 2, 3, 4, 5, 6, 7}; @@ -451,7 +452,7 @@ TEST(HexagonBuffer, copy_to) { } TEST(HexagonBuffer, nd_copy_to) { - Optional scope(String("global")); + ffi::Optional scope(ffi::String("global")); HexagonBuffer hb(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope); std::vector data_in{0, 1, 2, 3, 4, 5, 6, 7}; diff --git a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc index 6211bd63dfbc..9c74521091aa 100644 --- a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc @@ -21,6 +21,7 @@ #include "../src/runtime/hexagon/hexagon_device_api.h" +using namespace tvm; using namespace tvm::runtime; using namespace tvm::runtime::hexagon; using namespace tvm::ffi; @@ -46,10 +47,10 @@ class HexagonDeviceAPITest : public ::testing::Test { int64_t shape1d[1]{256}; int64_t shape2d[2]{256, 256}; int64_t shape3d[3]{256, 256, 256}; - Optional default_scope; - Optional invalid_scope = String("invalid"); - Optional global_scope = String("global"); - Optional global_vtcm_scope = String("global.vtcm"); + ffi::Optional default_scope; + ffi::Optional invalid_scope = ffi::String("invalid"); + ffi::Optional global_scope = ffi::String("global"); + ffi::Optional global_vtcm_scope = ffi::String("global.vtcm"); }; TEST_F(HexagonDeviceAPITest, global) { CHECK(hexapi != nullptr); } diff --git a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc 
b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc index 2e47473f8a17..dd95a8fb37a7 100644 --- a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc @@ -21,6 +21,7 @@ #include "../src/runtime/hexagon/hexagon_device_api.h" +using namespace tvm; using namespace tvm::runtime; using namespace tvm::runtime::hexagon; using namespace tvm::ffi; @@ -56,8 +57,8 @@ class HexagonUserDMATest : public ::testing::Test { uint32_t length = 0x4000; // 16KB const bool ENABLE_BYPASS = true; const bool DISABLE_BYPASS = false; - Optional global_scope = String("global"); - Optional global_vtcm_scope = String("global.vtcm"); + ffi::Optional global_scope = ffi::String("global"); + ffi::Optional global_vtcm_scope = ffi::String("global.vtcm"); }; TEST_F(HexagonUserDMATest, wait) { diff --git a/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc b/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc index 3cf008c874ab..baa4035e47fb 100644 --- a/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc @@ -21,6 +21,7 @@ #include "../src/runtime/hexagon/hexagon_device_api.h" +using namespace tvm; using namespace tvm::runtime; using namespace tvm::runtime::hexagon; using namespace tvm::ffi; @@ -256,28 +257,28 @@ TEST_F(HexagonVtcmPoolTest, vtcm_alignment) { void* ptr; // Invalid alignments - EXPECT_THROW(test_hexbuffs->AllocateHexagonBuffer(min_bytes, 128 + 1, String("global")), + EXPECT_THROW(test_hexbuffs->AllocateHexagonBuffer(min_bytes, 128 + 1, ffi::String("global")), InternalError); - EXPECT_THROW(test_hexbuffs->AllocateHexagonBuffer(min_bytes, 2048 + 1, String("global")), + EXPECT_THROW(test_hexbuffs->AllocateHexagonBuffer(min_bytes, 2048 + 1, ffi::String("global")), InternalError); // Valid alignments, sizes need to be adjusted - ptr = test_hexbuffs->AllocateHexagonBuffer(1, 128, String("global")); + ptr = test_hexbuffs->AllocateHexagonBuffer(1, 128, ffi::String("global")); 
CHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr; - ptr = test_hexbuffs->AllocateHexagonBuffer(127, 128, String("global")); + ptr = test_hexbuffs->AllocateHexagonBuffer(127, 128, ffi::String("global")); CHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr; - ptr = test_hexbuffs->AllocateHexagonBuffer(129, 128, String("global")); + ptr = test_hexbuffs->AllocateHexagonBuffer(129, 128, ffi::String("global")); CHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr; - ptr = test_hexbuffs->AllocateHexagonBuffer(1, 2048, String("global")); + ptr = test_hexbuffs->AllocateHexagonBuffer(1, 2048, ffi::String("global")); CHECK((reinterpret_cast(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr; - ptr = test_hexbuffs->AllocateHexagonBuffer(2047, 2048, String("global")); + ptr = test_hexbuffs->AllocateHexagonBuffer(2047, 2048, ffi::String("global")); CHECK((reinterpret_cast(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr; - ptr = test_hexbuffs->AllocateHexagonBuffer(2049, 2048, String("global")); + ptr = test_hexbuffs->AllocateHexagonBuffer(2049, 2048, ffi::String("global")); CHECK((reinterpret_cast(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr; test_hexbuffs.reset(); diff --git a/tests/cpp-runtime/hexagon/run_all_tests.cc b/tests/cpp-runtime/hexagon/run_all_tests.cc index e6793b530172..6ede0f119281 100644 --- a/tests/cpp-runtime/hexagon/run_all_tests.cc +++ b/tests/cpp-runtime/hexagon/run_all_tests.cc @@ -38,7 +38,7 @@ namespace tvm { namespace runtime { namespace hexagon { -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("hexagon.run_all_tests", [](ffi::PackedArgs args, ffi::Any* rv) { // gtest args are passed into this packed func as a singular string @@ -64,7 +64,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ::testing::InitGoogleTest(&argc, argv.data()); *rv = RUN_ALL_TESTS(); }); -}); +} } // namespace 
hexagon } // namespace runtime diff --git a/tests/cpp-runtime/hexagon/run_unit_tests.cc b/tests/cpp-runtime/hexagon/run_unit_tests.cc index 03f786b58b07..88b04dd963a1 100644 --- a/tests/cpp-runtime/hexagon/run_unit_tests.cc +++ b/tests/cpp-runtime/hexagon/run_unit_tests.cc @@ -80,7 +80,7 @@ class GtestPrinter : public testing::EmptyTestEventListener { std::string GetOutput() { return gtest_out_.str(); } }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("hexagon.run_unit_tests", [](ffi::PackedArgs args, ffi::Any* rv) { // gtest args are passed into this packed func as a singular string @@ -118,7 +118,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = gtest_error_code_and_output.str(); delete gprinter; }); -}); +} } // namespace hexagon } // namespace runtime diff --git a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc index 1097a21128e1..0ab2f5ff6855 100644 --- a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc +++ b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc @@ -194,7 +194,7 @@ TEST_F(OpenCLCompileBin, SourceVsBinaryCompilationPerf) { { OpenCLModuleNode module(m_dataSrc, "cl", m_fmap, std::string()); module.Init(); - module.GetFunction("opencl.SetPreCompiledPrograms").value()(tvm::String(bytes)); + module.GetFunction("opencl.SetPreCompiledPrograms").value()(tvm::ffi::String(bytes)); Timestamp comp_start = std::chrono::high_resolution_clock::now(); for (size_t i = 0; i < m_kernelNames.size(); ++i) { OpenCLModuleNode::KTRefEntry e = {i, 1}; diff --git a/tests/cpp-runtime/opencl/opencl_nativeptr.cc b/tests/cpp-runtime/opencl/opencl_nativeptr.cc index 260effadea0b..1694de418b5c 100644 --- a/tests/cpp-runtime/opencl/opencl_nativeptr.cc +++ b/tests/cpp-runtime/opencl/opencl_nativeptr.cc @@ -32,7 +32,7 @@ using namespace tvm::runtime::cl; TEST(OpenCLNativePtr, access_memory) { OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); - auto A = 
tvm::runtime::NDArray::Empty({128, 128}, {kDLFloat, 32, 1}, {kDLOpenCL, 0}); + auto A = tvm::runtime::Tensor::Empty({128, 128}, {kDLFloat, 32, 1}, {kDLOpenCL, 0}); void* nptr = workspace->GetNativePtr(A); memset(nptr, 0x0, 128 * 128 * 4); } @@ -40,8 +40,8 @@ TEST(OpenCLNativePtr, access_memory) { TEST(OpenCLNatvePtr, data_loop) { OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); - auto cl_arr = tvm::runtime::NDArray::Empty({1024}, {kDLFloat, 32, 1}, {kDLOpenCL, 0}); - auto cpu_arr = tvm::runtime::NDArray::Empty({1024}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cl_arr = tvm::runtime::Tensor::Empty({1024}, {kDLFloat, 32, 1}, {kDLOpenCL, 0}); + auto cpu_arr = tvm::runtime::Tensor::Empty({1024}, {kDLFloat, 32, 1}, {kDLCPU, 0}); std::random_device rdev; std::mt19937 mt(rdev()); diff --git a/tests/cpp-runtime/opencl/texture_copy_test.cc b/tests/cpp-runtime/opencl/texture_copy_test.cc index 61d9044b6d86..001e65b90126 100644 --- a/tests/cpp-runtime/opencl/texture_copy_test.cc +++ b/tests/cpp-runtime/opencl/texture_copy_test.cc @@ -61,10 +61,10 @@ TEST(TextureCopy, HostDeviceRT) { (void)tvm::runtime::memory::MemoryManager::GetOrCreateAllocator( thr->device, tvm::runtime::memory::AllocatorType::kPooled); std::vector shape{16, 16, 4}; - auto cpu_arr0 = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto cpu_arr1 = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); - String mem_scope = "global.texture"; - auto opencl_txarr0 = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLOpenCL, 0}, mem_scope); + auto cpu_arr0 = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr1 = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + ffi::String mem_scope = "global.texture"; + auto opencl_txarr0 = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLOpenCL, 0}, mem_scope); size_t size = 1; for (size_t i = 0; i < shape.size(); ++i) { @@ -94,19 +94,19 @@ TEST_F(TextureCopyTest, ViewBufferAsBuffer) { using 
namespace tvm; std::vector shape{1, 16, 16, 8}; std::vector same_shape{1, 8, 16, 16}; - auto cpu_arr = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto cpu_arr_ret = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr_ret = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); - String mem_scope = "global"; + ffi::String mem_scope = "global"; DLDevice cl_dev = {kDLOpenCL, 0}; auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); auto buffer = allocator->Alloc(cl_dev, ffi::Shape(shape), {kDLFloat, 32, 1}); auto stor = Storage(buffer, allocator); - auto opencl_memobj = stor->AllocNDArrayScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, mem_scope); + auto opencl_memobj = stor->AllocTensorScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, mem_scope); auto opencl_memview = - stor->AllocNDArrayScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, mem_scope); + stor->AllocTensorScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, mem_scope); std::random_device dev; std::mt19937 mt(dev()); @@ -153,17 +153,17 @@ TEST_F(TextureCopyTest, ViewBufferAsImage) { // Shape that doesn't cause padding for image row std::vector shape{1, 16, 16, 8, 4}; std::vector same_shape{1, 8, 16, 16, 4}; - auto cpu_arr = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto cpu_arr_ret = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr_ret = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); DLDevice cl_dev = {kDLOpenCL, 0}; auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); auto buffer = allocator->Alloc(cl_dev, ffi::Shape(shape), {kDLFloat, 32, 1}); auto stor = Storage(buffer, allocator); - auto opencl_buf_obj = stor->AllocNDArrayScoped(0, ffi::Shape(shape), 
{kDLFloat, 32, 1}, "global"); + auto opencl_buf_obj = stor->AllocTensorScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global"); auto opencl_img_obj = - stor->AllocNDArrayScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global.texture"); + stor->AllocTensorScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global.texture"); std::random_device dev; std::mt19937 mt(dev()); @@ -210,8 +210,8 @@ TEST_F(TextureCopyTest, ViewImageAsBuffer) { // Shape that doesn't cause padding for image row std::vector shape{1, 16, 16, 8, 4}; std::vector same_shape{1, 8, 16, 16, 4}; - auto cpu_arr = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto cpu_arr_ret = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr_ret = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); DLDevice cl_dev = {kDLOpenCL, 0}; auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); @@ -219,9 +219,9 @@ TEST_F(TextureCopyTest, ViewImageAsBuffer) { auto stor = Storage(buffer, allocator); auto opencl_img_obj = - stor->AllocNDArrayScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global.texture"); + stor->AllocTensorScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global.texture"); auto opencl_buf_obj = - stor->AllocNDArrayScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global"); + stor->AllocTensorScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global"); std::random_device dev; std::mt19937 mt(dev()); @@ -268,8 +268,8 @@ TEST_F(TextureCopyTest, ViewImageAsImage) { // Shape that doesn't cause padding for image row std::vector shape{1, 16, 16, 8, 4}; std::vector same_shape{1, 8, 16, 16, 4}; - auto cpu_arr = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto cpu_arr_ret = runtime::NDArray::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); 
+ auto cpu_arr_ret = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); DLDevice cl_dev = {kDLOpenCL, 0}; auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); @@ -277,9 +277,9 @@ TEST_F(TextureCopyTest, ViewImageAsImage) { auto stor = Storage(buffer, allocator); auto opencl_img_obj_1 = - stor->AllocNDArrayScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global.texture"); + stor->AllocTensorScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global.texture"); auto opencl_img_obj_2 = - stor->AllocNDArrayScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global.texture"); + stor->AllocTensorScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global.texture"); std::random_device dev; std::mt19937 mt(dev()); diff --git a/tests/cpp/data_type_rewriter_test.cc b/tests/cpp/data_type_rewriter_test.cc index c5e6d4f75843..1eec334344b3 100644 --- a/tests/cpp/data_type_rewriter_test.cc +++ b/tests/cpp/data_type_rewriter_test.cc @@ -37,7 +37,7 @@ TYPED_TEST_SUITE(DataTypeLegalizerBinaryOp, BinaryOpTypes); TYPED_TEST(DataTypeLegalizerBinaryOp, Basic) { using RefType = TypeParam; using NodeType = typename RefType::ContainerType; - auto node = make_object(); + auto node = ffi::make_object(); node->a = Var("a", DataType::Int(32)); node->b = IntImm(DataType::Int(64), 2); DataTypeLegalizer legalizer; @@ -48,7 +48,7 @@ TYPED_TEST(DataTypeLegalizerBinaryOp, Basic) { } TEST(DataTypeLegalizer, Select) { - auto node = make_object(); + auto node = ffi::make_object(); node->condition = Var("cond", DataType::Bool()); node->true_value = Var("a", DataType::Int(64)); node->false_value = IntImm(DataType::Int(32), 2); @@ -73,8 +73,8 @@ TEST(DataTypeLegalizer, IfThenElse) { } TEST(DataTypeLegalizer, Block) { - auto block_node = make_object(); - auto iter_var_node = make_object(); + auto block_node = ffi::make_object(); + auto iter_var_node = ffi::make_object(); iter_var_node->var = Var("i", DataType::Int(32)); iter_var_node->dom = 
Range::FromMinExtent(IntImm(DataType::Int(64), 0), IntImm(DataType::Int(64), 10)); @@ -84,12 +84,12 @@ TEST(DataTypeLegalizer, Block) { block_node->writes = {}; block_node->name_hint = "block"; block_node->body = Evaluate(Integer(0)); - auto block_realize_node = make_object(); + auto block_realize_node = ffi::make_object(); auto loop_var = Var("i", DataType::Int(32)); block_realize_node->iter_values = {loop_var}; block_realize_node->predicate = const_true(); block_realize_node->block = Block(block_node); - auto for_node = make_object(); + auto for_node = ffi::make_object(); for_node->loop_var = loop_var; for_node->min = IntImm(DataType::Int(64), 0); for_node->extent = IntImm(DataType::Int(64), 10); @@ -113,7 +113,7 @@ TEST(DataTypeLegalizer, Block) { } TEST(DataTypeLegalizer, For) { - auto node = make_object(); + auto node = ffi::make_object(); node->body = Evaluate(Integer(0)); node->loop_var = Var("i", DataType::Int(32)); node->min = IntImm(DataType::Int(64), 0); @@ -126,7 +126,7 @@ TEST(DataTypeLegalizer, For) { } TEST(DataTypeLegalizer, Ramp) { - auto node = make_object(); + auto node = ffi::make_object(); node->base = IntImm(DataType::Int(64), 0); node->stride = IntImm(DataType::Int(32), 1); int lanes = 4; diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 579479ccc0e5..05fbd5ce548c 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -51,5 +51,5 @@ TEST(ExprNodeRef, Basic) { Var x("x"); PrimExpr z = max(x + 1 + 2, 100); const tir::MaxNode* op = z.as(); - ICHECK(GetRef(op).same_as(z)); + ICHECK(ffi::GetRef(op).same_as(z)); } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 348792d6ff88..ec7b4111d240 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -215,7 +215,7 @@ TEST(IRF, StmtMutator) { Stmt body2 = Evaluate(1); Stmt bref = body.as()->body; auto* extentptr = body.as()->extents.get(); - Array arr{std::move(body), body2, body2}; + ffi::Array arr{std::move(body), body2, 
body2}; auto* arrptr = arr.get(); arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); ICHECK(arr.get() == arrptr); @@ -228,9 +228,9 @@ TEST(IRF, StmtMutator) { ICHECK(bref.as()->value.as()); } { - Array arr{fmakealloc()}; + ffi::Array arr{fmakealloc()}; // mutate array get reference by another one, triiger copy. - Array arr2 = arr; + ffi::Array arr2 = arr; auto* arrptr = arr.get(); arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); ICHECK(arr.get() != arrptr); @@ -242,7 +242,7 @@ TEST(IRF, StmtMutator) { ICHECK(arr2.get() == arr.get()); } { - Array arr{fmakeif()}; + ffi::Array arr{fmakeif()}; arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); ICHECK(arr[0].as()->else_case.as()->value.same_as(x)); // mutate but no content change. @@ -332,7 +332,7 @@ TEST(IRF, Substitute) { // test substitute buffer var Var y = x.copy_with_suffix("subst"); BufferLoad buffer_load = fmaketest(); - auto f_subst = [&](const Var& var) -> Optional { + auto f_subst = [&](const Var& var) -> ffi::Optional { if (var.same_as(x)) { return y; } @@ -345,7 +345,7 @@ TEST(IRF, Substitute) { { // test identity substitution PrimExpr expr = fmaketest(); - auto f_subst = [&](const Var& var) -> Optional { return var; }; + auto f_subst = [&](const Var& var) -> ffi::Optional { return var; }; PrimExpr new_expr = Substitute(expr, f_subst); // the expression is not changed ICHECK(new_expr.same_as(expr)); diff --git a/tests/cpp/ndarray_test.cc b/tests/cpp/ndarray_test.cc index 57ad3ba90b40..c2452f9146b1 100644 --- a/tests/cpp/ndarray_test.cc +++ b/tests/cpp/ndarray_test.cc @@ -19,12 +19,12 @@ #include #include -#include +#include using namespace tvm; -TEST(NDArrayTest, IsContiguous_ContiguousStride) { - auto array = runtime::NDArray::Empty({5, 10}, DataType::Float(32), {kDLCPU}); +TEST(TensorTest, IsContiguous_ContiguousStride) { + auto array = runtime::Tensor::Empty({5, 10}, DataType::Float(32), {kDLCPU}); DLManagedTensor* managed_tensor = array.ToDLPack(); int64_t strides[] = {10, 
1}; @@ -35,8 +35,8 @@ TEST(NDArrayTest, IsContiguous_ContiguousStride) { managed_tensor->deleter(managed_tensor); } -TEST(NDArrayTest, IsContiguous_NullStride) { - auto array = runtime::NDArray::Empty({5, 10}, DataType::Float(32), {kDLCPU}); +TEST(TensorTest, IsContiguous_NullStride) { + auto array = runtime::Tensor::Empty({5, 10}, DataType::Float(32), {kDLCPU}); DLManagedTensor* managed_tensor = array.ToDLPack(); managed_tensor->dl_tensor.strides = nullptr; @@ -46,8 +46,8 @@ TEST(NDArrayTest, IsContiguous_NullStride) { managed_tensor->deleter(managed_tensor); } -TEST(NDArrayTest, IsContiguous_AnyStrideForSingular) { - auto array = runtime::NDArray::Empty({5, 1, 10}, DataType::Float(32), {kDLCPU}); +TEST(TensorTest, IsContiguous_AnyStrideForSingular) { + auto array = runtime::Tensor::Empty({5, 1, 10}, DataType::Float(32), {kDLCPU}); DLManagedTensor* managed_tensor = array.ToDLPack(); int64_t strides[] = {10, 1, 1}; // strides[1] is normalized to 1 because shape[1] == 1. @@ -59,8 +59,8 @@ TEST(NDArrayTest, IsContiguous_AnyStrideForSingular) { managed_tensor->deleter(managed_tensor); } -TEST(NDArrayTest, IsContiguous_UncontiguousStride) { - auto array = runtime::NDArray::Empty({5, 1, 10}, DataType::Float(32), {kDLCPU}); +TEST(TensorTest, IsContiguous_UncontiguousStride) { + auto array = runtime::Tensor::Empty({5, 1, 10}, DataType::Float(32), {kDLCPU}); DLManagedTensor* managed_tensor = array.ToDLPack(); int64_t strides[] = {1, 1, 1}; diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index 644a80664fe1..c9628daf0d80 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -138,9 +138,9 @@ TEST(NestedMsg, Equal) { EXPECT_FALSE(Equal(M(std::nullopt), M(x), fequal)); - EXPECT_FALSE(Equal(M(x), M(Array({x})), fequal)); + EXPECT_FALSE(Equal(M(x), M(ffi::Array({x})), fequal)); - EXPECT_FALSE(Equal(M(Array({x})), M(x), fequal)); + EXPECT_FALSE(Equal(M(ffi::Array({x})), M(x), fequal)); } TEST(NestedMsg, MapAndDecompose) { @@ 
-232,7 +232,7 @@ TEST(NestedMsg, NestedMsgToExpr) { relax::Var x("x", sf0), y("y", sf0), z("z", sf0); NestedMsg msg = {c0, {c0, c1}, {c0, {c1, c2}}}; - auto expr = NestedMsgToExpr(msg, [&](Optional leaf) { + auto expr = NestedMsgToExpr(msg, [&](ffi::Optional leaf) { ICHECK(leaf.defined()); int value = leaf.value().IntValue(); switch (value) { @@ -251,7 +251,7 @@ TEST(NestedMsg, NestedMsgToExpr) { // test simplified relax::Var t("t", sf1); NestedMsg msg1 = {TupleGetItem(t, 0), TupleGetItem(t, 1)}; - auto expr1 = NestedMsgToExpr(msg1, [](Optional leaf) { return leaf.value(); }); + auto expr1 = NestedMsgToExpr(msg1, [](ffi::Optional leaf) { return leaf.value(); }); EXPECT_TRUE(StructuralEqual()(expr1, t)); } diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc index be69d77ccc73..fc02fb036bcf 100644 --- a/tests/cpp/object_protocol_test.cc +++ b/tests/cpp/object_protocol_test.cc @@ -31,39 +31,36 @@ class ObjBase : public Object { public: // dynamically allocate slow static constexpr const uint32_t _type_child_slots = 1; - static constexpr const char* _type_key = "test.ObjBase"; - TVM_DECLARE_BASE_OBJECT_INFO(ObjBase, Object); + TVM_FFI_DECLARE_OBJECT_INFO("test.ObjBase", ObjBase, Object); }; class ObjA : public ObjBase { public: static constexpr const uint32_t _type_child_slots = 0; - static constexpr const char* _type_key = "test.ObjA"; - TVM_DECLARE_BASE_OBJECT_INFO(ObjA, ObjBase); + TVM_FFI_DECLARE_OBJECT_INFO("test.ObjA", ObjA, ObjBase); }; class ObjB : public ObjBase { public: static constexpr const uint32_t _type_child_slots = 0; - static constexpr const char* _type_key = "test.ObjB"; - TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase); + TVM_FFI_DECLARE_OBJECT_INFO("test.ObjB", ObjB, ObjBase); }; class ObjAA : public ObjA { public: - static constexpr const char* _type_key = "test.ObjAA"; - TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.ObjAA", ObjAA, ObjA); }; } // namespace test } // namespace tvm 
TEST(ObjectHierachy, Basic) { + using namespace tvm; using namespace tvm::runtime; using namespace tvm::test; using namespace tvm::ffi; - ObjectRef refA(make_object()); + ObjectRef refA(ffi::make_object()); ICHECK_EQ(refA->type_index(), ObjA::RuntimeTypeIndex()); ICHECK(refA.as() != nullptr); ICHECK(refA.as() != nullptr); @@ -71,7 +68,7 @@ TEST(ObjectHierachy, Basic) { ICHECK(refA.as() == nullptr); ICHECK(refA.as() == nullptr); - ObjectRef refAA(make_object()); + ObjectRef refAA(ffi::make_object()); ICHECK_EQ(refAA->type_index(), ObjAA::RuntimeTypeIndex()); ICHECK(refAA.as() != nullptr); ICHECK(refAA.as() != nullptr); @@ -79,7 +76,7 @@ TEST(ObjectHierachy, Basic) { ICHECK(refAA.as() != nullptr); ICHECK(refAA.as() == nullptr); - ObjectRef refB(make_object()); + ObjectRef refB(ffi::make_object()); ICHECK_EQ(refB->type_index(), ObjB::RuntimeTypeIndex()); ICHECK(refB.as() != nullptr); ICHECK(refB.as() != nullptr); diff --git a/tests/cpp/support/scalars_test.cc b/tests/cpp/support/scalars_test.cc index 52bd2dc148c8..12a5145f2145 100644 --- a/tests/cpp/support/scalars_test.cc +++ b/tests/cpp/support/scalars_test.cc @@ -28,17 +28,17 @@ namespace { // Note that functional testing is via test_ir_parser.py and test_ir_text_printer.py. // Here we just check handling which is difficult to test via the standard Python API. 
-TEST(Scalars, IntImmToNDArray_Unsupported) { - ASSERT_THROW(IntImmToNDArray(IntImm(DataType::Int(15), 42)), runtime::InternalError); +TEST(Scalars, IntImmToTensor_Unsupported) { + ASSERT_THROW(IntImmToTensor(IntImm(DataType::Int(15), 42)), runtime::InternalError); } -TEST(Scalars, FloatImmtoNDArray_Unsupported) { - ASSERT_THROW(FloatImmToNDArray(FloatImm(DataType::Float(15), 42.0)), runtime::InternalError); +TEST(Scalars, FloatImmtoTensor_Unsupported) { + ASSERT_THROW(FloatImmToTensor(FloatImm(DataType::Float(15), 42.0)), runtime::InternalError); } -TEST(Scalars, NDArrayScalarToString_Unsupported) { - auto ndarray = runtime::NDArray::Empty({}, DataType::Int(8), {DLDeviceType::kDLCPU, 0}); - ASSERT_THROW(NDArrayScalarToString(ndarray), runtime::InternalError); +TEST(Scalars, TensorScalarToString_Unsupported) { + auto ndarray = runtime::Tensor::Empty({}, DataType::Int(8), {DLDeviceType::kDLCPU, 0}); + ASSERT_THROW(TensorScalarToString(ndarray), runtime::InternalError); } TEST(Scalars, IntImmToString_Unsupported) { diff --git a/tests/cpp/target/parsers/aprofile_test.cc b/tests/cpp/target/parsers/aprofile_test.cc index 26f52f4938a8..1e74b3f71599 100644 --- a/tests/cpp/target/parsers/aprofile_test.cc +++ b/tests/cpp/target/parsers/aprofile_test.cc @@ -44,9 +44,9 @@ static bool CheckArchitectureAvailability() { #if TVM_LLVM_VERSION > 120 auto llvm_instance = std::make_unique(); codegen::LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); - Array targets = llvm_backend.GetAllLLVMTargets(); + ffi::Array targets = llvm_backend.GetAllLLVMTargets(); int expected_target_count = 0; - for (String target : targets) { + for (ffi::String target : targets) { if (target == "aarch64" || target == "arm") { expected_target_count += 1; } @@ -74,9 +74,10 @@ class AProfileParser : public ::testing::Test { class AProfileParserTestWithParam : public AProfileParser, public testing::WithParamInterface {}; -static TargetFeatures ParseTargetWithAttrs(String mcpu, String mtriple, Array mattr) { 
+static TargetFeatures ParseTargetWithAttrs(ffi::String mcpu, ffi::String mtriple, + ffi::Array mattr) { TargetJSON target_json = { - {"kind", String("llvm")}, + {"kind", ffi::String("llvm")}, {"mtriple", mtriple}, {"mattr", mattr}, }; @@ -93,8 +94,8 @@ std::string FloatToStringWithoutTrailingZeros(float value) { } TEST_F(AProfileParser, ParseTargetKeys) { - TargetJSON target = ParseTarget({{"kind", String("llvm")}}); - Array keys = Downcast>(target.at("keys")); + TargetJSON target = ParseTarget({{"kind", ffi::String("llvm")}}); + ffi::Array keys = Downcast>(target.at("keys")); ASSERT_EQ(keys.size(), 2); ASSERT_EQ(keys[0], "arm_cpu"); ASSERT_EQ(keys[1], "cpu"); @@ -102,11 +103,11 @@ TEST_F(AProfileParser, ParseTargetKeys) { TEST_F(AProfileParser, ParseTargetWithExistingKeys) { TargetJSON target = ParseTarget({ - {"kind", String("llvm")}, - {"keys", Array{"cpu"}}, + {"kind", ffi::String("llvm")}, + {"keys", ffi::Array{"cpu"}}, }); TargetFeatures features = Downcast(target.at("features")); - Array keys = Downcast>(target.at("keys")); + ffi::Array keys = Downcast>(target.at("keys")); ASSERT_EQ(keys.size(), 2); ASSERT_EQ(keys[0], "cpu"); ASSERT_EQ(keys[1], "arm_cpu"); @@ -114,18 +115,18 @@ TEST_F(AProfileParser, ParseTargetWithExistingKeys) { TEST_F(AProfileParser, ParseTargetWithDuplicateKey) { TargetJSON target = ParseTarget({ - {"kind", String("llvm")}, - {"keys", Array{"cpu", "arm_cpu"}}, + {"kind", ffi::String("llvm")}, + {"keys", ffi::Array{"cpu", "arm_cpu"}}, }); TargetFeatures features = Downcast(target.at("features")); - Array keys = Downcast>(target.at("keys")); + ffi::Array keys = Downcast>(target.at("keys")); ASSERT_EQ(keys.size(), 2); ASSERT_EQ(keys[0], "cpu"); ASSERT_EQ(keys[1], "arm_cpu"); } TEST_F(AProfileParser, ParseTargetDefaults) { - TargetJSON target = ParseTarget({{"kind", String("llvm")}}); + TargetJSON target = ParseTarget({{"kind", ffi::String("llvm")}}); TargetFeatures features = Downcast(target.at("features")); 
ASSERT_EQ(Downcast(features.at("is_aarch64")), false); @@ -157,8 +158,8 @@ TEST_F(AProfileParser, IsAArch32Triple) { TEST_F(AProfileParser, IsAArch32BlankCPU) { TargetJSON target = ParseTarget({ - {"kind", String("llvm")}, - {"mtriple", String("arm-unknown-linux-gnu")}, + {"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("arm-unknown-linux-gnu")}, }); TargetFeatures features = Downcast(target.at("features")); ASSERT_EQ(IsArch(target), true); @@ -396,7 +397,7 @@ TEST_F(AProfileParser, UnexpectedTargetKind) { EXPECT_THROW( { try { - ParseTarget({{"kind", String("c")}}); + ParseTarget({{"kind", ffi::String("c")}}); } catch (const tvm::InternalError& e) { EXPECT_THAT(e.what(), HasSubstr("Expected target kind 'llvm', but got 'c'")); throw; @@ -409,7 +410,7 @@ TEST(AProfileParserInvalid, LLVMUnsupportedArchitecture) { if (has_aarch64_and_arm_targets) { GTEST_SKIP() << "LLVM has been compiled for the correct targets."; } - TargetJSON target = ParseTarget({{"kind", String("llvm")}}); + TargetJSON target = ParseTarget({{"kind", ffi::String("llvm")}}); TargetFeatures features = Downcast(target.at("features")); for (auto feature : features) { ASSERT_EQ(Downcast(feature.second), false); diff --git a/tests/cpp/target/parsers/mprofile_test.cc b/tests/cpp/target/parsers/mprofile_test.cc index 97fb227e4190..19baf006d895 100644 --- a/tests/cpp/target/parsers/mprofile_test.cc +++ b/tests/cpp/target/parsers/mprofile_test.cc @@ -37,30 +37,30 @@ class MProfileParserMVECPUs : public testing::TestWithParam {}; class MProfileParserDSPCPUs : public testing::TestWithParam {}; class MProfileParserNoExtensions : public testing::TestWithParam {}; -static TargetFeatures ParseTargetWithAttrs(String mcpu, Array mattr) { +static TargetFeatures ParseTargetWithAttrs(ffi::String mcpu, ffi::Array mattr) { return ParseTarget({{"mcpu", mcpu}, {"mattr", mattr}}); } TEST(MProfileParser, CheckIsNotArch) { - String mcpu = "cake"; + ffi::String mcpu = "cake"; TargetJSON fake_target = {{"mcpu", mcpu}}; 
ASSERT_EQ(IsArch(fake_target), false); } TEST_P(MProfileParserMVECPUs, CheckIsArch) { - String mcpu = GetParam(); + ffi::String mcpu = GetParam(); TargetJSON fake_target = {{"mcpu", mcpu}}; ASSERT_EQ(IsArch(fake_target), true); } TEST_P(MProfileParserDSPCPUs, CheckIsArch) { - String mcpu = GetParam(); + ffi::String mcpu = GetParam(); TargetJSON fake_target = {{"mcpu", mcpu}}; ASSERT_EQ(IsArch(fake_target), true); } TEST_P(MProfileParserNoExtensions, CheckIsArch) { - String mcpu = GetParam(); + ffi::String mcpu = GetParam(); TargetJSON fake_target = {{"mcpu", mcpu}}; ASSERT_EQ(IsArch(fake_target), true); } @@ -68,7 +68,7 @@ TEST_P(MProfileParserNoExtensions, CheckIsArch) { TEST(MProfileParser, ParseTarget) { TargetJSON target = ParseTarget({}); TargetFeatures features = Downcast(target.at("features")); - Array keys = Downcast>(target.at("keys")); + ffi::Array keys = Downcast>(target.at("keys")); ASSERT_EQ(keys.size(), 2); ASSERT_EQ(keys[0], "arm_cpu"); ASSERT_EQ(keys[1], "cpu"); @@ -79,10 +79,10 @@ TEST(MProfileParser, ParseTarget) { TEST(MProfileParser, ParseTargetWithExistingKeys) { TargetJSON target = ParseTarget({ - {"keys", Array{"cpu"}}, + {"keys", ffi::Array{"cpu"}}, }); TargetFeatures features = Downcast(target.at("features")); - Array keys = Downcast>(target.at("keys")); + ffi::Array keys = Downcast>(target.at("keys")); ASSERT_EQ(keys.size(), 2); ASSERT_EQ(keys[0], "cpu"); ASSERT_EQ(keys[1], "arm_cpu"); @@ -90,10 +90,10 @@ TEST(MProfileParser, ParseTargetWithExistingKeys) { TEST(MProfileParser, ParseTargetWithDuplicateKey) { TargetJSON target = ParseTarget({ - {"keys", Array{"cpu", "arm_cpu"}}, + {"keys", ffi::Array{"cpu", "arm_cpu"}}, }); TargetFeatures features = Downcast(target.at("features")); - Array keys = Downcast>(target.at("keys")); + ffi::Array keys = Downcast>(target.at("keys")); ASSERT_EQ(keys.size(), 2); ASSERT_EQ(keys[0], "cpu"); ASSERT_EQ(keys[1], "arm_cpu"); diff --git a/tests/cpp/target/virtual_device_test.cc 
b/tests/cpp/target/virtual_device_test.cc index d982a8ae2153..4f4b945cae8f 100644 --- a/tests/cpp/target/virtual_device_test.cc +++ b/tests/cpp/target/virtual_device_test.cc @@ -29,7 +29,7 @@ TEST(VirtualDevice, Join_Defined) { Target target_a = Target("cuda"); VirtualDevice lhs = VirtualDevice(kDLCUDA, 3); VirtualDevice rhs = VirtualDevice(kDLCUDA, -1, target_a, "global"); - Optional actual = VirtualDevice::Join(lhs, rhs); + ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_TRUE(actual.operator bool()); VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global"); EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); @@ -38,7 +38,7 @@ TEST(VirtualDevice, Join_Defined) { Target target_a = Target("cuda"); VirtualDevice lhs = VirtualDevice(kDLCUDA, -1, target_a, "global"); VirtualDevice rhs = VirtualDevice(kDLCUDA, 3); - Optional actual = VirtualDevice::Join(lhs, rhs); + ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_TRUE(actual.operator bool()); VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global"); EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); @@ -47,7 +47,7 @@ TEST(VirtualDevice, Join_Defined) { Target target_a = Target("cuda"); VirtualDevice lhs = VirtualDevice(kDLCUDA); VirtualDevice rhs = VirtualDevice(kDLCUDA, 2, target_a); - Optional actual = VirtualDevice::Join(lhs, rhs); + ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_TRUE(actual.operator bool()); VirtualDevice expected = VirtualDevice(kDLCUDA, 2, target_a); EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); @@ -56,7 +56,7 @@ TEST(VirtualDevice, Join_Defined) { Target target_a = Target("cuda"); VirtualDevice lhs = VirtualDevice(); VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, target_a, "global"); - Optional actual = VirtualDevice::Join(lhs, rhs); + ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_TRUE(actual.operator bool()); VirtualDevice expected = rhs; EXPECT_TRUE(StructuralEqual()(actual.value(), 
expected)); @@ -67,25 +67,25 @@ TEST(VirtualDevice, Join_Undefined) { { VirtualDevice lhs = VirtualDevice(kDLCUDA); VirtualDevice rhs = VirtualDevice(kDLCPU); - Optional actual = VirtualDevice::Join(lhs, rhs); + ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_FALSE(actual); } { VirtualDevice lhs = VirtualDevice(kDLCUDA, 3); VirtualDevice rhs = VirtualDevice(kDLCUDA, 4); - Optional actual = VirtualDevice::Join(lhs, rhs); + ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_FALSE(actual); } { VirtualDevice lhs = VirtualDevice(kDLCUDA, 3, Target("cuda")); VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, Target("cuda")); - Optional actual = VirtualDevice::Join(lhs, rhs); + ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_FALSE(actual); } { VirtualDevice lhs = VirtualDevice(kDLCUDA, 3, Target("cuda"), "local"); VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, Target("cuda"), "global"); - Optional actual = VirtualDevice::Join(lhs, rhs); + ffi::Optional actual = VirtualDevice::Join(lhs, rhs); EXPECT_FALSE(actual); } } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 17e3cae4ad18..ba959672a8ea 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -32,38 +32,37 @@ using namespace tvm; TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU) .set_attr("Attr1", "Value1") .add_attr_option("my_bool") - .add_attr_option>("your_names") - .add_attr_option>("her_maps"); + .add_attr_option>("your_names") + .add_attr_option>("her_maps"); TargetJSON TestTargetParser(TargetJSON target) { - String mcpu = Downcast(target.at("mcpu")); - target.Set("mcpu", String("super_") + mcpu); - target.Set("keys", Array({"super"})); - target.Set("features", Map{{"test", true}}); + ffi::String mcpu = Downcast(target.at("mcpu")); + target.Set("mcpu", ffi::String("super_") + mcpu); + target.Set("keys", ffi::Array({"super"})); + target.Set("features", ffi::Map{{"test", true}}); return target; } -Map TestAttrsPreProcessor(Map attrs) { - 
attrs.Set("mattr", String("woof")); +ffi::Map TestAttrsPreProcessor(ffi::Map attrs) { + attrs.Set("mattr", ffi::String("woof")); return attrs; } TVM_REGISTER_TARGET_KIND("TestTargetParser", kDLCPU) - .add_attr_option("mattr") - .add_attr_option("mcpu") + .add_attr_option("mattr") + .add_attr_option("mcpu") .set_default_keys({"cpu"}) .set_target_parser(TestTargetParser); TVM_REGISTER_TARGET_KIND("TestAttrsPreprocessor", kDLCPU) - .add_attr_option("mattr") + .add_attr_option("mattr") .set_default_keys({"cpu"}) - .set_attrs_preprocessor(TestAttrsPreProcessor); + .set_target_parser(TestAttrsPreProcessor); TVM_REGISTER_TARGET_KIND("TestClashingPreprocessor", kDLCPU) - .add_attr_option("mattr") - .add_attr_option("mcpu") + .add_attr_option("mattr") + .add_attr_option("mcpu") .set_default_keys({"cpu"}) - .set_attrs_preprocessor(TestAttrsPreProcessor) .set_target_parser(TestTargetParser); TEST(TargetKind, GetAttrMap) { @@ -74,13 +73,13 @@ TEST(TargetKind, GetAttrMap) { } TEST(TargetCreation, NestedConfig) { - Map config = { + ffi::Map config = { {"my_bool", true}, - {"your_names", Array{"junru", "jian"}}, - {"kind", String("TestTargetKind")}, + {"your_names", ffi::Array{"junru", "jian"}}, + {"kind", ffi::String("TestTargetKind")}, { "her_maps", - Map{ + ffi::Map{ {"a", 1}, {"b", 2}, }, @@ -92,25 +91,27 @@ TEST(TargetCreation, NestedConfig) { ICHECK(target->keys.empty()); bool my_bool = target->GetAttr("my_bool").value(); ICHECK_EQ(my_bool, true); - Array your_names = target->GetAttr>("your_names").value(); + ffi::Array your_names = + target->GetAttr>("your_names").value(); ICHECK_EQ(your_names.size(), 2U); ICHECK_EQ(your_names[0], "junru"); ICHECK_EQ(your_names[1], "jian"); - Map her_maps = target->GetAttr>("her_maps").value(); + ffi::Map her_maps = + target->GetAttr>("her_maps").value(); ICHECK_EQ(her_maps.size(), 2U); ICHECK_EQ(her_maps["a"], 1); ICHECK_EQ(her_maps["b"], 2); } TEST(TargetCreationFail, UnrecognizedConfigOption) { - Map config = { + ffi::Map config = { 
{"my_bool", true}, - {"your_names", Array{"junru", "jian"}}, - {"kind", String("TestTargetKind")}, + {"your_names", ffi::Array{"junru", "jian"}}, + {"kind", ffi::String("TestTargetKind")}, {"bad", ObjectRef(nullptr)}, { "her_maps", - Map{ + ffi::Map{ {"a", 1}, {"b", 2}, }, @@ -126,13 +127,13 @@ TEST(TargetCreationFail, UnrecognizedConfigOption) { } TEST(TargetCreationFail, TypeMismatch) { - Map config = { - {"my_bool", String("true")}, - {"your_names", Array{"junru", "jian"}}, - {"kind", String("TestTargetKind")}, + ffi::Map config = { + {"my_bool", ffi::String("true")}, + {"your_names", ffi::Array{"junru", "jian"}}, + {"kind", ffi::String("TestTargetKind")}, { "her_maps", - Map{ + ffi::Map{ {"a", 1}, {"b", 2}, }, @@ -148,12 +149,12 @@ TEST(TargetCreationFail, TypeMismatch) { } TEST(TargetCreationFail, TargetKindNotFound) { - Map config = { + ffi::Map config = { {"my_bool", "true"}, - {"your_names", Array{"junru", "jian"}}, + {"your_names", ffi::Array{"junru", "jian"}}, { "her_maps", - Map{ + ffi::Map{ {"a", 1}, {"b", 2}, }, @@ -170,7 +171,7 @@ TEST(TargetCreationFail, TargetKindNotFound) { TEST(TargetCreation, TargetParser) { Target test_target("TestTargetParser -mcpu=woof"); - ASSERT_EQ(test_target->GetAttr("mcpu").value(), "super_woof"); + ASSERT_EQ(test_target->GetAttr("mcpu").value(), "super_woof"); ASSERT_EQ(test_target->keys.size(), 1); ASSERT_EQ(test_target->keys[0], "super"); } @@ -185,10 +186,10 @@ TEST(TargetCreation, TargetFeatures) { } TEST(TargetCreation, TargetFeaturesBeforeParser) { - Map features = {{"test", true}}; - Map config = { - {"kind", String("TestTargetParser")}, - {"mcpu", String("woof")}, + ffi::Map features = {{"test", true}}; + ffi::Map config = { + {"kind", ffi::String("TestTargetParser")}, + {"mcpu", ffi::String("woof")}, {"features", features}, }; EXPECT_THROW(Target test(config), ffi::Error); @@ -196,53 +197,56 @@ TEST(TargetCreation, TargetFeaturesBeforeParser) { TEST(TargetCreation, TargetAttrsPreProcessor) { Target 
test_target("TestAttrsPreprocessor -mattr=cake"); - ASSERT_EQ(test_target->GetAttr("mattr").value(), "woof"); + ASSERT_EQ(test_target->GetAttr("mattr").value(), "woof"); } -TEST(TargetCreation, ClashingTargetProcessing) { - EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), ffi::Error); +TEST(TargetCreation, TargetParserProcessing) { + Target test_target("TestClashingPreprocessor -mcpu=woof -mattr=cake"); + ASSERT_EQ(test_target->GetAttr("mcpu").value(), "super_woof"); + ASSERT_EQ(test_target->GetAttr("mattr").value(), "cake"); } TVM_REGISTER_TARGET_KIND("TestStringKind", kDLCPU) - .add_attr_option("single") - .add_attr_option>("array") - .add_attr_option>>("nested-array") - .add_attr_option>>>("nested2-array"); + .add_attr_option("single") + .add_attr_option>("array") + .add_attr_option>>("nested-array") + .add_attr_option>>>("nested2-array"); TEST(TargetCreation, ProcessStrings) { Target test_target1("TestStringKind -single='\\'string with single quote'"); - ASSERT_TRUE(test_target1->GetAttr("single")); - String string1 = test_target1->GetAttr("single").value(); + ASSERT_TRUE(test_target1->GetAttr("single")); + ffi::String string1 = test_target1->GetAttr("single").value(); ASSERT_EQ(string1, "'string with single quote"); Target test_target2("TestStringKind -single='\\\'\\\\\\'blah\\\\\\'\\\''"); - ASSERT_TRUE(test_target2->GetAttr("single")); - String string2 = test_target2->GetAttr("single").value(); + ASSERT_TRUE(test_target2->GetAttr("single")); + ffi::String string2 = test_target2->GetAttr("single").value(); ASSERT_EQ(string2, "'\\\'blah\\\''"); Target test_target3("TestStringKind -array=-danny,-sammy=1,-kirby='string with space'"); - ASSERT_TRUE(test_target3->GetAttr>("array")); - Array array3 = test_target3->GetAttr>("array").value(); + ASSERT_TRUE(test_target3->GetAttr>("array")); + ffi::Array array3 = test_target3->GetAttr>("array").value(); ASSERT_EQ(array3[0], "-danny"); ASSERT_EQ(array3[1], "-sammy=1"); ASSERT_EQ(array3[2], 
"-kirby='string with space'"); Target test_target4("TestStringKind -array='fred, foo, bar',baz"); - ASSERT_TRUE(test_target4->GetAttr>("array")); - Array array4 = test_target4->GetAttr>("array").value(); + ASSERT_TRUE(test_target4->GetAttr>("array")); + ffi::Array array4 = test_target4->GetAttr>("array").value(); ASSERT_EQ(array4[0], "fred, foo, bar"); ASSERT_EQ(array4[1], "baz"); Target test_target5("TestStringKind -array='fr\\'ed','f\\'oo',' bar,baz '"); - ASSERT_TRUE(test_target5->GetAttr>("array")); - Array array5 = test_target5->GetAttr>("array").value(); + ASSERT_TRUE(test_target5->GetAttr>("array")); + ffi::Array array5 = test_target5->GetAttr>("array").value(); ASSERT_EQ(array5[0], "fr'ed"); ASSERT_EQ(array5[1], "f'oo"); ASSERT_EQ(array5[2], "bar,baz"); Target test_target6("TestStringKind -nested-array='foo0,foo1,foo2','bar0,bar1,bar2','baz0,baz1'"); - ASSERT_TRUE(test_target6->GetAttr>>("nested-array")); - Array> array6 = test_target6->GetAttr>>("nested-array").value(); + ASSERT_TRUE(test_target6->GetAttr>>("nested-array")); + ffi::Array> array6 = + test_target6->GetAttr>>("nested-array").value(); ASSERT_EQ(array6[0][0], "foo0"); ASSERT_EQ(array6[0][1], "foo1"); ASSERT_EQ(array6[0][2], "foo2"); @@ -257,9 +261,11 @@ TEST(TargetCreation, ProcessStrings) { "'\\'foo0,foo1\\',\\'bar0,bar1\\',\\'baz0,baz1\\''," "'\\'zing0,zing1\\',\\'fred\\''"); - ASSERT_TRUE(test_target7->GetAttr>>>("nested2-array")); - Array>> array7 = - test_target7->GetAttr>>>("nested2-array").value(); + ASSERT_TRUE( + test_target7->GetAttr>>>("nested2-array")); + ffi::Array>> array7 = + test_target7->GetAttr>>>("nested2-array") + .value(); // { // {foo0, foo1}, // {bar0, bar1}, @@ -449,8 +455,8 @@ TEST(TargetCreation, LLVMCommandLineSaveRestore) { } TEST(TargetCreation, DetectSystemTriple) { - Map config = { - {"kind", String("llvm")}, + ffi::Map config = { + {"kind", ffi::String("llvm")}, }; Target target = Target(config); @@ -461,17 +467,17 @@ TEST(TargetCreation, DetectSystemTriple) { 
GTEST_SKIP() << "LLVM is not available, skipping test"; } - Optional mtriple = target->GetAttr("mtriple"); - ASSERT_TRUE(mtriple.value() == (*pf)().cast()); + ffi::Optional mtriple = target->GetAttr("mtriple"); + ASSERT_TRUE(mtriple.value() == (*pf)().cast()); } #endif TEST(TargetCreation, DeduplicateKeys) { - Map config = { - {"kind", String("llvm")}, - {"keys", Array{"cpu", "arm_cpu"}}, - {"device", String("arm_cpu")}, + ffi::Map config = { + {"kind", ffi::String("llvm")}, + {"keys", ffi::Array{"cpu", "arm_cpu"}}, + {"device", ffi::String("arm_cpu")}, }; Target target = Target(config); ICHECK_EQ(target->kind, TargetKind::Get("llvm").value()); @@ -480,17 +486,17 @@ TEST(TargetCreation, DeduplicateKeys) { ICHECK_EQ(target->keys[0], "cpu"); ICHECK_EQ(target->keys[1], "arm_cpu"); ICHECK_EQ(target->attrs.size(), 2U); - ICHECK_EQ(target->GetAttr("device"), "arm_cpu"); + ICHECK_EQ(target->GetAttr("device"), "arm_cpu"); } TEST(TargetKindRegistry, ListTargetKinds) { - Array names = TargetKindRegEntry::ListTargetKinds(); + ffi::Array names = TargetKindRegEntry::ListTargetKinds(); ICHECK_EQ(names.empty(), false); ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1); } TEST(TargetKindRegistry, ListTargetOptions) { TargetKind llvm = TargetKind::Get("llvm").value(); - Map attrs = TargetKindRegEntry::ListTargetKindOptions(llvm); + ffi::Map attrs = TargetKindRegEntry::ListTargetKindOptions(llvm); ICHECK_EQ(attrs.empty(), false); } diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index 6c42972d9430..6ae6deb50d2e 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -167,8 +167,8 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) { TEST(ScalableDataType, TestScalableBool) { tvm::DataType scalable_type = tvm::DataType::Bool(4, true); - ASSERT_EQ(scalable_type.code(), kDLUInt); - ASSERT_EQ(scalable_type.bits(), 1); + ASSERT_EQ(scalable_type.code(), kDLBool); + 
ASSERT_EQ(scalable_type.bits(), 8); ASSERT_EQ(scalable_type.vscale_factor(), 4); ASSERT_TRUE(scalable_type.is_scalable_vector()); } diff --git a/tests/lint/cpplint.sh b/tests/lint/cpplint.sh index e49c6801ade7..84065e17b01d 100755 --- a/tests/lint/cpplint.sh +++ b/tests/lint/cpplint.sh @@ -19,7 +19,6 @@ set -e echo "Running 2 cpplints..." -python3 3rdparty/dmlc-core/scripts/lint.py --quiet tvm cpp ffi/include ffi/src python3 3rdparty/dmlc-core/scripts/lint.py --quiet tvm cpp \ include src \ examples/extension/src examples/graph_executor/src \ diff --git a/tests/lint/flake8.sh b/tests/lint/flake8.sh index 87dc8640d03f..91f057fe20ee 100755 --- a/tests/lint/flake8.sh +++ b/tests/lint/flake8.sh @@ -16,6 +16,40 @@ # specific language governing permissions and limitations # under the License. -set -e +set -euo pipefail -python3 -m flake8 . --count --select=E9,F63,F7 --show-source --statistics --exclude 3rdparty +LINT_ALL_FILES=true +REVISION= + +while (( $# )); do + case "$1" in + --rev) + LINT_ALL_FILES=false + REVISION=$2 + shift 2 + ;; + *) + echo "Usage: tests/lint/flake8.sh [--rev ]" + echo "" + echo "Run flake8 on Python files that changed since or on all files in the repo" + echo "Examples:" + echo "- Compare last one commit: tests/lint/flake8.sh --rev HEAD~1" + echo "- Compare against upstream/main: tests/lint/flake8.sh --rev upstream/main" + exit 1 + ;; + esac +done + +if [[ "$LINT_ALL_FILES" == "true" ]]; then + echo "Running flake8 on all files" + python3 -m flake8 . 
--count --select=E9,F63,F7 --show-source --statistics --exclude 3rdparty +else + # Get changed Python files, excluding 3rdparty + IFS=$'\n' read -a FILES -d'\n' < <(git diff --name-only --diff-filter=ACMRTUX $REVISION -- "*.py" "*.pyi" | grep -v "^3rdparty/") || true + if [ -z ${FILES+x} ] || [ ${#FILES[@]} -eq 0 ]; then + echo "No changes in Python files" + exit 0 + fi + echo "Running flake8 on changed files: ${FILES[@]}" + python3 -m flake8 ${FILES[@]} --count --select=E9,F63,F7 --show-source --statistics +fi diff --git a/tests/lint/pylint.sh b/tests/lint/pylint.sh index fdc753ca13b6..d65eba003a2c 100755 --- a/tests/lint/pylint.sh +++ b/tests/lint/pylint.sh @@ -15,6 +15,40 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -set -euxo pipefail +set -euo pipefail -python3 -m pylint python/tvm --rcfile="$(dirname "$0")"/pylintrc +LINT_ALL_FILES=true +REVISION= + +while (( $# )); do + case "$1" in + --rev) + LINT_ALL_FILES=false + REVISION=$2 + shift 2 + ;; + *) + echo "Usage: tests/lint/pylint.sh [--rev ]" + echo "" + echo "Run pylint on Python files that changed since or on all files in python/tvm" + echo "Examples:" + echo "- Compare last one commit: tests/lint/pylint.sh --rev HEAD~1" + echo "- Compare against upstream/main: tests/lint/pylint.sh --rev upstream/main" + exit 1 + ;; + esac +done + +if [[ "$LINT_ALL_FILES" == "true" ]]; then + echo "Running pylint on all files in python/tvm" + python3 -m pylint python/tvm --rcfile="$(dirname "$0")"/pylintrc +else + # Get changed Python files in python/tvm directory + IFS=$'\n' read -a FILES -d'\n' < <(git diff --name-only --diff-filter=ACMRTUX $REVISION -- "python/tvm/*.py" "python/tvm/**/*.py") || true + if [ -z ${FILES+x} ] || [ ${#FILES[@]} -eq 0 ]; then + echo "No changes in Python files under python/tvm" + exit 0 + fi + echo "Running pylint on changed files: ${FILES[@]}" + python3 -m pylint ${FILES[@]} --rcfile="$(dirname 
"$0")"/pylintrc +fi diff --git a/tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py b/tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py index 4767c24b693a..a9dbf74269e7 100644 --- a/tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py +++ b/tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py @@ -50,9 +50,9 @@ def check_llvm(): dev = tvm.cpu(0) # launch the kernel. n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) f(a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) diff --git a/tests/python/all-platform-minimal-test/test_runtime_ndarray.py b/tests/python/all-platform-minimal-test/test_runtime_ndarray.py index 29867c3ed8ee..7e00ba64fac4 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_ndarray.py +++ b/tests/python/all-platform-minimal-test/test_runtime_ndarray.py @@ -31,11 +31,11 @@ def test_nd_create(target, dev, dtype): x = np.random.randint(0, 10, size=(3, 4)) x = np.array(x, dtype=dtype) - y = tvm.nd.array(x, device=dev) + y = tvm.runtime.tensor(x, device=dev) z = y.copyto(dev) assert y.dtype == x.dtype assert y.shape == x.shape - assert isinstance(y, tvm.nd.NDArray) + assert isinstance(y, tvm.runtime.Tensor) np.testing.assert_equal(x, y.numpy()) np.testing.assert_equal(x, z.numpy()) @@ -48,7 +48,7 @@ def test_memory_usage(target, dev, dtype): if available_memory_before is None: pytest.skip(reason=f"Target '{target}' does not support queries of available memory") - arr = tvm.nd.empty([1024, 1024], dtype=dtype, device=dev) + arr = tvm.runtime.empty([1024, 1024], dtype=dtype, device=dev) 
available_memory_after = dev.available_global_memory num_elements = math.prod(arr.shape) @@ -61,8 +61,8 @@ def test_memory_usage(target, dev, dtype): # available memory may decrease by more than the requested amount. assert available_memory_after <= expected_memory_after - # TVM's NDArray type is a reference-counted handle to the - # underlying reference. After the last reference to an NDArray is + # TVM's Tensor type is a reference-counted handle to the + # underlying reference. After the last reference to an Tensor is # cleared, the backing allocation will be freed. del arr diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py index f315b8f3c210..7b868007a6b0 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py +++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py @@ -29,7 +29,7 @@ def test_get_global(): targs = (10, 10.0, "hello") # register into global function table - @tvm.register_func + @tvm.register_global_func def my_packed_func(*args): assert tuple(args) == targs return 10 @@ -50,7 +50,7 @@ def test(y): f2 = tvm.runtime.convert(test) # register into global function table - @tvm.register_func + @tvm.register_global_func def my_callback_with_node(y, f): assert y == x return f(y) @@ -112,7 +112,7 @@ def test_device_func(dev): x = test_device_func(tvm.cuda(7)) assert x == tvm.cpu(0) x = tvm.opencl(10) - x = tvm.testing.device_test(x, x.device_type, x.device_id) + x = tvm.testing.device_test(x, x.dlpack_device_type(), x.index) assert x == tvm.opencl(10) @@ -121,13 +121,12 @@ def test_numpy_scalar(): assert tvm.testing.echo(np.int64(maxint)) == maxint -def test_ndarray_args(): +def test_tensor_args(): def check(arr): - assert not arr.is_view assert tvm.testing.object_use_count(arr) == 2 fcheck = tvm.runtime.convert(check) - x = tvm.nd.array([1, 2, 3]) + x = tvm.runtime.tensor([1, 2, 3]) fcheck(x) assert 
tvm.testing.object_use_count(x) == 1 @@ -145,7 +144,7 @@ def test_dict_function_value_type(): if __name__ == "__main__": - test_ndarray_args() + test_tensor_args() test_numpy_scalar() test_rvalue_ref() test_empty_array() diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index 14bfec2328f2..8728df7e3f3a 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -298,5 +298,17 @@ class TestRampBound(BaseCompare): ) +class TestModularSetBound(BaseCompare): + analyzer = tvm.arith.Analyzer() + tx = tvm.te.var("tx", dtype="int32") + bx = tvm.te.var("bx", dtype="int32") + + expr = (bx * 2048 + tx * 16) % 7168 + + test_case = tvm.testing.parameter( + TestCase(expr, (0, 7152), {bx: (0, 3584), tx: (0, 128)}), + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/arith/test_arith_intset.py b/tests/python/arith/test_arith_intset.py index 18865a73df45..04014ca30095 100644 --- a/tests/python/arith/test_arith_intset.py +++ b/tests/python/arith/test_arith_intset.py @@ -387,5 +387,15 @@ def test_union_lower_bound(): assert result.max_value.same_as(pos_inf) +def test_modular_set(): + ck = IntSetChecker() + x = tvm.te.var("x", dtype="int32") + y = tvm.te.var("y", dtype="int32") + expr = (x * 2048 + y * 16) % 7168 + ck.verify( + expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 7152) + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 6954cf4e1d5c..ab11bbf3c1f6 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -93,7 +93,7 @@ class TestVector(BaseCompare): x, y, z = te.var("x"), te.var("y"), te.var("z") x64 = te.var("x", dtype="int64") vx = te.var("vx", dtype="int32x2") - vc = te.var("vc", dtype="uint1") + vc = 
te.var("vc", dtype="bool") test_case = tvm.testing.parameter( # Add rules TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x + y, 3, 4)), @@ -285,22 +285,22 @@ class TestVector(BaseCompare): tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")), ), ## Logical rules - TestCase(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("uint1x2")), + TestCase(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("boolx2")), TestCase( tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))), - (tvm.tir.NE(y, x)).astype("uint1x2"), + (tvm.tir.NE(y, x)).astype("boolx2"), ), - TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("uint1x2")), - TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("uint1x2")), - TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("uint1x2")), - TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("uint1x2")), + TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("boolx2")), + TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("boolx2")), + TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("boolx2")), + TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("boolx2")), TestCase( - tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")), - (tvm.tir.And(y <= x, vc)).astype("uint1x2"), + tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")), + (tvm.tir.And(y <= x, vc)).astype("boolx2"), ), TestCase( - tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")), - (tvm.tir.Or(y <= x, vc)).astype("uint1x2"), + tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")), + (tvm.tir.Or(y <= x, vc)).astype("boolx2"), ), ) @@ -941,6 +941,10 @@ class TestComparisons(BaseCompare): TestCase(x * 3 < y * 3, x < y), TestCase(x * (-3) < y * (-3), y < x), TestCase(x * 3 >= y * 3, y <= x), + # Eliminate bounded offset 
when comparing aligned values. + TestCase(x * 4 + y < z, x * 4 < z, [y >= 0, y < 4, flm(z, 4) == 0]), + TestCase(x * 4 + y >= z, z <= x * 4, [y >= 0, y < 4, flm(z, 4) == 0]), + TestCase(z < x * 4 + y, z <= x * 4, [y >= 1, y < 4, flm(z, 4) == 0]), TestCase(x * 4 >= 2, tvm.tir.LE(1, x)), TestCase(x * 2 >= 50, tvm.tir.LE(25, x)), TestCase(x * 4 <= 2, x <= 0), diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 5a61cb8a52a9..161548a7a14b 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -21,6 +21,7 @@ import tvm.testing from tvm import tir from tvm.script import tir as T +import tvm.ir def test_simplify_reshape_flattened_index(): @@ -144,5 +145,16 @@ def test_simplify_floor_mod_with_linear_offset(): assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0) +def test_simplify_float_division(): + # Test for the discussion: + # https://discuss.tvm.apache.org/t/discuss-is-constant-division-to-multiplication-rewrite-in-tvm-necessary/18615 + ana = tvm.arith.Analyzer() + x = tir.Var("x", "float32") + ry = x / 27 + # in old version, the division will be rewritten into x * T.float32(1 / 27) + sy = ana.rewrite_simplify(ry) + tvm.ir.assert_structural_equal(ry, sy) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_gpu_codegen_allreduce.py b/tests/python/codegen/test_gpu_codegen_allreduce.py index aa56411cc9e0..09c9fa13386e 100644 --- a/tests/python/codegen/test_gpu_codegen_allreduce.py +++ b/tests/python/codegen/test_gpu_codegen_allreduce.py @@ -76,8 +76,8 @@ def test_allreduce_sum(dims, target, dev): # prepare input and output array a_np = np.random.rand(1, d1, d2, d3).astype("float32") b_np = a_np.sum(axis=-1).astype("float32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros_like(b_np), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(np.zeros_like(b_np), dev) # launch kernel f(a, b) @@ -94,7 +94,7 
@@ def optional_metal_compile_callback(define_metal_compile_callback): if define_metal_compile_callback: - @tvm.register_func(name, override=True) + @tvm.register_global_func(name, override=True) def compile_metal(src, target): return tvm.contrib.xcode.compile_metal(src, sdk="macosx") @@ -102,9 +102,9 @@ def compile_metal(src, target): if define_metal_compile_callback: if cached is None: - tvm.ffi.registry.remove_global_func(name) + tvm_ffi.registry.remove_global_func(name) else: - tvm.register_func(name, cached, override=True) + tvm.register_global_func(name, cached, override=True) @tvm.testing.requires_metal(support_required="compile-only") @@ -143,8 +143,8 @@ def test_allreduce_max(dims, target, dev): # prepare input and output array a_np = -np.random.rand(1, d1, d2, d3).astype("float32") b_np = a_np.max(axis=-1).astype("float32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros_like(b_np), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(np.zeros_like(b_np), dev) # launch kernel f(a, b) diff --git a/tests/python/codegen/test_inject_ptx_ldg32.py b/tests/python/codegen/test_inject_ptx_ldg32.py index a45e8f57f38f..fd2f598c924e 100644 --- a/tests/python/codegen/test_inject_ptx_ldg32.py +++ b/tests/python/codegen/test_inject_ptx_ldg32.py @@ -50,8 +50,8 @@ def test_inject_ptx_intrin(): A_np = np.random.rand(16).astype("float32") B_np = np.zeros((32)).astype("float32") dev = tvm.cuda(0) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) C_np = np.zeros((32)).astype("float32") diff --git a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py index 3332d015a818..329dfac35d45 100644 --- a/tests/python/codegen/test_target_codegen.py +++ b/tests/python/codegen/test_target_codegen.py @@ -16,7 +16,7 @@ # under the License. 
import pytest - +import numpy as np import tvm from tvm.script import tir as T @@ -88,5 +88,47 @@ def func(a: T.handle, b: T.handle): tvm.compile(func) +@tvm.testing.parametrize_targets("c", "llvm") +def test_codegen_loop_step(target): + @T.prim_func + def test_loop_step( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + C: T.Buffer((1024,), "float32"), + ): + for i in T.serial(3, 1024, step=96): + C[i] = A[i] + B[i] + + with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]): + lib = tvm.compile(test_loop_step, target=target) + + src = lib.mod.inspect_source() + if target == "c": + assert src.find("for (int32_t i = 3; i < 1024; i += 96)") >= 0 + + dev = tvm.device(target, 0) + a_np = np.random.rand(1024).astype("float32") + b_np = np.random.rand(1024).astype("float32") + c_np = np.zeros(1024, dtype="float32") + a_tvm = tvm.runtime.tensor(a_np, dev) + b_tvm = tvm.runtime.tensor(b_np, dev) + c_tvm = tvm.runtime.tensor(c_np, dev) + + lib(a_tvm, b_tvm, c_tvm) + + c_result = c_tvm.numpy() + + # Check that the loop executes at positions 3, 99, 195, 291, 387, 483, 579, 675, 771, 867, 963 + for i in range(3, 1024, 96): + tvm.testing.assert_allclose(c_result[i], a_np[i] + b_np[i], rtol=1e-5) + + # Assert non-touched positions remain zero + for i in range(0, 3): + assert c_result[i] == 0.0 + for i in range(4, 1024): + if (i - 3) % 96 != 0: + assert c_result[i] == 0.0 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_blob.py b/tests/python/codegen/test_target_codegen_blob.py index 39373c4d840c..d57297ee6e22 100644 --- a/tests/python/codegen/test_target_codegen_blob.py +++ b/tests/python/codegen/test_target_codegen_blob.py @@ -77,8 +77,8 @@ def popen_check(): # Load the system wide library dev = tvm.cuda() a_np = np.random.uniform(size=12).astype("float32") - a_nd = tvm.nd.array(a_np, dev) - b_nd = tvm.nd.array(a_np, dev) + a_nd = tvm.runtime.tensor(a_np, dev) + b_nd = 
tvm.runtime.tensor(a_np, dev) syslibA = tvm.runtime.system_lib("modA_") syslibB = tvm.runtime.system_lib("modB_") # reload same lib twice diff --git a/tests/python/codegen/test_target_codegen_bool.py b/tests/python/codegen/test_target_codegen_bool.py index 96bd21329c93..d4524ac1d5fe 100644 --- a/tests/python/codegen/test_target_codegen_bool.py +++ b/tests/python/codegen/test_target_codegen_bool.py @@ -56,9 +56,9 @@ def test_cmp_load_store(target, dev, arr_size, compute, get_module): a_np = np.random.uniform(size=arr_size).astype(A.dtype) b_np = np.random.uniform(size=arr_size).astype(B.dtype) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - d = tvm.nd.array(np.zeros(arr_size, dtype=D.dtype), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) + d = tvm.runtime.tensor(np.zeros(arr_size, dtype=D.dtype), dev) f(a, b, d) np.testing.assert_equal( d.numpy(), diff --git a/tests/python/codegen/test_target_codegen_c_host.py b/tests/python/codegen/test_target_codegen_c_host.py index 3c80cfbeb0b4..e95108aeac17 100644 --- a/tests/python/codegen/test_target_codegen_c_host.py +++ b/tests/python/codegen/test_target_codegen_c_host.py @@ -47,9 +47,9 @@ def check_c(): dev = tvm.cpu(0) # launch the kernel. 
n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) fadd(a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) @@ -78,8 +78,8 @@ def check_c(): fadd = m["test_reinterpret"] dev = tvm.cpu(0) n = nn - a = tvm.nd.array(np.random.randint(-(2**30), 2**30, size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.randint(-(2**30), 2**30, size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) fadd(a, b) tvm.testing.assert_allclose(b.numpy(), (2 + a.numpy()).view("float32")) @@ -106,8 +106,8 @@ def check_c(): fceil = m["test_ceil"] dev = tvm.cpu(0) n = nn - a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.rand(n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) fceil(a, b) tvm.testing.assert_allclose(b.numpy(), (np.ceil(a.numpy()).view("float32"))) @@ -134,8 +134,8 @@ def check_c(): ffloor = m["test_floor"] dev = tvm.cpu(0) n = nn - a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.rand(n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) ffloor(a, b) tvm.testing.assert_allclose(b.numpy(), (np.floor(a.numpy()).view("float32"))) @@ -162,8 +162,8 @@ def check_c(): fround = m["test_round"] dev = tvm.cpu(0) n = nn - a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.rand(n).astype(A.dtype), dev) 
+ b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) fround(a, b) tvm.testing.assert_allclose(b.numpy(), (np.round(a.numpy()).view("float32"))) @@ -184,17 +184,9 @@ def subroutine(A_data: T.handle("float32")): built = tvm.tir.build(mod, target="c") - func_names = list(built["get_func_names"]()) - assert ( - "main" in func_names - ), "Externally exposed functions should be listed in available functions." - assert ( - "subroutine" not in func_names - ), "Internal function should not be listed in available functions." - source = built.inspect_source() assert ( - source.count("main(void*") == 2 + source.count("__tvm_ffi_main(void*") == 2 ), "Expected two occurrences, for forward-declaration and definition" assert ( source.count("subroutine(float*") == 2 diff --git a/tests/python/codegen/test_target_codegen_cross_llvm.py b/tests/python/codegen/test_target_codegen_cross_llvm.py index 9ae516c7de30..3cb8c3037254 100644 --- a/tests/python/codegen/test_target_codegen_cross_llvm.py +++ b/tests/python/codegen/test_target_codegen_cross_llvm.py @@ -81,9 +81,9 @@ def build_arm(): farm = remote.load_module("myadd.o") dev = remote.cpu(0) n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) farm(a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) print("Verification finish on remote..") diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index fb9c47410fea..1b31e64414b1 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -49,8 +49,8 @@ def check_cuda(dtype, n, lanes): fun = tvm.compile(sch.mod, 
target="cuda") dev = tvm.cuda(0) - a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes))) - c = tvm.nd.empty((n,), B.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes))) + c = tvm.runtime.empty((n,), B.dtype, dev) fun(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) @@ -105,10 +105,10 @@ def check_cuda(n, lanes): dev = tvm.cuda(0) np_a = np.random.uniform(size=(n, lanes)).astype("float32") np_a = np_bf162np_float(np_float2np_bf16(np_a)) - a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_float2np_bf16(np_a)) - c = tvm.nd.empty((n,), B.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np_float2np_bf16(np_a)) + c = tvm.runtime.empty((n,), B.dtype, dev) fun(a, c) - c = tvm.nd.empty((n, lanes), "uint16", dev).copyfrom(c) + c = tvm.runtime.empty((n, lanes), "uint16", dev).copyfrom(c) tvm.testing.assert_allclose(c.numpy(), np_float2np_bf16(np_a + 1)) check_cuda(64, 2) @@ -143,10 +143,10 @@ def check_cuda(dtype, n, lanes): np_c = np.random.randint(low=0, high=127, size=(n,)) np_d = [sum(x * y) + z for x, y, z in zip(np_a, np_b, np_c)] dev = tvm.cuda(0) - a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_a) - b = tvm.nd.empty((n,), B.dtype, dev).copyfrom(np_b) - c = tvm.nd.empty((n,), C.dtype, dev).copyfrom(np_c) - d = tvm.nd.empty((n,), D.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np_a) + b = tvm.runtime.empty((n,), B.dtype, dev).copyfrom(np_b) + c = tvm.runtime.empty((n,), C.dtype, dev).copyfrom(np_c) + d = tvm.runtime.empty((n,), D.dtype, dev) fun(a, b, c, d) tvm.testing.assert_allclose(d.numpy(), np_d) @@ -170,8 +170,8 @@ def check_cuda(dtype, n, lanes): fun = tvm.compile(sch.mod, target="cuda") np_a = np.random.randint(low=-128, high=127, size=(n, lanes)) - a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_a) - b = tvm.nd.empty((n,), B.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np_a) + b = tvm.runtime.empty((n,), B.dtype, dev) 
fun(a, b) tvm.testing.assert_allclose(a.numpy(), b.numpy()) @@ -197,7 +197,7 @@ def check_cuda(n, value, lanes): fun = tvm.compile(sch.mod, target="cuda") np_a = np.full((n, lanes), value, dtype=dtype) - a = tvm.nd.empty(np_a.shape, dtype, dev) + a = tvm.runtime.empty(np_a.shape, dtype, dev) fun(a) np.testing.assert_equal(a.numpy(), np_a) @@ -228,8 +228,8 @@ def check_inf_nan(dev, n, value, dtype): sch.bind(xi, "threadIdx.x") fun = tvm.compile(sch.mod, target="cuda") - a = tvm.nd.empty((n,), A.dtype, dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev) + c = tvm.runtime.empty((n,), A.dtype, dev) # Only need to test compiling here fun(a, c) @@ -267,8 +267,8 @@ def verify(nthd): vals = [nthd - 1, nthd, nthd + 1] for kk in [x for x in vals]: size = (nn, kk) - a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=size).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(nn, dtype=B.dtype), dev) func(a, b) tvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=1), rtol=1e-3) @@ -306,8 +306,8 @@ def verify(nthdx, nthdy): vy = [nthdy - 1, nthdy, nthdy + 1] for kk0, kk1 in [(x, y) for x in vx for y in vy]: size = (nn, kk0, kk1) - a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=size).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(nn, dtype=B.dtype), dev) func(a, b) tvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=(1, 2)), rtol=1e-3) @@ -352,8 +352,8 @@ def test_cuda_const_float_to_half(): dev = tvm.cuda(0) a_np = np.random.uniform(size=shape).astype(a.dtype) c_np = np.zeros(shape=shape, dtype=c.dtype) - a = tvm.nd.array(a_np, dev) - c = tvm.nd.array(c_np, dev) + a = tvm.runtime.tensor(a_np, dev) + c = tvm.runtime.tensor(c_np, dev) func(a, c) np.testing.assert_equal(c.numpy(), 
a_np > b.value) @@ -379,8 +379,8 @@ def test_cuda_floordiv_with_vectorization(): dev = tvm.cuda(0) a_np = np.random.uniform(size=(n,)).astype(A.dtype) b_np = np.array([a_np[i // k] for i in range(0, n)]) - a_nd = tvm.nd.array(a_np, dev) - b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev) + a_nd = tvm.runtime.tensor(a_np, dev) + b_nd = tvm.runtime.tensor(np.zeros(b_np.shape, dtype=b_np.dtype), dev) func(a_nd, b_nd) tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3) @@ -405,8 +405,8 @@ def test_cuda_floormod_with_vectorization(): dev = tvm.cuda(0) a_np = np.random.uniform(size=(n,)).astype(A.dtype) b_np = np.array([a_np[i % k] for i in range(0, n)]) - a_nd = tvm.nd.array(a_np, dev) - b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev) + a_nd = tvm.runtime.tensor(a_np, dev) + b_nd = tvm.runtime.tensor(np.zeros(b_np.shape, dtype=b_np.dtype), dev) func(a_nd, b_nd) tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3) @@ -438,9 +438,9 @@ def check(t0, t1, factor): a_np = np.random.randint(low, high, size=n).astype(A.dtype) b_np = np.random.randint(low, high, size=n).astype(B.dtype) c_np = (a_np + b_np).astype(A.dtype) - a_nd = tvm.nd.array(a_np, dev) - b_nd = tvm.nd.array(b_np, dev) - c_nd = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np.dtype), dev) + a_nd = tvm.runtime.tensor(a_np, dev) + b_nd = tvm.runtime.tensor(b_np, dev) + c_nd = tvm.runtime.tensor(np.zeros(c_np.shape, dtype=c_np.dtype), dev) func(a_nd, b_nd, c_nd) tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-3) @@ -535,8 +535,8 @@ def run_test(tvm_intrin, np_func, dtype): B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name="B") f = sched(A, B) dev = tvm.cuda(0) - a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(A.dtype), dev) f(a, b) 
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) @@ -560,8 +560,8 @@ def run_test(tvm_intrin, np_func): B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name="B") f = sched(A, B) dev = tvm.cuda(0) - a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(A.dtype), dev) f(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) @@ -585,8 +585,8 @@ def run_test(dtype): B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name="B") f = sched(A, B) dev = tvm.cuda(0) - a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev) + a = tvm.runtime.tensor(np.random.randint(0, 100000, size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(B.dtype), dev) f(a, b) ref = np.vectorize(ref_popcount)(a.numpy()) tvm.testing.assert_allclose(b.numpy(), ref) @@ -623,8 +623,8 @@ def check_cuda(dtype, n, l, padding, lanes): fun = tvm.compile(sch.mod, target="cuda") np_a = np.random.randint(low=-128, high=127, size=(n, l)).astype(A.dtype) - a = tvm.nd.empty((n, l), A.dtype, dev).copyfrom(np_a) - b = tvm.nd.empty((n // lanes, l + padding * 2, lanes), B.dtype, dev) + a = tvm.runtime.empty((n, l), A.dtype, dev).copyfrom(np_a) + b = tvm.runtime.empty((n // lanes, l + padding * 2, lanes), B.dtype, dev) fun(a, b) np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1) ref = np.pad( @@ -666,8 +666,8 @@ def build(A, C, N, C_N): kernel_source = f.imports[0].inspect_source() dev = tvm.cuda() a_data = np.arange(0, N).astype(A.dtype) - a = tvm.nd.array(a_data, dev) - c = tvm.nd.array(np.zeros(C_N, dtype=C.dtype), dev) + a = tvm.runtime.tensor(a_data, dev) + c = tvm.runtime.tensor(np.zeros(C_N, dtype=C.dtype), 
dev) f(a, c) return a_data, c.numpy(), kernel_source @@ -801,6 +801,25 @@ def main( assert 'extern "C" __device__ float add(float a, float b) {\n return (a + b);\n}' in cuda_code +@tvm.testing.requires_cuda +def test_cuda_float_const_hex_format(): + """Test that float constants are emitted in hexadecimal format for precision""" + + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((1024, 1024), "float32"), + ): + for bx in T.thread_binding(1024, "blockIdx.x"): + for tx in T.thread_binding(1024, "threadIdx.x"): + A[bx, tx] = T.float32(1 / 27) + + lib = tvm.compile(Module, target="cuda") + cuda_code = lib.mod.imports[0].inspect_source() + assert "0x1.2f684bda12f68p-5f" in cuda_code + + @tvm.testing.requires_cuda def test_device_host_call_same_func(): @I.ir_module @@ -834,9 +853,9 @@ def main( dev = tvm.cuda(0) a_np = np.random.randint(0, 10, (128, 128), dtype="int32") b_np = np.random.randint(0, 10, (128, 128), dtype="int32") - a_tvm = tvm.nd.array(a_np, device=dev) - b_tvm = tvm.nd.array(b_np, device=dev) - c_tvm = tvm.nd.empty((128, 128), dtype="int32", device=dev) + a_tvm = tvm.runtime.tensor(a_np, device=dev) + b_tvm = tvm.runtime.tensor(b_np, device=dev) + c_tvm = tvm.runtime.empty((128, 128), dtype="int32", device=dev) lib["main"](a_tvm, b_tvm, c_tvm) tvm.testing.assert_allclose(c_tvm.numpy(), a_np + b_np) @@ -858,5 +877,37 @@ def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): assert "return;" in cuda_code +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_cuda_loop_step(): + @T.prim_func + def cuda_loop_step( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + C: T.Buffer((1024,), "float32"), + ): + # Each thread computes a strided subset of the i loop: start = tx*3, step = 96 (3 * 32 threads) + for bx in T.thread_binding(1, "blockIdx.x"): + for tx in T.thread_binding(96, "threadIdx.x"): + for i in T.serial(tx, 1024, step=96): + C[i] = A[i] + B[i] + + target = 
tvm.target.Target({"kind": "cuda"}) + with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]): + lib = tvm.compile(cuda_loop_step, target=target) + + cuda_src = lib.mod.imports[0].inspect_source() + assert "i += 96" in cuda_src + dev = tvm.cuda(0) + a_np = np.random.uniform(1, 100, (1024,)).astype("float32") + b_np = np.random.uniform(1, 100, (1024,)).astype("float32") + c_np = np.zeros((1024,), dtype="float32") + a_nd = tvm.runtime.tensor(a_np, dev) + b_nd = tvm.runtime.tensor(b_np, dev) + c_nd = tvm.runtime.tensor(c_np, dev) + lib["main"](a_nd, b_nd, c_nd) + tvm.testing.assert_allclose(c_nd.numpy(), a_np + b_np) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py index 364f9461c2f9..ef425dbf73e0 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp4.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -25,9 +25,11 @@ from tvm.script import tir as T try: - import ml_dtypes + from ml_dtypes import float4_e2m1fn + + ML_DTYPES_AVAILABLE = True except ImportError: - ml_dtypes = None + ML_DTYPES_AVAILABLE = False @pytest.mark.parametrize("promoted_dtype", ["float32x2", "float16x2"]) @@ -63,7 +65,6 @@ def add( fadd = tvm.compile(sch.mod, target=target) dev = tvm.device(target, 0) - numpytype = "float4_e2m1fn" if "x" in native_dtype: lanes = int(native_dtype.split("x")[-1]) else: @@ -75,18 +76,39 @@ def add( promoted_base_dtype = promoted_dtype np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,) - a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) - a = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + + # Create test data - either using ml_dtypes if available, or using int8 with valid FP4 values + if ML_DTYPES_AVAILABLE: + a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(float4_e2m1fn) + b_np = np.random.uniform(low=0, high=5, 
size=np_shape).astype(float4_e2m1fn) + else: + # float4_e2m1fn possible values: [0, 0.5, 1, 1.5, 2, 3, 4, 6] + # We will create int8 arrays with valid FP4 bit patterns + valid_fp4_values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] # 4-bit values + a_np = np.random.choice(valid_fp4_values, size=np_shape).astype(np.int8) + b_np = np.random.choice(valid_fp4_values, size=np_shape).astype(np.int8) + + a = tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev) a.copyfrom(a_np) - b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) - b = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + b = tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev) b.copyfrom(b_np) - c = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + c = tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev) fadd(a, b, c) - tvm.testing.assert_allclose( - c.numpy().astype(promoted_base_dtype), (a_np + b_np).astype(promoted_base_dtype) - ) + # For the comparison, we will convert result to the promoted dtype and compare + # Note: When ml_dtypes is not available, we skip the numpy-level computation comparison + # and just verify that the CUDA kernel compiles and executes without error + c_result = c.numpy().astype(promoted_base_dtype) + + if ML_DTYPES_AVAILABLE: + # Full comparison when ml_dtypes is available + expected = (a_np + b_np).astype(promoted_base_dtype) + tvm.testing.assert_allclose(c_result, expected) + else: + # When ml_dtypes is not available, we just verify the comparison ran successfully + # by checking that we got a result with the expected shape and dtype + assert c_result.shape == np_shape + assert c_result.dtype == promoted_base_dtype @tvm.testing.requires_cuda_compute_version(10) diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index c0b6130bcb80..4ea938cad8ad 100644 --- 
a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -76,9 +76,9 @@ def add( dev = tvm.device(target, 0) - a = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(dtype), dev) - b = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(dtype), dev) - c = tvm.nd.array(np.zeros(64, dtype=dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(low=0, high=5, size=64).astype(dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(low=0, high=5, size=64).astype(dtype), dev) + c = tvm.runtime.tensor(np.zeros(64, dtype=dtype), dev) fadd(a, b, c) tvm.testing.assert_allclose( @@ -135,9 +135,9 @@ def add( np_shape = (length, vector_length) a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(dtype) - a = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev) - r = tvm.nd.empty(shape=(length,), dtype=packed_dtype, device=dev) - b = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev) + a = tvm.runtime.empty(shape=(length,), dtype=native_dtype, device=dev) + r = tvm.runtime.empty(shape=(length,), dtype=packed_dtype, device=dev) + b = tvm.runtime.empty(shape=(length,), dtype=native_dtype, device=dev) a.copyfrom(a_np) f(a, r, b) tvm.testing.assert_allclose(a.numpy().astype("float16"), b.numpy().astype("float16")) @@ -205,12 +205,12 @@ def add( np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,) a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) - a = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + a = tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev) a.copyfrom(a_np) b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) - b = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + b = tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev) b.copyfrom(b_np) - c = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + c = 
tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev) fadd(a, b, c) tvm.testing.assert_allclose( @@ -243,8 +243,8 @@ def vector_broadcast(a: T.Buffer((), dtype), vec: T.Buffer((bcast_length,), dtyp dev = tvm.device(target, 0) a_np = np.random.uniform(low=0, high=4, size=()).astype(dtype) - a = tvm.nd.array(a_np, device=dev) - b = tvm.nd.empty((bcast_length,), dtype=dtype, device=dev) + a = tvm.runtime.tensor(a_np, device=dev) + b = tvm.runtime.empty((bcast_length,), dtype=dtype, device=dev) func(a, b) @@ -276,9 +276,9 @@ def vector_load( dev = tvm.device(target, 0) a_np = np.random.uniform(low=0, high=1, size=(length,)).astype(dtype) - a = tvm.nd.array(a_np, device=dev) + a = tvm.runtime.tensor(a_np, device=dev) - b = tvm.nd.empty((length // vector_length,), dtype=vec_dtype, device=dev) + b = tvm.runtime.empty((length // vector_length,), dtype=vec_dtype, device=dev) f(a, b) @@ -325,12 +325,12 @@ def add( dev = tvm.device(target, 0) a_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype) - a = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev) + a = tvm.runtime.empty(shape=(length,), dtype=vec_dtype, device=dev) a.copyfrom(a_np) b_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype) - b = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev) + b = tvm.runtime.empty(shape=(length,), dtype=vec_dtype, device=dev) b.copyfrom(b_np) - c = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev) + c = tvm.runtime.empty(shape=(length,), dtype=vec_dtype, device=dev) fadd(a, b, c) c_expected = a_np + b_np @@ -805,7 +805,7 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): dev = tvm.device(target_str, 0) weight_np = np.random.uniform(-100, 100, weight_shape).astype(model_dtype) - weight = tvm.nd.array(weight_np, device=dev) + weight = tvm.runtime.tensor(weight_np, device=dev) quant_weight, scales = quant(weight) quant_weight_np, scales_np = quant_weight.numpy(), scales.numpy() @@ 
-955,16 +955,16 @@ def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: dev = tvm.cuda(0) x_data = np.zeros((1, reduce_size), dtype=np.float16) - x = tvm.nd.array(x_data, device=dev) + x = tvm.runtime.tensor(x_data, device=dev) indptr_data = np.zeros((1, 2), dtype=np.int32) - indptr = tvm.nd.array(indptr_data, device=dev) + indptr = tvm.runtime.tensor(indptr_data, device=dev) weight_data = np.zeros((num_experts, spatial_size, reduce_size), dtype="float8_e4m3fn") - weight = tvm.nd.array(weight_data, device=dev) + weight = tvm.runtime.tensor(weight_data, device=dev) scale_data = np.zeros((1,), dtype=np.float32) - scale = tvm.nd.array(scale_data, device=dev) + scale = tvm.runtime.tensor(scale_data, device=dev) vm = relax.VirtualMachine(rt_mod, dev) # Ensure this runs without failure. Utilizing dlight thread extents TS, TR = 4, 64 @@ -1000,12 +1000,12 @@ def func_vectorize( a_np = np.random.rand(128).astype("float8_e4m3fn") b_np = np.random.rand(128).astype(dtype) c_np = (a_np.astype(dtype) * b_np) + 3 - a_tvm = tvm.nd.array(a_np, device=device) - b_tvm = tvm.nd.array(b_np, device=device) - c_tvm = tvm.nd.empty((128,), dtype=dtype, device=device) + a_tvm = tvm.runtime.tensor(a_np, device=device) + b_tvm = tvm.runtime.tensor(b_np, device=device) + c_tvm = tvm.runtime.empty((128,), dtype=dtype, device=device) f(a_tvm, b_tvm, c_tvm) c_tvm = c_tvm.numpy() - np.testing.assert_allclose( + tvm.testing.assert_allclose( c_tvm.astype(np.float32), c_np.astype(np.float32), atol=5e-1, rtol=1e-2 ) diff --git a/tests/python/codegen/test_target_codegen_device.py b/tests/python/codegen/test_target_codegen_device.py index 4dad03d7004c..b897d50b41c7 100644 --- a/tests/python/codegen/test_target_codegen_device.py +++ b/tests/python/codegen/test_target_codegen_device.py @@ -50,7 +50,7 @@ def check_target(device): dev = tvm.device(device, 0) f = tvm.compile(sch.mod, target=device) # launch the kernel. 
- a = tvm.nd.empty((n,), dtype=A.dtype, device=dev) + a = tvm.runtime.empty((n,), dtype=A.dtype, device=dev) f(a) assert a.numpy()[0] == value + 3 @@ -95,12 +95,12 @@ def check_target(device, host): dev = tvm.device(device, 0) target = tvm.target.Target(device, host) mhost = tvm.tir.build(sch.mod, target=target) - f = mhost.entry_func + f = mhost.main # launch the kernel. n = 1027 - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=()).astype(B.dtype), dev) - d = tvm.nd.array(np.zeros(n, dtype=D.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=()).astype(B.dtype), dev) + d = tvm.runtime.tensor(np.zeros(n, dtype=D.dtype), dev) f(a, b, d) tvm.testing.assert_allclose(d.numpy(), a.numpy() + b.numpy() + 1) diff --git a/tests/python/codegen/test_target_codegen_extern.py b/tests/python/codegen/test_target_codegen_extern.py index 35227baaff5b..06e0926005bf 100644 --- a/tests/python/codegen/test_target_codegen_extern.py +++ b/tests/python/codegen/test_target_codegen_extern.py @@ -73,8 +73,8 @@ def check_target(target): dev = tvm.device(target, 0) # launch the kernel. n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) @@ -97,7 +97,7 @@ def extern_generator(ins, outs): # Create IRModule directly mod = tvm.IRModule.from_expr(te.create_prim_func([A, C])) - @tvm.register_func + @tvm.register_global_func def my_extern_array_func1(aa, bb): aa.copyto(bb) @@ -109,8 +109,8 @@ def check_target(target): dev = tvm.cpu(0) # launch the kernel. 
n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy()) @@ -140,10 +140,10 @@ def check_target(target): dev = tvm.cpu(0) # launch the kernel. n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) - @tvm.register_func + @tvm.register_global_func def my_extern_array_func2(aa, bb): assert aa.shape == a.shape tvm.testing.assert_allclose(aa.numpy(), a.numpy() + 1) diff --git a/tests/python/codegen/test_target_codegen_gpu_common.py b/tests/python/codegen/test_target_codegen_gpu_common.py index 08f43a114084..b115fddb57f7 100644 --- a/tests/python/codegen/test_target_codegen_gpu_common.py +++ b/tests/python/codegen/test_target_codegen_gpu_common.py @@ -41,8 +41,8 @@ def run_test(tvm_intrin, np_func, dtype): (x,) = sch.get_loops(sch.get_block("B")) sch.bind(x, "threadIdx.x") f = tvm.compile(sch.mod, target=target) - a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev) + a = tvm.runtime.tensor(np.random.randint(0, 100000, size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(B.dtype), dev) f(a, b) ref = np.vectorize(partial(np_func, dtype=dtype))(a.numpy()) tvm.testing.assert_allclose(b.numpy(), ref) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 953adf78b342..88b791d1aa52 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -118,7 +118,7 @@ def check_llvm(): f = 
tvm.compile(sch.mod, target="llvm") dev = tvm.cpu(0) # launch the kernel. - a = tvm.nd.empty((), dtype=A.dtype, device=dev) + a = tvm.runtime.empty((), dtype=A.dtype, device=dev) f(a) assert a.numpy() == value + 3 @@ -160,8 +160,8 @@ def check_llvm(): f = tvm.compile(sch.mod, target="llvm") dev = tvm.cpu(0) # launch the kernel. - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), np.sqrt(a.numpy() + 1) * 2 + 2, rtol=1e-5) @@ -193,8 +193,8 @@ def check_llvm(nn, base): dev = tvm.cpu(0) # launch the kernel. n = nn - a = tvm.nd.array(np.random.uniform(size=(n + base)).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=(n + base)).astype(A.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy()[::-1][:n]) @@ -226,9 +226,9 @@ def test_llvm_vadd_pipeline(): f = tvm.compile(sch.mod, target="llvm") dev = tvm.cpu(0) n = 128 - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) f(a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) @@ -258,8 +258,8 @@ def check_llvm(nn, base, stride): dev = tvm.cpu(0) # launch the kernel. 
n = nn - a = tvm.nd.array(np.random.uniform(size=(n + base, stride)).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros((n, stride), dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=(n + base, stride)).astype(A.dtype), dev) + c = tvm.runtime.tensor(np.zeros((n, stride), dtype=C.dtype), dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy()[base:] + 1) @@ -288,8 +288,8 @@ def check_llvm(): dev = tvm.cpu(0) # launch the kernel. n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1 + 1) @@ -320,9 +320,9 @@ def test_multiple_func(): f = tvm.compile(mod, target="llvm") dev = tvm.cpu(0) n = 10 - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) # Test both functions f["fadd1"](a, b, c) @@ -345,8 +345,8 @@ def check_llvm(n, offset): f = tvm.compile(sch.mod, target="llvm") dev = tvm.cpu(0) # launch the kernel. - a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.tensor(np.random.uniform(size=(n,)).astype(A.dtype), dev) + c = tvm.runtime.empty((n,), A.dtype, dev) f(a, c) c_np = a.numpy() c_np[:offset] = 0 @@ -369,8 +369,8 @@ def check_llvm(n): f = tvm.compile(sch.mod, target="llvm") dev = tvm.cpu(0) # launch the kernel. 
- a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - c = tvm.nd.empty((n,), C.dtype, dev) + a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) + c = tvm.runtime.empty((n,), C.dtype, dev) f(a, c) c_np = a.numpy() == 1 tvm.testing.assert_allclose(c.numpy(), c_np) @@ -395,9 +395,9 @@ def check_llvm(n): f = tvm.compile(sch.mod, target="llvm") dev = tvm.cpu(0) # launch the kernel. - a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - sc = tvm.nd.array(np.random.randint(0, 2, size=()).astype(scale.dtype), dev) - d = tvm.nd.empty((), D.dtype, dev) + a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) + sc = tvm.runtime.tensor(np.random.randint(0, 2, size=()).astype(scale.dtype), dev) + d = tvm.runtime.empty((), D.dtype, dev) f(a, sc, d) d_np = np.sum(a.numpy()) * sc.numpy() + 1 tvm.testing.assert_allclose(d.numpy(), d_np) @@ -423,9 +423,9 @@ def check_llvm(n): f = tvm.compile(sch.mod, target="llvm") dev = tvm.cpu(0) # launch the kernel. 
- a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - sc = tvm.nd.array(np.random.randint(0, 2, size=()).astype(scale.dtype), dev) - d = tvm.nd.empty((), D.dtype, dev) + a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) + sc = tvm.runtime.tensor(np.random.randint(0, 2, size=()).astype(scale.dtype), dev) + d = tvm.runtime.empty((), D.dtype, dev) f(a, sc, d) d_np = np.sum(a.numpy()) * sc.numpy() + 1 tvm.testing.assert_allclose(d.numpy(), d_np) @@ -531,16 +531,16 @@ def clipb(x): f = tvm.compile(sch.mod, target="llvm") # Fill input arrays with values - A_arr = tvm.nd.empty((end - start + 1,), dtype) - B_arr = tvm.nd.empty((dend - dstart + 1,), dtype) + A_arr = tvm.runtime.empty((end - start + 1,), dtype) + B_arr = tvm.runtime.empty((dend - dstart + 1,), dtype) A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype)) B_np = np.arange(dstart, dend + 1, dtype=dtype) # If the range of the divisor contains 0, replace it with 1 to avoid division by zero if dend >= 0 and dstart <= 0: B_np[-dstart] = 1 B_arr.copyfrom(B_np) - D_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype) - M_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype) + D_arr = tvm.runtime.empty((end - start + 1, dend - dstart + 1), dtype) + M_arr = tvm.runtime.empty((end - start + 1, dend - dstart + 1), dtype) # Run the function and convert the results to numpy f(A_arr, B_arr, D_arr, M_arr) @@ -636,8 +636,8 @@ def check_llvm_reciprocal(n): # Build from scheduled TIR f = tvm.compile(sch.mod, target="llvm") - a = tvm.nd.array(np.full((n,), 100, "float32")) - b = tvm.nd.empty((n,), "float32") + a = tvm.runtime.tensor(np.full((n,), 100, "float32")) + b = tvm.runtime.empty((n,), "float32") f(a, b) tvm.testing.assert_allclose(b.numpy(), np.zeros((n,), "float32")) @@ -656,8 +656,8 @@ def check_llvm_sigmoid(n): # Build from scheduled TIR f = tvm.compile(sch.mod, target="llvm") - a = tvm.nd.array(np.full((n,), -1000, "float32")) - b = 
tvm.nd.empty((n,), "float32") + a = tvm.runtime.tensor(np.full((n,), -1000, "float32")) + b = tvm.runtime.empty((n,), "float32") f(a, b) tvm.testing.assert_allclose(b.numpy(), np.zeros((n,), "float32")) @@ -780,9 +780,9 @@ def dotest(do_vectorize): npa = np.random.rand(32).astype("bfloat16") npb = np.random.rand(32).astype("bfloat16") res = npa + npb - a_ = tvm.nd.array(npa) - b_ = tvm.nd.array(npb) - c_ = tvm.nd.empty((32,), "bfloat16") + a_ = tvm.runtime.tensor(npa) + b_ = tvm.runtime.tensor(npb) + c_ = tvm.runtime.empty((32,), "bfloat16") module(a_, b_, c_) # Note: directly compare without casting to float32 should work with the # latest numpy version. @@ -868,8 +868,8 @@ def check_llvm(use_file): f = tvm.compile(sch.mod, target="llvm") dev = tvm.cpu(0) # launch the kernel. - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) f(a, b) tvm.testing.assert_allclose(b.numpy(), a.numpy() + 1.0) @@ -953,7 +953,10 @@ def test_llvm_target_attributes(): assert re.match('.*"target-cpu"="skylake".*', attribute_definitions[k]) assert re.match('.*"target-features"=".*[+]avx512f.*".*', attribute_definitions[k]) - expected_functions = ["test_func", "test_func_compute_", "__tvm_parallel_lambda"] + expected_functions = [ + "__tvm_ffi_test_func", + "__tvm_parallel_lambda", + ] for n in expected_functions: assert n in functions_with_target @@ -1024,7 +1027,7 @@ def subroutine(A_data: T.handle("float32")): built = tvm.compile(mod) - arr = tvm.nd.array(np.zeros([1], "float32"), device=dev) + arr = tvm.runtime.tensor(np.zeros([1], "float32"), device=dev) built["main"](arr) assert arr.numpy()[0] == 42.0 @@ -1188,10 +1191,10 @@ def func(a0: T.bool, a1: T.Buffer([10], "float32")) -> T.int32: built(1, 1) with pytest.raises(RuntimeError): - built(1, 
tvm.nd.empty([10], "int32")) + built(1, tvm.runtime.empty([10], "int32")) with pytest.raises(RuntimeError): - built(False, tvm.nd.empty([11], "float32")) + built(False, tvm.runtime.empty([11], "float32")) if __name__ == "__main__": diff --git a/tests/python/codegen/test_target_codegen_metal.py b/tests/python/codegen/test_target_codegen_metal.py index 6b413d532371..b969f0e0b911 100644 --- a/tests/python/codegen/test_target_codegen_metal.py +++ b/tests/python/codegen/test_target_codegen_metal.py @@ -37,8 +37,8 @@ def check_inf_nan(dev, n, value, dtype): (x,) = sch.get_loops(sch.get_block("C")) sch.bind(x, "threadIdx.x") fun = tvm.compile(sch.mod, target=target) - a = tvm.nd.empty((n,), A.dtype, dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev) + c = tvm.runtime.empty((n,), A.dtype, dev) # Only need to test compiling here fun(a, c) @@ -70,11 +70,11 @@ def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")): dev = tvm.metal() a = (np.arange(6).reshape(2, 3)).astype("float32") - a_nd = tvm.nd.array(a, dev) - b_nd = tvm.nd.empty((6,), "float32", dev) + a_nd = tvm.runtime.tensor(a, dev) + b_nd = tvm.runtime.empty((6,), "float32", dev) f = tvm.compile(IRModule, target=target) f(a_nd, b_nd) - np.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5) @tvm.testing.requires_gpu @@ -90,8 +90,8 @@ def check_erf(dev, n, dtype): (x,) = sch.get_loops(sch.get_block("C")) sch.bind(x, "threadIdx.x") fun = tvm.compile(sch.mod, target=target) - a = tvm.nd.empty((n,), A.dtype, dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev) + c = tvm.runtime.empty((n,), A.dtype, dev) # Only need to test compiling here fun(a, c) @@ -119,7 +119,7 @@ def main(A: T.Buffer((1, 2), "int32")): f = tvm.compile(IRModule, target=target) dev = tvm.metal() - a_nd = tvm.nd.empty((1, 2), "int32", dev) + a_nd = 
tvm.runtime.empty((1, 2), "int32", dev) f(a_nd) assert tuple(a_nd.numpy()[0, :]) == (0, 3) @@ -141,12 +141,12 @@ def main(A: T.Buffer((6), "float32"), B: T.Buffer((6,), "float32")): target = "metal" dev = tvm.metal() a = np.arange(6).astype("float32") - a_nd = tvm.nd.array(a, dev) - b_nd = tvm.nd.empty((6,), "float32", dev) + a_nd = tvm.runtime.tensor(a, dev) + b_nd = tvm.runtime.empty((6,), "float32", dev) f = tvm.compile(IRModule, target=target) f(a_nd, b_nd) a.reshape(3, 2)[:, 1] = 0 - np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5) @tvm.testing.requires_gpu @@ -162,11 +162,11 @@ def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")): dev = tvm.metal() a = np.arange(16).astype("uint8") - a_nd = tvm.nd.array(a, dev) - b_nd = tvm.nd.empty((16,), "float32", dev) + a_nd = tvm.runtime.tensor(a, dev) + b_nd = tvm.runtime.empty((16,), "float32", dev) f = tvm.compile(func, target="metal") f(a_nd, b_nd) - np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5) @tvm.testing.requires_metal(support_required="compile-only") @@ -180,7 +180,7 @@ def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float3 vi = T.axis.spatial(16, i) B[vi] = A[vi] + x - @tvm.register_func("tvm_callback_metal_compile") + @tvm.register_global_func("tvm_callback_metal_compile") def compile_metal(src, target): return xcode.compile_metal(src) diff --git a/tests/python/codegen/test_target_codegen_opencl.py b/tests/python/codegen/test_target_codegen_opencl.py index 4eb96747bcee..3e0fe7e31e50 100644 --- a/tests/python/codegen/test_target_codegen_opencl.py +++ b/tests/python/codegen/test_target_codegen_opencl.py @@ -39,8 +39,8 @@ def check_if_then_else(dev, n, dtype): (x,) = sch.get_loops(sch.get_block("C")) sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, 
target=target) - a = tvm.nd.empty((n,), A.dtype, dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev) + c = tvm.runtime.empty((n,), A.dtype, dev) # Only need to test compiling here fun(a, c) @@ -57,8 +57,8 @@ def check_select(dev, n, dtype): sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) - a = tvm.nd.empty((n,), A.dtype, dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev) + c = tvm.runtime.empty((n,), A.dtype, dev) # Only need to test compiling here fun(a, c) @@ -86,8 +86,8 @@ def check_inf_nan(dev, n, value, dtype): (x,) = sch.get_loops(sch.get_block("C")) sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) - a = tvm.nd.empty((n,), A.dtype, dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev) + c = tvm.runtime.empty((n,), A.dtype, dev) # Only need to test compiling here fun(a, c) @@ -115,8 +115,8 @@ def check_max(dev, n, dtype): sch.bind(x, "threadIdx.x") fun = tvm.tir.build(sch.mod, target=target) - a = tvm.nd.empty((n,), A.dtype, dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev) + c = tvm.runtime.empty((n,), A.dtype, dev) # Only need to test compiling here fun(a, c) @@ -179,7 +179,7 @@ def check_type_casting(ctx, n, dtype): sch.vectorize(vx) fun = tvm.tir.build(sch.mod, target=target) - c = tvm.nd.empty((n,), dtype, ctx) + c = tvm.runtime.empty((n,), dtype, ctx) assembly = fun.imports[0].inspect_source() lcond = "convert_int4(((convert_uint4(((uint4)(((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3)))))" rcond = "(convert_uint4(((((int4)(((convert_int(get_local_id(0))))+(1*0), ((convert_int(get_local_id(0))))+(1*1), ((convert_int(get_local_id(0))))+(1*2), ((convert_int(get_local_id(0))))+(1*3))) % ((int4)(3, 3, 3, 3))) == ((int4)(1, 1, 1, 1))))))))" diff --git 
a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py index 1a30ab203f04..9e2d18e109f9 100644 --- a/tests/python/codegen/test_target_codegen_riscv.py +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm import tvm.testing from tvm.script import tir as T @@ -46,5 +47,31 @@ def load_vec(A: T.Buffer((N,), "int8")): check_rvv_presence(16, 32) +@tvm.testing.requires_llvm_minimum_version(14) +@tvm.testing.parametrize_targets( + "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_rvv_vscale_llvm_dbginfo(target): + # fmt: off + @T.prim_func + def rvv_with_vscale(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle): + A = T.match_buffer(A_handle, (8,), dtype="float32", align=4, offset_factor=1) + B = T.match_buffer(B_handle, (4, 8), dtype="float32", align=4, offset_factor=1, strides=[8, 1]) + C = T.match_buffer(C_handle, (4,), dtype="float32", align=4, offset_factor=1) + with T.block("root"): + T.reads(A[0:8], B[0:4, 0:8]) + zero = T.call_llvm_intrin("float32xvscalex2", "llvm.riscv.vfmv.v.f", T.Broadcast(T.float32(0.0), T.vscale() * 2), C[0], T.uint64(1)) + vec_A = T.call_llvm_intrin("float32xvscalex4", "llvm.riscv.vle", T.Broadcast(T.float32(0.0), T.vscale() * 4), T.tvm_access_ptr(T.type_annotation("float32"), A.data, 0, 8, 1), T.int64(8)) + vec_B = T.call_llvm_intrin("float32xvscalex4", "llvm.riscv.vle", T.Broadcast(T.float32(0.0), T.vscale() * 4), T.tvm_access_ptr(T.type_annotation("float32"), B.data, 0 * 8, 8, 1), T.int64(8)) + prod = T.call_llvm_intrin("float32xvscalex4", "llvm.riscv.vfmul", T.Broadcast(T.float32(0.0), T.vscale() * 4), vec_A, vec_B, T.uint64(7), T.uint64(8)) + 
redsum = T.call_llvm_intrin("float32xvscalex2", "llvm.riscv.vfredusum", T.Broadcast(T.float32(0.0), T.vscale() * 2), prod, zero, T.uint64(7), T.uint64(8)) + # fmt: on + + # tvm.error.InternalError: Can't fetch the lanes of a scalable vector at a compile time. + with tvm.target.Target(target): + f = tvm.tir.build(rvv_with_vscale, target) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_rocm.py b/tests/python/codegen/test_target_codegen_rocm.py index a89d71f2be48..cdd84fc57ae1 100644 --- a/tests/python/codegen/test_target_codegen_rocm.py +++ b/tests/python/codegen/test_target_codegen_rocm.py @@ -32,8 +32,8 @@ def check_inf_nan(dev, n, value, dtype): sch.bind(xo, "blockIdx.x") sch.bind(xi, "threadIdx.x") fun = tvm.compile(sch.mod, "rocm") - a = tvm.nd.empty((n,), A.dtype, dev) - c = tvm.nd.empty((n,), A.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev) + c = tvm.runtime.empty((n,), A.dtype, dev) # Only need to test compiling here fun(a, c) @@ -53,7 +53,7 @@ def check_rocm(dtype, n): A = te.placeholder((n,), name="A", dtype=dtype) dev = tvm.rocm(0) a_np = np.random.uniform(size=(n,)).astype(A.dtype) - a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(a_np) + a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(a_np) b_np = a.numpy() tvm.testing.assert_allclose(a_np, b_np) tvm.testing.assert_allclose(a_np, a.numpy()) @@ -79,8 +79,8 @@ def check_rocm(dtype, n, lanes): fun = tvm.compile(sch.mod, target="rocm") dev = tvm.rocm(0) - a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes))) - c = tvm.nd.empty((n,), B.dtype, dev) + a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes))) + c = tvm.runtime.empty((n,), B.dtype, dev) fun(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) @@ -109,7 +109,7 @@ def func( mod = tvm.compile(func, target="rocm") dev = tvm.rocm(0) - a = tvm.nd.array(np.random.uniform(size=(32,)).astype("float32"), dev) + a = 
tvm.runtime.tensor(np.random.uniform(size=(32,)).astype("float32"), dev) mod(a) tvm.testing.assert_allclose(a.numpy(), np.ones((32,)) * a.numpy()[0]) @@ -132,7 +132,7 @@ def func( mod = tvm.compile(func, target="rocm") dev = tvm.rocm(0) - a = tvm.nd.array(np.ones((4,)).astype("float32"), dev) - b = tvm.nd.array(np.zeros((4,)).astype("float32"), dev) + a = tvm.runtime.tensor(np.ones((4,)).astype("float32"), dev) + b = tvm.runtime.tensor(np.zeros((4,)).astype("float32"), dev) mod(a, b) tvm.testing.assert_allclose(b.numpy(), np.exp2(a.numpy())) diff --git a/tests/python/codegen/test_target_codegen_static_init.py b/tests/python/codegen/test_target_codegen_static_init.py index 4d993e5d6b7b..30161913360a 100644 --- a/tests/python/codegen/test_target_codegen_static_init.py +++ b/tests/python/codegen/test_target_codegen_static_init.py @@ -36,7 +36,7 @@ def test_static_callback(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")) f = tvm.driver.build(mod, target="llvm") - a = tvm.nd.array(np.zeros(10, dtype=dtype)) + a = tvm.runtime.tensor(np.zeros(10, dtype=dtype)) f(a) f(a) np.testing.assert_equal(a.numpy(), np.ones(a.shape[0])) @@ -51,7 +51,7 @@ def test_static_init(): handle = tvm.tir.call_intrin("handle", "tir.tvm_static_handle") ib.emit(tvm.tir.call_packed("test_static_callback", handle, Ab)) - @tvm.register_func("test_static_callback") + @tvm.register_global_func("test_static_callback") def test_cb(sh, A): assert isinstance(sh, ctypes.c_void_p) return sh @@ -59,7 +59,7 @@ def test_cb(sh, A): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")) f = tvm.driver.build(mod, target="llvm") - a = tvm.nd.array(np.zeros(10, dtype=dtype)) + a = tvm.runtime.tensor(np.zeros(10, dtype=dtype)) f(a) diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index a523ae037794..cf7b46692661 100644 --- 
a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -99,7 +99,7 @@ def test_array_copy(dev, dtype, fuzz_seed): log_arr_size = np.random.uniform(low=np.log(1), high=np.log(32768)) arr_size = np.exp(log_arr_size).astype(int) a_np = np.random.uniform(size=(arr_size,)).astype(dtype) - a = tvm.nd.empty((arr_size,), dtype, dev).copyfrom(a_np) + a = tvm.runtime.empty((arr_size,), dtype, dev).copyfrom(a_np) b_np = a.numpy() tvm.testing.assert_allclose(a_np, b_np) tvm.testing.assert_allclose(a_np, a.numpy()) @@ -123,8 +123,10 @@ def test_array_vectorize_add(target, dev, dtype): sch.bind(xi, "threadIdx.x") f = tvm.compile(sch.mod, target=target) - a = tvm.nd.empty((arr_size,), A.dtype, dev).copyfrom(np.random.uniform(size=(arr_size, lanes))) - c = tvm.nd.empty((arr_size,), B.dtype, dev) + a = tvm.runtime.empty((arr_size,), A.dtype, dev).copyfrom( + np.random.uniform(size=(arr_size, lanes)) + ) + c = tvm.runtime.empty((arr_size,), B.dtype, dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) @@ -146,8 +148,8 @@ def test_vulkan_bool_load(target, dev): a_np = np.random.uniform(size=arr_size) > 0.5 b_np = np.zeros((arr_size,), dtype="int32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) f(a, b) ref = a_np.astype(np.int32) tvm.testing.assert_allclose(b.numpy(), ref) @@ -198,8 +200,8 @@ def test_vulkan_constant_passing(target, dev, vulkan_parameter_impl, vulkan_para n = 1024 scalars = np.array([1 for _ in scalars]).astype(dtype) - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) f_add(*scalars, a, b) tvm.testing.assert_allclose(a.numpy() + sum(scalars), b.numpy()) @@ -244,13 +246,13 @@ def do_compute(A, B, n): # Build func = 
tvm.compile(sch.mod, target=target) - a = tvm.nd.array(np.array([5], dtype=A.dtype), dev) - b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.array([5], dtype=A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype=A.dtype), dev) func(a, b) tvm.testing.assert_allclose(b.numpy(), [55]) - a = tvm.nd.array(np.array([-5], dtype=A.dtype), dev) - b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.array([-5], dtype=A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype=A.dtype), dev) func(a, b) tvm.testing.assert_allclose(b.numpy(), [210]) @@ -295,8 +297,8 @@ def do_compute(A, B, n): n = 32 a_np = np.arange(n).astype(dtype=A.dtype) b_np = np.zeros((n,), dtype="int32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) func(a, b) tvm.testing.assert_allclose(b.numpy(), a_np) @@ -386,9 +388,9 @@ def test_ramp_broadcast_index(self, target, dev, mod, ref_data): f = tvm.compile(mod, target=target) a_np, reorder_np, b_np = ref_data - a = tvm.nd.array(a_np, dev) - r = tvm.nd.array(reorder_np, dev) - b = tvm.nd.array(np.zeros(shape=b_np.shape, dtype="int32"), dev) + a = tvm.runtime.tensor(a_np, dev) + r = tvm.runtime.tensor(reorder_np, dev) + b = tvm.runtime.tensor(np.zeros(shape=b_np.shape, dtype="int32"), dev) f(a, r, b) tvm.testing.assert_allclose(b.numpy(), b_np) @@ -426,7 +428,7 @@ def func(A: T.Buffer((N, 2), "int32")): built = tvm.compile(func, target=target) - a_dev = tvm.nd.empty([N, 2], "int32", dev) + a_dev = tvm.runtime.empty([N, 2], "int32", dev) built(a_dev) a = a_dev.numpy() @@ -538,9 +540,9 @@ def tensorize_load(block, dim): dev = tvm.device(target, 0) - A = tvm.nd.array(np.random.randn(M, K).astype("float16"), dev) - B = tvm.nd.array(np.random.randn(K, N).astype("float16"), dev) - C = tvm.nd.array(np.random.randn(M, N).astype(out_dtype), dev) + A = tvm.runtime.tensor(np.random.randn(M, K).astype("float16"), dev) + B = 
tvm.runtime.tensor(np.random.randn(K, N).astype("float16"), dev) + C = tvm.runtime.tensor(np.random.randn(M, N).astype(out_dtype), dev) f(A, B, C) @@ -614,8 +616,8 @@ def run_test(tvm_intrin, np_func): else: data = np.random.uniform(0.1, 0.9, size=n) - a = tvm.nd.array(data.astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev) + a = tvm.runtime.tensor(data.astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype=A.dtype), dev) func(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index c0e1553ea782..e2a15cc60b10 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -71,9 +71,15 @@ def verify(target="llvm"): ) if target == "c": f = compiling(f, name) - matrix_input1 = tvm.nd.array(np.random.uniform(size=ashape).astype(input1_data.dtype), dev) - matrix_input2 = tvm.nd.array(np.random.uniform(size=bshape).astype(input2_data.dtype), dev) - matrix_result = tvm.nd.array(np.zeros((matrix_n, matrix_m), dtype=final_result.dtype), dev) + matrix_input1 = tvm.runtime.tensor( + np.random.uniform(size=ashape).astype(input1_data.dtype), dev + ) + matrix_input2 = tvm.runtime.tensor( + np.random.uniform(size=bshape).astype(input2_data.dtype), dev + ) + matrix_result = tvm.runtime.tensor( + np.zeros((matrix_n, matrix_m), dtype=final_result.dtype), dev + ) matrix_bias = 10.0 f(matrix_input1, matrix_input2, matrix_result, matrix_bias) tvm.testing.assert_allclose( @@ -149,13 +155,15 @@ def verify(target="llvm"): f = tvm.compile( te.create_prim_func([input1_data, input2_data, final_result, bias]), target=target ) - matrix_input1 = tvm.nd.array( + matrix_input1 = tvm.runtime.tensor( np.random.randint(low=0, high=50, size=ashape).astype(input1_data.dtype), dev ) - matrix_input2 = tvm.nd.array( + matrix_input2 = tvm.runtime.tensor( np.random.randint(low=0, high=50, size=bshape).astype(input2_data.dtype), 
dev ) - matrix_result = tvm.nd.array(np.zeros((matrix_n, matrix_m), dtype=final_result.dtype), dev) + matrix_result = tvm.runtime.tensor( + np.zeros((matrix_n, matrix_m), dtype=final_result.dtype), dev + ) matrix_bias = 10 f(matrix_input1, matrix_input2, matrix_result, matrix_bias) tvm.testing.assert_allclose( @@ -235,9 +243,13 @@ def verify(target="llvm"): ) if target == "c": f = compiling(f, name) - matrix_input1 = tvm.nd.array(np.random.uniform(size=ashape).astype(input1_data.dtype), dev) - matrix_input2 = tvm.nd.array(np.random.uniform(size=bshape).astype(input2_data.dtype), dev) - matrix_result = tvm.nd.array( + matrix_input1 = tvm.runtime.tensor( + np.random.uniform(size=ashape).astype(input1_data.dtype), dev + ) + matrix_input2 = tvm.runtime.tensor( + np.random.uniform(size=bshape).astype(input2_data.dtype), dev + ) + matrix_result = tvm.runtime.tensor( np.zeros((batch, matrix_n, matrix_m), dtype=final_result.dtype), dev ) f(matrix_input1, matrix_input2, matrix_result) diff --git a/tests/python/contrib/test_coreml_runtime.py b/tests/python/contrib/test_coreml_runtime.py index c2284dbe64f6..014a57b28787 100644 --- a/tests/python/contrib/test_coreml_runtime.py +++ b/tests/python/contrib/test_coreml_runtime.py @@ -73,7 +73,7 @@ def verify(coreml_model, model_path, dev): # inference via tvm coreml runtime runtime = coreml_runtime.create("main", model_path, dev) for name in inputs: - runtime.set_input(name, tvm.nd.array(inputs[name], dev)) + runtime.set_input(name, tvm.runtime.tensor(inputs[name], dev)) runtime.invoke() tvm_outputs = [runtime.get_output(i).numpy() for i in range(runtime.get_num_outputs())] diff --git a/tests/python/contrib/test_cutlass_gemm.py b/tests/python/contrib/test_cutlass_gemm.py index 33f7ef1160a1..951085e8530c 100644 --- a/tests/python/contrib/test_cutlass_gemm.py +++ b/tests/python/contrib/test_cutlass_gemm.py @@ -24,7 +24,7 @@ from tvm.contrib.pickle_memoize import memoize -def get_random_ndarray(shape, dtype): +def 
get_random_tensor(shape, dtype): if dtype == "int8": return np.random.randint(-128, 128, shape).astype(dtype) elif dtype == "uint8": @@ -44,8 +44,8 @@ def verify_group_gemm( def get_ref_data(): assert M % num_groups == 0 M_per_group = M // num_groups - a_np = get_random_ndarray((M, K), x_dtype) - b_np = get_random_ndarray((num_groups, N, K), weight_dtype) + a_np = get_random_tensor((M, K), x_dtype) + b_np = get_random_tensor((num_groups, N, K), weight_dtype) indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group c_np = np.concatenate( [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i in range(num_groups)], @@ -59,13 +59,13 @@ def to_numpy_dtype(dtype): a_np, b_np, indptr_np, c_np = get_ref_data() dev = tvm.cuda(0) - a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev) - b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev) - c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev) - indptr_nd = tvm.nd.array(indptr_np, device=dev) - workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev) + a_nd = tvm.runtime.tensor(a_np.astype(to_numpy_dtype(x_dtype)), device=dev) + b_nd = tvm.runtime.tensor(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev) + c_nd = tvm.runtime.empty(c_np.shape, dtype=out_dtype, device=dev) + indptr_nd = tvm.runtime.tensor(indptr_np, device=dev) + workspace = tvm.runtime.empty((4096 * 1024,), dtype="uint8", device=dev) if use_scale: - scale = tvm.nd.array(np.array([1.0], dtype="float32"), device=dev) + scale = tvm.runtime.tensor(np.array([1.0], dtype="float32"), device=dev) group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd) else: group_gemm_func(a_nd, b_nd, indptr_nd, workspace, c_nd) @@ -319,12 +319,12 @@ def test_fp8_e4m3_groupwise_scaled_gemm(): x_np, x_scale_np = rowwise_quant_fp8_e4m3((M, K), block_size, dtype) w_np, w_scale_np = blockwise_quant_fp8_e4m3((N, K), block_size, dtype) o_np = blockwise_matmul(x_np, x_scale_np, w_np, w_scale_np, 
block_size, dtype) - x_tvm = tvm.nd.array(x_np, device=device) - x_scale_tvm = tvm.nd.array(x_scale_np.T, device=device) - w_tvm = tvm.nd.array(w_np, device=device) - w_scale_tvm = tvm.nd.array(w_scale_np, device=device) - workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=device) - o_tvm = tvm.nd.empty((M, N), dtype=dtype, device=device) + x_tvm = tvm.runtime.tensor(x_np, device=device) + x_scale_tvm = tvm.runtime.tensor(x_scale_np.T, device=device) + w_tvm = tvm.runtime.tensor(w_np, device=device) + w_scale_tvm = tvm.runtime.tensor(w_scale_np, device=device) + workspace = tvm.runtime.empty((4096 * 1024,), dtype="uint8", device=device) + o_tvm = tvm.runtime.empty((M, N), dtype=dtype, device=device) gemm_func( x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0], block_size[1], o_tvm ) @@ -353,12 +353,12 @@ def test_fp8_e4m3_groupwise_scaled_bmm(): x_np, x_scale_np = rowwise_quant_fp8_e4m3((B, M, K), block_size, dtype) w_np, w_scale_np = blockwise_quant_fp8_e4m3((B, N, K), block_size, dtype) o_np = blockwise_bmm(x_np, x_scale_np, w_np, w_scale_np, block_size, dtype) - x_tvm = tvm.nd.array(x_np, device=device) - x_scale_tvm = tvm.nd.array(x_scale_np.transpose(0, 2, 1), device=device) - w_tvm = tvm.nd.array(w_np, device=device) - w_scale_tvm = tvm.nd.array(w_scale_np, device=device) - workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=device) - o_tvm = tvm.nd.empty((B, M, N), dtype=dtype, device=device) + x_tvm = tvm.runtime.tensor(x_np, device=device) + x_scale_tvm = tvm.runtime.tensor(x_scale_np.transpose(0, 2, 1), device=device) + w_tvm = tvm.runtime.tensor(w_np, device=device) + w_scale_tvm = tvm.runtime.tensor(w_scale_np, device=device) + workspace = tvm.runtime.empty((4096 * 1024,), dtype="uint8", device=device) + o_tvm = tvm.runtime.empty((B, M, N), dtype=dtype, device=device) gemm_func( x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0], block_size[1], o_tvm ) diff --git a/tests/python/contrib/test_dlpack.py 
b/tests/python/contrib/test_dlpack.py index 421853899979..f0632f3ac7db 100644 --- a/tests/python/contrib/test_dlpack.py +++ b/tests/python/contrib/test_dlpack.py @@ -23,21 +23,19 @@ def verify_torch_dlpack(): a = np.random.randn(1337) - tvm_a = tvm.nd.array(a) - np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).numpy(), a) + tvm_a = tvm.runtime.tensor(a) + np.testing.assert_equal(tvm.runtime.from_dlpack(tvm_a).numpy(), a) try: import torch import torch.utils.dlpack x = torch.rand(56, 56) - tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + tvm_x = tvm.runtime.from_dlpack(torch.utils.dlpack.to_dlpack(x)) np.testing.assert_equal(x.numpy(), tvm_x.numpy()) - y = tvm.nd.from_dlpack(tvm_x) + y = tvm.runtime.from_dlpack(tvm_x) np.testing.assert_equal(y.numpy(), tvm_x.numpy()) - np.testing.assert_equal( - torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.numpy() - ) + np.testing.assert_equal(torch.utils.dlpack.from_dlpack(y).numpy(), tvm_x.numpy()) n = tvm.runtime.convert(137) xx = torch.rand(137, 137) diff --git a/tests/python/contrib/test_edgetpu_runtime.py b/tests/python/contrib/test_edgetpu_runtime.py index 2bf58106dfdc..6fdd1799a1eb 100644 --- a/tests/python/contrib/test_edgetpu_runtime.py +++ b/tests/python/contrib/test_edgetpu_runtime.py @@ -76,7 +76,7 @@ def check_remote(server, target_edgetpu=False): with open(tflite_model_path, "rb") as model_fin: runtime = tflite_runtime.create(model_fin.read(), dev, runtime_target) - runtime.set_input(0, tvm.nd.array(tflite_input, dev)) + runtime.set_input(0, tvm.runtime.tensor(tflite_input, dev)) runtime.invoke() out = runtime.get_output(0) np.testing.assert_equal(out.numpy(), tflite_output) diff --git a/tests/python/contrib/test_hexagon/README_RPC.md b/tests/python/contrib/test_hexagon/README_RPC.md index 8d185fcbebeb..c70aa1e99087 100644 --- a/tests/python/contrib/test_hexagon/README_RPC.md +++ b/tests/python/contrib/test_hexagon/README_RPC.md @@ -80,7 +80,7 @@ Which eventually jumps to the 
following line in C++, which creates a RPC client [https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129) ```cpp -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("rpc.Connect", [](ffi::PackedArgs args, ffi::Any* rv) { auto url = args[0].cast(); @@ -89,7 +89,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = RPCClientConnect(url, port, key, ffi::PackedArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); }); -}); +} ``` `tvm.contrib.hexagon.create_hexagon_session` is defined here. It establishes a link between android and hexagon, this code runs on android. @@ -98,7 +98,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ```cpp -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -111,7 +111,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto sess = CreateClientSession(ep); *rv = CreateRPCSessionModule(sess); }); -}); +} ``` `HexagonTransportChannel` is the one that actually knows how to talk to Hexagon. 
It uses functions such as `hexagon_rpc_send`, `hexagon_rpc_receive` defined in @@ -125,23 +125,23 @@ TVM_FFI_STATIC_INIT_BLOCK({ [https://github.com/apache/tvm/blob/b2757817af7ba3aefe16ea3ccb6d4982dd7fd531/python/tvm/runtime/ndarray.py#L183](https://github.com/apache/tvm/blob/b2757817af7ba3aefe16ea3ccb6d4982dd7fd531/python/tvm/runtime/ndarray.py#L183) ```python -check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes)) +check_call(_LIB.TVMTensorCopyFromBytes(self.handle, data, nbytes)) ``` -[https://github.com/apache/tvm/blob/37cd9837ff302e4490696ca57a9fbba6404c7046/src/runtime/ndarray.cc#L322](https://github.com/apache/tvm/blob/37cd9837ff302e4490696ca57a9fbba6404c7046/src/runtime/ndarray.cc#L322) +[https://github.com/apache/tvm/blob/37cd9837ff302e4490696ca57a9fbba6404c7046/src/runtime/tensor.cc#L322](https://github.com/apache/tvm/blob/37cd9837ff302e4490696ca57a9fbba6404c7046/src/runtime/tensor.cc#L322) ```cpp -int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { +int TVMTensorCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); - ArrayCopyFromBytes(handle, data, nbytes); + TensorCopyFromBytes(handle, data, nbytes); API_END(); } ``` -Now we come to `ArrayCopyFromBytes` function. The first non-obvious question is, which `DeviceAPI` is selected by `DeviceAPI::Get(handle->device)`? +Now we come to `TensorCopyFromBytes` function. The first non-obvious question is, which `DeviceAPI` is selected by `DeviceAPI::Get(handle->device)`? ```cpp -void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { +void TensorCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { ... DLTensor from; ... 
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 376cc8c7da12..4718fa7e0671 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -100,9 +100,9 @@ def build_and_run(inputs, func, target: str, target_host: str, *args, **kwargs): dev = tvm.device(target) tensors = [] for tensor in inputs: - tensors.append(tvm.nd.array(tensor, dev)) + tensors.append(tvm.runtime.tensor(tensor, dev)) tensors.append( - tvm.nd.array( + tvm.runtime.tensor( numpy.zeros([i.value for i in placeholders[-1].shape], dtype=placeholders[-1].dtype), dev, ) diff --git a/tests/python/contrib/test_hexagon/pytest_util.py b/tests/python/contrib/test_hexagon/pytest_util.py index c078edf7a934..925c29282b18 100644 --- a/tests/python/contrib/test_hexagon/pytest_util.py +++ b/tests/python/contrib/test_hexagon/pytest_util.py @@ -140,7 +140,7 @@ def get_numpy_dtype_info(dtype) -> Union[np.finfo, np.iinfo]: TensorContentDtypeMax = collections.namedtuple("TensorContentDtypeMax", []) -def create_populated_numpy_ndarray( +def create_populated_numpy_tensor( input_shape: Union[list, tuple], dtype: str, input_tensor_populator ) -> np.ndarray: """ diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 965795d29e02..e5fc783510ac 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test different strategies for loading data into vtcm before running HVX workloads. 
""" +"""Test different strategies for loading data into vtcm before running HVX workloads.""" import numpy as np import pytest @@ -281,19 +281,15 @@ def evaluate( ) module = hexagon_session.load_module(func_tir) - a_hexagon = tvm.runtime.ndarray.array(a_data, device=hexagon_session.device) - b_hexagon = tvm.runtime.ndarray.array(b_data, device=hexagon_session.device) - c_hexagon = tvm.runtime.ndarray.array(c_data, device=hexagon_session.device) + a_hexagon = tvm.runtime.tensor(a_data, device=hexagon_session.device) + b_hexagon = tvm.runtime.tensor(b_data, device=hexagon_session.device) + c_hexagon = tvm.runtime.tensor(c_data, device=hexagon_session.device) if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1) else: - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10) time = timer(a_hexagon, b_hexagon, c_hexagon) if expected_output is not None: diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py index d3adbc12c922..dc77b7ad39a4 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py @@ -242,9 +242,9 @@ def _benchmark_hexagon_elementwise_add_kernel( ) # Create the target-side tensors to hold the primfunc's inputs and outputs... 
- input1_data = tvm.nd.empty(shape, dtype, hexagon_session.device, mem_scope) - input2_data = tvm.nd.empty(shape, dtype, hexagon_session.device, mem_scope) - output_data = tvm.nd.empty(shape, dtype, hexagon_session.device, mem_scope) + input1_data = tvm.runtime.empty(shape, dtype, hexagon_session.device, mem_scope) + input2_data = tvm.runtime.empty(shape, dtype, hexagon_session.device, mem_scope) + output_data = tvm.runtime.empty(shape, dtype, hexagon_session.device, mem_scope) # Populate the primfunc's input tensors... input1_data.copyfrom(host_numpy_input1_data) diff --git a/tests/python/contrib/test_hexagon/test_dma_builtin.py b/tests/python/contrib/test_hexagon/test_dma_builtin.py index 479b680065e1..1592bd020fd6 100644 --- a/tests/python/contrib/test_hexagon/test_dma_builtin.py +++ b/tests/python/contrib/test_hexagon/test_dma_builtin.py @@ -164,8 +164,8 @@ def test_vtcm_alloc_compute(self, hexagon_launcher, mode, module): vm_rt = relax.VirtualMachine( vm_mod, dev, "naive" ) # Use naive allocator to exercise VTCM allocation in relax - data0 = tvm.nd.array(input_arg0_data, dev) - data1 = tvm.nd.array(input_arg1_data, dev) + data0 = tvm.runtime.tensor(input_arg0_data, dev) + data1 = tvm.runtime.tensor(input_arg1_data, dev) vm_rt.set_input("main", data0, data1) vm_rt.invoke_stateful("main") hexagon_output = vm_rt.get_outputs("main").numpy() diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py index c7f9d2a00fed..5d9f4128d172 100644 --- a/tests/python/contrib/test_hexagon/test_meta_schedule.py +++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py @@ -174,9 +174,9 @@ def verify_dense(sch, target, m_size, n_size, k_size, hexagon_session): k_output * 4 + t_idx ] - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(pack_width, dev) - c = tvm.nd.array(np.zeros((m_size, n_size), dtype="int32"), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(pack_width, dev) + c = 
tvm.runtime.tensor(np.zeros((m_size, n_size), dtype="int32"), dev) mod(a, b, c) np.testing.assert_equal(c.numpy(), c_np) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index 6e1b7db4d5c5..6abfa812175f 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -148,17 +148,15 @@ def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch): b = np.random.randint(0, 16, b_shape, dtype=b_dtype) c = np.zeros(c_shape, dtype=c_dtype) - a_hexagon = tvm.runtime.ndarray.array(a, device=hexagon_session.device) - b_hexagon = tvm.runtime.ndarray.array(b, device=hexagon_session.device) - c_hexagon = tvm.runtime.ndarray.array(c, device=hexagon_session.device) + a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device) + b_hexagon = tvm.runtime.tensor(b, device=hexagon_session.device) + c_hexagon = tvm.runtime.tensor(c, device=hexagon_session.device) # These are reduced for CI but number=100 and repeat=10 does a good job of removing noise. number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b)) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index a0b94d89cfa6..ceabc6355732 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test different strategies for loading data into vtcm before running HVX workloads. 
""" +"""Test different strategies for loading data into vtcm before running HVX workloads.""" import numpy as np import tvm @@ -318,17 +318,15 @@ def setup_and_run(hexagon_session, sch, a, b, c, operations, mem_scope="global") func_tir = tvm.compile(sch.mod["main"], target=get_hexagon_target("v69")) module = hexagon_session.load_module(func_tir) - a_hexagon = tvm.runtime.ndarray.array(a, device=hexagon_session.device, mem_scope=mem_scope) - b_hexagon = tvm.runtime.ndarray.array(b, device=hexagon_session.device, mem_scope=mem_scope) - c_hexagon = tvm.runtime.ndarray.array(c, device=hexagon_session.device, mem_scope=mem_scope) + a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device, mem_scope=mem_scope) + b_hexagon = tvm.runtime.tensor(b, device=hexagon_session.device, mem_scope=mem_scope) + c_hexagon = tvm.runtime.tensor(c, device=hexagon_session.device, mem_scope=mem_scope) # These are reduced for CI but number=100 and repeat=10 does a good job of removing noise. number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) time = timer(a_hexagon, b_hexagon, c_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) return gops, c_hexagon.numpy() @@ -343,16 +341,16 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations): b_vtcm = np.zeros((b.size), dtype="uint8") c_vtcm = np.zeros((c.size), dtype="int32") - a_hexagon = tvm.runtime.ndarray.array(a, device=hexagon_session.device, mem_scope="global") - b_hexagon = tvm.runtime.ndarray.array(b, device=hexagon_session.device, mem_scope="global") - c_hexagon = tvm.runtime.ndarray.array(c, device=hexagon_session.device, mem_scope="global") - a_vtcm_hexagon = tvm.runtime.ndarray.array( + a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device, mem_scope="global") + b_hexagon = tvm.runtime.tensor(b, 
device=hexagon_session.device, mem_scope="global") + c_hexagon = tvm.runtime.tensor(c, device=hexagon_session.device, mem_scope="global") + a_vtcm_hexagon = tvm.runtime.tensor( a_vtcm, device=hexagon_session.device, mem_scope="global.vtcm" ) - b_vtcm_hexagon = tvm.runtime.ndarray.array( + b_vtcm_hexagon = tvm.runtime.tensor( b_vtcm, device=hexagon_session.device, mem_scope="global.vtcm" ) - c_vtcm_hexagon = tvm.runtime.ndarray.array( + c_vtcm_hexagon = tvm.runtime.tensor( c_vtcm, device=hexagon_session.device, mem_scope="global.vtcm" ) @@ -360,9 +358,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, b_vtcm_hexagon, c_vtcm_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) return gops, c_hexagon.numpy() diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py index dd765178dc32..60731a8febe0 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test parallelism for multiple different scalar workloads. 
""" +"""Test parallelism for multiple different scalar workloads.""" import numpy as np @@ -96,17 +96,15 @@ def evaluate(hexagon_session, operations, expected, sch): b = np.random.random(shape).astype(dtype) c = np.zeros(shape, dtype=dtype) - a_hexagon = tvm.runtime.ndarray.array(a, device=hexagon_session.device) - b_hexagon = tvm.runtime.ndarray.array(b, device=hexagon_session.device) - c_hexagon = tvm.runtime.ndarray.array(c, device=hexagon_session.device) + a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device) + b_hexagon = tvm.runtime.tensor(b, device=hexagon_session.device) + c_hexagon = tvm.runtime.tensor(c, device=hexagon_session.device) # These are reduced for CI but number=100 and repeat=10 does a good job of removing noise. number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected(a, b)) diff --git a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py index 42038b97f90e..8a56e91581cb 100644 --- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py +++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py @@ -86,7 +86,7 @@ def test_alloc_storage_with_scope_global(hexagon_launcher): vm_mod = session.get_executor_from_factory(lib) # This is the important line which tests nd allocator vm_rt = relax.VirtualMachine(vm_mod, dev, memory_cfg="naive") - x = tvm.nd.array(arg0, dev) + x = tvm.runtime.tensor(arg0, dev) vm_rt.set_input("main", x) vm_rt.invoke_stateful("main") hexagon_output = vm_rt.get_outputs("main").numpy() diff --git a/tests/python/contrib/test_hexagon/test_relax_integration.py b/tests/python/contrib/test_hexagon/test_relax_integration.py index 
5e1bfac3625e..4a3d122ce0fb 100644 --- a/tests/python/contrib/test_hexagon/test_relax_integration.py +++ b/tests/python/contrib/test_hexagon/test_relax_integration.py @@ -57,7 +57,7 @@ def test_mobilenet_onnx(hexagon_session: Session): vm_mod = hexagon_session.get_executor_from_factory(exe) vm_rt = relax.VirtualMachine(vm_mod, dev) - data = tvm.nd.array(data_np, dev) + data = tvm.runtime.tensor(data_np, dev) vm_rt.set_input("main", data) vm_rt.invoke_stateful("main") hexagon_res = vm_rt.get_outputs("main") @@ -67,7 +67,7 @@ def test_mobilenet_onnx(hexagon_session: Session): exe = tvm.compile(relax_mod, "llvm") dev = tvm.cpu() vm_rt = relax.VirtualMachine(exe, dev) - data = tvm.nd.array(data_np, dev) + data = tvm.runtime.tensor(data_np, dev) llvm_res = vm_rt["main"](data) tvm.testing.assert_allclose(hexagon_res.numpy(), llvm_res.numpy(), rtol=1e-3) @@ -91,7 +91,7 @@ def test_mobilenet(hexagon_session: Session): vm_mod = hexagon_session.get_executor_from_factory(exe) vm_rt = relax.VirtualMachine(vm_mod, dev) - data = tvm.nd.array(data_np, dev) + data = tvm.runtime.tensor(data_np, dev) vm_rt.set_input("main", data) vm_rt.invoke_stateful("main") hexagon_res = vm_rt.get_outputs("main") @@ -101,7 +101,7 @@ def test_mobilenet(hexagon_session: Session): exe = tvm.compile(relax_mod, "llvm") dev = tvm.cpu() vm_rt = relax.VirtualMachine(exe, dev) - data = tvm.nd.array(data_np, dev) + data = tvm.runtime.tensor(data_np, dev) llvm_res = vm_rt["main"](data) tvm.testing.assert_allclose(hexagon_res.numpy(), llvm_res.numpy(), rtol=1e-3) diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 3be9683a7deb..b4d2aed433b9 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -89,7 +89,7 @@ def check(out, ref): if "int" in dtype: np.testing.assert_equal(out.numpy(), ref) else: - 
np.testing.assert_allclose(out.numpy(), ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-3, atol=1e-3) return check @@ -188,12 +188,12 @@ def test_async_software_pipeline( with hexagon_launcher.create_session() as hexagon_session: dev = hexagon_session.device mod = hexagon_session.load_module(func) - out = tvm.nd.array(out_np, device=dev) - a = tvm.nd.array(a_np, device=dev) + out = tvm.runtime.tensor(out_np, device=dev) + a = tvm.runtime.tensor(a_np, device=dev) if comp_type == "single_input": mod(a, out) else: - b = tvm.nd.array(b_np, device=dev) + b = tvm.runtime.tensor(b_np, device=dev) mod(a, b, out) verify(out, ref) diff --git a/tests/python/contrib/test_hexagon/test_take.py b/tests/python/contrib/test_hexagon/test_take.py index 15058e17af5a..4f6169b48ca7 100644 --- a/tests/python/contrib/test_hexagon/test_take.py +++ b/tests/python/contrib/test_hexagon/test_take.py @@ -322,7 +322,7 @@ def abs( # Quantizing input : scale is returned as float64 and zp is returned as int32 inp_quant, inp_scale, inp_zero_point = quantize_np(data, dtype) -inp_quant = tvm.nd.array(inp_quant.astype(np.uint8)) +inp_quant = tvm.runtime.tensor(inp_quant.astype(np.uint8)) # Test the implementations value output with numpy data. 
First the IR is runn through pass diff --git a/tests/python/contrib/test_hexagon/test_thread_pool.py b/tests/python/contrib/test_hexagon/test_thread_pool.py index 2dc426749680..f61a2560cfad 100644 --- a/tests/python/contrib/test_hexagon/test_thread_pool.py +++ b/tests/python/contrib/test_hexagon/test_thread_pool.py @@ -60,9 +60,9 @@ def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32): def generate_add_test_data(hexagon_session: Session, n=128 * 1024): - a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device) - b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device) - c = tvm.nd.array(np.zeros(n, dtype="float32"), hexagon_session.device) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), hexagon_session.device) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), hexagon_session.device) + c = tvm.runtime.tensor(np.zeros(n, dtype="float32"), hexagon_session.device) return (a, b, c, n) diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py index 265f2bf5fd2d..42fca9c153aa 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py @@ -101,20 +101,16 @@ def evaluate(hexagon_session, sch, size): a = np.random.randint(-128, 127, a_shape, dtype="int8") a_vtcm = np.zeros(a_shape, dtype="int8") - a_hexagon = tvm.runtime.ndarray.array(a, device=hexagon_session.device, mem_scope="global") - a_vtcm_hexagon = tvm.runtime.ndarray.array( + a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device, mem_scope="global") + a_vtcm_hexagon = tvm.runtime.tensor( a_vtcm, device=hexagon_session.device, mem_scope="global.vtcm" ) if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 - ) + timer = 
module.time_evaluator("main", hexagon_session.device, number=1, repeat=1) else: - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10) runtime = timer(a_hexagon, a_vtcm_hexagon) diff --git a/tests/python/contrib/test_hipblas.py b/tests/python/contrib/test_hipblas.py index 33187fa4efba..d285dd45491d 100644 --- a/tests/python/contrib/test_hipblas.py +++ b/tests/python/contrib/test_hipblas.py @@ -36,9 +36,9 @@ def verify(target="rocm"): return dev = tvm.rocm(0) f = tvm.compile(te.create_prim_func([A, B, C]), target=target) - a = tvm.nd.array(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros((n, m), dtype=C.dtype), dev) f(a, b, c) tvm.testing.assert_allclose( c.numpy(), np.dot(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)), rtol=rtol @@ -60,13 +60,13 @@ def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5): f = tvm.compile(te.create_prim_func([A, B, C]), target="rocm") if "int" in in_dtype: - a = tvm.nd.array(np.random.uniform(1, 10, size=Ashape).astype(in_dtype), dev) - b = tvm.nd.array(np.random.uniform(1, 10, size=Bshape).astype(in_dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(1, 10, size=Ashape).astype(in_dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(1, 10, size=Bshape).astype(in_dtype), dev) else: - a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=Ashape).astype(A.dtype), dev) + b = 
tvm.runtime.tensor(np.random.uniform(size=Bshape).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros(Cshape, dtype=C.dtype), dev) + c = tvm.runtime.tensor(np.zeros(Cshape, dtype=C.dtype), dev) f(a, b, c) tvm.testing.assert_allclose( c.numpy(), diff --git a/tests/python/contrib/test_mps.py b/tests/python/contrib/test_mps.py index 41847f3b8fea..cc459e81f51d 100644 --- a/tests/python/contrib/test_mps.py +++ b/tests/python/contrib/test_mps.py @@ -36,9 +36,9 @@ def verify(A, B, C): return dev = tvm.metal(0) f = tvm.compile(te.create_prim_func([A, B, C]), target="metal") - a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=(n, l)).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=(l, m)).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros((n, m), dtype=C.dtype), dev) f(a, b, c) tvm.testing.assert_allclose(c.numpy(), np.dot(a.numpy(), b.numpy()), rtol=1e-5) @@ -65,9 +65,9 @@ def verify(A, B, C, target="llvm"): return dev = tvm.metal(0) f = tvm.compile(te.create_prim_func([A, B, C]), target="metal") - a = tvm.nd.array(np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), dev) f(a, b, c) verify(A, B, C, s1) diff --git a/tests/python/contrib/test_msc/test_plugin.py b/tests/python/contrib/test_msc/test_plugin.py index 3cacb8a646ba..1feeed2a7c84 100644 --- a/tests/python/contrib/test_msc/test_plugin.py +++ 
b/tests/python/contrib/test_msc/test_plugin.py @@ -241,7 +241,7 @@ def _get_tvm_model(tvm_manager): data = block_builder.emit_output(data) block_builder.emit_func_output(data) mod = block_builder.finalize() - return BindParams("main", {"weight": tvm.nd.array(weights)})(mod) + return BindParams("main", {"weight": tvm.runtime.tensor(weights)})(mod) def _build_plugin(frameworks, plugin_root): @@ -264,7 +264,7 @@ def _run_relax(relax_mod, target_name, data): with tvm.transform.PassContext(opt_level=3): relax_exec = tvm.compile(relax_mod, target) runnable = tvm.relax.VirtualMachine(relax_exec, device) - data = tvm.nd.array(data, device) + data = tvm.runtime.tensor(data, device) return runnable["main"](data).numpy() diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index 41e8f0e44e64..0a8be3df11a0 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -40,7 +40,7 @@ def verify_model(torch_model, input_info, opt_config=None): args = [msc_utils.random_data(i, MSCFramework.TVM) for i in input_info] def _tvm_runtime_to_np(obj): - if isinstance(obj, tvm.runtime.NDArray): + if isinstance(obj, tvm.runtime.Tensor): return obj.numpy() elif isinstance(obj, tvm.runtime.ShapeTuple): return np.array(obj, dtype="int64") diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py index a3eaae09afbc..66b56210c233 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorrt.py +++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py @@ -47,7 +47,7 @@ def build_and_run(mod, inputs): rt_mod = tvm.compile(mod, target) runnable = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) res = runnable["main"](*inputs) - if isinstance(res, tvm.runtime.NDArray): + if isinstance(res, tvm.runtime.Tensor): return [res.numpy()] return [e.numpy() for e in res] @@ -104,7 +104,7 @@ def 
verify_model(torch_model, input_info, **trans_config): output_folder = msc_utils.msc_dir() # tranalte to tensorrt mod = codegen.to_tensorrt(mod, graphs, weights, output_folder=output_folder) - tvm_datas = [tvm.nd.array(i, device=tvm.cuda()) for i in datas] + tvm_datas = [tvm.runtime.tensor(i, device=tvm.cuda()) for i in datas] results = build_and_run(mod, tvm_datas) for gol, res in zip(golden, results): tvm.testing.assert_allclose(gol, res, atol=1e-3, rtol=1e-3) diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py index c8c8054dfb6b..10091cb9adff 100644 --- a/tests/python/contrib/test_random.py +++ b/tests/python/contrib/test_random.py @@ -40,7 +40,7 @@ def verify(target="llvm"): return dev = tvm.cpu(0) f = tvm.compile(te.create_prim_func([A]), target=target) - a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.zeros((m, n), dtype=A.dtype), dev) f(a) na = a.numpy() assert abs(np.mean(na)) < 0.3 @@ -65,7 +65,7 @@ def verify(target="llvm"): return dev = tvm.cpu(0) f = tvm.compile(te.create_prim_func([A]), target=target) - a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.zeros((m, n), dtype=A.dtype), dev) f(a) na = a.numpy() assert abs(np.mean(na) - 0.5) < 1e-1 @@ -90,7 +90,7 @@ def verify(target="llvm"): return dev = tvm.cpu(0) f = tvm.compile(te.create_prim_func([A]), target=target) - a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.zeros((m, n), dtype=A.dtype), dev) f(a) na = a.numpy() assert abs(np.mean(na) - 3) < 1e-1 @@ -107,7 +107,7 @@ def test_local(dev, dtype): if not tvm.get_global_func("tvm.contrib.random.random_fill", True): print("skip because extern function is not available") return - value = tvm.nd.empty((512, 512), dtype, dev) + value = tvm.runtime.empty((512, 512), dtype, dev) random_fill = tvm.get_global_func("tvm.contrib.random.random_fill") random_fill(value) @@ -126,7 +126,7 @@ def test_rpc(dtype): def 
check_remote(server): remote = rpc.connect(server.host, server.port) - value = tvm.nd.empty((512, 512), dtype, remote.cpu()) + value = tvm.runtime.empty((512, 512), dtype, remote.cpu()) random_fill = remote.get_function("tvm.contrib.random.random_fill") random_fill(value) @@ -170,7 +170,7 @@ def test_body(): configure_threads = tvm.get_global_func("runtime.config_threadpool") configure_threads(1, num_thread_used) - test_input = tvm.runtime.ndarray.empty((10, 10)) + test_input = tvm.runtime.empty((10, 10)) random_fill = tvm.get_global_func("tvm.contrib.random.random_fill_for_measure") random_fill(test_input) except: # pylint: disable=bare-except diff --git a/tests/python/contrib/test_rocblas.py b/tests/python/contrib/test_rocblas.py index a715a5bb4a74..6b57395ce847 100644 --- a/tests/python/contrib/test_rocblas.py +++ b/tests/python/contrib/test_rocblas.py @@ -40,9 +40,9 @@ def verify(target="rocm"): return dev = tvm.rocm(0) f = tvm.compile(te.create_prim_func([A, B, C]), target=target) - a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=(n, l)).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=(l, m)).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros((n, m), dtype=C.dtype), dev) f(a, b, c) tvm.testing.assert_allclose(c.numpy(), np.dot(a.numpy(), b.numpy()), rtol=1e-5) @@ -73,9 +73,9 @@ def verify(target="rocm"): return dev = tvm.rocm(0) f = tvm.compile(te.create_prim_func([A, B, C]), target=target) - a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((batch, m, n), dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=ashape).astype(A.dtype), dev) + b = 
tvm.runtime.tensor(np.random.uniform(size=bshape).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros((batch, m, n), dtype=C.dtype), dev) f(a, b, c) tvm.testing.assert_allclose( c.numpy(), get_numpy(a.numpy(), b.numpy(), transa, transb), rtol=1e-5 diff --git a/tests/python/contrib/test_rpc_tracker.py b/tests/python/contrib/test_rpc_tracker.py index f6918db4e286..8dbc1c700412 100644 --- a/tests/python/contrib/test_rpc_tracker.py +++ b/tests/python/contrib/test_rpc_tracker.py @@ -31,7 +31,7 @@ def check_server_drop(): # pylint: disable=import-outside-toplevel from tvm.rpc.base import TrackerCode - @tvm.register_func("rpc.test2.addone") + @tvm.register_global_func("rpc.test2.addone") def addone(x): return x + 1 diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index a853df569498..aa80cf484823 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -53,9 +53,9 @@ def test_sort(): dev = tvm.cpu(0) target = "llvm" f = tvm.compile(te.create_prim_func([data, sort_num, out]), target=target) - a = tvm.nd.array(np.array(input_data).astype(data.dtype), dev) - b = tvm.nd.array(np.array(sort_num_input).astype(sort_num.dtype), dev) - c = tvm.nd.array(np.zeros(a.shape, dtype=out.dtype), dev) + a = tvm.runtime.tensor(np.array(input_data).astype(data.dtype), dev) + b = tvm.runtime.tensor(np.array(sort_num_input).astype(sort_num.dtype), dev) + c = tvm.runtime.tensor(np.zeros(a.shape, dtype=out.dtype), dev) f(a, b, c) tvm.testing.assert_allclose(c.numpy(), np.array(sorted_index).astype(out.dtype), rtol=1e-5) @@ -85,9 +85,9 @@ def test_sort_np(): np_data = np.random.uniform(size=dshape) np_out = np.argsort(np_data, axis=axis) sort_num_input = np.full(reduced_shape, dshape[axis]) - a = tvm.nd.array(np.array(np_data).astype(data.dtype), dev) - b = tvm.nd.array(np.array(sort_num_input).astype(sort_num.dtype), dev) - c = tvm.nd.array(np.zeros(a.shape, dtype=out.dtype), dev) + a = 
tvm.runtime.tensor(np.array(np_data).astype(data.dtype), dev) + b = tvm.runtime.tensor(np.array(sort_num_input).astype(sort_num.dtype), dev) + c = tvm.runtime.tensor(np.zeros(a.shape, dtype=out.dtype), dev) f(a, b, c) tvm.testing.assert_allclose(c.numpy(), np_out, rtol=1e-5) diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 9938f85cd563..f75156fa0467 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -92,7 +92,7 @@ def test_local(): # inference via tvm tflite runtime with open(tflite_model_path, "rb") as model_fin: runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input)) + runtime.set_input(0, tvm.runtime.tensor(tflite_input)) runtime.invoke() out = runtime.get_output(0) np.testing.assert_equal(out.numpy(), tflite_output) @@ -138,7 +138,7 @@ def check_remote(server): with open(tflite_model_path, "rb") as model_fin: runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) + runtime.set_input(0, tvm.runtime.tensor(tflite_input, remote.cpu(0))) runtime.invoke() out = runtime.get_output(0) np.testing.assert_equal(out.numpy(), tflite_output) diff --git a/tests/python/contrib/test_tir_triton_integration.py b/tests/python/contrib/test_tir_triton_integration.py index b349d2fabce5..95ccf28fbddb 100644 --- a/tests/python/contrib/test_tir_triton_integration.py +++ b/tests/python/contrib/test_tir_triton_integration.py @@ -110,8 +110,8 @@ def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle): assert len(Module.get_attr("external_mods")) == 1 device = tvm.cuda(0) - x_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device) - y_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device) + x_nd = tvm.runtime.tensor(np.random.rand(256).astype(np.float32), device) + y_nd = 
tvm.runtime.tensor(np.random.rand(256).astype(np.float32), device) output_np = x_nd.numpy() + y_nd.numpy() with tvm.target.Target("cuda"): diff --git a/tests/python/contrib/test_tvmjs.py b/tests/python/contrib/test_tvmjs.py index 22742ec224ef..4de1b6c9850c 100644 --- a/tests/python/contrib/test_tvmjs.py +++ b/tests/python/contrib/test_tvmjs.py @@ -52,8 +52,8 @@ def test_save_load_float8(dtype): arr = np.arange(16, dtype=np_dtype) with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: - tvmjs.dump_ndarray_cache({"arr": arr}, temp_dir) - cache, _ = tvmjs.load_ndarray_cache(temp_dir, tvm.cpu()) + tvmjs.dump_tensor_cache({"arr": arr}, temp_dir) + cache, _ = tvmjs.load_tensor_cache(temp_dir, tvm.cpu()) after_roundtrip = cache["arr"].numpy() diff --git a/tests/python/contrib/test_util.py b/tests/python/contrib/test_util.py index d22ce14b291e..10360422e93a 100644 --- a/tests/python/contrib/test_util.py +++ b/tests/python/contrib/test_util.py @@ -17,8 +17,10 @@ """Tests for functions in tvm/python/tvm/contrib/util.py.""" import datetime +import multiprocessing as mp import os import shutil +import tempfile from tvm.contrib import utils @@ -32,6 +34,17 @@ def validate_debug_dir_path(temp_dir, expected_basename): assert abs(datetime.datetime.now() - create_time) < datetime.timedelta(seconds=60) +def _create_debug_tempdir(root_dir): + from tvm.contrib import utils as worker_utils + + worker_utils.TempDirectory._DEBUG_PARENT_DIR = None + worker_utils.TempDirectory._NUM_TEMPDIR_CREATED = 0 + worker_utils.tempfile.gettempdir = lambda: root_dir + + temp_dir = worker_utils.tempdir(keep_for_debug=True) + return temp_dir.temp_dir + + def test_tempdir(): """Tests for temporary dir""" assert utils.TempDirectory._KEEP_FOR_DEBUG is False, "don't submit with KEEP_FOR_DEBUG == True" @@ -85,5 +98,24 @@ def test_tempdir(): utils.TempDirectory.TEMPDIRS = old_tempdirs +def test_tempdir_debug_parent_dir_is_multiprocess_safe(): + root_dir = tempfile.mkdtemp(prefix="tvm-util-tempdir-") + 
try: + ctx = mp.get_context("spawn") + with ctx.Pool(8) as pool: + temp_dirs = pool.map(_create_debug_tempdir, [root_dir] * 32) + assert len(temp_dirs) == 32 + assert len(set(temp_dirs)) == 32 + assert os.path.isdir(os.path.join(root_dir, "tvm-debug-mode-tempdirs")) + finally: + shutil.rmtree(root_dir, ignore_errors=True) + + +def test_tempdir_remove_tolerates_partial_initialization(): + temp_dir = object.__new__(utils.TempDirectory) + temp_dir.remove() + temp_dir.__del__() + + if __name__ == "__main__": test_tempdir() diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py index d0defa15b869..8e78058331a5 100644 --- a/tests/python/disco/test_callback.py +++ b/tests/python/disco/test_callback.py @@ -91,7 +91,7 @@ def transform_params( params = transform_params(worker_id, fget_item) # Worker 0 is the same PID as the controlling scope, so - # `debug_get_from_remote(0)` returns the NDArray containing + # `debug_get_from_remote(0)` returns the Tensor containing # the output. params_gpu0 = params.debug_get_from_remote(0) assert params_gpu0[0].device == tvm.cuda(0) @@ -109,7 +109,7 @@ def transform_params( ) # Worker 1 is a different PID altogether, so - # `debug_get_from_remote(1)` returns a new NDArray within the + # `debug_get_from_remote(1)` returns a new Tensor within the # calling scope's PID. 
params_gpu1 = params.debug_get_from_remote(1) assert params_gpu1[0].device == tvm.cpu() diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py index 649b865b6c3b..8a1518765fb2 100644 --- a/tests/python/disco/test_ccl.py +++ b/tests/python/disco/test_ccl.py @@ -491,9 +491,9 @@ def relax_build(mod, target): W1 = np.random.randn(128, 128).astype("float32") W2 = np.random.randn(128, 128).astype("float32") Y_expected = VirtualMachine(relax_build(MLP, target), device=dev)["main"]( - tvm.nd.array(X, device=dev), - tvm.nd.array(W1, device=dev), - tvm.nd.array(W2, device=dev), + tvm.runtime.tensor(X, device=dev), + tvm.runtime.tensor(W1, device=dev), + tvm.runtime.tensor(W2, device=dev), ).numpy() with tempfile.TemporaryDirectory() as tmpdir: @@ -512,12 +512,12 @@ def relax_build(mod, target): d_W2.debug_copy_from(0, W2[:64, :]) d_W2.debug_copy_from(1, W2[64:, :]) d_Y = mod["main"](d_X, d_W1, d_W2) - Y_result = tvm.nd.empty((128, 128), "float32", device=dev) + Y_result = tvm.runtime.empty((128, 128), "float32", device=dev) sess.copy_from_worker_0(Y_result, d_Y) sess.sync_worker_0() Y_result = Y_result.numpy() # pylint: enable=invalid-name - np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4) + tvm.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4) @pytest.mark.parametrize("session_kind", _all_session_kinds) @@ -632,11 +632,11 @@ def relax_build(mod, target): Wv = np.random.randn(128, 512).astype("float32") Wo = np.random.randn(512, 128).astype("float32") Y_expected = VirtualMachine(relax_build(Attention, target), device=dev)["main"]( - tvm.nd.array(X, device=dev), - tvm.nd.array(Wq, device=dev), - tvm.nd.array(Wk, device=dev), - tvm.nd.array(Wv, device=dev), - tvm.nd.array(Wo, device=dev), + tvm.runtime.tensor(X, device=dev), + tvm.runtime.tensor(Wq, device=dev), + tvm.runtime.tensor(Wk, device=dev), + tvm.runtime.tensor(Wv, device=dev), + tvm.runtime.tensor(Wo, device=dev), ).numpy() with 
tempfile.TemporaryDirectory() as tmpdir: @@ -661,12 +661,12 @@ def relax_build(mod, target): d_Wo.debug_copy_from(0, Wo[:256, :]) d_Wo.debug_copy_from(1, Wo[256:, :]) d_Y = mod["main"](d_X, d_Wq, d_Wk, d_Wv, d_Wo) - Y_result = tvm.nd.empty((1, 10, 128), "float32", device=dev) + Y_result = tvm.runtime.empty((1, 10, 128), "float32", device=dev) sess.copy_from_worker_0(Y_result, d_Y) sess.sync_worker_0() Y_result = Y_result.numpy() # pylint: enable=invalid-name - np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(Y_result, Y_expected, rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py index 5089336f09d3..a68f53917603 100644 --- a/tests/python/disco/test_loader.py +++ b/tests/python/disco/test_loader.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import dlight as dl from tvm import relax as rx -from tvm.ffi import register_func +from tvm_ffi import register_global_func from tvm.contrib import tvmjs from tvm.runtime import ShapeTuple from tvm.runtime import disco as di @@ -35,19 +35,19 @@ from tvm.contrib import tvmjs -@register_func("tests.disco.shard_dim_0", override=True) +@register_global_func("tests.disco.shard_dim_0", override=True) def _shard_dim_0(src, num_shards, tgt): s_0, s_1 = src.shape tgt.copyfrom(src.numpy().reshape(num_shards, s_0 // num_shards, s_1)) -@register_func("tests.disco.shard_dim_1", override=True) +@register_global_func("tests.disco.shard_dim_1", override=True) def _shard_dim_1(src, num_shards, tgt): s_0, s_1 = src.shape tgt.copyfrom(src.numpy().reshape(s_0, num_shards, s_1 // num_shards).transpose(1, 0, 2)) -@register_func("tests.disco.shard_qkv_0", override=True) +@register_global_func("tests.disco.shard_qkv_0", override=True) def _shard_qkv_0(src, num_shards, q_heads, kv_heads, tgt): total_dim, hidden_size = src.shape head_dim = total_dim // (q_heads + kv_heads + kv_heads) @@ -75,19 +75,19 @@ def 
_shard_qkv_0(src, num_shards, q_heads, kv_heads, tgt): tgt.copyfrom(w_qkv) -@register_func("tests.disco.shard_qkv_1", override=True) +@register_global_func("tests.disco.shard_qkv_1", override=True) def _shard_qkv_1(src, tgt): s, _, _, h = src.shape # pylint: disable=invalid-name tgt.copyfrom(src.numpy().reshape(s, -1, h)) def _create_loader(sess, path, param_dict, shard_info): - path_ndarray_cache = path + "/ndarray-cache.json" - tvmjs.dump_ndarray_cache(param_dict, path, encode_format="raw") - with open(path_ndarray_cache, "r", encoding="utf-8") as i_f: - ndarray_cache = i_f.read() + path_tensor_cache = path + "/tensor-cache.json" + tvmjs.dump_tensor_cache(param_dict, path, encode_format="raw") + with open(path_tensor_cache, "r", encoding="utf-8") as i_f: + tensor_cache = i_f.read() loader_create = sess.get_global_func("runtime.disco.ShardLoader") - loader = loader_create(path_ndarray_cache, ndarray_cache, json.dumps(shard_info), None) + loader = loader_create(path_tensor_cache, tensor_cache, json.dumps(shard_info), None) return loader @@ -100,7 +100,8 @@ def _simulate_presharded_weights(base_path, param_dict, num_shards, shard_info): assert key in shard_info, f"ShardInfo lacks shard info about param: {key}" shard_dim = shard_info[key] sharded_params[key] = [ - tvm.nd.array(np_shard) for np_shard in np.split(ndarray, num_shards, axis=shard_dim) + tvm.runtime.tensor(np_shard) + for np_shard in np.split(ndarray, num_shards, axis=shard_dim) ] # Re-order so that the parameter order is sorted first by shard, @@ -113,7 +114,7 @@ def _simulate_presharded_weights(base_path, param_dict, num_shards, shard_info): for key, shards in sharded_params.items() } - tvmjs.dump_ndarray_cache( + tvmjs.dump_tensor_cache( sharded_params, base_path, encode_format="raw", @@ -169,11 +170,11 @@ def test_load_shard(): def _create_presharded_loader(sess, path): - path_ndarray_cache = path + "/ndarray-cache.json" - with open(path_ndarray_cache, "r", encoding="utf-8") as i_f: - ndarray_cache = 
i_f.read() + path_tensor_cache = path + "/tensor-cache.json" + with open(path_tensor_cache, "r", encoding="utf-8") as i_f: + tensor_cache = i_f.read() loader_create = sess.get_global_func("runtime.disco.ShardLoader") - loader = loader_create(path_ndarray_cache, ndarray_cache, json.dumps({}), None) + loader = loader_create(path_tensor_cache, tensor_cache, json.dumps({}), None) return loader diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index db357c54397b..721115947480 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -37,13 +37,13 @@ def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device): x_array = sess.empty(np_array.shape, "float32", device=device) - host_array = tvm.nd.array(np_array, device=device) + host_array = tvm.runtime.tensor(np_array, device=device) sess.copy_to_worker_0(host_array, x_array) return x_array def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype): - host_array = tvm.nd.empty(shape, dtype, device=tvm.cpu()) + host_array = tvm.runtime.empty(shape, dtype, device=tvm.cpu()) sess.copy_from_worker_0(host_array, remote_array) sess.sync_worker_0() return host_array.numpy() @@ -142,14 +142,14 @@ def test_float(session_kind): @pytest.mark.parametrize("session_kind", _all_session_kinds) -def test_ndarray(session_kind): +def test_tensor(session_kind): num_workers = 4 sess = session_kind(num_workers=num_workers) device = tvm.cpu(0) x_np = np.arange(6).astype("float32").reshape([2, 3]) y_np = np.arange(6).astype("float32").reshape([2, 3]) + 1 x_disc = _numpy_to_worker_0(sess, x_np, device=device) - y_disc = sess.get_global_func("tests.disco.add_one_ndarray")(x_disc) + y_disc = sess.get_global_func("tests.disco.add_one_tensor")(x_disc) y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype) np.testing.assert_equal(y_nd, y_np) diff --git a/tests/python/driver/test_compile.py b/tests/python/driver/test_compile.py index 
1ed4fc67ca6a..25c71b16dd6f 100644 --- a/tests/python/driver/test_compile.py +++ b/tests/python/driver/test_compile.py @@ -47,14 +47,14 @@ def test_compile_tir(): dev = tvm.cpu(0) a_np = np.random.uniform(size=10).astype(np.float32) b_np = np.random.uniform(size=10).astype(np.float32) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros(10, dtype=np.float32), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) + c = tvm.runtime.tensor(np.zeros(10, dtype=np.float32), dev) exec_prim(a, b, c) - np.testing.assert_allclose(c.numpy(), a_np + b_np) + tvm.testing.assert_allclose(c.numpy(), a_np + b_np) exec_mod(a, b, c) - np.testing.assert_allclose(c.numpy(), a_np + b_np) + tvm.testing.assert_allclose(c.numpy(), a_np + b_np) def test_compile_relax(): @@ -77,12 +77,12 @@ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")) -> R.Te dev = tvm.cpu(0) x_np = np.random.uniform(size=(3, 4)).astype(np.float32) y_np = np.random.uniform(size=(3, 4)).astype(np.float32) - x = tvm.nd.array(x_np, dev) - y = tvm.nd.array(y_np, dev) + x = tvm.runtime.tensor(x_np, dev) + y = tvm.runtime.tensor(y_np, dev) vm = relax.VirtualMachine(exec_relax, dev) z = vm["main"](x, y) - np.testing.assert_allclose(z.numpy(), x_np + y_np) + tvm.testing.assert_allclose(z.numpy(), x_np + y_np) @tvm.testing.skip_if_32bit(reason="skipping test for i386.") @@ -107,15 +107,15 @@ def main(x: R.Tensor((4,), "float32")): assert isinstance(ex, Executable) dev = tvm.cpu(0) - x = tvm.nd.array(np.array([1, 2, 3, 4], dtype=np.float32), dev) - y = tvm.nd.array(np.zeros(4, dtype=np.float32), dev) + x = tvm.runtime.tensor(np.array([1, 2, 3, 4], dtype=np.float32), dev) + y = tvm.runtime.tensor(np.zeros(4, dtype=np.float32), dev) # For tir function, we can directly call the function ex["add_one"](x, y) - np.testing.assert_allclose(y.numpy(), x.numpy() + 1) + tvm.testing.assert_allclose(y.numpy(), x.numpy() + 1) # For relax function, we need to use 
the vm to call the function vm = relax.VirtualMachine(ex, dev) z = vm["main"](x) - np.testing.assert_allclose(z.numpy(), x.numpy() + 1) + tvm.testing.assert_allclose(z.numpy(), x.numpy() + 1) if __name__ == "__main__": diff --git a/tests/python/ffi/test_access_path.py b/tests/python/ffi/test_access_path.py deleted file mode 100644 index 06fbb64ff217..000000000000 --- a/tests/python/ffi/test_access_path.py +++ /dev/null @@ -1,133 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pytest -from tvm.ffi.access_path import AccessPath, AccessKind - - -def test_root_path(): - root = AccessPath.root() - assert isinstance(root, AccessPath) - steps = root.to_steps() - assert len(steps) == 0 - assert root == AccessPath.root() - - -def test_path_attr(): - path = AccessPath.root().attr("foo") - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.ATTR - assert steps[0].key == "foo" - assert path.parent == AccessPath.root() - - -def test_path_array_item(): - path = AccessPath.root().array_item(2) - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.ARRAY_ITEM - assert steps[0].key == 2 - assert path.parent == AccessPath.root() - - -def test_path_missing_array_element(): - path = AccessPath.root().array_item_missing(2) - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.ARRAY_ITEM_MISSING - assert steps[0].key == 2 - assert path.parent == AccessPath.root() - - -def test_path_map_item(): - path = AccessPath.root().map_item("foo") - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.MAP_ITEM - assert steps[0].key == "foo" - assert path.parent == AccessPath.root() - - -def test_path_missing_map_item(): - path = AccessPath.root().map_item_missing("foo") - assert isinstance(path, AccessPath) - steps = path.to_steps() - assert len(steps) == 1 - assert steps[0].kind == AccessKind.MAP_ITEM_MISSING - assert steps[0].key == "foo" - assert path.parent == AccessPath.root() - - -def test_path_is_prefix_of(): - # Root is prefix of root - assert AccessPath.root().is_prefix_of(AccessPath.root()) - - # Root is prefix of any path - assert AccessPath.root().is_prefix_of(AccessPath.root().attr("foo")) - - # Non-root is not prefix of root - assert not 
AccessPath.root().attr("foo").is_prefix_of(AccessPath.root()) - - # Path is prefix of itself - assert AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo")) - - # Different attrs are not prefixes of each other - assert not AccessPath.root().attr("bar").is_prefix_of(AccessPath.root().attr("foo")) - - # Shorter path is prefix of longer path with same start - assert AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("foo").array_item(2)) - - # Longer path is not prefix of shorter path - assert ( - not AccessPath.root().attr("foo").array_item(2).is_prefix_of(AccessPath.root().attr("foo")) - ) - - # Different paths are not prefixes - assert ( - not AccessPath.root().attr("foo").is_prefix_of(AccessPath.root().attr("bar").array_item(2)) - ) - - -def test_path_equal(): - # Root equals root - assert AccessPath.root() == AccessPath.root() - - # Root does not equal non-root paths - assert not (AccessPath.root() == AccessPath.root().attr("foo")) - - # Non-root does not equal root - assert not (AccessPath.root().attr("foo") == AccessPath.root()) - - # Path equals itself - assert AccessPath.root().attr("foo") == AccessPath.root().attr("foo") - - # Different attrs are not equal - assert not (AccessPath.root().attr("bar") == AccessPath.root().attr("foo")) - - # Shorter path does not equal longer path - assert not (AccessPath.root().attr("foo") == AccessPath.root().attr("foo").array_item(2)) - - # Longer path does not equal shorter path - assert not (AccessPath.root().attr("foo").array_item(2) == AccessPath.root().attr("foo")) - - # Different paths are not equal - assert not (AccessPath.root().attr("foo") == AccessPath.root().attr("bar").array_item(2)) diff --git a/tests/python/ffi/test_container.py b/tests/python/ffi/test_container.py deleted file mode 100644 index 25468f452acc..000000000000 --- a/tests/python/ffi/test_container.py +++ /dev/null @@ -1,102 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor 
license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import pytest -import pickle -import tvm.ffi as tvm_ffi - - -def test_array(): - a = tvm_ffi.convert([1, 2, 3]) - assert isinstance(a, tvm_ffi.Array) - assert len(a) == 3 - assert a[-1] == 3 - a_slice = a[-3:-1] - assert (a_slice[0], a_slice[1]) == (1, 2) - - -def test_bad_constructor_init_state(): - """Test when error is raised before __init_handle_by_constructor - - This case we need the FFI binding to gracefully handle both repr - and dealloc by ensuring the chandle is initialized and there is - proper repr code - """ - with pytest.raises(TypeError): - tvm_ffi.Array(1) - - with pytest.raises(AttributeError): - tvm_ffi.Map(1) - - -def test_array_of_array_map(): - a = tvm_ffi.convert([[1, 2, 3], {"A": 5, "B": 6}]) - assert isinstance(a, tvm_ffi.Array) - assert len(a) == 2 - assert isinstance(a[0], tvm_ffi.Array) - assert isinstance(a[1], tvm_ffi.Map) - assert tuple(a[0]) == (1, 2, 3) - assert a[1]["A"] == 5 - assert a[1]["B"] == 6 - - -def test_int_map(): - amap = tvm_ffi.convert({3: 2, 4: 3}) - assert 3 in amap - assert len(amap) == 2 - dd = dict(amap.items()) - assert 3 in dd - assert 4 in dd - assert 5 not in amap - assert tuple(amap.items()) == ((3, 2), (4, 3)) - assert tuple(amap.keys()) == (3, 4) - assert tuple(amap.values()) == (2, 3) - - -def test_str_map(): 
- data = [] - for i in reversed(range(10)): - data.append((f"a{i}", i)) - amap = tvm_ffi.convert({k: v for k, v in data}) - assert tuple(amap.items()) == tuple(data) - for k, v in data: - assert k in amap - assert amap[k] == v - assert amap.get(k) == v - - assert tuple(k for k in amap) == tuple(k for k, _ in data) - - -def test_key_not_found(): - amap = tvm_ffi.convert({3: 2, 4: 3}) - with pytest.raises(KeyError): - amap[5] - - -def test_repr(): - a = tvm_ffi.convert([1, 2, 3]) - assert str(a) == "[1, 2, 3]" - amap = tvm_ffi.convert({3: 2, 4: 3}) - assert str(amap) == "{3: 2, 4: 3}" - - smap = tvm_ffi.convert({"a": 1, "b": 2}) - assert str(smap) == "{'a': 1, 'b': 2}" - - -def test_serialization(): - a = tvm_ffi.convert([1, 2, 3]) - b = pickle.loads(pickle.dumps(a)) - assert str(b) == "[1, 2, 3]" diff --git a/tests/python/ffi/test_device.py b/tests/python/ffi/test_device.py deleted file mode 100644 index 5800a0c44178..000000000000 --- a/tests/python/ffi/test_device.py +++ /dev/null @@ -1,94 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pytest -import pickle -from tvm.ffi import Device -from tvm import ffi as tvm_ffi - - -def test_device(): - device = tvm_ffi.Device("cuda", 0) - assert device.device_type == tvm_ffi.Device.kDLCUDA - assert device.device_id == 0 - assert str(device) == "cuda:0" - assert device.__repr__() == "device(type='cuda', index=0)" - - -def test_device_from_str(): - device = tvm_ffi.device("ext_dev:0") - assert device.device_type == tvm_ffi.Device.kDLExtDev - assert device.device_id == 0 - assert str(device) == "ext_dev:0" - assert device.__repr__() == "device(type='ext_dev', index=0)" - - -@pytest.mark.parametrize( - "dev_str, expected_device_type, expect_device_id", - [ - ("cpu", Device.kDLCPU, 0), - ("cuda", Device.kDLCUDA, 0), - ("cuda:0", Device.kDLCUDA, 0), - ("cuda:3", Device.kDLCUDA, 3), - ("metal:2", Device.kDLMetal, 2), - ], -) -def test_device(dev_str, expected_device_type, expect_device_id): - dev = tvm_ffi.device(dev_str) - assert dev.device_type == expected_device_type - assert dev.device_id == expect_device_id - - -@pytest.mark.parametrize( - "dev_type, dev_id, expected_device_type, expect_device_id", - [ - ("cpu", 0, Device.kDLCPU, 0), - ("cuda", 0, Device.kDLCUDA, 0), - (Device.kDLCUDA, 0, Device.kDLCUDA, 0), - ("cuda", 3, Device.kDLCUDA, 3), - (Device.kDLMetal, 2, Device.kDLMetal, 2), - ], -) -def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expect_device_id): - dev = tvm_ffi.device(dev_type=dev_type, dev_id=dev_id) - assert dev.device_type == expected_device_type - assert dev.device_id == expect_device_id - - -@pytest.mark.parametrize( - "dev_type, dev_id", - [ - ("cpu:0:0", None), - ("cpu:?", None), - ("cpu:", None), - ], -) -def test_deive_type_error(dev_type, dev_id): - with pytest.raises(ValueError): - dev = tvm_ffi.device(dev_type=dev_type, dev_id=dev_id) - - -def test_deive_id_error(): - with pytest.raises(TypeError): - dev = tvm_ffi.device(dev_type="cpu", dev_id="?") - - -def test_device_pickle(): - device = 
tvm_ffi.device("cuda", 0) - device_pickled = pickle.loads(pickle.dumps(device)) - assert device_pickled.device_type == device.device_type - assert device_pickled.device_id == device.device_id diff --git a/tests/python/ffi/test_dtype.py b/tests/python/ffi/test_dtype.py deleted file mode 100644 index 332d0e1827d8..000000000000 --- a/tests/python/ffi/test_dtype.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pytest -import pickle -import numpy as np -import tvm -import tvm.testing -from tvm import ffi as tvm_ffi - - -def test_dtype(): - float32 = tvm_ffi.dtype("float32") - assert float32.__repr__() == "dtype('float32')" - assert type(float32) == tvm_ffi.dtype - x = np.array([1, 2, 3], dtype=float32) - assert x.dtype == float32 - - -@pytest.mark.parametrize( - "dtype_str, expected_size", - [ - ("float32", 4), - ("float32x4", 16), - ("float8_e5m2x4", 4), - ("float6_e2m3fnx4", 3), - ("float4_e2m1fnx4", 2), - ("uint8", 1), - ("bool", 1), - ], -) -def test_dtype_itemsize(dtype_str, expected_size): - dtype = tvm_ffi.dtype(dtype_str) - assert dtype.itemsize == expected_size - - -@pytest.mark.parametrize("dtype_str", ["int32xvscalex4"]) -def test_dtype_itemmize_error(dtype_str): - with pytest.raises(ValueError): - tvm_ffi.dtype(dtype_str).itemsize - - -@pytest.mark.parametrize( - "dtype_str", - [ - "float32", - "float32x4", - "float8_e5m2x4", - "float6_e2m3fnx4", - "float4_e2m1fnx4", - "uint8", - "bool", - ], -) -def test_dtype_pickle(dtype_str): - dtype = tvm_ffi.dtype(dtype_str) - dtype_pickled = pickle.loads(pickle.dumps(dtype)) - assert dtype_pickled.type_code == dtype.type_code - assert dtype_pickled.bits == dtype.bits - assert dtype_pickled.lanes == dtype.lanes - - -@pytest.mark.parametrize("dtype_str", ["float32", "bool"]) -def test_dtype_with_lanes(dtype_str): - dtype = tvm_ffi.dtype(dtype_str) - dtype_with_lanes = dtype.with_lanes(4) - assert dtype_with_lanes.type_code == dtype.type_code - assert dtype_with_lanes.bits == dtype.bits - assert dtype_with_lanes.lanes == 4 - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/ffi/test_error.py b/tests/python/ffi/test_error.py deleted file mode 100644 index e3d02234b580..000000000000 --- a/tests/python/ffi/test_error.py +++ /dev/null @@ -1,113 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pytest -import platform -from tvm import ffi as tvm_ffi - - -def test_parse_traceback(): - traceback = """ - File "test.py", line 1, in - File "test.py", line 3, in run_test - """ - parsed = tvm_ffi.error._parse_traceback(traceback) - assert len(parsed) == 2 - assert parsed[0] == ("test.py", 1, "") - assert parsed[1] == ("test.py", 3, "run_test") - - -def test_error_from_cxx(): - test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error") - - try: - test_raise_error("ValueError", "error XYZ") - except ValueError as e: - assert e.__tvm_ffi_error__.kind == "ValueError" - assert e.__tvm_ffi_error__.message == "error XYZ" - assert e.__tvm_ffi_error__.traceback.find("TestRaiseError") != -1 - - fapply = tvm_ffi.convert(lambda f, *args: f(*args)) - - with pytest.raises(TypeError): - fapply(test_raise_error, "TypeError", "error XYZ") - - # wrong number of arguments - with pytest.raises(TypeError): - tvm_ffi.convert(lambda x: x)() - - -@pytest.mark.skipif( - "32bit" in platform.architecture(), - reason="libbacktrace file name support is not available in i386 yet", -) -def test_error_from_nested_pyfunc(): - fapply = tvm_ffi.convert(lambda f, *args: f(*args)) - cxx_test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error") - cxx_test_apply = 
tvm_ffi.get_global_func("testing.apply") - - record_object = [] - - def raise_error(): - try: - fapply(cxx_test_raise_error, "ValueError", "error XYZ") - except ValueError as e: - assert e.__tvm_ffi_error__.kind == "ValueError" - assert e.__tvm_ffi_error__.message == "error XYZ" - assert e.__tvm_ffi_error__.traceback.find("TestRaiseError") != -1 - record_object.append(e.__tvm_ffi_error__) - raise e - - try: - cxx_test_apply(raise_error) - except ValueError as e: - traceback = e.__tvm_ffi_error__.traceback - assert e.__tvm_ffi_error__.same_as(record_object[0]) - assert traceback.count("TestRaiseError") == 1 - assert traceback.count("TestApply") == 1 - assert traceback.count("") == 1 - pos_cxx_raise = traceback.find("TestRaiseError") - pos_cxx_apply = traceback.find("TestApply") - pos_lambda = traceback.find("") - assert pos_cxx_raise > pos_lambda - assert pos_lambda > pos_cxx_apply - - -def test_error_traceback_update(): - fecho = tvm_ffi.get_global_func("testing.echo") - - def raise_error(): - raise ValueError("error XYZ") - - try: - raise_error() - except ValueError as e: - ffi_error = tvm_ffi.convert(e) - assert ffi_error.traceback.find("raise_error") != -1 - - def raise_cxx_error(): - cxx_test_raise_error = tvm_ffi.get_global_func("testing.test_raise_error") - cxx_test_raise_error("ValueError", "error XYZ") - - try: - raise_cxx_error() - except ValueError as e: - assert e.__tvm_ffi_error__.traceback.find("raise_cxx_error") == -1 - ffi_error1 = tvm_ffi.convert(e) - ffi_error2 = fecho(e) - assert ffi_error1.traceback.find("raise_cxx_error") != -1 - assert ffi_error2.traceback.find("raise_cxx_error") != -1 diff --git a/tests/python/ffi/test_function.py b/tests/python/ffi/test_function.py deleted file mode 100644 index 5a8b4acb1f4e..000000000000 --- a/tests/python/ffi/test_function.py +++ /dev/null @@ -1,163 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import gc -import ctypes -import numpy as np -from tvm import ffi as tvm_ffi - - -def test_echo(): - fecho = tvm_ffi.get_global_func("testing.echo") - assert isinstance(fecho, tvm_ffi.Function) - # test each type - assert fecho(None) is None - - # test bool - bool_result = fecho(True) - assert isinstance(bool_result, bool) - assert bool_result is True - bool_result = fecho(False) - assert isinstance(bool_result, bool) - assert bool_result is False - - # test int/float - assert fecho(1) == 1 - assert fecho(1.2) == 1.2 - - # test str - str_result = fecho("hello") - assert isinstance(str_result, str) - assert str_result == "hello" - - # test bytes - bytes_result = fecho(b"abc") - assert isinstance(bytes_result, bytes) - assert bytes_result == b"abc" - - # test dtype - dtype_result = fecho(tvm_ffi.dtype("float32")) - assert isinstance(dtype_result, tvm_ffi.dtype) - assert dtype_result == tvm_ffi.dtype("float32") - - # test device - device_result = fecho(tvm_ffi.device("cuda:1")) - assert isinstance(device_result, tvm_ffi.Device) - assert device_result.device_type == tvm_ffi.Device.kDLCUDA - assert device_result.device_id == 1 - assert str(device_result) == "cuda:1" - assert device_result.__repr__() == "device(type='cuda', index=1)" - - # test c_void_p - c_void_p_result = 
fecho(ctypes.c_void_p(0x12345678)) - assert isinstance(c_void_p_result, ctypes.c_void_p) - assert c_void_p_result.value == 0x12345678 - - # test function: aka object - fadd = tvm_ffi.convert(lambda a, b: a + b) - fadd1 = fecho(fadd) - assert fadd1(1, 2) == 3 - assert fadd1.same_as(fadd) - - def check_ndarray(): - np_data = np.arange(10, dtype="int32") - if not hasattr(np_data, "__dlpack__"): - return - # test NDArray - x = tvm_ffi.from_dlpack(np_data) - assert isinstance(x, tvm_ffi.NDArray) - nd_result = fecho(x) - assert isinstance(nd_result, tvm_ffi.NDArray) - assert nd_result.shape == (10,) - assert nd_result.dtype == tvm_ffi.dtype("int32") - assert nd_result.device.device_type == tvm_ffi.Device.kDLCPU - assert nd_result.device.device_id == 0 - - check_ndarray() - - -def test_return_raw_str_bytes(): - assert tvm_ffi.convert(lambda: "hello")() == "hello" - assert tvm_ffi.convert(lambda: b"hello")() == b"hello" - assert tvm_ffi.convert(lambda: bytearray(b"hello"))() == b"hello" - - -def test_pyfunc_convert(): - def add(a, b): - return a + b - - fadd = tvm_ffi.convert(add) - assert isinstance(fadd, tvm_ffi.Function) - assert fadd(1, 2) == 3 - - def fapply(f, *args): - return f(*args) - - fapply = tvm_ffi.convert(fapply) - assert fapply(add, 1, 3.3) == 4.3 - - -def test_global_func(): - @tvm_ffi.register_func("mytest.echo") - def echo(x): - return x - - f = tvm_ffi.get_global_func("mytest.echo") - assert f.same_as(echo) - assert f(1) == 1 - - assert "mytest.echo" in tvm_ffi.registry.list_global_func_names() - - tvm_ffi.registry.remove_global_func("mytest.echo") - assert "mytest.echo" not in tvm_ffi.registry.list_global_func_names() - assert tvm_ffi.get_global_func("mytest.echo", allow_missing=True) is None - - -def test_rvalue_ref(): - use_count = tvm_ffi.get_global_func("testing.object_use_count") - - def callback(x, expected_count): - # The use count of TVM FFI objects is decremented as part of - # `ObjectRef.__del__`, which runs when the Python object is - # 
destructed. However, Python object destruction is not - # deterministic, and even CPython's reference-counting is - # considered an implementation detail. Therefore, to ensure - # correct results from this test, `gc.collect()` must be - # explicitly called. - gc.collect() - assert expected_count == use_count(x) - return x._move() - - f = tvm_ffi.convert(callback) - - def check0(): - x = tvm_ffi.convert([1, 2]) - assert use_count(x) == 1 - f(x, 2) - y = f(x._move(), 1) - assert x.__ctypes_handle__().value == None - - def check1(): - x = tvm_ffi.convert([1, 2]) - assert use_count(x) == 1 - y = f(x, 2) - z = f(x._move(), 2) - assert x.__ctypes_handle__().value == None - assert y.__ctypes_handle__().value is not None - - check0() - check1() diff --git a/tests/python/ffi/test_ndarray.py b/tests/python/ffi/test_ndarray.py deleted file mode 100644 index 5b75171b55bb..000000000000 --- a/tests/python/ffi/test_ndarray.py +++ /dev/null @@ -1,76 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import pytest - -try: - import torch -except ImportError: - torch = None - -from tvm import ffi as tvm_ffi -import numpy as np - - -def test_ndarray_attributes(): - data = np.zeros((10, 8, 4, 2), dtype="int16") - if not hasattr(data, "__dlpack__"): - return - x = tvm_ffi.from_dlpack(data) - assert isinstance(x, tvm_ffi.NDArray) - assert x.shape == (10, 8, 4, 2) - assert x.dtype == tvm_ffi.dtype("int16") - assert x.device.device_type == tvm_ffi.Device.kDLCPU - assert x.device.device_id == 0 - x2 = np.from_dlpack(x) - np.testing.assert_equal(x2, data) - - -def test_shape_object(): - shape = tvm_ffi.Shape((10, 8, 4, 2)) - assert isinstance(shape, tvm_ffi.Shape) - assert shape == (10, 8, 4, 2) - - fecho = tvm_ffi.convert(lambda x: x) - shape2 = fecho(shape) - assert shape2.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__) - assert isinstance(shape2, tvm_ffi.Shape) - assert isinstance(shape2, tuple) - - shape3 = tvm_ffi.convert(shape) - assert shape3.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__) - assert isinstance(shape3, tvm_ffi.Shape) - - -@pytest.mark.skipif(torch is None, reason="Torch is not installed") -def test_ndarray_auto_dlpack(): - def check(x, y): - assert isinstance(y, tvm_ffi.NDArray) - assert y.shape == (128,) - assert y.dtype == tvm_ffi.dtype("int64") - assert y.device.device_type == tvm_ffi.Device.kDLCPU - assert y.device.device_id == 0 - x2 = torch.from_dlpack(y) - np.testing.assert_equal(x2.numpy(), x.numpy()) - - x = torch.arange(128) - fecho = tvm_ffi.get_global_func("testing.echo") - y = fecho(x) - check(x, y) - - # pass in list of tensors - y = fecho([x]) - check(x, y[0]) diff --git a/tests/python/ffi/test_object.py b/tests/python/ffi/test_object.py deleted file mode 100644 index d333cbca089c..000000000000 --- a/tests/python/ffi/test_object.py +++ /dev/null @@ -1,70 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import pytest - -from tvm import ffi as tvm_ffi - - -def test_make_object(): - # with default values - obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase") - assert obj0.v_i64 == 10 - assert obj0.v_f64 == 10.0 - assert obj0.v_str == "hello" - - -def test_method(): - obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12) - assert obj0.add_i64(1) == 13 - assert type(obj0).add_i64.__doc__ == "add_i64 method" - assert type(obj0).v_i64.__doc__ == "i64 field" - - -def test_setter(): - # test setter - obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, v_str="hello") - assert obj0.v_i64 == 10 - obj0.v_i64 = 11 - assert obj0.v_i64 == 11 - obj0.v_str = "world" - assert obj0.v_str == "world" - - with pytest.raises(TypeError): - obj0.v_str = 1 - - with pytest.raises(TypeError): - obj0.v_i64 = "hello" - - -def test_derived_object(): - with pytest.raises(TypeError): - obj0 = tvm_ffi.testing.create_object("testing.TestObjectDerived") - - v_map = tvm_ffi.convert({"a": 1}) - v_array = tvm_ffi.convert([1, 2, 3]) - - obj0 = tvm_ffi.testing.create_object( - "testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array - ) - assert obj0.v_map.same_as(v_map) - assert obj0.v_array.same_as(v_array) - assert obj0.v_i64 == 20 - assert obj0.v_f64 == 10.0 - 
assert obj0.v_str == "hello" - - obj0.v_i64 = 21 - assert obj0.v_i64 == 21 diff --git a/tests/python/ffi/test_string.py b/tests/python/ffi/test_string.py deleted file mode 100644 index 85fed5670c72..000000000000 --- a/tests/python/ffi/test_string.py +++ /dev/null @@ -1,54 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pickle -from tvm import ffi as tvm_ffi - - -def test_string(): - fecho = tvm_ffi.get_global_func("testing.echo") - s = tvm_ffi.String("hello") - s2 = fecho(s) - assert s2 == "hello" - s3 = tvm_ffi.convert("hello") - assert isinstance(s3, str) - - x = "hello long string" - assert fecho(x) == x - - s4 = pickle.loads(pickle.dumps(s)) - assert s4 == "hello" - - -def test_bytes(): - fecho = tvm_ffi.get_global_func("testing.echo") - b = tvm_ffi.Bytes(b"hello") - assert isinstance(b, tvm_ffi.Bytes) - b2 = fecho(b) - assert b2 == b"hello" - - b3 = tvm_ffi.convert(b"hello") - assert isinstance(b3, tvm_ffi.Bytes) - assert isinstance(b3, bytes) - - b4 = tvm_ffi.convert(bytearray(b"hello")) - assert isinstance(b4, tvm_ffi.Bytes) - assert isinstance(b4, bytes) - - b5 = pickle.loads(pickle.dumps(b)) - assert b5 == b"hello" - assert isinstance(b5, tvm_ffi.Bytes) diff --git a/tests/python/ir/test_container_structural_equal.py b/tests/python/ir/test_container_structural_equal.py index 251b33f910e7..957d0946ed00 100644 --- a/tests/python/ir/test_container_structural_equal.py +++ b/tests/python/ir/test_container_structural_equal.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm.ffi.access_path import AccessPath +from tvm_ffi.access_path import AccessPath from tvm.ir.base import get_first_structural_mismatch diff --git a/tests/python/ir/test_datatype_nv_fp4.py b/tests/python/ir/test_datatype_nv_fp4.py index 85047fc4a5fd..d237176e6c55 100644 --- a/tests/python/ir/test_datatype_nv_fp4.py +++ b/tests/python/ir/test_datatype_nv_fp4.py @@ -36,7 +36,7 @@ def test_create_nv_fp4_nd_array(np_dtype, dtype_str): """Skip test if ml_dtypes is not installed""" return x = np.random.rand(128, 128).astype(np_dtype) - x_nd = tvm.nd.array(x) + x_nd = tvm.runtime.tensor(x) assert x_nd.dtype == dtype_str np.testing.assert_equal(x_nd.numpy(), x) diff --git a/tests/python/ir/test_datatype_nv_fp8.py b/tests/python/ir/test_datatype_nv_fp8.py index d27cc0314328..0c17e844757f 100644 --- 
a/tests/python/ir/test_datatype_nv_fp8.py +++ b/tests/python/ir/test_datatype_nv_fp8.py @@ -85,7 +85,7 @@ def test_create_nv_fp8_nd_array(np_dtype, dtype_str): """Skip test if ml_dtypes is not installed""" return x = np.random.rand(128, 128).astype(np_dtype) - x_nd = tvm.nd.array(x) + x_nd = tvm.runtime.tensor(x) assert x_nd.dtype == dtype_str np.testing.assert_equal(x_nd.numpy(), x) @@ -110,7 +110,7 @@ def test_fp8_unary_op(np_dtype, dtype_str): a_fp32 = np.zeros(128).astype(np.float32) a_roundtrip = np.zeros(128).astype(np_dtype) args = list( - map(lambda _: tvm.nd.array(_), [a, b, a_add_b, a_sub_b, a_mul_b, a_fp32, a_roundtrip]) + map(lambda _: tvm.runtime.tensor(_), [a, b, a_add_b, a_sub_b, a_mul_b, a_fp32, a_roundtrip]) ) f(*args) expected_a_fp32 = a.astype(np.float32) diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index 1004bad702f6..12502b6e6c7e 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import pytest +import tvm_ffi import tvm from tvm import te import numpy as np @@ -90,7 +91,7 @@ def test_getattr_map(): a = te.var("a") b = te.var("b") amap = tvm.runtime.convert({a: 2, b: 3}) - assert isinstance(amap, tvm.ffi.Map) + assert isinstance(amap, tvm_ffi.Map) def test_in_container(): @@ -100,12 +101,12 @@ def test_in_container(): assert "d" not in arr -def test_ndarray_container(): - x = tvm.nd.array([1, 2, 3]) +def test_tensor_container(): + x = tvm.runtime.tensor([1, 2, 3]) arr = tvm.runtime.convert([x, x]) assert arr[0].same_as(x) assert arr[1].same_as(x) - assert isinstance(arr[0], tvm.nd.NDArray) + assert isinstance(arr[0], tvm.runtime.Tensor) def test_return_variant_type(): diff --git a/tests/python/ir/test_node_reflection.py b/tests/python/ir/test_node_reflection.py index be00bc3a4777..52b2a29f59c0 100644 --- a/tests/python/ir/test_node_reflection.py +++ b/tests/python/ir/test_node_reflection.py @@ -94,7 +94,7 @@ def test_make_sum(): def test_env_func(): - @tvm.register_func("test.env_func") + @tvm.register_global_func("test.env_func") def test(x): return x + 1 @@ -163,19 +163,19 @@ def test_dict(): assert set(dir(x.__class__)) <= set(dir(x)) -def test_ndarray(): +def test_tensor(): dev = tvm.cpu(0) - tvm_arr = tvm.nd.array(np.random.rand(4), device=dev) + tvm_arr = tvm.runtime.tensor(np.random.rand(4), device=dev) tvm_arr2 = tvm.ir.load_json(tvm.ir.save_json(tvm_arr)) tvm.ir.assert_structural_equal(tvm_arr, tvm_arr2) np.testing.assert_array_equal(tvm_arr.numpy(), tvm_arr2.numpy()) -def test_ndarray_dict(): +def test_tensor_dict(): dev = tvm.cpu(0) m1 = { - "key1": tvm.nd.array(np.random.rand(4), device=dev), - "key2": tvm.nd.array(np.random.rand(4), device=dev), + "key1": tvm.runtime.tensor(np.random.rand(4), device=dev), + "key2": tvm.runtime.tensor(np.random.rand(4), device=dev), } m2 = tvm.ir.load_json(tvm.ir.save_json(m1)) tvm.ir.assert_structural_equal(m1, m2) @@ -196,7 +196,7 @@ def test_alloc_const(): shape = (16,) buf = 
tvm.tir.decl_buffer(shape, dtype) np_data = np.random.rand(*shape).astype(dtype) - data = tvm.nd.array(np_data, device=dev) + data = tvm.runtime.tensor(np_data, device=dev) body = tvm.tir.Evaluate(0) alloc_const = tvm.tir.AllocateConst(buf.data, dtype, shape, data, body) alloc_const2 = tvm.ir.load_json(tvm.ir.save_json(alloc_const)) diff --git a/tests/python/meta_schedule/test_meta_schedule_builder.py b/tests/python/meta_schedule/test_meta_schedule_builder.py index 090a393fbeeb..6da0a089180c 100644 --- a/tests/python/meta_schedule/test_meta_schedule_builder.py +++ b/tests/python/meta_schedule/test_meta_schedule_builder.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import script -from tvm.ffi import register_func +from tvm_ffi import register_global_func from tvm.meta_schedule.builder import ( BuilderInput, BuilderResult, @@ -163,7 +163,7 @@ def test_meta_schedule_error_handle_build_func(): """Test the error handing during building""" def initializer(): - @register_func("meta_schedule.builder.test_build") + @register_global_func("meta_schedule.builder.test_build") def test_build(mod: Module, target: Target, _) -> None: # pylint: disable=unused-variable raise ValueError("Builder intended Test Error (build func).") @@ -182,7 +182,7 @@ def test_meta_schedule_error_handle_export_func(): """Test the error handing during building""" def initializer(): - @register_func("meta_schedule.builder.test_export") + @register_global_func("meta_schedule.builder.test_export") def test_build(mod: Module) -> str: # pylint: disable=unused-variable raise ValueError("Builder intended Test Error (export func).") @@ -201,7 +201,7 @@ def test_meta_schedule_error_handle_time_out(): """Test the error handing time out during building""" def initializer(): - @register_func("meta_schedule.builder.test_time_out") + @register_global_func("meta_schedule.builder.test_time_out") def timeout_build(mod, target, _): # pylint: disable=unused-argument, unused-variable time.sleep(2) diff --git 
a/tests/python/meta_schedule/test_meta_schedule_database.py b/tests/python/meta_schedule/test_meta_schedule_database.py index 84ec862f0ef8..f8b2354c33bf 100644 --- a/tests/python/meta_schedule/test_meta_schedule_database.py +++ b/tests/python/meta_schedule/test_meta_schedule_database.py @@ -587,7 +587,7 @@ def MatmulPrimFunc() -> IRModule: @pytest.mark.parametrize("f_mod", [MatmulPrimFunc]) -@pytest.mark.parametrize("mod_eq", ["structural", "ignore-ndarray", "anchor-block"]) +@pytest.mark.parametrize("mod_eq", ["structural", "ignore-tensor", "anchor-block"]) def test_json_database_commit_workload(f_mod, mod_eq): mod: IRModule = f_mod() with tempfile.TemporaryDirectory() as tmpdir: @@ -596,7 +596,7 @@ def test_json_database_commit_workload(f_mod, mod_eq): @pytest.mark.parametrize("f_mod", [MatmulPrimFunc]) -@pytest.mark.parametrize("mod_eq", ["structural", "ignore-ndarray", "anchor-block"]) +@pytest.mark.parametrize("mod_eq", ["structural", "ignore-tensor", "anchor-block"]) def test_memory_database_commit_workload(f_mod, mod_eq): mod: IRModule = f_mod() database = ms.database.MemoryDatabase(module_equality=mod_eq) diff --git a/tests/python/meta_schedule/test_meta_schedule_feature_extractor.py b/tests/python/meta_schedule/test_meta_schedule_feature_extractor.py index 84d07dbf6e11..8b718f86a104 100644 --- a/tests/python/meta_schedule/test_meta_schedule_feature_extractor.py +++ b/tests/python/meta_schedule/test_meta_schedule_feature_extractor.py @@ -19,11 +19,11 @@ from typing import List import numpy as np +import tvm.runtime from tvm.meta_schedule import TuneContext from tvm.meta_schedule.feature_extractor import PyFeatureExtractor from tvm.meta_schedule.search_strategy import MeasureCandidate from tvm.meta_schedule.utils import derived_object -from tvm.runtime.ndarray import array def test_meta_schedule_feature_extractor(): @@ -34,7 +34,7 @@ def extract_from( context: TuneContext, # pylint: disable = unused-argument candidates: List[MeasureCandidate], # pylint: 
disable = unused-argument ) -> List[np.ndarray]: - return [array(np.random.rand(4, 5))] + return [tvm.runtime.tensor(np.random.rand(4, 5))] extractor = FancyFeatureExtractor() features = extractor.extract_from(TuneContext(), []) diff --git a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py index 057cd0e9f7ae..b901c3ce1372 100644 --- a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py @@ -846,21 +846,21 @@ def _create_schedule(): 1.0, 0.0, 0.0, - 25.000000042995662, - 20.000001375860553, - 23.00000017198264, - 14.000088052430122, + 25.00000004, + 19.99718086, + 23.00000017, + 13.99726771, 1.0, 0.0, 0.0, - 18.00000550343433, - 20.00562591970089, - 2.321928094887362, - 23.00000017198264, - 18.00000550343433, - 21.000000687930438, - 12.0003521774803, - 12.0003521774803, + 18.0000055, + 20.00000138, + 2.32192809, + 23.00000017, + 17.997185, + 21.00000069, + 11.99753235, + 12.00035218, ], rtol=1e-5, atol=1e-5, @@ -872,21 +872,21 @@ def _create_schedule(): 0.0, 1.0, 0.0, - 25.000000042995662, - 12.0003521774803, - 23.00000017198264, - 9.002815015607053, + 25.00000004, + 11.00070427, + 23.00000017, + 5.04439412, 1.0, 0.0, 0.0, - 6.022367813028454, - 11.98049663618346, - 8.005624549193879, - 17.000011006847668, - 4.087462841250339, - 15.000044026886828, - 1.584962500721156, - 4.087462841250339, + 6.02236781, + 11.98049664, + 8.00562455, + 17.00001101, + 3.169925, + 15.00004403, + 0.169925, + 4.08746284, ], rtol=1e-5, atol=1e-5, @@ -1052,21 +1052,21 @@ def _create_schedule(): 1.0, 0.0, 0.0, - 22.00000034396526, - 20.000001375860553, - 20.000001375860553, - 14.000088052430122, + 22.00000034, + 19.85798251, + 20.00000138, + 13.85807816, 1.0, 0.0, 0.0, - 15.000044026886828, - 20.17555076886471, - 2.321928094887362, - 
20.000001375860553, - 18.00000550343433, - 18.00000550343433, - 12.0003521774803, - 4.087462841250339, + 15.00004403, + 20.04456622, + 2.32192809, + 20.00000138, + 17.85798707, + 18.0000055, + 11.8583696, + 4.08746284, ], rtol=1e-5, atol=1e-5, @@ -1078,20 +1078,20 @@ def _create_schedule(): 0.0, 1.0, 0.0, - 22.00000034396526, - 9.002815015607053, - 20.000001375860553, - 3.169925001442312, + 22.00000034, + 7.01122726, + 20.00000138, + 4.08746284, 1.0, 0.0, 0.0, 3.169925001442312, - 9.61654884377899, + 4.08746284, 8.005624549193879, 14.000088052430122, - 1.584962500721156, - 12.0003521774803, - 0.044394119358453436, + 0.5849625, + 12.00035218, + 0.08746284, 4.087462841250339, ], rtol=1e-5, diff --git a/tests/python/meta_schedule/test_meta_schedule_mma_tensorize.py b/tests/python/meta_schedule/test_meta_schedule_mma_tensorize.py new file mode 100644 index 000000000000..a318ea35158f --- /dev/null +++ b/tests/python/meta_schedule/test_meta_schedule_mma_tensorize.py @@ -0,0 +1,338 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +import numpy as np +from tvm.script import tir as T +from tvm.tir.schedule import Schedule +import tvm.tir.tensor_intrin # pylint: disable=unused-import +import tvm.testing + +import pytest + +torch = pytest.importorskip("torch") + +M, N, K = 4096, 4096, 4096 +np.random.seed(0) + + +@tvm.script.ir_module +class Gemm_F16F16F16: + # fmt: off + @T.prim_func + def main( + A: T.Buffer((M, K), "float16"), # type: ignore + B: T.Buffer((K, N), "float16"), # type: ignore + C: T.Buffer((M, N), "float16"), # type: ignore + ): + for i, j, k in T.grid(M, N, K): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class Gemm_F16F16F32: + # fmt: off + @T.prim_func + def main( + A: T.Buffer((M, K), "float16"), # type: ignore + B: T.Buffer((K, N), "float16"), # type: ignore + C: T.Buffer((M, N), "float32"), # type: ignore + ): + for i, j, k in T.grid(M, N, K): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + T.cast(A[vi, vk], "float32") * T.cast(B[vk, vj], "float32") + + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_run_target(mod=None, tgt_str=None, in_dtype="float16", out_dtype="float16"): + if mod is None: + return + tgt_str = tgt_str or "cuda" + target = tvm.target.Target(target=tgt_str) + with tvm.transform.PassContext(opt_level=3): + lib: tvm.runtime.Module = tvm.compile(mod, target=target) + + dev = tvm.device(tgt_str, 0) + a_np = np.random.rand(M, K).astype(in_dtype) + b_np = np.random.rand(K, N).astype(in_dtype) + c_np = np.ones((M, N), dtype=out_dtype) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) + c = tvm.runtime.tensor(c_np, dev) + + f = lib["main"] + f(a, b, c) + + c_th = torch.matmul(torch.tensor(a_np).to(tgt_str), torch.tensor(b_np).to(tgt_str)).to( + torch.float32 if out_dtype == 
"float32" else torch.float16 + ) + c_f = torch.tensor(c.numpy()).to(tgt_str) + torch.allclose(c_th, c_f, rtol=0.05, atol=0.05) + + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_f16f16f16_mma_gemm(): + # fmt: off + mod = Gemm_F16F16F16 + sch = Schedule(mod) + b0 = sch.get_block(name="C", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + b2 = sch.reindex(block=b0, buffer=("write", 0)) + b3 = sch.reindex(block=b0, buffer=("read", 0)) + b4 = sch.reindex(block=b0, buffer=("read", 1)) + sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda vi, vk: (vi, vk,), pad_value=None, assume_injective_transform=True) + sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda vj, vk: (vk, vj,), pad_value=None, assume_injective_transform=True) + sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda vi, vj: (vi, vj,), pad_value=None, assume_injective_transform=True) + sch.transform_block_layout(block=b2, index_map=lambda vi, vj: (vi, vj,)) + sch.transform_block_layout(block=b3, index_map=lambda vi, vk: (vi, vk,)) + sch.transform_block_layout(block=b4, index_map=lambda vj, vk: (vk, vj,)) + sch.transform_block_layout(block=b0, index_map=lambda vi, vj, vk: (vi, vj, vk,)) + l5, l6, l7 = sch.get_loops(block=b0) + l8, l9 = sch.split(loop=l7, factors=[None, 8], preserve_unit_iters=True, disable_predication=False) + l10, l11 = sch.split(loop=l6, factors=[None, 8], preserve_unit_iters=True, disable_predication=False) + l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True, disable_predication=False) + l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0) + sch.reorder(l16, l18, l13, l11, l9) + b20 = sch.blockize(target=l13, preserve_unit_iters=True) + sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="mma_sync_m16n8k8_f16f16f16") + sch.annotate(block_or_loop=b20, 
ann_key="meta_schedule.auto_tensorize_init", ann_val="mma_init_m16n8k8_f16") + sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1) + l21, l22, l23 = sch.get_loops(block=b20) + v24, v25, v26, v27, v28 = sch.sample_partitioned_tile(loop=l21, n=5, partition_pos=3, innerpart_factor=2, decision=[2, 16, 4, 1, 2]) + l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True, disable_predication=False) + v34, v35, v36, v37, v38 = sch.sample_partitioned_tile(loop=l22, n=5, partition_pos=3, innerpart_factor=4, decision=[2, 16, 4, 1, 4]) + l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True, disable_predication=False) + v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4, decision=[128, 1, 4]) + l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True, disable_predication=False) + sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43) + l50 = sch.fuse(l29, l39, preserve_unit_iters=True) + sch.bind(loop=l50, thread_axis="blockIdx.y") + l51 = sch.fuse(l30, l40, preserve_unit_iters=True) + sch.bind(loop=l51, thread_axis="blockIdx.x") + l52 = sch.fuse(l31, l41, preserve_unit_iters=True) + sch.bind(loop=l52, thread_axis="threadIdx.y") + sch.annotate(block_or_loop=b20, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b20, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b53 = sch.write_at(loop=l52, block=b20, write_buffer_index=0, storage_scope="m16n8k8.matrixC") + sch.reverse_compute_inline(block=b2) + b54 = sch.read_at(loop=l47, block=b20, read_buffer_index=0, storage_scope="shared.dyn") + sch.annotate(block_or_loop=b54, ann_key="permuted_layout", ann_val="g2s_A") + b55 = sch.read_at(loop=l47, block=b20, read_buffer_index=1, storage_scope="shared.dyn") + sch.annotate(block_or_loop=b55, ann_key="permuted_layout", 
ann_val="g2s_B") + b56 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="m16n8k8.matrixA") + sch.compute_at(block=b56, loop=l48, preserve_unit_loops=True, index=-1) + l57, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b56) + l64, l65 = sch.split(loop=l63, factors=[None, 8], preserve_unit_iters=True, disable_predication=False) + l66, l67 = sch.split(loop=l62, factors=[None, 32], preserve_unit_iters=True, disable_predication=False) + l68, l69, l70, l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b56) + sch.reorder(l75, l67, l65) + b77 = sch.blockize(target=l67, preserve_unit_iters=True) + sch.annotate(block_or_loop=b77, ann_key="meta_schedule.auto_tensorize", ann_val="mma_load_m16n8k8_f16_A_shared_dyn") + sch.annotate(block_or_loop=b77, ann_key="permuted_layout", ann_val="s2l_A") + b78 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="m16n8k8.matrixB") + sch.compute_at(block=b78, loop=l48, preserve_unit_loops=True, index=-1) + l79, l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b78) + l86, l87 = sch.split(loop=l85, factors=[None, 32], preserve_unit_iters=True, disable_predication=False) + l88, l89 = sch.split(loop=l84, factors=[None, 8], preserve_unit_iters=True, disable_predication=False) + l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b78) + sch.reorder(l97, l89, l87) + b99 = sch.blockize(target=l89, preserve_unit_iters=True) + sch.annotate(block_or_loop=b99, ann_key="meta_schedule.auto_tensorize", ann_val="mma_load_m16n8k8_f16_B_shared_dyn") + sch.annotate(block_or_loop=b99, ann_key="permuted_layout", ann_val="s2l_B") + b100, = sch.get_producers(block=b54) + sch.compute_inline(block=b100) + sch.storage_align(block=b54, buffer_index=0, axis=-2, factor=32, offset=8) + b101, = sch.get_producers(block=b55) + sch.compute_inline(block=b101) + sch.storage_align(block=b55, buffer_index=0, axis=-2, factor=32, offset=8) + sch.annotate(block_or_loop=b54, ann_key="vector_bytes", ann_val=16) + 
sch.annotate(block_or_loop=b55, ann_key="vector_bytes", ann_val=16) + sch.annotate(block_or_loop=l48, ann_key="software_pipeline_stage", ann_val=[0, 0, 1]) + sch.annotate(block_or_loop=l48, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + sch.annotate(block_or_loop=l47, ann_key="software_pipeline_async_stages", ann_val=[0]) + sch.annotate(block_or_loop=l47, ann_key="software_pipeline_stage", ann_val=[0, 0, 1, 2, 2]) + sch.annotate(block_or_loop=l47, ann_key="software_pipeline_order", ann_val=[0, 1, 3, 2, 4]) + v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v102) + sch.enter_postproc() + b103 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b103, ann_key="meta_schedule.unroll_explicit") + b104, b105, b106, b107, b108, b109 = sch.get_child_blocks(b103) + l110, l111, l112, l113 = sch.get_loops(block=b104) + l114, l115, l116, l117 = sch.get_loops(block=b105) + l118, l119, l120, l121, l122, l123, l124 = sch.get_loops(block=b106) + l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b107) + l132, l133, l134, l135, l136, l137, l138, l139, l140, l141 = sch.get_loops(block=b108) + l142, l143, l144 = sch.get_loops(block=b109) + b145 = sch.get_block(name="C_o", func_name="main") + l146, l147, l148, l149, l150, l151, l152, l153, l154, l155 = sch.get_loops(block=b145) + b156 = sch.decompose_reduction(block=b145, loop=l149) + sch.unannotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize") + sch.annotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize", ann_val="mma_init_m16n8k8_f16") + sch.unannotate(block_or_loop=b145, ann_key="meta_schedule.auto_tensorize_init") + sch.unannotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize_init") + b157 = sch.get_block(name="C_o_init", 
func_name="main") + sch.unannotate(block_or_loop=b157, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b157, tensor_intrin="mma_init_m16n8k8_f16", preserve_unit_iters=True) + b158 = sch.get_block(name="A_reindex_shared.dyn_m16n8k8.matrixA_o", func_name="main") + sch.unannotate(block_or_loop=b158, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b158, tensor_intrin="mma_load_m16n8k8_f16_A_shared_dyn", preserve_unit_iters=True) + b159 = sch.get_block(name="B_reindex_shared.dyn_m16n8k8.matrixB_o", func_name="main") + sch.unannotate(block_or_loop=b159, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b159, tensor_intrin="mma_load_m16n8k8_f16_B_shared_dyn", preserve_unit_iters=True) + b160 = sch.get_block(name="C_o_update", func_name="main") + sch.unannotate(block_or_loop=b160, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b160, tensor_intrin="mma_sync_m16n8k8_f16f16f16", preserve_unit_iters=True) + mod = sch.mod + test_run_target(mod) + + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_f16f16f32_mma_gemm(): + mod = Gemm_F16F16F32 + sch = Schedule(mod) + # fmt: off + sch = Schedule(mod) + b0 = sch.get_block(name="C", func_name="main") + b1 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") + b2 = sch.reindex(block=b0, buffer=("write", 0)) + b3 = sch.reindex(block=b0, buffer=("read", 0)) + b4 = sch.reindex(block=b0, buffer=("read", 1)) + sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda vi, vk: (vi, vk,), pad_value=None, assume_injective_transform=True) + sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda vj, vk: (vk, vj,), pad_value=None, assume_injective_transform=True) + sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda vi, vj: (vi, vj,), pad_value=None, assume_injective_transform=True) + 
sch.transform_block_layout(block=b2, index_map=lambda vi, vj: (vi, vj,)) + sch.transform_block_layout(block=b3, index_map=lambda vi, vk: (vi, vk,)) + sch.transform_block_layout(block=b4, index_map=lambda vj, vk: (vk, vj,)) + sch.transform_block_layout(block=b0, index_map=lambda vi, vj, vk: (vi, vj, vk,)) + l5, l6, l7 = sch.get_loops(block=b0) + l8, l9 = sch.split(loop=l7, factors=[None, 8], preserve_unit_iters=True, disable_predication=False) + l10, l11 = sch.split(loop=l6, factors=[None, 8], preserve_unit_iters=True, disable_predication=False) + l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True, disable_predication=False) + l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0) + sch.reorder(l16, l18, l13, l11, l9) + b20 = sch.blockize(target=l13, preserve_unit_iters=True) + sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="mma_sync_m16n8k8_f16f16f32") + sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="mma_init_m16n8k8_f32") + sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1) + l21, l22, l23 = sch.get_loops(block=b20) + v24, v25, v26, v27, v28 = sch.sample_partitioned_tile(loop=l21, n=5, partition_pos=3, innerpart_factor=2, decision=[1, 16, 2, 2, 4]) + l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True, disable_predication=False) + v34, v35, v36, v37, v38 = sch.sample_partitioned_tile(loop=l22, n=5, partition_pos=3, innerpart_factor=4, decision=[2, 16, 2, 4, 2]) + l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True, disable_predication=False) + v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4, decision=[128, 1, 4]) + l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True, disable_predication=False) + sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43) + l50 = 
sch.fuse(l29, l39, preserve_unit_iters=True) + sch.bind(loop=l50, thread_axis="blockIdx.y") + l51 = sch.fuse(l30, l40, preserve_unit_iters=True) + sch.bind(loop=l51, thread_axis="blockIdx.x") + l52 = sch.fuse(l31, l41, preserve_unit_iters=True) + sch.bind(loop=l52, thread_axis="threadIdx.y") + sch.annotate(block_or_loop=b20, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) + sch.annotate(block_or_loop=b20, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) + b53 = sch.write_at(loop=l52, block=b20, write_buffer_index=0, storage_scope="m16n8k8.matrixC") + sch.reverse_compute_inline(block=b2) + b54 = sch.read_at(loop=l47, block=b20, read_buffer_index=0, storage_scope="shared.dyn") + sch.annotate(block_or_loop=b54, ann_key="permuted_layout", ann_val="g2s_A") + b55 = sch.read_at(loop=l47, block=b20, read_buffer_index=1, storage_scope="shared.dyn") + sch.annotate(block_or_loop=b55, ann_key="permuted_layout", ann_val="g2s_B") + b56 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="m16n8k8.matrixA") + sch.compute_at(block=b56, loop=l48, preserve_unit_loops=True, index=-1) + l57, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b56) + l64, l65 = sch.split(loop=l63, factors=[None, 8], preserve_unit_iters=True, disable_predication=False) + l66, l67 = sch.split(loop=l62, factors=[None, 32], preserve_unit_iters=True, disable_predication=False) + l68, l69, l70, l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b56) + sch.reorder(l75, l67, l65) + b77 = sch.blockize(target=l67, preserve_unit_iters=True) + sch.annotate(block_or_loop=b77, ann_key="meta_schedule.auto_tensorize", ann_val="mma_load_m16n8k8_f16_A_shared_dyn") + sch.annotate(block_or_loop=b77, ann_key="permuted_layout", ann_val="s2l_A") + b78 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="m16n8k8.matrixB") + sch.compute_at(block=b78, loop=l48, preserve_unit_loops=True, index=-1) + l79, l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b78) + l86, l87 = 
sch.split(loop=l85, factors=[None, 32], preserve_unit_iters=True, disable_predication=False) + l88, l89 = sch.split(loop=l84, factors=[None, 8], preserve_unit_iters=True, disable_predication=False) + l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b78) + sch.reorder(l97, l89, l87) + b99 = sch.blockize(target=l89, preserve_unit_iters=True) + sch.annotate(block_or_loop=b99, ann_key="meta_schedule.auto_tensorize", ann_val="mma_load_m16n8k8_f16_B_shared_dyn") + sch.annotate(block_or_loop=b99, ann_key="permuted_layout", ann_val="s2l_B") + b100, = sch.get_producers(block=b54) + sch.compute_inline(block=b100) + sch.storage_align(block=b54, buffer_index=0, axis=-2, factor=32, offset=8) + b101, = sch.get_producers(block=b55) + sch.compute_inline(block=b101) + sch.storage_align(block=b55, buffer_index=0, axis=-2, factor=32, offset=8) + sch.annotate(block_or_loop=b54, ann_key="vector_bytes", ann_val=16) + sch.annotate(block_or_loop=b55, ann_key="vector_bytes", ann_val=16) + sch.annotate(block_or_loop=l48, ann_key="software_pipeline_stage", ann_val=[0, 0, 1]) + sch.annotate(block_or_loop=l48, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + sch.annotate(block_or_loop=l47, ann_key="software_pipeline_async_stages", ann_val=[0]) + sch.annotate(block_or_loop=l47, ann_key="software_pipeline_stage", ann_val=[0, 0, 1, 2, 2]) + sch.annotate(block_or_loop=l47, ann_key="software_pipeline_order", ann_val=[0, 1, 3, 2, 4]) + v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) + sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v102) + sch.enter_postproc() + b103 = sch.get_block(name="root", func_name="main") + sch.unannotate(block_or_loop=b103, ann_key="meta_schedule.unroll_explicit") + b104, b105, b106, b107, b108, b109 = sch.get_child_blocks(b103) + l110, l111, l112, l113 = 
sch.get_loops(block=b104) + l114, l115, l116, l117 = sch.get_loops(block=b105) + l118, l119, l120, l121, l122, l123, l124 = sch.get_loops(block=b106) + l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b107) + l132, l133, l134, l135, l136, l137, l138, l139, l140, l141 = sch.get_loops(block=b108) + sch.annotate(block_or_loop=l132, ann_key="pragma_auto_unroll_max_step", ann_val=0) + sch.annotate(block_or_loop=l132, ann_key="pragma_unroll_explicit", ann_val=1) + l142, l143, l144 = sch.get_loops(block=b109) + b145 = sch.get_block(name="C_o", func_name="main") + l146, l147, l148, l149, l150, l151, l152, l153, l154, l155 = sch.get_loops(block=b145) + b156 = sch.decompose_reduction(block=b145, loop=l149) + sch.unannotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize") + sch.annotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize", ann_val="mma_init_m16n8k8_f32") + sch.unannotate(block_or_loop=b145, ann_key="meta_schedule.auto_tensorize_init") + sch.unannotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize_init") + b157 = sch.get_block(name="C_o_init", func_name="main") + sch.unannotate(block_or_loop=b157, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b157, tensor_intrin="mma_init_m16n8k8_f32", preserve_unit_iters=True) + b158 = sch.get_block(name="A_reindex_shared.dyn_m16n8k8.matrixA_o", func_name="main") + sch.unannotate(block_or_loop=b158, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b158, tensor_intrin="mma_load_m16n8k8_f16_A_shared_dyn", preserve_unit_iters=True) + b159 = sch.get_block(name="B_reindex_shared.dyn_m16n8k8.matrixB_o", func_name="main") + sch.unannotate(block_or_loop=b159, ann_key="meta_schedule.auto_tensorize") + sch.tensorize(block_or_loop=b159, tensor_intrin="mma_load_m16n8k8_f16_B_shared_dyn", preserve_unit_iters=True) + b160 = sch.get_block(name="C_o_update", func_name="main") + sch.unannotate(block_or_loop=b160, ann_key="meta_schedule.auto_tensorize") + 
sch.tensorize(block_or_loop=b160, tensor_intrin="mma_sync_m16n8k8_f16f16f32", preserve_unit_iters=True) + mod = sch.mod + test_run_target(mod, out_dtype="float32") + + +if __name__ == """__main__""": + test_f16f16f16_mma_gemm() + test_f16f16f32_mma_gemm() diff --git a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py index 57d9d0961088..61888ed1a70e 100644 --- a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import te from tvm.ir.module import IRModule -from tvm.ffi import register_func +from tvm_ffi import register_global_func from tvm.error import TVMError from tvm.meta_schedule import TuneContext from tvm.meta_schedule.schedule_rule import PyScheduleRule diff --git a/tests/python/meta_schedule/test_meta_schedule_runner.py b/tests/python/meta_schedule/test_meta_schedule_runner.py index e5deefe7507c..5b4f6944df91 100644 --- a/tests/python/meta_schedule/test_meta_schedule_runner.py +++ b/tests/python/meta_schedule/test_meta_schedule_runner.py @@ -25,7 +25,7 @@ import pytest import tvm import tvm.testing -from tvm.ffi import register_func +from tvm_ffi import register_global_func from tvm.meta_schedule.arg_info import TensorInfo from tvm.meta_schedule.builder import BuilderInput, LocalBuilder from tvm.meta_schedule.runner import ( @@ -454,7 +454,7 @@ def test_meta_schedule_local_runner_time_out(): ) def initializer(): - @register_func("meta_schedule.runner.test_time_out") + @register_global_func("meta_schedule.runner.test_time_out") def timeout_session_creator( # pylint: disable=unused-variable device: Device, # pylint: disable=unused-argument args_info: T_ARG_INFO_JSON_OBJ_LIST, # pylint: disable=unused-argument @@ -492,7 +492,7 @@ def test_meta_schedule_rpc_runner_exception(): """Test meta schedule RPC Runner exception""" def initializer(): - 
@register_func("meta_schedule.runner.test_exception") + @register_global_func("meta_schedule.runner.test_exception") def exception_session_creator( # pylint: disable=unused-variable rpc_config: RPCConfig, # pylint: disable=unused-argument ) -> RPCSession: @@ -556,7 +556,7 @@ def test_meta_schedule_local_runner_exception(): ) def initializer(): - @register_func("meta_schedule.runner.test_exception") + @register_global_func("meta_schedule.runner.test_exception") def timeout_session_creator( # pylint: disable=unused-variable device: Device, # pylint: disable=unused-argument args_info: T_ARG_INFO_JSON_OBJ_LIST, # pylint: disable=unused-argument diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py index 7222c4d64972..332bebd79d31 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py @@ -42,7 +42,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] -@tvm.register_func("meta_schedule.cpu.test_apply_custom_rule") +@tvm.register_global_func("meta_schedule.cpu.test_apply_custom_rule") def sch_fn(sch: tvm.tir.Schedule, block: tvm.tir.Block) -> List[tvm.tir.Schedule]: raise ValueError("Intended for meta_schedule.cpu.test_apply_custom_rule") diff --git a/tests/python/meta_schedule/test_meta_schedule_search_strategy.py b/tests/python/meta_schedule/test_meta_schedule_search_strategy.py index 29c20ced0488..04a6e187a6a7 100644 --- a/tests/python/meta_schedule/test_meta_schedule_search_strategy.py +++ b/tests/python/meta_schedule/test_meta_schedule_search_strategy.py @@ -306,9 +306,40 @@ def __str__(self) -> str: assert candidates is None +def test_search_strategy_abstract_class_instantiation(): + """Test that directly instantiating abstract SearchStrategy raises TypeError instead 
of segfault.""" + from tvm.meta_schedule import SearchStrategy + from tvm.target import Target + from tvm.meta_schedule import TuneContext + + # Test that direct instantiation raises TypeError + # This prevents segfault when SearchStrategy() is called directly + with pytest.raises(TypeError, match="Cannot instantiate abstract class SearchStrategy"): + SearchStrategy() + + # Test that TuneContext with SearchStrategy() raises TypeError + # The error should occur when trying to create SearchStrategy() instance in the function call + # Since SearchStrategy() fails in __new__, it will fail before TuneContext.__init__ is called + with pytest.raises(TypeError, match="Cannot instantiate abstract class SearchStrategy"): + # This will fail when evaluating SearchStrategy() as an argument + TuneContext( + mod=Matmul, # Use the existing Matmul module from the test file + target=Target("llvm"), + search_strategy=SearchStrategy(), # This should fail in __new__ before reaching TuneContext + ) + + # Test that SearchStrategy.create() works correctly + strategy = SearchStrategy.create("evolutionary") + assert strategy is not None + assert isinstance(strategy, SearchStrategy) + # Verify it's not the abstract class itself + assert type(strategy) is not SearchStrategy + + if __name__ == "__main__": test_meta_schedule_replay_func(ms.search_strategy.ReplayFunc) test_meta_schedule_replay_func(ms.search_strategy.ReplayTrace) test_meta_schedule_evolutionary_search() test_meta_schedule_evolutionary_search_early_stop() test_meta_schedule_evolutionary_search_fail_init_population() + test_search_strategy_abstract_class_instantiation() diff --git a/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py index 3f0964cfa8ed..64898ecdbaa5 100644 --- a/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py +++ b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py @@ -47,8 +47,8 @@ def 
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) vm = relax.VirtualMachine(ex, dev) - gpu_data = tvm.nd.array(raw_data_for_tvm, dev) - gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] + gpu_data = tvm.runtime.tensor(raw_data_for_tvm, dev) + gpu_params = [tvm.runtime.tensor(p, dev) for p in tvm_params["main"]] gpu_out = vm["main"](gpu_data, *gpu_params) pytorch_out = torch_module(torch_data) @@ -57,11 +57,11 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar for i in range(len(pytorch_out)): actual = gpu_out[i].numpy() desired = pytorch_out[i].detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) else: actual = gpu_out[0].numpy() desired = pytorch_out.detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) @tvm.testing.parametrize_targets("cuda") diff --git a/tests/python/nightly/test_nnapi/test_network.py b/tests/python/nightly/test_nnapi/test_network.py index 2f9863eb4ee8..82094cb74c29 100644 --- a/tests/python/nightly/test_nnapi/test_network.py +++ b/tests/python/nightly/test_nnapi/test_network.py @@ -125,7 +125,7 @@ def test_network(name, dtype): for _name, (shape, _dtype) in inputs.items(): input_data[_name] = np.random.uniform(-1.0, 1.0, shape).astype(_dtype) - inputs_tvm: List[tvm.nd.NDArray] = [tvm.nd.array(v) for k, v in input_data.items()] + inputs_tvm: List[tvm.runtime.Tensor] = [tvm.runtime.tensor(v) for k, v in input_data.items()] outputs = _build_and_run_network(remote_obj, tracker, mod, inputs_tvm) nnapi_out = outputs[0] expected_out = outputs[1] diff --git a/tests/python/nightly/test_nnapi/test_ops.py b/tests/python/nightly/test_nnapi/test_ops.py index 
a6837d2ce5c1..fc10e9b169c0 100644 --- a/tests/python/nightly/test_nnapi/test_ops.py +++ b/tests/python/nightly/test_nnapi/test_ops.py @@ -255,7 +255,7 @@ def main( tracker, mod, inputs=[ - tvm.nd.array(np.random.uniform(size=(8, 10, 15)).astype("float32")), + tvm.runtime.tensor(np.random.uniform(size=(8, 10, 15)).astype("float32")), ], ) @@ -284,7 +284,7 @@ def main( tracker, mod, inputs=[ - tvm.nd.array(np.random.uniform(size=(1, 10, 15)).astype("float32")), + tvm.runtime.tensor(np.random.uniform(size=(1, 10, 15)).astype("float32")), ], ) @@ -351,7 +351,7 @@ def main( def verify(remote_obj, tracker, mod, inputs): - inputs_tvm: List[tvm.nd.NDArray] = [tvm.nd.array(v) for v in inputs] + inputs_tvm: List[tvm.runtime.Tensor] = [tvm.runtime.tensor(v) for v in inputs] outputs = _build_and_run_network(remote_obj, tracker, mod, inputs_tvm) nnapi_out = outputs[0] expected_out = outputs[1] diff --git a/tests/python/relax/adreno/test_transform_annotate_custom_scope.py b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py new file mode 100644 index 000000000000..24b4cf66b888 --- /dev/null +++ b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py @@ -0,0 +1,1204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.relax.transform.legalize_ops import adreno as legalize_adreno +from tvm.ir.module import IRModule +from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor + + +@visitor +class ValidateScope(PyExprVisitor): # pylint: disable=abstract-method + def __init__(self, scope_info: dict) -> None: + self.scope_info = scope_info + self.matched = True + + def visit(self, mod: IRModule) -> None: + """Entry point""" + for _, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.visit_expr(func) + return self.matched + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if call.op.name == "relax.call_tir": + # if call.args[0].name_hint in self.scope_info: + for idx, arg in enumerate(call.args[1]): + arg_sinfo = arg.struct_info + assert isinstance( + arg_sinfo, relax.TensorStructInfo + ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" + call_mem_scope = ( + "global" if not arg_sinfo.vdevice else arg_sinfo.vdevice.memory_scope + ) + assert ( + call_mem_scope == self.scope_info[call.args[0].name_hint][0][idx] + ), f"Scope mismatched for argument {idx} in {call.args[0].name_hint}" + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + call_mem_scope = ( + "global" + if not call.sinfo_args[0].vdevice + else call.sinfo_args[0].vdevice.memory_scope + ) + assert ( + call_mem_scope == self.scope_info[call.args[0].name_hint][1][0] + ), f"Scope mismatched for return scope: {call.args[0].name_hint}" + else: + assert isinstance( + call.sinfo_args[0], relax.TupleStructInfo + ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" + for idx, sinfo in enumerate(call.sinfo_args[0].fields): + call_mem_scope = "global" if not sinfo.vdevice else sinfo.vdevice.memory_scope + assert ( + call_mem_scope == self.scope_info[call.args[0].name_hint][1][idx] + ), f"Scope 
mismatched for return scope for {idx} in {call.args[0].name_hint}" + + +def verify(mod, expected): + tgt = tvm.target.Target("opencl --device=adreno", host="llvm") + skip_ops = [ + "relax.nn.conv2d", + "relax.nn.max_pool2d", + "relax.nn.adaptive_avg_pool2d", + # "relax.nn.layer_norm", + ] + with tgt: + mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) + mod = tvm.relax.transform.DecomposeOpsForInference()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o", "NCHW4c"]} + mod = tvm.relax.transform.ConvertLayout(desired_layouts)(mod) + mod = tvm.relax.transform.Normalize()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.backend.adreno.transform.AnnotateCustomMemoryScope(tgt)(mod) + # There is a possibility of some skipped ops above might not use 5D layouts. 
+ mod = tvm.relax.transform.LegalizeOps()(mod) + mod = tvm.relax.transform.LegalizeOps( + {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, + )(mod) + # Lets get pattern info for newly legalized ops + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.FuseOps()(mod) + mod = tvm.relax.transform.FuseTIR()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.backend.adreno.transform.FoldVDeviceScopeChange()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) + mod = tvm.relax.transform.Normalize()(mod) + + ValidateScope(expected).visit(mod) + + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 64, 56, 56), "float32"), w: R.Tensor((32, 64, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 32, 54, 54), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-nhwc"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-nhwc", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_NCHW_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d( + x, + w, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + 
["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_NHWC_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), "float32"), w: R.Tensor((4, 3, 3, 16), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 26, 26, 4), "float32") = R.nn.conv2d( + x, + w, + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype="float32", + ) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def _test_conv2d_symbolic_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor("float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + gv: R.Tensor( + (N, T.int64(4), H + T.int64(1) - Hw, W + T.int64(1) - Ww), "float32" + ) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 
28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_relu": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_relu_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + Expected = { + "relu": (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_relu1": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_relu_tanh_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + Expected = { + 
"te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_relu_tir_tanh": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_add_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_fma_relu_conv2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + Expected = { + "te_layout_transform": (["global"], 
["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "relu": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform4": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_keepdims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2, 3], keepdims=True) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", 
"global.texture-weight"], + ["global"], + ), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_reduce_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 26), "float32") = R.sum(gv, axis=[1, 2]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_transpose_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "transpose": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_expand_dims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", 
ndim=6): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "expand_dims": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_squeeze_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=3): + with R.dataflow(): + gv: R.Tensor((1, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((4, 26, 26), "float32") = R.squeeze(gv, axis=[0]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "squeeze": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_strided_slice_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3] + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + 
"te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "strided_slice": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + R.output(gv3) + return gv3 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_split_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + R.output(gv4) + return gv4 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + 
"fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_split_transpose_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + gv5: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv4[0], axes=[3, 2, 1, 0]) + gv6: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv4[1], axes=[3, 2, 1, 0]) + gv7: R.Tensor((26, 26, 8, 2), "float32") = R.concat((gv5, gv6), axis=2) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), + "fused_transpose_transpose_concatenate1": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_maxpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0], + layout="NCHW", + 
out_layout="NCHW", + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "max_pool2d_opencl": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_avgpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW") + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "adaptive_avg_pool2d_opencl": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_softmax_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.softmax(gv, axis=1) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": 
(["global"], ["global"]), + "softmax": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_layernorm_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm( + gv, gamma, beta, axes=[-2, -1] + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "layer_norm": (["global", "global", "global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_binary_broadcast_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "add": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_binary_ewise_scalar_sub_indexed(): + @I.ir_module + class Input: + @R.function + def 
main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, R.const(1, "float32")) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_residual_block(): + """ + - some kind of residual block followed by convolution to have texture after residual block + - scalar data type verification which should be mapped to global memory scope + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (1) <- to get textures as output + / \ + conv2d (2) | + \ / + add <- add should be fused into conv2d (2) + multiply to scalar <- buffer to the input of multiply scalar value + relu + | <- texture in intermediate tensor + conv2d (3) + relu + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 2, 2), "float32"), + w2: R.Tensor((32, 32, 1, 1), "float32"), + w3: R.Tensor((32, 32, 2, 2), "float32"), + bias: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[1, 1], out_dtype="float32") + bias_1 = R.multiply(bias, R.const(0.15, "float32")) + gv4 = R.add(gv3, bias_1) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv5, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.nn.relu(gv6) + R.output(gv7) + return gv7 + + Expected = 
{ + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "multiply": (["global"], ["global"]), + "fused_conv2d_NCHWc_OIHWo1_opencl_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_conv2d_NCHWc_OIHWo2_opencl_relu1": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform4": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_conv2d_fallback_to_buffer_conv2d(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], 
["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform4": (["global"], ["global"]), + "conv2d": (["global", "global"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + "concatenate": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_conv2d_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((8, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( 
+ ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo2_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "concatenate": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 1, 1), "float32"), + w2: R.Tensor((32, 32, 2, 2), "float32"), + w3: R.Tensor((32, 32, 1, 1), "float32"), + w4: R.Tensor((32, 32, 2, 2), "float32"), + bias1: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") + gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) + gv2 = R.nn.conv2d( + gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32" + ) + gv3 = R.add(gv2, bias1) + gv4 = R.nn.relu(gv3) + gv5 = R.nn.conv2d( + gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32" + ) + gv6 = R.nn.conv2d( + gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32" + ) + gv7 = R.nn.relu(gv6) + gv8 = R.add(gv2, gv5) + 
gv9 = R.add(gv8, gv6) + R.output(gv9) + return gv9 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "max_pool2d_opencl": (["global.texture-weight"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo2_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_conv2d_NCHWc_OIHWo1_opencl_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_conv2d_NCHWc_OIHWo3_opencl_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_injective_inputs1(): + """ + Input + / \ + / | + | / + conv2d (1) / + | / + conv2d (2) mean / + / \ / + | | \ / + | | (3) add + | | | + | \ / + \ mul + \ / + add + + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 4, 40, 40), "float32"), + w1: R.Tensor((4, 4, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + w3: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + mean = R.mean(x, axis=1, keepdims=True) + conv1 = R.nn.conv2d( + x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + ad3 = R.add(conv1, conv2) + ad1 = R.add(mean, conv1) + ad2 = R.multiply(ad1, conv1) + gv = R.add(ad3, ad2) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + 
["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "fused_mean_add1": (["global", "global"], ["global"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_multiply_add": ( + [ + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + ], + ["global"], + ), + } + verify(Input, Expected) + + +def test_injective_nwo_inputs2(): + """ + Input + / \ + | \ + conv2d \ + | / + conv2d mean / + / \ / + add | \ | + | | \ | + | | \ / + | | (3) add + | | | + | \ / + | \ / + \ mul + \ / + add + + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 4, 40, 40), "float32"), + w1: R.Tensor((4, 4, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + w3: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + mean = R.mean(x, axis=1, keepdims=True) + conv1 = R.nn.conv2d( + x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + ad3 = R.add(conv1, conv2) + ad1 = R.add(mean, conv1) + ad2 = R.multiply(ad1, conv2) + gv = R.add(ad2, ad3) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "fused_mean_add1": (["global", "global"], ["global"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_multiply_add": ( + [ + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + ], + ["global"], + ), + } + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git 
a/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py b/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py new file mode 100644 index 000000000000..b461f39dd744 --- /dev/null +++ b/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.ir.module import IRModule + + +def verify(input, expected): + mod = tvm.relax.backend.adreno.transform.FoldVDeviceScopeChange()(input) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_maxpool2d_scope_folding(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for 
self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # noqa: F722 + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv2 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5: R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = R.to_vdevice(lv2, dst_vdevice="opencl:1:global") + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", 
vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // 
T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # noqa: F722 + cls = Expected + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" + ), + ) + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/backend/clml/utils.py b/tests/python/relax/backend/clml/utils.py index dd7e269f5535..d32a2df38ffd 100644 --- a/tests/python/relax/backend/clml/utils.py +++ b/tests/python/relax/backend/clml/utils.py @@ -56,7 +56,7 @@ def build_and_run( vm = relax.VirtualMachine(ex, dev) f = 
vm["main"] - inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] vm.set_input("main", *inputs) vm.invoke_stateful("main") tvm_output = vm.get_outputs("main") diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py index 81acf5ee863d..5c994028ac88 100644 --- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py +++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py @@ -170,7 +170,7 @@ def set_global_func(head_dim, dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f.main) ( ftranspose_append, @@ -212,7 +212,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): rope_scale, rope_theta, None, # rope_ext_factors - tvm.nd.empty((), dtype, device=device), + tvm.runtime.empty((), dtype, device=device), ftranspose_append, None, # f_transpose_append_mla ["tir", fattn_prefill_ragged], @@ -262,8 +262,8 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): values_expected = expected_v[seq_id] assert keys_expected.shape == values_expected.shape seq_length = expected_k[seq_id].shape[1] - keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device) - values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device) + keys = tvm.runtime.empty(keys_expected.shape, dtype=dtype, device=device) + values = tvm.runtime.empty(values_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values) torch.testing.assert_close( torch.from_numpy(keys.numpy()).to(device_torch), keys_expected, rtol=1e-3, atol=1e-3 @@ -460,8 +460,10 @@ def apply_attention( queries_np = global_new_q[layer_id] keys_np = global_new_k[layer_id] values_np = global_new_v[layer_id] - qkv = 
tvm.nd.array(torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device) - outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) + qkv = tvm.runtime.tensor( + torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device + ) + outputs = tvm.runtime.empty(queries_np.shape, dtype, device=device) if not only_update_host: fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs) diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py index b0b41c8e92b4..0bdf63b6d547 100644 --- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py +++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py @@ -62,12 +62,12 @@ def test_kv_transfer_without_disco(): k_np = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16) v_np = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16) if rank == 0: - k = tvm.nd.array(k_np, dev) - v = tvm.nd.array(v_np, dev) + k = tvm.runtime.tensor(k_np, dev) + v = tvm.runtime.tensor(v_np, dev) remote_position_map_np = np.array(position_map_array, dtype=np.int32) - remote_position_map = tvm.nd.array(remote_position_map_np, dev) + remote_position_map = tvm.runtime.tensor(remote_position_map_np, dev) remote_tp_group_pe_offset_np = np.array([1] * len(position_map_array), dtype=np.int32) - remote_tp_group_pe_offset = tvm.nd.array(remote_tp_group_pe_offset_np, dev) + remote_tp_group_pe_offset = tvm.runtime.tensor(remote_tp_group_pe_offset_np, dev) transfer_func = tvm.get_global_func("nvshmem.KVTransfer") layer_view = pages._create_view( [num_pages, 2, num_kv_heads, page_size, head_dim], @@ -85,10 +85,10 @@ def test_kv_transfer_without_disco(): offset_in_page = position % page_size original_k = k_np[i] transferred_k = pages_np[layer_id, page_id, 0, :, offset_in_page, :] - np.testing.assert_allclose(original_k, transferred_k) + 
tvm.testing.assert_allclose(original_k, transferred_k) original_v = v_np[i] transferred_v = pages_np[layer_id, page_id, 1, :, offset_in_page, :] - np.testing.assert_allclose(original_v, transferred_v) + tvm.testing.assert_allclose(original_v, transferred_v) finalize_func = tvm.get_global_func("runtime.disco.nvshmem.finalize_nvshmem") finalize_func() comm.Barrier() @@ -120,13 +120,13 @@ def test_kv_transfer_page_to_page_without_disco(): if rank == 0: pages.copyfrom(pages_np) remote_position_map_np = np.array(rank_1_position_map_array, dtype=np.int32) - remote_position_map = tvm.nd.array(remote_position_map_np, dev) + remote_position_map = tvm.runtime.tensor(remote_position_map_np, dev) local_position_map_np = np.array(rank_0_position_map_array, dtype=np.int32) - local_position_map = tvm.nd.array(local_position_map_np, dev) + local_position_map = tvm.runtime.tensor(local_position_map_np, dev) remote_tp_group_pe_offset_np = np.array( [1] * len(rank_0_position_map_array), dtype=np.int32 ) - remote_tp_group_pe_offset = tvm.nd.array(remote_tp_group_pe_offset_np, dev) + remote_tp_group_pe_offset = tvm.runtime.tensor(remote_tp_group_pe_offset_np, dev) transfer_func = tvm.get_global_func("nvshmem.KVTransferPageToPage") layer_view = pages._create_view( [num_pages, 2, num_kv_heads, page_size, head_dim], @@ -154,7 +154,7 @@ def test_kv_transfer_page_to_page_without_disco(): rank_0_offset_in_page = rank_0_position % page_size rank_0_entry = pages_np[layer_id, rank_0_page_id, :, :, rank_0_offset_in_page, :] transferred_entry = new_pages_np[layer_id, page_id, :, :, offset_in_page, :] - np.testing.assert_allclose(rank_0_entry, transferred_entry) + tvm.testing.assert_allclose(rank_0_entry, transferred_entry) finalize_func = tvm.get_global_func("runtime.disco.nvshmem.finalize_nvshmem") finalize_func() comm.Barrier() @@ -197,7 +197,7 @@ def test_kv_transfer_with_disco(): remote_position_map = sess.empty((len(position_map_array),), "int32") remote_tp_group_pe_offset_np = np.array([2] 
* len(position_map_array), dtype=np.int32) remote_tp_group_pe_offset = sess.empty((len(remote_tp_group_pe_offset_np),), "int32") - f_view_func = sess.get_global_func("runtime.TVMArrayCreateView") + f_view_func = sess.get_global_func("runtime.TVMTensorCreateView") layer_view = f_view_func( pages, ShapeTuple([num_pages, 2, num_kv_heads, page_size, head_dim]), @@ -223,20 +223,20 @@ def test_kv_transfer_with_disco(): offset_in_page = position % page_size original_k = k_np_0[i] transferred_k = pages_np[layer_id, page_id, 0, :, offset_in_page, :] - np.testing.assert_allclose(original_k, transferred_k) + tvm.testing.assert_allclose(original_k, transferred_k) original_v = v_np_0[i] transferred_v = pages_np[layer_id, page_id, 1, :, offset_in_page, :] - np.testing.assert_allclose(original_v, transferred_v) + tvm.testing.assert_allclose(original_v, transferred_v) pages_np = pages.debug_get_from_remote(1).numpy() for i, position in enumerate(position_map_array): page_id = position // page_size offset_in_page = position % page_size original_k = k_np_1[i] transferred_k = pages_np[layer_id, page_id, 0, :, offset_in_page, :] - np.testing.assert_allclose(original_k, transferred_k) + tvm.testing.assert_allclose(original_k, transferred_k) original_v = v_np_1[i] transferred_v = pages_np[layer_id, page_id, 1, :, offset_in_page, :] - np.testing.assert_allclose(original_v, transferred_v) + tvm.testing.assert_allclose(original_v, transferred_v) finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem") finalize_dfunc() for i in range(2): diff --git a/tests/python/relax/test_backend_dispatch_sampling.py b/tests/python/relax/test_backend_dispatch_sampling.py index de31efc3fa96..fb36f877758b 100644 --- a/tests/python/relax/test_backend_dispatch_sampling.py +++ b/tests/python/relax/test_backend_dispatch_sampling.py @@ -103,7 +103,7 @@ def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handl u: T.float32 = uniform_samples[bx, 0] aggregate[()] = 
T.Cast("float32", 0) step_iter[()] = 0 - while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)): + while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < T.Cast("int64", (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512))): with T.block(""): T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()]) T.writes(sample_id_local[()], aggregate[()]) diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py index 004050aaf892..d48227fc6277 100644 --- a/tests/python/relax/test_backend_dispatch_sort_scan.py +++ b/tests/python/relax/test_backend_dispatch_sort_scan.py @@ -428,7 +428,7 @@ def main(x: R.Tensor(("m", "n"), "int32")): mod = DispatchSortScan()(Module) ex = tvm.compile(mod, target) vm = tvm.relax.VirtualMachine(ex, dev) - tvm_data = tvm.nd.array(np_data, dev) + tvm_data = tvm.runtime.tensor(np_data, dev) cumsum = vm["main"](tvm_data) tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum) diff --git a/tests/python/relax/test_base_py_module.py b/tests/python/relax/test_base_py_module.py new file mode 100644 index 000000000000..1f888991be1b --- /dev/null +++ b/tests/python/relax/test_base_py_module.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Test BasePyModule core functionality. + +This test verifies: +1. BasePyModule instantiation and basic methods +2. TIR function compilation and execution +3. Python function integration +4. DLPack conversion between PyTorch and TVM +""" + +import pytest +import torch +import tvm +from tvm import relax, tir +from tvm.script import relax as R, tir as T +from tvm.relax import BasePyModule +import numpy as np + + +class TestBasePyModule: + """Test BasePyModule core functionality.""" + + def test_base_py_module_instantiation(self): + @T.prim_func + def simple_func(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")): + for i in T.grid(10): + B[i] = A[i] * 2.0 + + ir_mod = tvm.IRModule({"simple_func": simple_func}) + device = tvm.cpu(0) + py_mod = BasePyModule(ir_mod, device) + + assert isinstance(py_mod, BasePyModule) + assert hasattr(py_mod, "call_tir") + assert hasattr(py_mod, "call_dps_packed") + assert hasattr(py_mod, "compiled_tir_funcs") + + def test_base_py_module_instantiation_gpu(self): + @T.prim_func + def simple_func(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")): + for i in T.grid(10): + B[i] = A[i] * 2.0 + + ir_mod = tvm.IRModule({"simple_func": simple_func}) + + if tvm.cuda().exist: + device = tvm.cuda(0) + py_mod = BasePyModule(ir_mod, device) + + assert isinstance(py_mod, BasePyModule) + assert hasattr(py_mod, "call_tir") + assert hasattr(py_mod, "call_dps_packed") + assert hasattr(py_mod, "compiled_tir_funcs") + # Check if target contains "cuda" instead of exact match + assert "cuda" in str(py_mod.target) + else: + 
pytest.skip("CUDA not available") + + def test_tir_function_compilation(self): + @T.prim_func + def add_func( + A: T.Buffer((5,), "float32"), B: T.Buffer((5,), "float32"), C: T.Buffer((5,), "float32") + ): + for i in T.grid(5): + C[i] = A[i] + B[i] + + ir_mod = tvm.IRModule({"add_func": add_func}) + device = tvm.cpu(0) + py_mod = BasePyModule(ir_mod, device) + + assert "add_func" in py_mod.tir_func_names + assert "add_func" in py_mod.compiled_tir_funcs + + def test_call_tir_with_pytorch_tensors(self): + @T.prim_func + def scale_func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + for i in T.grid(4): + B[i] = A[i] * T.float32(2.5) + + ir_mod = tvm.IRModule({"scale_func": scale_func}) + device = tvm.cpu(0) + py_mod = BasePyModule(ir_mod, device) + + input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) + scale_value = 2.5 + + result = py_mod.call_tir(scale_func, [input_tensor], R.Tensor((4,), "float32")) + + assert isinstance(result, torch.Tensor) + assert result.shape == (4,) + expected = input_tensor * scale_value + assert torch.allclose(result, expected, atol=1e-5) + + def test_call_tir_with_pytorch_tensors_gpu(self): + if tvm.cuda().exist: + # Create a simple IRModule without TIR functions for GPU testing + ir_mod = tvm.IRModule({}) + device = tvm.cuda(0) + py_mod = BasePyModule(ir_mod, device) + + # Test basic GPU functionality without TIR compilation issues + assert isinstance(py_mod, BasePyModule) + assert hasattr(py_mod, "call_tir") + assert hasattr(py_mod, "call_dps_packed") + assert "cuda" in str(py_mod.target) + + # Test that we can create GPU tensors and they work + input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device="cuda") + assert input_tensor.device.type == "cuda" + assert input_tensor.shape == (4,) + else: + pytest.skip("CUDA not available") + + def test_dlpack_conversion_pytorch_to_tvm(self): + @T.prim_func + def identity_func(A: T.Buffer((3,), "float32"), B: T.Buffer((3,), "float32")): + for 
i in T.grid(3): + B[i] = A[i] + + ir_mod = tvm.IRModule({"identity_func": identity_func}) + device = tvm.cpu(0) + py_mod = BasePyModule(ir_mod, device) + + input_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + + result = py_mod.call_tir(identity_func, [input_tensor], R.Tensor((3,), "float32")) + + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, input_tensor, atol=1e-5) + + def test_dlpack_conversion_tvm_to_pytorch(self): + @T.prim_func + def constant_func(B: T.Buffer((2,), "float32")): + for i in T.grid(2): + B[i] = T.float32(5.0) + + ir_mod = tvm.IRModule({"constant_func": constant_func}) + device = tvm.cpu(0) + py_mod = BasePyModule(ir_mod, device) + + result = py_mod.call_tir(constant_func, [], R.Tensor((2,), "float32")) + + assert isinstance(result, torch.Tensor) + assert result.shape == (2,) + expected = torch.tensor([5.0, 5.0], dtype=torch.float32) + assert torch.allclose(result, expected, atol=1e-5) + + def test_add_python_function(self): + ir_mod = tvm.IRModule({}) + device = tvm.cpu(0) + py_mod = BasePyModule(ir_mod, device) + + def custom_activation(x): + return torch.tanh(x) + + py_mod.add_python_function("custom_activation", custom_activation) + + assert hasattr(py_mod, "custom_activation") + assert "custom_activation" in py_mod.pyfuncs + + input_tensor = torch.tensor([1.0, -1.0, 0.0], dtype=torch.float32) + result = py_mod.custom_activation(input_tensor) + + assert isinstance(result, torch.Tensor) + expected = torch.tanh(input_tensor) + assert torch.allclose(result, expected, atol=1e-5) + + def test_call_dps_packed_with_python_function(self): + ir_mod = tvm.IRModule({}) + device = tvm.cpu(0) + py_mod = BasePyModule(ir_mod, device) + + def my_softmax(tensor, dim): + return torch.softmax(tensor, dim=dim) + + py_mod.add_python_function("my_softmax", my_softmax) + + input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + + result = py_mod.call_dps_packed( + "my_softmax", [input_tensor, 1], 
R.Tensor((2, 2), "float32") + ) + + assert isinstance(result, torch.Tensor) + expected = torch.softmax(input_tensor, dim=1) + assert torch.allclose(result, expected, atol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py new file mode 100644 index 000000000000..0b5b97b0c323 --- /dev/null +++ b/tests/python/relax/test_base_py_module_printer.py @@ -0,0 +1,807 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring, invalid-name, unused-argument + +import pytest +import tvm +from tvm.relax.base_py_module import BasePyModule +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R + + +@I.ir_module +class SimplePyFuncModule(BasePyModule): + """Test simple Python functions with basic operations.""" + + @I.pyfunc + def add(self, x, y): + """Simple addition function.""" + x_tvm = self._convert_pytorch_to_tvm(x) + y_tvm = self._convert_pytorch_to_tvm(y) + result = self.call_tir(self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32")) + return self._convert_tvm_to_pytorch(result) + + @I.pyfunc + def multiply(self, x, y): + """Simple multiplication function.""" + x_tvm = self._convert_pytorch_to_tvm(x) + y_tvm = self._convert_pytorch_to_tvm(y) + result = self.call_tir( + self.multiply_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32") + ) + return self._convert_tvm_to_pytorch(result) + + @T.prim_func + def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (5,), "float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + + for i in range(5): + out[i] = x[i] + y[i] + + @T.prim_func + def multiply_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (5,), "float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + + for i in range(5): + out[i] = x[i] * y[i] + + @R.function + def main_relax( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.add(x, y) + + +@I.ir_module +class ComplexPyFuncModule(BasePyModule): + """Test complex Python logic with ML pipeline and error handling.""" + + @I.pyfunc + def ml_pipeline(self, input_data, model_params): + """Complex ML pipeline with data validation and error handling.""" + # Data validation + if input_data is None or model_params is 
None: + raise ValueError("Inputs cannot be None") + + try: + # Convert to TVM format + tvm_data = self._convert_pytorch_to_tvm(input_data) + tvm_params = self._convert_pytorch_to_tvm(model_params) + + # Run ML inference + features = self.call_tir( + self.extract_features, [tvm_data], out_sinfo=R.Tensor((10,), "float32") + ) + + predictions = self.call_tir( + self.ml_inference, [features, tvm_params], out_sinfo=R.Tensor((5,), "float32") + ) + + # Post-process results + final_result = self.call_tir( + self.post_process, [predictions], out_sinfo=R.Tensor((5,), "float32") + ) + + return self._convert_tvm_to_pytorch(final_result) + + except Exception as e: + self._log_error(f"ML pipeline failed: {e}") + return self._get_default_value() + + @I.pyfunc + def data_preprocessing(self, raw_data): + """Data preprocessing with conditional logic.""" + if hasattr(raw_data, "numpy"): + # Vectorized path for numpy-compatible data + data_np = raw_data.numpy() + processed = self._vectorized_preprocess(data_np) + else: + # Fallback path for other data types + processed = self._elementwise_preprocess(raw_data) + + # Convert and return + tvm_processed = self._convert_pytorch_to_tvm(processed) + result = self.call_tir( + self.normalize_data, [tvm_processed], out_sinfo=R.Tensor((10,), "float32") + ) + return self._convert_tvm_to_pytorch(result) + + @T.prim_func + def extract_features(data: T.handle, features: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (10,), "float32") + Features = T.match_buffer(features, (10,), "float32") + + for i in range(10): + Features[i] = T.sqrt(Data[i]) + + @T.prim_func + def ml_inference(features: T.handle, params: T.handle, output: T.handle): + T.func_attr({"tir.noalias": True}) + Features = T.match_buffer(features, (10,), "float32") + Params = T.match_buffer(params, (10,), "float32") + Output = T.match_buffer(output, (5,), "float32") + + for i in range(5): + Output[i] = Features[i] * Params[i] + Features[i + 5] * Params[i + 
5] + + @T.prim_func + def post_process(predictions: T.handle, final: T.handle): + T.func_attr({"tir.noalias": True}) + Predictions = T.match_buffer(predictions, (5,), "float32") + Final = T.match_buffer(final, (5,), "float32") + + for i in range(5): + Final[i] = T.max(Predictions[i], 0.0) + + @T.prim_func + def normalize_data(data: T.handle, normalized: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (10,), "float32") + Normalized = T.match_buffer(normalized, (10,), "float32") + + for i in range(10): + Normalized[i] = Data[i] / 255.0 + + +@I.ir_module +class EdgeCasePyFuncModule(BasePyModule): + """Test edge cases and boundary conditions.""" + + @I.pyfunc + def empty_func(self): + """Empty function with no operations.""" + pass + + @I.pyfunc + def single_return(self, x): + """Function with immediate return.""" + return x + + @I.pyfunc + def nested_conditionals(self, data, threshold): + """Function with complex nested conditional logic.""" + if data is None: + return None + + if hasattr(data, "shape"): + if len(data.shape) == 1: + if data.shape[0] > threshold: + return self._process_large_data(data) + else: + return self._process_small_data(data) + elif len(data.shape) == 2: + return self._process_2d_data(data) + else: + return self._process_nd_data(data) + else: + return self._process_scalar_data(data) + + @I.pyfunc + def loop_with_break(self, data, max_iter): + """Function with loop and break statement.""" + result = [] + for i, item in enumerate(data): + if i >= max_iter: + break + if item > 0: + result.append(item * 2) + else: + result.append(0) + return result + + @T.prim_func + def dummy_tir(data: T.handle, output: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (1,), "float32") + Output = T.match_buffer(output, (1,), "float32") + Output[0] = Data[0] + + +@I.ir_module +class PerformancePyFuncModule(BasePyModule): + """Test performance optimization patterns.""" + + @I.pyfunc + def 
vectorized_operation(self, x, y): + """Vectorized operation with numpy fallback.""" + try: + # Try vectorized operation first + if hasattr(x, "numpy") and hasattr(y, "numpy"): + x_np = x.numpy() + y_np = y.numpy() + result_np = x_np + y_np + return self._convert_numpy_to_pytorch(result_np) + except Exception: + pass + + # Fallback to TVM processing + x_tvm = self._convert_pytorch_to_tvm(x) + y_tvm = self._convert_pytorch_to_tvm(y) + result = self.call_tir( + self.vectorized_add, [x_tvm, y_tvm], out_sinfo=R.Tensor((10,), "float32") + ) + return self._convert_tvm_to_pytorch(result) + + @I.pyfunc + def batch_processing(self, batch_data): + """Batch processing with memory optimization.""" + batch_size = len(batch_data) + results = [] + + # Process in chunks to optimize memory usage + chunk_size = min(batch_size, 100) + for i in range(0, batch_size, chunk_size): + chunk = batch_data[i : i + chunk_size] + chunk_result = self._process_chunk(chunk) + results.extend(chunk_result) + + return results + + @I.pyfunc + def memory_efficient_transform(self, large_tensor): + """Memory-efficient tensor transformation.""" + # Use in-place operations when possible + if hasattr(large_tensor, "requires_grad") and not large_tensor.requires_grad: + # In-place operation for efficiency + large_tensor.add_(1.0) + return large_tensor + else: + # Create new tensor if gradients are needed + return large_tensor + 1.0 + + @T.prim_func + def vectorized_add(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"tir.noalias": True}) + A = T.match_buffer(a, (10,), "float32") + B = T.match_buffer(b, (10,), "float32") + C = T.match_buffer(c, (10,), "float32") + + for i in range(10): + C[i] = A[i] + B[i] + + +@I.ir_module +class IntegrationPyFuncModule(BasePyModule): + """Test integration with external libraries and complex workflows.""" + + @I.pyfunc + def sklearn_integration(self, input_data, scaler_params): + """Integration with scikit-learn preprocessing.""" + try: + # Import sklearn components + 
from sklearn.preprocessing import StandardScaler + from sklearn.decomposition import PCA + + # Create and fit scaler + scaler = StandardScaler() + if scaler_params is not None: + scaler.mean_ = scaler_params["mean"] + scaler.scale_ = scaler_params["scale"] + else: + scaler.fit(input_data) + + # Transform data + scaled_data = scaler.transform(input_data) + + # Apply PCA if needed + if input_data.shape[1] > 10: + pca = PCA(n_components=10) + reduced_data = pca.fit_transform(scaled_data) + else: + reduced_data = scaled_data + + # Convert to TVM and process + tvm_data = self._convert_pytorch_to_tvm(reduced_data) + result = self.call_tir( + self.final_transform, + [tvm_data], + out_sinfo=R.Tensor((reduced_data.shape[0], 10), "float32"), + ) + + return self._convert_tvm_to_pytorch(result) + + except ImportError: + # Fallback if sklearn is not available + return self._fallback_preprocessing(input_data) + + @I.pyfunc + def multi_stage_pipeline(self, raw_input): + """Multi-stage processing pipeline.""" + # Stage 1: Data cleaning + cleaned = self._clean_data(raw_input) + + # Stage 2: Feature extraction + features = self._extract_features(cleaned) + + # Stage 3: Model inference + predictions = self._run_inference(features) + + # Stage 4: Post-processing + final_result = self._post_process_output(predictions) + + return final_result + + @T.prim_func + def final_transform(data: T.handle, output: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (10, 10), "float32") + Output = T.match_buffer(output, (10, 10), "float32") + + for i in range(10): + for j in range(10): + Output[i, j] = T.tanh(Data[i, j]) + + +@I.ir_module +class ErrorHandlingPyFuncModule(BasePyModule): + """Test comprehensive error handling and validation.""" + + @I.pyfunc + def robust_data_processing(self, input_data, config): + """Robust data processing with comprehensive error handling.""" + try: + # Validate inputs + if not self._validate_inputs(input_data, config): + raise 
ValueError("Invalid input data or configuration") + + # Check data types + if not self._check_data_types(input_data): + raise TypeError("Unsupported data types") + + # Process data with retry logic + max_retries = config.get("max_retries", 3) + for attempt in range(max_retries): + try: + result = self._process_with_validation(input_data, config) + if self._validate_output(result): + return result + else: + raise RuntimeError("Output validation failed") + except Exception as e: + if attempt == max_retries - 1: + raise + self._log_warning(f"Attempt {attempt + 1} failed: {e}") + continue + + except Exception as e: + self._log_error(f"Data processing failed: {e}") + return self._get_safe_fallback(input_data, config) + + @I.pyfunc + def graceful_degradation(self, primary_input, fallback_input): + """Function that gracefully degrades when primary path fails.""" + try: + # Try primary processing path + result = self._primary_processing(primary_input) + return result + except Exception as e: + self._log_warning(f"Primary processing failed: {e}") + + try: + # Try fallback path + result = self._fallback_processing(fallback_input) + return result + except Exception as e2: + self._log_error(f"Fallback processing also failed: {e2}") + # Return safe default + return self._get_safe_default() + + @T.prim_func + def safe_transform(data: T.handle, output: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (5,), "float32") + Output = T.match_buffer(output, (5,), "float32") + + for i in range(5): + # Safe operation that handles edge cases + if Data[i] > 0: + Output[i] = T.sqrt(Data[i]) + else: + Output[i] = 0.0 + + +# Pytest test functions to verify the classes work correctly +def test_simple_pyfunc_module_creation(): + """Test that SimplePyFuncModule can be created.""" + # Get the IRModule instance from the TVMScript decorated class + ir_mod = SimplePyFuncModule + device = tvm.cpu() + + # Create BasePyModule instance + module = BasePyModule(ir_mod, device) 
+ assert isinstance(module, BasePyModule) + + # Note: Python functions are stored in pyfuncs, not as direct attributes + # We need to check if they exist in the IRModule's pyfuncs + if hasattr(ir_mod, "pyfuncs"): + assert "add" in ir_mod.pyfuncs + assert "multiply" in ir_mod.pyfuncs + + # Check that TIR functions exist + assert hasattr(module, "add_tir") + assert hasattr(module, "multiply_tir") + + # Note: This particular TVMScript is for testing purpose only, and cannot compile + # Relax functions may not be available due to TVMScript compilation issues + print("Note: This TVMScript is for testing purpose only, and cannot compile") + + +def test_complex_pyfunc_module_creation(): + """Test that ComplexPyFuncModule can be created.""" + ir_mod = ComplexPyFuncModule + device = tvm.cpu() + + module = BasePyModule(ir_mod, device) + assert isinstance(module, BasePyModule) + + # Check Python functions in pyfuncs + if hasattr(ir_mod, "pyfuncs"): + assert "ml_pipeline" in ir_mod.pyfuncs + assert "data_preprocessing" in ir_mod.pyfuncs + + # Check TIR functions + assert hasattr(module, "extract_features") + assert hasattr(module, "ml_inference") + assert hasattr(module, "post_process") + assert hasattr(module, "normalize_data") + + +def test_edge_case_pyfunc_module_creation(): + """Test that EdgeCasePyFuncModule can be created.""" + ir_mod = EdgeCasePyFuncModule + device = tvm.cpu() + + module = BasePyModule(ir_mod, device) + assert isinstance(module, BasePyModule) + + # Check Python functions in pyfuncs + if hasattr(ir_mod, "pyfuncs"): + assert "empty_func" in ir_mod.pyfuncs + assert "single_return" in ir_mod.pyfuncs + assert "nested_conditionals" in ir_mod.pyfuncs + assert "loop_with_break" in ir_mod.pyfuncs + + # Check TIR function + assert hasattr(module, "dummy_tir") + + +def test_performance_pyfunc_module_creation(): + """Test that PerformancePyFuncModule can be created.""" + ir_mod = PerformancePyFuncModule + device = tvm.cpu() + + module = BasePyModule(ir_mod, device) 
+ assert isinstance(module, BasePyModule) + + # Check Python functions in pyfuncs + if hasattr(ir_mod, "pyfuncs"): + assert "vectorized_operation" in ir_mod.pyfuncs + assert "batch_processing" in ir_mod.pyfuncs + assert "memory_efficient_transform" in ir_mod.pyfuncs + + # Check TIR function + assert hasattr(module, "vectorized_add") + + +def test_integration_pyfunc_module_creation(): + """Test that IntegrationPyFuncModule can be created.""" + ir_mod = IntegrationPyFuncModule + device = tvm.cpu() + + module = BasePyModule(ir_mod, device) + assert isinstance(module, BasePyModule) + + # Check Python functions in pyfuncs + if hasattr(ir_mod, "pyfuncs"): + assert "sklearn_integration" in ir_mod.pyfuncs + assert "multi_stage_pipeline" in ir_mod.pyfuncs + + # Check TIR function + assert hasattr(module, "final_transform") + + +def test_error_handling_pyfunc_module_creation(): + """Test that ErrorHandlingPyFuncModule can be created.""" + ir_mod = ErrorHandlingPyFuncModule + device = tvm.cpu() + + module = BasePyModule(ir_mod, device) + assert isinstance(module, BasePyModule) + + # Check Python functions in pyfuncs + if hasattr(ir_mod, "pyfuncs"): + assert "robust_data_processing" in ir_mod.pyfuncs + assert "graceful_degradation" in ir_mod.pyfuncs + + # Check TIR function + assert hasattr(module, "safe_transform") + + +def test_all_modules_inherit_from_base(): + """Test that all modules properly inherit from BasePyModule.""" + modules = [ + SimplePyFuncModule, + ComplexPyFuncModule, + EdgeCasePyFuncModule, + PerformancePyFuncModule, + IntegrationPyFuncModule, + ErrorHandlingPyFuncModule, + ] + + device = tvm.cpu() + for ir_mod in modules: + module = BasePyModule(ir_mod, device) + assert isinstance(module, BasePyModule) + assert hasattr(module, "script") + assert hasattr(module, "show") + + +def test_pyfunc_decorators(): + """Test that all @I.pyfunc decorated functions are present.""" + ir_mod = SimplePyFuncModule + device = tvm.cpu() + module = BasePyModule(ir_mod, device) + 
+ # Check that the functions exist in pyfuncs + if hasattr(ir_mod, "pyfuncs"): + assert "add" in ir_mod.pyfuncs + assert "multiply" in ir_mod.pyfuncs + + # Get the actual function objects + add_func = ir_mod.pyfuncs["add"] + multiply_func = ir_mod.pyfuncs["multiply"] + + # Check that they are callable + assert callable(add_func) + assert callable(multiply_func) + + # Check function signatures + import inspect + + add_sig = inspect.signature(add_func) + assert len(add_sig.parameters) == 3 # self, x, y + + multiply_sig = inspect.signature(multiply_func) + assert len(multiply_sig.parameters) == 3 # self, x, y + + +def test_tir_functions(): + """Test that TIR functions are properly defined.""" + ir_mod = SimplePyFuncModule + device = tvm.cpu() + module = BasePyModule(ir_mod, device) + + # Check TIR function attributes + assert hasattr(module, "add_tir") + assert hasattr(module, "multiply_tir") + + # These should be callable (though they're TIR functions) + assert callable(module.add_tir) + assert callable(module.multiply_tir) + + +def test_relax_functions(): + """Test that Relax functions are properly defined.""" + ir_mod = SimplePyFuncModule + device = tvm.cpu() + module = BasePyModule(ir_mod, device) + + # Note: This particular TVMScript is for testing purpose only, and cannot compile + # Relax functions may not be available due to TVMScript compilation issues + print("Note: This TVMScript is for testing purpose only, and cannot compile") + + # We can still check that the module was created successfully + assert isinstance(module, BasePyModule) + assert hasattr(module, "script") + assert hasattr(module, "show") + + +def test_module_docstrings(): + """Test that all modules have proper docstrings.""" + modules = [ + SimplePyFuncModule, + ComplexPyFuncModule, + EdgeCasePyFuncModule, + PerformancePyFuncModule, + IntegrationPyFuncModule, + ErrorHandlingPyFuncModule, + ] + + for module_class in modules: + # TVMScript decorator changes the class, so we check that it's 
callable + # and can create instances instead of checking docstrings + assert callable(module_class) + # We can't directly instantiate TVMScript decorated classes + # but we can create BasePyModule instances with them + device = tvm.cpu() + instance = BasePyModule(module_class, device) + assert isinstance(instance, BasePyModule) + + +def test_python_function_complexity(): + """Test that complex Python functions have the expected structure.""" + ir_mod = ComplexPyFuncModule + device = tvm.cpu() + module = BasePyModule(ir_mod, device) + + # Check that complex functions exist in pyfuncs + if hasattr(ir_mod, "pyfuncs"): + assert "ml_pipeline" in ir_mod.pyfuncs + assert "data_preprocessing" in ir_mod.pyfuncs + + # Get the actual function objects + ml_func = ir_mod.pyfuncs["ml_pipeline"] + preprocess_func = ir_mod.pyfuncs["data_preprocessing"] + + # These should be callable + assert callable(ml_func) + assert callable(preprocess_func) + + # Check function signatures + import inspect + + ml_sig = inspect.signature(ml_func) + assert len(ml_sig.parameters) == 3 # self, input_data, model_params + + preprocess_sig = inspect.signature(preprocess_func) + assert len(preprocess_sig.parameters) == 2 # self, raw_data + + +def test_script_and_show_methods(): + """Test that script() and show() methods work correctly.""" + ir_mod = SimplePyFuncModule + device = tvm.cpu() + module = BasePyModule(ir_mod, device) + + # Test script() method + script_output = module.script() + assert isinstance(script_output, str) + assert len(script_output) > 0 + + # Test show() method + try: + module.show() + # If we get here, show() worked + assert True + except Exception as e: + # If show() fails, the feature is not working properly + pytest.fail(f"show() method failed: {e}") + + +def test_python_functions_in_irmodule(): + """Test that Python functions are properly stored in IRModule pyfuncs.""" + ir_mod = SimplePyFuncModule + device = tvm.cpu() + module = BasePyModule(ir_mod, device) + + # Check that 
pyfuncs attribute exists and contains our functions + if hasattr(ir_mod, "pyfuncs"): + pyfuncs = ir_mod.pyfuncs + assert isinstance(pyfuncs, dict) + assert "add" in pyfuncs + assert "multiply" in pyfuncs + + # Check that the functions are callable + assert callable(pyfuncs["add"]) + assert callable(pyfuncs["multiply"]) + + # Check function names + assert pyfuncs["add"].__name__ == "add" + assert pyfuncs["multiply"].__name__ == "multiply" + else: + pytest.fail("pyfuncs attribute not found in IRModule") + + +def test_call_py_func_with_base_py_module(): + """Test R.call_py_func with BasePyModule.""" + import torch + import numpy as np + from tvm.relax.op import call_py_func + from tvm.relax.expr import StringImm + from tvm.relax import Var, TensorStructInfo + + # Test 1: Operator creation and basic properties + x = Var("x", TensorStructInfo((5,), "float32")) + y = Var("y", TensorStructInfo((5,), "float32")) + + call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32")) + + assert call_expr.op.name == "relax.call_py_func" + assert call_expr.args[0].value == "test_func" + assert len(call_expr.args) == 2 + + # Test 2: Compilation validation + try: + call_py_func( + "invalid", + (Var("x", TensorStructInfo((5,), "float32")),), + out_sinfo=R.Tensor((5,), "float32"), + ) + assert False, "Should raise type error" + except Exception as e: + assert "Mismatched type" in str(e) or "Expected" in str(e) + + # Test 3: Validation and error handling + @I.ir_module + class ValidationTestModule(BasePyModule): + @R.function + def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32")) + return result + + device = tvm.cpu() + module = ValidationTestModule(device) + + x = torch.randn(5, dtype=torch.float32) + + with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"): + module.call_py_func("non_existent_func", [x]) + 
+ # Test 4: Using call_py_func within Relax functions + @I.ir_module + class RelaxCallPyFuncModule(BasePyModule): + @I.pyfunc + def torch_relu(self, x): + """PyTorch ReLU implementation.""" + return torch.relu(x) + + @I.pyfunc + def torch_softmax(self, x, dim=0): + """PyTorch softmax implementation.""" + return torch.softmax(x, dim=dim) + + @R.function + def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"): + relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32")) + final_result = R.call_py_func( + "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32") + ) + return final_result + + device = tvm.cpu() + module = RelaxCallPyFuncModule(device) + + x = torch.randn(10, dtype=torch.float32) + + expected = torch.softmax(torch.relu(x), dim=0) + + relu_result = module.call_py_func("torch_relu", [x]) + final_result = module.call_py_func("torch_softmax", [relu_result]) + + # Convert to numpy for comparison + if isinstance(final_result, tvm.runtime.Tensor): + final_result_np = final_result.numpy() + else: + final_result_np = final_result + + if isinstance(expected, torch.Tensor): + expected_np = expected.numpy() + else: + expected_np = expected + + # Use numpy for comparison since we have numpy arrays + tvm.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_base_py_module_symbolic_shape.py b/tests/python/relax/test_base_py_module_symbolic_shape.py new file mode 100644 index 000000000000..3179c8f51eed --- /dev/null +++ b/tests/python/relax/test_base_py_module_symbolic_shape.py @@ -0,0 +1,367 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import pytest + +import tvm +from tvm.ir import IRModule +from tvm.relax.base_py_module import BasePyModule +from tvm import tir, relax +from tvm.script import ir as I, tir as T, relax as R + + +def _make_module(): + return IRModule({}) + + +def test_infer_concrete_shape_from_numpy_input(): + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + sym_shape = [n, m] + + x = np.zeros((3, 4), dtype="float32") + inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x]) + assert inferred == [3, 4] + + +def test_infer_concrete_shape_all_concrete_dims(): + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + shape = [tir.IntImm("int32", 5), 6] + inferred = bpm._infer_concrete_shape_from_args(shape, in_args=[]) + assert inferred == [5, 6] + + +def test_infer_concrete_shape_error_when_uninferrable(): + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + k = tir.Var("k", "int64") + with pytest.raises(ValueError): + bpm._infer_concrete_shape_from_args([k, 8], in_args=[]) + + +@I.ir_module +class AddModuleSymbolic(BasePyModule): + @T.prim_func + def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + T.func_attr({"global_symbol": "add_tir"}) + n = T.int64() + x = T.match_buffer(var_x, (n,), dtype="float32") + y = 
T.match_buffer(var_y, (n,), dtype="float32") + out = T.match_buffer(var_out, (n,), dtype="float32") + + for i in T.serial(n): + out[i] = x[i] + y[i] + + @R.function + def main_relax( + x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32") + ) -> R.Tensor(("n",), "float32"): + return R.add(x, y) + + +def test_base_py_module_relax_symbolic_end_to_end(): + bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm") + + a = np.random.randn(5).astype("float32") + b = np.random.randn(5).astype("float32") + out = bpm.main_relax(a, b) + assert isinstance(out, np.ndarray) or hasattr(out, "numpy") + out_np = out if isinstance(out, np.ndarray) else out.numpy() + tvm.testing.assert_allclose(out_np, a + b, rtol=1e-6, atol=1e-6) + + a7 = np.random.randn(7).astype("float32") + b7 = np.random.randn(7).astype("float32") + out2 = bpm.main_relax(a7, b7) + out2_np = out2 if isinstance(out2, np.ndarray) else out2.numpy() + tvm.testing.assert_allclose(out2_np, a7 + b7, rtol=1e-6, atol=1e-6) + + +def test_base_py_module_tir_symbolic_end_to_end(): + bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm") + + a = np.random.randn(5).astype("float32") + b = np.random.randn(5).astype("float32") + + n = tir.Var("n", "int64") + out_sinfo = relax.TensorStructInfo((n,), "float32") + + out = bpm.call_tir("add_tir", [a, b], out_sinfo) + out_np = out if isinstance(out, np.ndarray) else out.numpy() + tvm.testing.assert_allclose(out_np, a + b, rtol=1e-6, atol=1e-6) + + +def test_infer_concrete_shape_multiple_symbolic_dims(): + """Test shape inference with multiple symbolic dimensions.""" + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + k = tir.Var("k", "int64") + sym_shape = [n, m, k] + + x = np.zeros((2, 3, 4), dtype="float32") + inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x]) + assert inferred == [2, 3, 4] + + +def test_infer_concrete_shape_mixed_concrete_symbolic(): + """Test shape 
inference with mixed concrete and symbolic dimensions.""" + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + n = tir.Var("n", "int64") + sym_shape = [n, 5, 10] # First dim is symbolic, others are concrete + + x = np.zeros((3, 5, 10), dtype="float32") + inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x]) + assert inferred == [3, 5, 10] + + +def test_infer_concrete_shape_from_tvm_tensors(): + """Test shape inference from TVM tensors.""" + try: + # Try to create TVM tensor using new API + x_np = np.zeros((3, 4), dtype="float32") + x_tvm = tvm.runtime.tensor(x_np) + + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + sym_shape = [n, m] + + inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x_tvm]) + assert inferred == [3, 4] + except AttributeError: + # Skip if tvm.runtime.tensor is not available + pytest.skip("tvm.runtime.tensor not available") + + +def test_infer_concrete_shape_multiple_inputs(): + """Test shape inference when multiple inputs are available.""" + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + sym_shape = [n, m] + + # Multiple inputs with different shapes - should use first matching one + x1 = np.zeros((2, 3), dtype="float32") + x2 = np.zeros((4, 5), dtype="float32") + inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x1, x2]) + assert inferred == [2, 3] # Should use first input + + +def test_infer_concrete_shape_wrong_ndim(): + """Test shape inference when input has wrong number of dimensions.""" + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + sym_shape = [n, m] # 2D + + x = np.zeros((3,), dtype="float32") # 1D - wrong ndim + with pytest.raises(ValueError, match="Cannot infer concrete output shape"): + 
bpm._infer_concrete_shape_from_args(sym_shape, [x]) + + +@I.ir_module +class MatrixModuleSymbolic(BasePyModule): + @T.prim_func + def matmul_tir(var_a: T.handle, var_b: T.handle, var_c: T.handle): + T.func_attr({"global_symbol": "matmul_tir"}) + m = T.int64() + n = T.int64() + k = T.int64() + a = T.match_buffer(var_a, (m, k), dtype="float32") + b = T.match_buffer(var_b, (k, n), dtype="float32") + c = T.match_buffer(var_c, (m, n), dtype="float32") + + for i in T.serial(m): + for j in T.serial(n): + c[i, j] = 0.0 + for l in T.serial(k): + c[i, j] = c[i, j] + a[i, l] * b[l, j] + + @R.function + def matmul_relax( + a: R.Tensor(("m", "k"), "float32"), b: R.Tensor(("k", "n"), "float32") + ) -> R.Tensor(("m", "n"), "float32"): + return R.matmul(a, b) + + +def test_base_py_module_multiple_symbolic_dims(): + """Test BasePyModule with multiple symbolic dimensions.""" + bpm = MatrixModuleSymbolic(device=tvm.cpu(0), target="llvm") + + # Test Relax function with multiple symbolic dims + a = np.random.randn(2, 3).astype("float32") + b = np.random.randn(3, 4).astype("float32") + out = bpm.matmul_relax(a, b) + out_np = out if isinstance(out, np.ndarray) else out.numpy() + expected = np.matmul(a, b) + tvm.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6) + + # Test TIR function with multiple symbolic dims + # Use concrete shapes for TIR function to avoid constraint issues + out_sinfo = relax.TensorStructInfo((2, 4), "float32") + out_tir = bpm.call_tir("matmul_tir", [a, b], out_sinfo) + out_tir_np = out_tir if isinstance(out_tir, np.ndarray) else out_tir.numpy() + tvm.testing.assert_allclose(out_tir_np, expected, rtol=1e-6, atol=1e-6) + + +def test_base_py_module_call_dps_packed_symbolic(): + """Test call_dps_packed with symbolic shapes.""" + try: + # Register a simple test function + @tvm.register_global_func("test_add_packed") + def test_add_packed(a, b, out): + """Add two tensors element-wise.""" + a_np = a.numpy() + b_np = b.numpy() + result = a_np + b_np + out[:] 
= result + + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + a = np.random.randn(5).astype("float32") + b = np.random.randn(5).astype("float32") + + n = tir.Var("n", "int64") + out_sinfo = relax.TensorStructInfo((n,), "float32") + + out = bpm.call_dps_packed("test_add_packed", [a, b], out_sinfo) + out_np = out if isinstance(out, np.ndarray) else out.numpy() + tvm.testing.assert_allclose(out_np, a + b, rtol=1e-6, atol=1e-6) + + except AttributeError as e: + pytest.skip(f"call_dps_packed test requires register_global_func: {e}") + + +def test_base_py_module_call_dps_packed_multiple_args(): + """Test call_dps_packed with multiple arguments and symbolic shapes.""" + try: + # Register a function that takes multiple arguments + @tvm.register_global_func("test_matmul_packed") + def test_matmul_packed(a, b, out): + """Matrix multiplication.""" + a_np = a.numpy() + b_np = b.numpy() + result = np.matmul(a_np, b_np) + out[:] = result + + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + a = np.random.randn(2, 3).astype("float32") + b = np.random.randn(3, 4).astype("float32") + + out_sinfo = relax.TensorStructInfo((2, 4), "float32") + + out = bpm.call_dps_packed("test_matmul_packed", [a, b], out_sinfo) + out_np = out if isinstance(out, np.ndarray) else out.numpy() + expected = np.matmul(a, b) + tvm.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6) + + except AttributeError as e: + pytest.skip(f"call_dps_packed test requires register_global_func: {e}") + + +def test_base_py_module_call_dps_packed_scalar_args(): + """Test call_dps_packed with scalar arguments and symbolic shapes.""" + try: + # Register a function that takes scalar arguments + @tvm.register_global_func("test_add_scalar_packed") + def test_add_scalar_packed(x, scalar, out): + """Add scalar to tensor.""" + x_np = x.numpy() + if hasattr(scalar, "numpy"): + scalar_val = scalar.numpy() + else: + scalar_val = scalar + result = x_np + 
scalar_val + out[:] = result + + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + x = np.random.randn(4).astype("float32") + scalar = 2.5 + + n = tir.Var("n", "int64") + out_sinfo = relax.TensorStructInfo((n,), "float32") + + out = bpm.call_dps_packed("test_add_scalar_packed", [x, scalar], out_sinfo) + out_np = out if isinstance(out, np.ndarray) else out.numpy() + expected = x + scalar + tvm.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6) + + except AttributeError as e: + pytest.skip(f"call_dps_packed test requires register_global_func: {e}") + + +def test_infer_concrete_shape_from_pytorch_tensors(): + """Test shape inference from PyTorch tensors (if available).""" + try: + import torch + except ImportError: + pytest.skip("PyTorch not available") + + mod = _make_module() + bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") + + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + sym_shape = [n, m] + + x_torch = torch.zeros((3, 4), dtype=torch.float32) + inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x_torch]) + assert inferred == [3, 4] + + +def test_base_py_module_relax_with_pytorch_tensors(): + """Test Relax functions with PyTorch tensors and symbolic shapes.""" + try: + import torch + except ImportError: + pytest.skip("PyTorch not available") + + bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm") + + a_torch = torch.randn(5, dtype=torch.float32) + b_torch = torch.randn(5, dtype=torch.float32) + + out = bpm.main_relax(a_torch, b_torch) + out_np = out if isinstance(out, np.ndarray) else out.numpy() + expected = a_torch.numpy() + b_torch.numpy() + tvm.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py index be60524e8475..56372a63e576 100644 --- a/tests/python/relax/test_blockbuilder_core.py +++ 
b/tests/python/relax/test_blockbuilder_core.py @@ -31,7 +31,7 @@ @pytest.fixture(scope="module") def register_nop(): - @tvm.register_func("test.blockbuilder.nop") + @tvm.register_global_func("test.blockbuilder.nop") def nop(): pass diff --git a/tests/python/relax/test_codegen_coreml.py b/tests/python/relax/test_codegen_coreml.py index 7b9c22b8b9d8..b07271e8949a 100644 --- a/tests/python/relax/test_codegen_coreml.py +++ b/tests/python/relax/test_codegen_coreml.py @@ -75,8 +75,8 @@ def test_add(): gv = bb.emit_output(lv0) bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) - y_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) + y_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data, y_data]) @@ -90,7 +90,7 @@ def test_add_const(): gv = bb.emit_output(lv0) bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data]) @@ -105,14 +105,14 @@ def test_multiply(): bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) - y_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) + y_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data, y_data]) def test_matmul(): x = relax.Var("x", relax.TensorStructInfo([8, 10], "float32")) - y = relax.Constant(tvm.nd.array(np.random.rand(10, 8).astype("float32"), dev)) + y = relax.Constant(tvm.runtime.tensor(np.random.rand(10, 8).astype("float32"), dev)) bb = relax.BlockBuilder() with bb.function("main", [x]): with bb.dataflow(): @@ -121,7 +121,7 @@ def test_matmul(): bb.emit_func_output(gv) mod = bb.get() - x_data = 
tvm.nd.array(np.random.rand(8, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(8, 10).astype("float32"), dev) verify(mod, [x_data]) x = relax.Var("x", relax.TensorStructInfo([8, 10], "float32")) @@ -134,8 +134,8 @@ def test_matmul(): bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(8, 10).astype("float32"), dev) - y_data = tvm.nd.array(np.random.rand(10, 8).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(8, 10).astype("float32"), dev) + y_data = tvm.runtime.tensor(np.random.rand(10, 8).astype("float32"), dev) verify(mod, [x_data, y_data]) @@ -150,7 +150,7 @@ def test_clip(): bb.emit_func_output(gv0) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data]) x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32")) @@ -164,7 +164,7 @@ def test_clip(): gv1 = bb.emit_output(lv1) bb.emit_func_output([gv0, gv1]) - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data]) @@ -179,7 +179,7 @@ def get_mod(axis): bb.emit_func_output(gv) return bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(get_mod(axis=0), [x_data]) verify(get_mod(axis=1), [x_data]) @@ -194,7 +194,7 @@ def test_relu(): bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data]) @@ -209,7 +209,7 @@ def test_batch_flatten(): bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10, 10).astype("float32"), dev) verify(mod, [x_data]) 
@@ -224,7 +224,7 @@ def test_softmax(): bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data]) @@ -238,7 +238,7 @@ def test_conv2d(): gv = bb.emit_output(lv0) bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(1, 3, 224, 224).astype("float32"), dev) verify(mod, [x_data]) @@ -251,7 +251,7 @@ def test_global_avg_pool2d(): gv = bb.emit_output(lv0) bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(1, 1, 10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(1, 1, 10, 10).astype("float32"), dev) verify(mod, [x_data]) @@ -266,8 +266,8 @@ def test_subgraph1(): gv = bb.emit_output(lv1) bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) - y_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) + y_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data, y_data]) @@ -287,8 +287,8 @@ def test_subgraph2(): gv = bb.emit_output(lv3) bb.emit_func_output(gv) mod = bb.get() - x_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) - y_data = tvm.nd.array(np.random.rand(10, 10).astype("float32"), dev) + x_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) + y_data = tvm.runtime.tensor(np.random.rand(10, 10).astype("float32"), dev) verify(mod, [x_data, y_data]) diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 152f04fc3ce7..32666ebd1d8c 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -52,7 +52,7 @@ def build_and_run(mod, inputs_np, target, 
legalize=False, cuda_graph=False): ex = tvm.compile(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] - inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] # For cuda graph, run the compiled function twice to make sure that we can launch the cached # graph on the second run. diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py index 990f21138619..b92e2fee40ed 100644 --- a/tests/python/relax/test_codegen_cudnn.py +++ b/tests/python/relax/test_codegen_cudnn.py @@ -113,7 +113,7 @@ def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False): ex = tvm.compile(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] - inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] # For cuda graph, run the compiled function twice to make sure that we can launch the cached # graph on the second run. 
@@ -193,9 +193,13 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, with_bias, activation): out = get_result_with_relax_cudnn_offload(mod, args) ref = build_and_run(mod, args, "llvm", legalize=True) if dtype == "float16": - tvm.testing.assert_allclose(out, ref, rtol=1e-1, atol=1e-1) + # FIXME(lei): currently raise into 3e-1 to prevent flaky test + # see https://github.com/apache/tvm/pull/18319 + tvm.testing.assert_allclose(out, ref, rtol=3e-1, atol=3e-1) else: - tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + # Increased tolerance to 2.5e-2 to prevent flaky test due to numerical + # differences between cuDNN and LLVM implementations + tvm.testing.assert_allclose(out, ref, rtol=2.5e-2, atol=2.5e-2) @pytest.mark.skip(reason="flaky test") diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 6528e1c93c0c..c645dce96bd4 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -94,7 +94,7 @@ def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False): dev = tvm.device(target, 0) vm = relax.VirtualMachine(ex, dev) f = vm["main"] - inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] # For cuda graph, run the compiled function twice to make sure that we can launch the cached # graph on the second run. 
@@ -1481,15 +1481,15 @@ def main_residual( vm = relax.vm.VirtualMachine(ex, tvm.cpu(0)) packed_weight, scales, bias_trans = vm[transform_func_name]( - (tvm.nd.array(y), tvm.nd.array(bias)) + (tvm.runtime.tensor(y), tvm.runtime.tensor(bias)) ) dev = tvm.device("cuda", 0) ex = tvm.compile(mod_deploy, target="cuda") vm = relax.vm.VirtualMachine(ex, dev) - x_nd = tvm.nd.array(x, dev) - residual_nd = tvm.nd.array(residual, dev) + x_nd = tvm.runtime.tensor(x, dev) + residual_nd = tvm.runtime.tensor(residual, dev) params = [packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev)] for f_name in ["main_bias", "main_cast_bias", "main_residual"]: @@ -1634,14 +1634,14 @@ def main( vm = relax.vm.VirtualMachine(ex, tvm.cpu(0)) packed_weight, scales, bias_trans = vm[transform_func_name]( - (tvm.nd.array(y), tvm.nd.array(bias)) + (tvm.runtime.tensor(y), tvm.runtime.tensor(bias)) ) dev = tvm.device("cuda", 0) ex = tvm.compile(mod_deploy, target="cuda") vm = relax.vm.VirtualMachine(ex, dev) - x_nd = tvm.nd.array(x, dev) + x_nd = tvm.runtime.tensor(x, dev) inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev)] out = vm["main"](*inp).numpy() @@ -1909,13 +1909,13 @@ def main( ex = tvm.compile(mod_transform, target="llvm") vm = relax.vm.VirtualMachine(ex, tvm.cpu(0)) - packed_weight, scales = vm[transform_func_name]((tvm.nd.array(y),)) + packed_weight, scales = vm[transform_func_name]((tvm.runtime.tensor(y),)) dev = tvm.device("cuda", 0) ex = tvm.compile(mod_deploy, target="cuda") vm = relax.vm.VirtualMachine(ex, dev) - x_nd = tvm.nd.array(x, dev) + x_nd = tvm.runtime.tensor(x, dev) inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev)] out = vm["main"](*inp).numpy() ref = np.dot(x, y.transpose()) @@ -2064,13 +2064,13 @@ def main( ex = tvm.compile(mod_transform, target="llvm") vm = relax.vm.VirtualMachine(ex, tvm.cpu(0)) - packed_weight, scales = vm[transform_func_name]((tvm.nd.array(y),)) + packed_weight, scales = 
vm[transform_func_name]((tvm.runtime.tensor(y),)) dev = tvm.device("cuda", 0) ex = tvm.compile(mod_deploy, target="cuda") vm = relax.vm.VirtualMachine(ex, dev) - x_nd = tvm.nd.array(x, dev) + x_nd = tvm.runtime.tensor(x, dev) inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev)] out = vm["main"](*inp).numpy() ref = np.dot(x, y.transpose()) diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py index 370c5f03a486..f386f8f2f8d0 100644 --- a/tests/python/relax/test_codegen_dnnl.py +++ b/tests/python/relax/test_codegen_dnnl.py @@ -54,7 +54,7 @@ def main( def build_and_run(mod, inputs, legalize=False): target = tvm.target.Target("llvm") dev = tvm.cpu() - inputs = [tvm.nd.array(inp, dev) for inp in inputs] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs] with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}): ex = tvm.compile(mod, target) diff --git a/tests/python/relax/test_codegen_hipblas.py b/tests/python/relax/test_codegen_hipblas.py index 004e70e4e60e..286acc44f1f1 100644 --- a/tests/python/relax/test_codegen_hipblas.py +++ b/tests/python/relax/test_codegen_hipblas.py @@ -45,7 +45,7 @@ def build_and_run(mod, inputs_np, target, legalize=False): ex = tvm.compile(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] - inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] return f(*inputs).numpy() diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py index 746f4eba6028..84467a67a9c4 100644 --- a/tests/python/relax/test_codegen_tensorrt.py +++ b/tests/python/relax/test_codegen_tensorrt.py @@ -67,7 +67,7 @@ def build_and_run(mod, inputs_np, target, legalize=False): ex = tvm.compile(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] - inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] 
return f(*inputs).numpy() diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py index 0a8d338a455e..fade620dfea4 100644 --- a/tests/python/relax/test_contrib_vllm.py +++ b/tests/python/relax/test_contrib_vllm.py @@ -48,7 +48,7 @@ def build_and_run(mod, inputs_np, target, legalize=True): dev = tvm.device(target, 0) vm = relax.VirtualMachine(ex, dev) f = vm["main"] - inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] out = f(*inputs) @@ -752,17 +752,21 @@ def test_reconstruct_from_cache(): dev = tvm.device("cuda", 0) - key = tvm.nd.array(np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev) - value = tvm.nd.array(np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev) - slot_mapping = tvm.nd.array(np.arange(num_tokens).astype("int32"), dev) + key = tvm.runtime.tensor( + np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev + ) + value = tvm.runtime.tensor( + np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev + ) + slot_mapping = tvm.runtime.tensor(np.arange(num_tokens).astype("int32"), dev) - k_cache = tvm.nd.array( + k_cache = tvm.runtime.tensor( np.random.randn(num_blocks, num_heads, head_dim // vec_size, block_size, vec_size).astype( "float16" ), dev, ) - v_cache = tvm.nd.array( + v_cache = tvm.runtime.tensor( np.random.randn(num_blocks, num_heads, head_dim, block_size).astype("float16"), dev ) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index f6413c1d8206..00805152b499 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -526,8 +526,8 @@ def main( new_mod = transform_pass(EndToEndTest) tvm.ir.assert_structural_equal(new_mod, Expected) - x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) - y = tvm.nd.array(np.random.rand(1, 3).astype("float32")) + x = 
tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) + y = tvm.runtime.tensor(np.random.rand(1, 3).astype("float32")) expected = np.zeros((2, 3), dtype="float32") target = tvm.target.Target("llvm") @@ -609,8 +609,8 @@ def main( return s tvm.ir.assert_structural_equal(new_mod, Expected, map_free_vars=True) - x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) - y = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + x = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) + y = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) expected = np.zeros((2, 3), dtype="float32") target = tvm.target.Target("llvm") diff --git a/tests/python/relax/test_dlpack_integration.py b/tests/python/relax/test_dlpack_integration.py new file mode 100644 index 000000000000..b212f710b200 --- /dev/null +++ b/tests/python/relax/test_dlpack_integration.py @@ -0,0 +1,296 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Test DLPack integration between PyTorch and TVM. + +This test verifies: +1. DLPack conversion from PyTorch to TVM +2. DLPack conversion from TVM to PyTorch +3. Data integrity preservation during conversion +4. Functionality equivalence between DLPack and numpy fallback +5. 
Error handling for unsupported data types +""" + +import pytest +import torch +import tvm +from tvm import relax, tir +from tvm.script import relax as R, tir as T +from tvm.relax import BasePyModule +import numpy as np + + +class TestDLPackIntegration: + def test_dlpack_pytorch_to_tvm_conversion(self): + pytorch_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor) + + assert isinstance(tvm_tensor, tvm.runtime.Tensor) + assert tvm_tensor.shape == pytorch_tensor.shape + assert str(tvm_tensor.dtype) == str(pytorch_tensor.dtype).replace("torch.", "") + + tvm_numpy = tvm_tensor.numpy() + pytorch_numpy = pytorch_tensor.numpy() + tvm.testing.assert_allclose(tvm_numpy, pytorch_numpy, atol=1e-5) + + def test_dlpack_pytorch_to_tvm_conversion_gpu(self): + if tvm.cuda().exist: + pytorch_tensor = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32, device="cuda" + ) + + tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor) + + assert isinstance(tvm_tensor, tvm.runtime.Tensor) + assert tvm_tensor.shape == pytorch_tensor.shape + assert str(tvm_tensor.dtype) == str(pytorch_tensor.dtype).replace("torch.", "") + assert str(tvm_tensor.device) == "cuda:0" + + # Move to CPU for numpy conversion + tvm_numpy = tvm_tensor.numpy() + pytorch_numpy = pytorch_tensor.cpu().numpy() + tvm.testing.assert_allclose(tvm_numpy, pytorch_numpy, atol=1e-5) + else: + pytest.skip("CUDA not available") + + def test_dlpack_tvm_to_pytorch_conversion(self): + import numpy as np + + data = np.array([1.0, 2.0, 3.0, 5.0], dtype="float32") + tvm_tensor = tvm.runtime.tensor(data) + + pytorch_tensor = torch.from_dlpack(tvm_tensor) + + assert isinstance(pytorch_tensor, torch.Tensor) + assert pytorch_tensor.shape == tvm_tensor.shape + assert pytorch_tensor.dtype == torch.float32 + + tvm_numpy = tvm_tensor.numpy() + pytorch_numpy = pytorch_tensor.numpy() + tvm.testing.assert_allclose(tvm_numpy, pytorch_numpy, atol=1e-5) + + def 
test_dlpack_tvm_to_pytorch_conversion_gpu(self): + if tvm.cuda().exist: + import numpy as np + + data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype="float32") + tvm_tensor = tvm.runtime.tensor(data, device=tvm.cuda(0)) + + pytorch_tensor = torch.from_dlpack(tvm_tensor) + + assert isinstance(pytorch_tensor, torch.Tensor) + assert pytorch_tensor.shape == tvm_tensor.shape + assert pytorch_tensor.dtype == torch.float32 + assert pytorch_tensor.device.type == "cuda" + + tvm_numpy = tvm_tensor.numpy() + pytorch_numpy = pytorch_tensor.cpu().numpy() + tvm.testing.assert_allclose(tvm_numpy, pytorch_numpy, atol=1e-5) + else: + pytest.skip("CUDA not available") + + def test_dlpack_roundtrip_conversion(self): + """Test roundtrip conversion: PyTorch -> TVM -> PyTorch.""" + # Create PyTorch tensor + original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + # Convert to TVM + tvm_tensor = tvm.runtime.from_dlpack(original_tensor) + + # Convert back to PyTorch + result_tensor = torch.from_dlpack(tvm_tensor) + + # Verify roundtrip integrity + assert torch.allclose(original_tensor, result_tensor, atol=1e-5) + assert original_tensor.dtype == result_tensor.dtype + assert original_tensor.shape == result_tensor.shape + + def test_dlpack_different_data_types(self): + """Test DLPack conversion with different data types.""" + test_types = [ + (torch.float32, "float32"), + (torch.float64, "float64"), + (torch.int32, "int32"), + (torch.int64, "int64"), + ] + + for torch_dtype, tvm_dtype in test_types: + # Create PyTorch tensor + pytorch_tensor = torch.tensor([1, 2, 3], dtype=torch_dtype) + + # Convert to TVM + tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor) + + # Convert back to PyTorch + result_tensor = torch.from_dlpack(tvm_tensor) + + # Verify conversion + assert torch.allclose(pytorch_tensor, result_tensor, atol=1e-5) + assert pytorch_tensor.dtype == result_tensor.dtype + + def test_dlpack_different_shapes(self): + """Test DLPack conversion with different tensor 
shapes.""" + test_shapes = [ + (1,), + (2, 3), + (4, 5, 6), + (1, 1, 1, 1), + ] + + for shape in test_shapes: + # Create PyTorch tensor + pytorch_tensor = torch.randn(shape, dtype=torch.float32) + + # Convert to TVM + tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor) + + # Convert back to PyTorch + result_tensor = torch.from_dlpack(tvm_tensor) + + # Verify conversion + assert torch.allclose(pytorch_tensor, result_tensor, atol=1e-5) + assert pytorch_tensor.shape == result_tensor.shape + + def test_dlpack_functionality_verification(self): + """Test that DLPack and numpy conversions produce identical results.""" + # Create large PyTorch tensor + size = 1000000 + pytorch_tensor = torch.randn(size, dtype=torch.float32) + + # Test DLPack conversion + tvm_tensor_dlpack = tvm.runtime.from_dlpack(pytorch_tensor) + + # Test numpy conversion + numpy_array = pytorch_tensor.detach().cpu().numpy() + tvm_tensor_numpy = tvm.runtime.tensor(numpy_array) + + # Verify both methods produce same result + result_dlpack = torch.from_dlpack(tvm_tensor_dlpack) + result_numpy = torch.from_numpy(tvm_tensor_numpy.numpy()) + assert torch.allclose(result_dlpack, result_numpy, atol=1e-5) + + # Verify data integrity + assert torch.allclose(result_dlpack, pytorch_tensor, atol=1e-5) + assert result_dlpack.shape == pytorch_tensor.shape + assert result_dlpack.dtype == pytorch_tensor.dtype + + def test_dlpack_error_handling(self): + """Test DLPack error handling for unsupported operations.""" + # Test with non-contiguous tensor + pytorch_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + non_contiguous = pytorch_tensor[::2] # Create non-contiguous view + + # This should work (PyTorch handles non-contiguous tensors) + try: + tvm_tensor = tvm.runtime.from_dlpack(non_contiguous) + result_tensor = torch.from_dlpack(tvm_tensor) + assert torch.allclose(non_contiguous, result_tensor, atol=1e-5) + except Exception as e: + # If it fails, that's also acceptable + pass + + def 
test_dlpack_with_base_py_module(self): + """Test DLPack conversion within BasePyModule context.""" + # Create a simple IRModule + @T.prim_func + def identity_func(A: T.Buffer((3,), "float32"), B: T.Buffer((3,), "float32")): + for i in T.grid(3): + B[i] = A[i] + + ir_mod = tvm.IRModule({"identity_func": identity_func}) + device = tvm.cpu(0) + py_mod = BasePyModule(ir_mod, device) + + # Create PyTorch tensor + input_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + + # Call TIR function (this will trigger DLPack conversion) + result = py_mod.call_tir(identity_func, [input_tensor], R.Tensor((3,), "float32")) + + # Verify result + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, input_tensor, atol=1e-5) + + def test_dlpack_device_consistency(self): + """Test DLPack conversion maintains device consistency.""" + # Test CPU tensor + cpu_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + cpu_tvm = tvm.runtime.from_dlpack(cpu_tensor) + cpu_result = torch.from_dlpack(cpu_tvm) + + assert cpu_result.device.type == "cpu" + assert torch.allclose(cpu_tensor, cpu_result, atol=1e-5) + + # Note: GPU testing would require CUDA/OpenCL setup + # This is a basic test that CPU works correctly + + def test_dlpack_memory_sharing(self): + """Test that DLPack conversion shares memory when possible.""" + # Create PyTorch tensor + pytorch_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + # Convert to TVM + tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor) + + # Modify the original tensor + pytorch_tensor[0] = 10.0 + + # Convert back to PyTorch + result_tensor = torch.from_dlpack(tvm_tensor) + + # The result should reflect the modification (memory sharing) + assert result_tensor[0] == 10.0 + assert torch.allclose(pytorch_tensor, result_tensor, atol=1e-5) + + def test_dlpack_batch_operations(self): + """Test DLPack conversion with batch operations.""" + # Create batch of tensors + batch_size = 10 + pytorch_tensors = 
[torch.randn(5, dtype=torch.float32) for _ in range(batch_size)] + + # Convert all to TVM + tvm_tensors = [tvm.runtime.from_dlpack(t) for t in pytorch_tensors] + + # Convert all back to PyTorch + result_tensors = [torch.from_dlpack(t) for t in tvm_tensors] + + # Verify all conversions + for i in range(batch_size): + assert torch.allclose(pytorch_tensors[i], result_tensors[i], atol=1e-5) + + def test_dlpack_edge_cases(self): + """Test DLPack conversion with edge cases.""" + # Empty tensor + empty_tensor = torch.tensor([], dtype=torch.float32) + empty_tvm = tvm.runtime.from_dlpack(empty_tensor) + empty_result = torch.from_dlpack(empty_tvm) + + assert empty_result.shape == empty_tensor.shape + assert empty_result.dtype == empty_tensor.dtype + + # Single element tensor + single_tensor = torch.tensor([42.0], dtype=torch.float32) + single_tvm = tvm.runtime.from_dlpack(single_tensor) + single_result = torch.from_dlpack(single_tvm) + + assert single_result.shape == single_tensor.shape + assert single_result[0] == 42.0 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_e2e_op_dynamic.py b/tests/python/relax/test_e2e_op_dynamic.py index 9179802360b3..ea1f3a778e47 100644 --- a/tests/python/relax/test_e2e_op_dynamic.py +++ b/tests/python/relax/test_e2e_op_dynamic.py @@ -52,10 +52,10 @@ def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: R.Tensor((4,),"int64"), vm = build(DynamicStridedSlice) x_np = np.random.rand(8, 9, 10, 10).astype(np.float32) - data_nd = tvm.nd.array(x_np, dev) - begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev) - end_nd = tvm.nd.array(np.array(end).astype("int64"), dev) - strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev) + data_nd = tvm.runtime.tensor(x_np, dev) + begin_nd = tvm.runtime.tensor(np.array(begin).astype("int64"), dev) + end_nd = tvm.runtime.tensor(np.array(end).astype("int64"), dev) + strides_nd = tvm.runtime.tensor(np.array(strides).astype("int64"), dev) # Reference 
implementation out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) @@ -85,10 +85,10 @@ def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: R.Tensor((4,),"int64 vm = build(DynamicStridedSlice) x_np = np.random.rand(8, 9, 10, 10).astype(np.float32) - data_nd = tvm.nd.array(x_np, dev) - begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev) - end_nd = tvm.nd.array(np.array(end).astype("int64"), dev) - strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev) + data_nd = tvm.runtime.tensor(x_np, dev) + begin_nd = tvm.runtime.tensor(np.array(begin).astype("int64"), dev) + end_nd = tvm.runtime.tensor(np.array(end).astype("int64"), dev) + strides_nd = tvm.runtime.tensor(np.array(strides).astype("int64"), dev) # Reference implementation out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) diff --git a/tests/python/relax/test_frontend_common.py b/tests/python/relax/test_frontend_common.py index 39f9af103134..85424df2f602 100644 --- a/tests/python/relax/test_frontend_common.py +++ b/tests/python/relax/test_frontend_common.py @@ -16,7 +16,11 @@ # under the License. 
import tvm import tvm.testing +from tvm import relax from tvm.relax.frontend import detach_params +from tvm.relax.frontend.common import autopad +from tvm.script import ir as I +from tvm.script import tir as T from tvm.script.parser import relax as R @@ -25,7 +29,7 @@ def test_detach_params(): def func(x: R.Tensor((2, 3), "float32")): return x - param = tvm.nd.empty((3,), "float32") + param = tvm.runtime.empty((3,), "float32") mod = tvm.IRModule({"func": func.with_attr("params", [param])}) detached_mod, detached_params = detach_params(mod) @@ -37,5 +41,175 @@ def func(x: R.Tensor((2, 3), "float32")): tvm.testing.assert_allclose(detached_params["func"][0].numpy(), param.numpy()) +class TestAutopad: + def _test_autopad(self, pad_type, expected): + bb = relax.BlockBuilder() + input_shape = (1, 1, 4, 4) + x = relax.Var("x", relax.TensorStructInfo(input_shape, "float32")) + + with bb.function("main", [x]): + with bb.dataflow(): + result = autopad( + bb, + x, + strides=[2, 2], + kernel_shape=[3, 3], + dilations=(1, 1), + pad_type=pad_type, + deconv=False, + mode="SAME_UPPER", + pad_value=0.0, + ) + out = bb.emit_output(result) + bb.emit_func_output(out) + + tvm.ir.assert_structural_equal(bb.get(), expected) + + def test_constant(self): + @I.ir_module + class expected: + @T.prim_func(private=True) + def pad( + x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"), + PadInput: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): + with T.block("PadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[v_i0, v_i1, v_i2, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else( + T.int64(0) <= v_i2 + and v_i2 < T.int64(4) + and T.int64(0) <= v_i3 + and v_i3 < T.int64(4), + x[v_i0, v_i1, v_i2, v_i3], + T.float32(0.0), + ) + + @R.function + 
def main( + x: R.Tensor((1, 1, 4, 4), dtype="float32") + ) -> R.Tensor((1, 1, 5, 5), dtype="float32"): + cls = expected + with R.dataflow(): + lv = R.call_tir( + cls.pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") + ) + gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + self._test_autopad("constant", expected) + + def test_edge(self): + @I.ir_module + class expected: + @T.prim_func(private=True) + def replicate_pad( + x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"), + ReplicatePadInput: T.Buffer( + (T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32" + ), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): + with T.block("ReplicatePadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads( + x[ + T.int64(0), + T.int64(0), + T.int64(0) : T.int64(4), + T.int64(0) : T.int64(4), + ] + ) + T.writes(ReplicatePadInput[v_i0, v_i1, v_i2, v_i3]) + ReplicatePadInput[v_i0, v_i1, v_i2, v_i3] = x[ + T.if_then_else( + v_i0 < T.int64(0), + T.int64(0), + T.if_then_else(T.int64(1) <= v_i0, T.int64(0), v_i0), + ), + T.if_then_else( + v_i1 < T.int64(0), + T.int64(0), + T.if_then_else(T.int64(1) <= v_i1, T.int64(0), v_i1), + ), + T.if_then_else( + v_i2 < T.int64(0), + T.int64(0), + T.if_then_else(T.int64(4) <= v_i2, T.int64(3), v_i2), + ), + T.if_then_else( + v_i3 < T.int64(0), + T.int64(0), + T.if_then_else(T.int64(4) <= v_i3, T.int64(3), v_i3), + ), + ] + + @R.function + def main( + x: R.Tensor((1, 1, 4, 4), dtype="float32") + ) -> R.Tensor((1, 1, 5, 5), dtype="float32"): + cls = expected + with R.dataflow(): + lv = R.call_tir( + cls.replicate_pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") + ) + gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + self._test_autopad("edge", expected) + + def test_reflect(self): + @I.ir_module + class expected: + @T.prim_func(private=True) + 
def mirror_pad( + x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"), + MirrorPadInput: T.Buffer( + (T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32" + ), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): + with T.block("MirrorPadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[v_i0, v_i1, T.int64(0) : T.int64(4), T.int64(0) : T.int64(4)]) + T.writes(MirrorPadInput[v_i0, v_i1, v_i2, v_i3]) + MirrorPadInput[v_i0, v_i1, v_i2, v_i3] = x[ + v_i0, + v_i1, + T.if_then_else( + T.int64(4) <= v_i2, + T.int64(6) - v_i2, + T.if_then_else(v_i2 < T.int64(0), v_i2 * T.int64(-1), v_i2), + ), + T.if_then_else( + T.int64(4) <= v_i3, + T.int64(6) - v_i3, + T.if_then_else(v_i3 < T.int64(0), v_i3 * T.int64(-1), v_i3), + ), + ] + + @R.function + def main( + x: R.Tensor((1, 1, 4, 4), dtype="float32") + ) -> R.Tensor((1, 1, 5, 5), dtype="float32"): + cls = expected + with R.dataflow(): + lv = R.call_tir( + cls.mirror_pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") + ) + gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + self._test_autopad("reflect", expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index fb1544be68a8..70619714dd10 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -275,7 +275,7 @@ def verify_dynamo_model(torch_model, input_info, binding, expected): args.append(torch.zeros(*info[0], dtype=_convert_data_type(info[1]))) graph_model = dynamo.export(torch_model)(*args)[0] mod = from_fx(graph_model, input_info, unwrap_unit_return_tuple=True) - binding = {k: tvm.nd.array(v) for k, v in binding.items()} + binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()} expected = relax.transform.BindParams("main", binding)(expected) 
tvm.ir.assert_structural_equal(mod, expected) @@ -285,14 +285,34 @@ def _convert_data_type(input_type): import torch # type: ignore input_type = input_type.lower() if isinstance(input_type, str) else input_type - if input_type == "float32": - return torch.float32 - elif input_type == "float16": + # Float types + if input_type == "float16": return torch.float16 - elif input_type == "int64": - return torch.int64 + elif input_type == "float32": + return torch.float32 + elif input_type == "float64": + return torch.float64 + elif input_type == "bfloat16": + return torch.bfloat16 + # Signed integer types + elif input_type == "int8": + return torch.int8 + elif input_type == "int16": + return torch.int16 elif input_type == "int32": return torch.int32 + elif input_type == "int64": + return torch.int64 + # Unsigned integer types + elif input_type == "uint8": + return torch.uint8 + elif input_type == "uint16": + return torch.uint16 + elif input_type == "uint32": + return torch.uint32 + elif input_type == "uint64": + return torch.uint64 + # Boolean elif input_type == "bool": return torch.bool else: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 406a5d9a1c70..01e16e7564ac 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -17,6 +17,7 @@ import operator import pytest import torch +import numpy as np from torch import nn from torch.nn import Module from torch.export import export @@ -30,13 +31,63 @@ from tvm.relax.frontend.torch import from_exported_program -def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None): +def verify_model( + torch_model, + example_args, + binding, + expected, + dynamic_shapes=None, + run_ep_decomposition=True, + keep_params_as_input=False, + unwrap_unit_return_tuple=False, + no_bind_return_tuple=False, + map_free_vars=False, + custom_convert_map=None, 
+): exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes) - mod = from_exported_program(exported_program) + mod = from_exported_program( + exported_program, + run_ep_decomposition=run_ep_decomposition, + keep_params_as_input=keep_params_as_input, + unwrap_unit_return_tuple=unwrap_unit_return_tuple, + no_bind_return_tuple=no_bind_return_tuple, + custom_convert_map=custom_convert_map, + ) - binding = {k: tvm.nd.array(v) for k, v in binding.items()} + binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()} expected = relax.transform.BindParams("main", binding)(expected) - tvm.ir.assert_structural_equal(mod, expected) + tvm.ir.assert_structural_equal(mod, expected, map_free_vars=map_free_vars) + + +def verify_model_numerically(torch_model, example_args, rtol=1e-7, atol=1e-7): + """Verify model by comparing numerical outputs between PyTorch and TVM.""" + with torch.no_grad(): + pytorch_output = torch_model(*example_args) + + exported_program = export(torch_model, args=example_args) + mod = from_exported_program(exported_program) + target = tvm.target.Target("llvm") + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + tvm_args = [tvm.runtime.tensor(arg.numpy()) for arg in example_args] + tvm_output = vm["main"](*tvm_args) + + if hasattr(tvm_output, "numpy"): + tvm_output_np = tvm_output.numpy() + else: + tvm_output_np = tvm_output[0].numpy() + + pytorch_output_np = ( + pytorch_output.numpy() + if isinstance(pytorch_output, torch.Tensor) + else pytorch_output[0].numpy() + ) + + assert ( + pytorch_output_np.shape == tvm_output_np.shape + ), f"Shape mismatch: PyTorch {pytorch_output_np.shape} vs TVM {tvm_output_np.shape}" + tvm.testing.assert_allclose(pytorch_output_np, tvm_output_np, rtol=rtol, atol=atol) operator_basic_unary = [ @@ -58,18 +109,13 @@ def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=No (torch.log, R.log), (torch.neg, R.negative), (torch.relu, R.nn.relu), - 
(torch.relu_, R.nn.relu), (torch.round, R.round), (torch.rsqrt, R.rsqrt), - (torch.selu, R.nn.selu), (torch.sigmoid, R.sigmoid), - (torch.ops.aten.silu, R.nn.silu), - (torch.ops.aten.silu_, R.nn.silu), (torch.sin, R.sin), (torch.sinh, R.sinh), (torch.sign, R.sign), (torch.sqrt, R.sqrt), - (torch.square, R.square), (torch.tan, R.tan), (torch.tanh, R.tanh), (torch.trunc, R.trunc), @@ -100,7 +146,6 @@ def main( operator_bool_unary = [ - (torch.isfinite, R.isfinite), (torch.isinf, R.isinf), (torch.isnan, R.isnan), ] @@ -129,6 +174,47 @@ def main( verify_model(UnaryOp(), example_args, {}, expected) +def test_sqrt_integer_input(): + """Test that sqrt operation works with integer tensors by auto-converting to float.""" + example_args = (torch.tensor([[4, 9, 16, 25]], dtype=torch.int64),) + + class SqrtIntModel(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected_int64: + @R.function + def main( + input_1: R.Tensor((1, 4), dtype="int64") + ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32") + lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv) + gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + verify_model(SqrtIntModel(), example_args, {}, expected_int64) + + example_args_int32 = (torch.tensor([[1, 4, 9]], dtype=torch.int32),) + + @tvm.script.ir_module + class expected_int32: + @R.function + def main( + input_1: R.Tensor((1, 3), dtype="int32") + ) -> R.Tuple(R.Tensor((1, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3), dtype="float32") = R.astype(input_1, dtype="float32") + lv1: R.Tensor((1, 3), dtype="float32") = R.sqrt(lv) + gv: R.Tuple(R.Tensor((1, 3), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32) + + def test_extended_unary_ops(): example_args = (torch.randn(1, 3, 10, 10, 
dtype=torch.float32),) @@ -154,21 +240,14 @@ def main( ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( lv, R.const(1.0, "float32") ) - lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( - lv_div, R.const(1.0, "float32") - ) - lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum( - R.const(0.0, "float32"), lv_sub - ) - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(1.0, "float32"), lv_min + lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater( + input_1, R.const(0.0, "float32") ) - lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) - lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv2, input_1, lv1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -265,20 +344,44 @@ def forward(self, input): return torch.ops.aten.dropout_(input, 0.5, train=True) @tvm.script.ir_module - class expected_dropout: + class expected_dropout_for_1_2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_dropout_for_3: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + # block 0 + with R.dataflow(): + lv: 
R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros( + R.shape([1, 3, 10, 10]), dtype="float32" + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv, R.const(0.5, "float32") + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv1) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv2, lv2) R.output(gv) return gv - verify_model(Dropout1(), example_args, {}, expected_dropout) - verify_model(Dropout2(), example_args, {}, expected_dropout) - verify_model(Dropout3(), example_args, {}, expected_dropout) + verify_model(Dropout1(), example_args, {}, expected_dropout_for_1_2) + verify_model(Dropout2(), example_args, {}, expected_dropout_for_1_2) + verify_model(Dropout3(), example_args, {}, expected_dropout_for_3) # elu class Elu(Module): @@ -297,23 +400,27 @@ def forward(self, input): class expected_elu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 with R.dataflow(): - lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( - R.const(1.0, dtype="float32"), lv_exp + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater( + input, R.const(0.0, "float32") + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + input, R.const(1.0, "float32") ) - lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu( - lv_one_minus_exp + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + input, R.const(1.0, "float32") ) - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv2) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( + lv3, R.const(1.0, "float32") ) - lv_relu_x: 
R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) - lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + lv4, R.const(1.0, "float32") + ) + lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, lv1, lv5) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,) R.output(gv) return gv @@ -340,12 +447,19 @@ def main( inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) - lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) - lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( - lv1, R.const(6, "float32") + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( + inp_0, R.const(3.0, "float32") + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv, R.prim_value(0), R.prim_value(T.float64("inf")) + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv1, R.prim_value(T.float64("-inf")), R.prim_value(6) + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv2, R.const(6.0, "float32") ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -370,25 +484,85 @@ def forward(self, input): return torch.ops.aten.hardswish_(input) @tvm.script.ir_module - class expected1: + class expected_hardswish_for_1_2: @R.function def main( inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) - lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) - lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( - lv1, R.const(6, 
"float32") + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( + inp_0, R.const(3.0, "float32") + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv, R.prim_value(0), R.prim_value(T.float64("inf")) + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv1, R.prim_value(T.float64("-inf")), R.prim_value(6) ) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv3, R.const(6.0, "float32") + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_hardswish_for_3: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( + input, R.const(3.0, "float32") + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv, R.prim_value(0), R.prim_value(T.float64("inf")) + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv1, R.prim_value(T.float64("-inf")), R.prim_value(6) + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv2) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv3, R.const(6.0, "float32") + ) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv4, lv4) + R.output(gv) + return gv + + verify_model(Hardswish(), example_args, {}, expected_hardswish_for_1_2) + verify_model(Hardswish2(), example_args, {}, expected_hardswish_for_1_2) + verify_model(Hardswish3(), example_args, {}, expected_hardswish_for_3) + + # isfinite + class IsFinite(Module): + def forward(self, input): + return torch.isfinite(input) + + @tvm.script.ir_module + class expected_isfinite: + @R.function + def main( + input: 
R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input) + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal( + lv, R.const(float("inf"), "float32") + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input, input) + lv3: R.Tensor((1, 3, 10, 10), dtype="bool") = R.multiply(lv2, lv1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv3,) R.output(gv) return gv - verify_model(Hardswish(), example_args, {}, expected1) - verify_model(Hardswish2(), example_args, {}, expected1) - verify_model(Hardswish3(), example_args, {}, expected1) + verify_model(IsFinite(), example_args, {}, expected_isfinite) # log2 class Log2(Module): @@ -557,9 +731,151 @@ def main( R.output(gv) return gv + @tvm.script.ir_module + class expected_relu6_3: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + x, R.prim_value(0), R.prim_value(6) + ) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv, lv) + R.output(gv) + return gv + verify_model(ReLU6_1(), example_args, {}, expected_relu6_1) verify_model(ReLU6_2(), example_args, {}, expected_relu6_2) - verify_model(ReLU6_3(), example_args, {}, expected_relu6_2) + verify_model(ReLU6_3(), example_args, {}, expected_relu6_3) + + # selu + class SELU(Module): + def forward(self, input): + return torch.nn.functional.selu(input) + + @tvm.script.ir_module + class expected_selu: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater( + input, R.const(0.0, "float32") + ) + lv1: R.Tensor((1, 
3, 10, 10), dtype="float32") = R.multiply( + input, R.const(1.0507010221481323, "float32") + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + input, R.const(1.0, "float32") + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv2) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( + lv3, R.const(1.0, "float32") + ) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + lv4, R.const(1.7580993175506592, "float32") + ) + lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, lv1, lv5) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,) + R.output(gv) + return gv + + verify_model(SELU(), example_args, {}, expected_selu) + + # silu + class SiLU(Module): + def forward(self, input): + return torch.nn.functional.silu(input) + + @tvm.script.ir_module + class expected_silu: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + verify_model(SiLU(), example_args, {}, expected_silu) + + # silu_ + class SiLU_(Module): + def forward(self, input): + return torch.ops.aten.silu_(input) + + @tvm.script.ir_module + class expected_silu_: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = ( + lv1, + lv1, + ) + R.output(gv) + return gv + + verify_model(SiLU_(), example_args, {}, 
expected_silu_) + + # square + class Square(Module): + def forward(self, input): + return torch.square(input) + + @tvm.script.ir_module + class expected_square: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power( + input, R.const(2.0, "float32") + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Square(), example_args, {}, expected_square) + + # relu_ + class ReLU_(Module): + def forward(self, input): + return torch.relu_(input.clone()) + + @tvm.script.ir_module + class expected_relu_: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(ReLU_(), example_args, {}, expected_relu_) def test_hardtanh(): @@ -580,7 +896,7 @@ def forward(self, input): return torch.ops.aten.hardtanh_(input) @tvm.script.ir_module - class expected1: + class expected_for_1_2: @R.function def main( inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -593,10 +909,29 @@ def main( R.output(gv) return gv + @tvm.script.ir_module + class expected_hardtanh_for_3: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0)) + ) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv, lv) + R.output(gv) + return gv + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) 
- verify_model(Hardtanh(), example_args, {}, expected1) - verify_model(Hardtanh2(), example_args, {}, expected1) - verify_model(Hardtanh3(), example_args, {}, expected1) + verify_model(Hardtanh(), example_args, {}, expected_for_1_2) + verify_model(Hardtanh2(), example_args, {}, expected_for_1_2) + verify_model(Hardtanh3(), example_args, {}, expected_hardtanh_for_3) def test_softplus(): @@ -624,10 +959,20 @@ def main( x: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus( - x, beta=1.0, threshold=20.0 + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + x, R.const(1.0, "float32") ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv1, R.const(1.0, "float32")) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(lv2) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv3, R.const(1.0, "float32") + ) + lv5: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater( + lv, R.const(20.0, "float32") + ) + lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv5, x, lv4) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,) R.output(gv) return gv @@ -659,22 +1004,40 @@ def forward(self, input): return torch.ops.aten.leaky_relu_(input, 0.02) @tvm.script.ir_module - class expected: + class expected_for_1_2: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input_1, 0.02) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input_1, alpha=0.02) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv + @tvm.script.ir_module + class expected_for_3: + @R.function + def 
main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32") + ): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input, alpha=0.02) + gv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv, lv) + R.output(gv) + return gv + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(LeakyReLU0(), example_args, {}, expected) - verify_model(LeakyReLU1(), example_args, {}, expected) - verify_model(LeakyReLU2(), example_args, {}, expected) + verify_model(LeakyReLU0(), example_args, {}, expected_for_1_2) + verify_model(LeakyReLU1(), example_args, {}, expected_for_1_2) + verify_model(LeakyReLU2(), example_args, {}, expected_for_3) def test_logaddexp(): @@ -686,13 +1049,32 @@ def forward(self, input1, input2): class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), - input_2: R.Tensor((1, 3, 10, 10), dtype="float32"), + input1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input2: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log_add_exp(input_1, input_2) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater_equal(input1, input2) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, input1, input2) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, input2, input1) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input1) + lv4: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal( + lv3, R.const(float("inf"), "float32") + ) + lv5: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input1, input1) + lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.multiply(lv5, lv4) + lv7: R.Tensor((1, 3, 10, 
10), dtype="bool") = R.logical_not(lv6) + lv8: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input1, input2) + lv9: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_and(lv7, lv8) + lv10: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lv2, lv1) + lv11: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv10) + lv12: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( + lv11, R.const(1.0, "float32") + ) + lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(lv12) + lv14: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv1, lv13) + lv15: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv9, input1, lv14) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv15,) R.output(gv) return gv @@ -758,10 +1140,13 @@ def main( x: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu( - x, R.const([0.25], dtype="float32"), axis=1 + lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.reshape( + R.const([0.25], dtype="float32"), R.shape([1, 1, 1, 1]) ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(x, R.const(0.0, "float32")) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, x) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, x, lv2) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -853,26 +1238,18 @@ def main( input: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( - input, R.const(0.5, "float32") - ) - lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater( - input, R.const(0.5, "float32") + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input) + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lv, R.const(0.5, "float32")) + lv2: R.Tensor((1, 
3, 10, 10), dtype="float32") = R.sign(input) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + lv2, R.const(0.5, "float32") ) - lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32") - lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2) - - lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( - input, R.const(0.5, "float32") + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(input, lv3) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + input, R.const(0.0, "float32") ) - lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32")) - lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5) - lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32") - lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7) - - lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8) - - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,) + lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, lv4, lv5) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,) R.output(gv) return gv @@ -892,12 +1269,23 @@ def forward(self, input): class expected_tril: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32") + input: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + lv: R.Tensor((10,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2]) + lv2: R.Tensor((10,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64" + ) + lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1]) + lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3) + lv5: R.Tensor((10, 
10), dtype="bool") = R.less_equal(lv4, R.const(1, "int64")) + lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32") + lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,) R.output(gv) return gv @@ -911,12 +1299,23 @@ def forward(self, input): class expected_triu: @R.function def main( - input_1: R.Tensor((10, 10), dtype="float32") + input: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + lv: R.Tensor((10,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2]) + lv2: R.Tensor((10,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64" + ) + lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1]) + lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3) + lv5: R.Tensor((10, 10), dtype="bool") = R.greater_equal(lv4, R.const(1, "int64")) + lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32") + lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,) R.output(gv) return gv @@ -971,6 +1370,21 @@ def main( R.output(gv) return gv + @tvm.script.ir_module + class expected_binary1_inplace: + @R.function + def main( + lhs: R.Tensor((10, 10), dtype="float32"), + rhs: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs) + gv: R.Tuple( + R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32") + ) = (lv, lv) + R.output(gv) + return gv + class Binary2(Module): def __init__(self, op): 
super().__init__() @@ -991,10 +1405,132 @@ def main( R.output(gv) return gv - verify_model(Binary1(op), example_args1, {}, expected_binary1) - verify_model(Binary2(op), example_args2, {}, expected_binary2) - - + @tvm.script.ir_module + class expected_binary2_inplace: + @R.function + def main( + lhs: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, R.const(1.0)) + gv: R.Tuple( + R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32") + ) = (lv, lv) + R.output(gv) + return gv + + inplace_ops = [ + torch.ops.aten.add_, + torch.ops.aten.bitwise_or_, + torch.ops.aten.mul_, + ] + + expected1 = expected_binary1_inplace if op in inplace_ops else expected_binary1 + expected2 = expected_binary2_inplace if op in inplace_ops else expected_binary2 + verify_model(Binary1(op), example_args1, {}, expected1) + verify_model(Binary2(op), example_args2, {}, expected2) + + +operator_binary_scalar = [ + (torch.ops.aten.add.Scalar, R.add), + (torch.ops.aten.bitwise_and.Scalar, R.bitwise_and), + (torch.ops.aten.bitwise_or.Scalar, R.bitwise_or), + (torch.ops.aten.bitwise_xor.Scalar, R.bitwise_xor), + (torch.ops.aten.div.Scalar, R.divide), + (torch.ops.aten.sub.Scalar, R.subtract), + (torch.ops.aten.mul.Scalar, R.multiply), + (torch.ops.aten.remainder.Scalar, R.floor_mod), +] + + +@pytest.mark.parametrize("op, relax_op", operator_binary_scalar) +def test_binary_scalar(op, relax_op): + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + class BinaryScalar(Module): + def __init__(self, op): + super().__init__() + self.op = op + + def forward(self, lhs): + return self.op(lhs, 1.0) + + @tvm.script.ir_module + class expected_binary_scalar: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: 
R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(lhs, R.const(1.0)) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(BinaryScalar(op), example_args, {}, expected_binary_scalar) + + +operator_binary_promote = [ + (operator.add, R.add), + (operator.sub, R.subtract), + (operator.mul, R.multiply), + (operator.truediv, R.divide), + (operator.pow, R.power), + (operator.mod, R.floor_mod), +] + + +@pytest.mark.parametrize("op, relax_op", operator_binary_promote) +def test_binary_dtype_promotion(op, relax_op): + """Ensure binary ops promote differing dtypes following PyTorch rules.""" + + class BinaryPromoteLHS(Module): + def forward(self, x): + arange_val = torch.arange(x.shape[1]) # int64 by default + return op(x, arange_val) + + @tvm.script.ir_module + class expected_promote_lhs: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((3,), dtype="float32") = R.astype(lv, dtype="float32") + lv2: R.Tensor((2, 3), dtype="float32") = relax_op(x, lv1) + gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv2,) + R.output(gv) + return gv + + class BinaryPromoteRHS(Module): + def forward(self, x): + arange_val = torch.arange(x.shape[1]) # int64 by default + return op(arange_val, x) + + @tvm.script.ir_module + class expected_promote_rhs: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((3,), dtype="float32") = R.astype(lv, dtype="float32") + lv2: R.Tensor((2, 3), dtype="float32") = relax_op(lv1, x) + gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv2,) + R.output(gv) + return gv + + 
example_args = (torch.randn(2, 3, dtype=torch.float32),) + verify_model(BinaryPromoteLHS(), example_args, {}, expected_promote_lhs) + verify_model(BinaryPromoteRHS(), example_args, {}, expected_promote_rhs) + + operator_binary_2 = [ (operator.eq, R.equal), (operator.ne, R.not_equal), @@ -1157,11 +1693,11 @@ def main( x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): with R.dataflow(): - lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1]) - lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8])) - lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1) - lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False) - lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32")) + lv: R.Tensor((10, 10, 1), dtype="float32") = R.reshape(x, R.shape([10, 10, 1])) + lv1: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, test_elements) + lv2: R.Tensor((10, 10, 8), dtype="int8") = R.astype(lv1, dtype="int8") + lv3: R.Tensor((10, 10), dtype="int8") = R.max(lv2, axis=[-1], keepdims=False) + lv4: R.Tensor((10, 10), dtype="bool") = R.astype(lv3, dtype="bool") gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,) R.output(gv) return gv @@ -1238,7 +1774,7 @@ def main( def test_batchnorm2d(): - class BatchNorm2d(Module): + class BatchNorm2d1(Module): def __init__(self): super().__init__() self.bn = torch.nn.BatchNorm2d(3) @@ -1272,6 +1808,49 @@ def main( epsilon=1e-05, center=True, scale=True, + momentum=0.1, + training=False, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + class BatchNorm2dCustom(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3, eps=0.001, momentum=0.01) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected2: + 
@R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=0.001, + center=True, + scale=True, + momentum=0.01, + training=False, ) lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) @@ -1280,14 +1859,99 @@ def main( example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - model = BatchNorm2d().eval() - binding = { - "w1": model.bn.weight.detach().numpy(), - "w2": model.bn.bias.detach().numpy(), - "w3": model.bn.running_mean.detach().numpy(), - "w4": model.bn.running_var.detach().numpy(), + model_1 = BatchNorm2d1().eval() + binding_1 = { + "w1": model_1.bn.weight.detach().numpy(), + "w2": model_1.bn.bias.detach().numpy(), + "w3": model_1.bn.running_mean.detach().numpy(), + "w4": model_1.bn.running_var.detach().numpy(), } - verify_model(model, example_args, binding, expected1) + verify_model(model_1, example_args, binding_1, expected1) + + model_2 = BatchNorm2dCustom().eval() + binding_2 = { + "w1": model_2.bn.weight.detach().numpy(), + "w2": model_2.bn.bias.detach().numpy(), + "w3": model_2.bn.running_mean.detach().numpy(), + "w4": model_2.bn.running_var.detach().numpy(), + } + verify_model(model_2, example_args, binding_2, expected2) + + class BatchNorm2dTraining(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((2, 3, 4, 4), 
dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple( + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((), dtype="int64"), + R.Tensor((2, 3, 4, 4), dtype="float32"), + ): + with R.dataflow(): + lv: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64")) + lv1: R.Tuple( + R.Tensor((2, 3, 4, 4), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=0.1, + center=True, + scale=True, + momentum=1.0, + training=True, + ) + lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0] + lv3: R.Tensor((3,), dtype="float32") = lv1[1] + lv4: R.Tensor((3,), dtype="float32") = R.zeros(R.shape([3]), dtype="float32") + lv5: R.Tuple( + R.Tensor((2, 3, 4, 4), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = (lv2, lv3, lv4, lv4, lv4) + lv6: R.Tensor((2, 3, 4, 4), dtype="float32") = lv5[0] + lv7: R.Tensor((3,), dtype="float32") = lv5[3] + lv8: R.Tensor((3,), dtype="float32") = lv5[4] + gv: R.Tuple( + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((), dtype="int64"), + R.Tensor((2, 3, 4, 4), dtype="float32"), + ) = (lv7, lv8, lv, lv6) + R.output(gv) + return gv + + example_args_train = (torch.randn(2, 3, 4, 4, dtype=torch.float32),) + + model_3 = BatchNorm2dTraining() + model_3.train() # Set to training mode + binding_3 = { + "w1": model_3.bn.weight.detach().numpy(), + "w2": model_3.bn.bias.detach().numpy(), + "w3": model_3.bn.running_mean.detach().numpy(), + "w4": model_3.bn.running_var.detach().numpy(), + } + verify_model(model_3, example_args_train, binding_3, expected3) def test_adaptive_avgpool1d(): @@ -1310,10 +1974,12 @@ def main( input_1: 
R.Tensor((1, 3, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.adaptive_avg_pool1d( - input_1, output_size=[5], layout="NCW" + lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 5), dtype="float32") = R.nn.adaptive_avg_pool2d( + lv, output_size=[1, 5], layout="NCHW" ) - gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv,) + lv2: R.Tensor((1, 3, 5), dtype="float32") = R.squeeze(lv1, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv2,) R.output(gv) return gv @@ -1446,6 +2112,63 @@ def main( verify_model(Addmm2(), example_args, {}, expected2) +def test_sparse_addmm(): + class SparseAddmm1(Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3): + return torch.sparse.addmm(x1, x2, x3) + + class SparseAddmm2(Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3): + return torch.sparse.addmm(x1, x2, x3, beta=0.8, alpha=0.5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, 
"float32")) + lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32")) + lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + ) + + verify_model(SparseAddmm1(), example_args, {}, expected1) + verify_model(SparseAddmm2(), example_args, {}, expected2) + + def test_avg_pool1d(): class AvgPool1d1(Module): def __init__(self): @@ -1459,21 +2182,23 @@ def forward(self, input): class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10), dtype="float32") + input: R.Tensor((1, 3, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.avg_pool1d( - input_1, - pool_size=[1], - strides=[1], - dilation=[1], - padding=[0, 0], + lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 10), dtype="float32") = R.nn.avg_pool2d( + lv, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], ceil_mode=False, count_include_pad=True, - layout="NCW", - out_layout="NCW", + layout="NCHW", + out_layout="NCHW", ) - gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv,) + lv2: R.Tensor((1, 3, 10), dtype="float32") = R.squeeze(lv1, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv2,) R.output(gv) return gv @@ -1494,20 +2219,24 @@ def forward(self, input): @tvm.script.ir_module class expected2: @R.function - def main(input_1: R.Tensor((1, 3, 10), dtype="float32")): + def main( + input: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 6), dtype="float32")): with R.dataflow(): - lv = R.nn.avg_pool1d( - input_1, - pool_size=[3], - strides=[2], - dilation=[1], - padding=[1, 1], + lv: R.Tensor((1, 3, 1, 10), 
dtype="float32") = R.expand_dims(input, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 6), dtype="float32") = R.nn.avg_pool2d( + lv, + pool_size=[1, 3], + strides=[1, 2], + dilation=[1, 1], + padding=[0, 1, 0, 1], ceil_mode=True, count_include_pad=True, - layout="NCW", - out_layout="NCW", + layout="NCHW", + out_layout="NCHW", ) - gv = (lv,) + lv2: R.Tensor((1, 3, 6), dtype="float32") = R.squeeze(lv1, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 6), dtype="float32")) = (lv2,) R.output(gv) return gv @@ -1518,20 +2247,24 @@ def forward(self, input): @tvm.script.ir_module class expected3: @R.function - def main(input_1: R.Tensor((1, 3, 10), dtype="float32")): + def main( + input: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")): with R.dataflow(): - lv = R.nn.avg_pool1d( - input_1, - pool_size=[2], - strides=[2], - dilation=[1], - padding=[0, 0], + lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 5), dtype="float32") = R.nn.avg_pool2d( + lv, + pool_size=[1, 2], + strides=[1, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], ceil_mode=False, count_include_pad=True, - layout="NCW", - out_layout="NCW", + layout="NCHW", + out_layout="NCHW", ) - gv = (lv,) + lv2: R.Tensor((1, 3, 5), dtype="float32") = R.squeeze(lv1, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv2,) R.output(gv) return gv @@ -1565,6 +2298,7 @@ def main( strides=[1, 1], dilation=[1, 1], padding=[0, 0, 0, 0], + count_include_pad=True, layout="NCHW", out_layout="NCHW", ) @@ -1598,6 +2332,7 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): dilation=[1, 1], padding=[2, 2, 2, 2], ceil_mode=True, + count_include_pad=True, layout="NCHW", out_layout="NCHW", ) @@ -1610,7 +2345,7 @@ def forward(self, input): return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2) @tvm.script.ir_module - class expected3: + class expected4: @R.function def main(input_1: R.Tensor((1, 3, 10, 
10), dtype="float32")): with R.dataflow(): @@ -1621,6 +2356,7 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): dilation=[1, 1], padding=[0, 0, 0, 0], ceil_mode=False, + count_include_pad=True, layout="NCHW", out_layout="NCHW", ) @@ -1628,21 +2364,48 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): R.output(gv) return gv - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(AvgPool2d1(), example_args, {}, expected1) - verify_model(AvgPool2d2(), example_args, {}, expected2) - verify_model(AvgPool2d3(), example_args, {}, expected2) - verify_model(AvgPool2d4(), example_args, {}, expected3) - - -def test_avg_pool3d(): - class AvgPool3d1(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool3d(kernel_size=1) - + class AvgPool2d5(Module): def forward(self, input): - return self.pool(input) + return torch.nn.functional.avg_pool2d( + input, kernel_size=[2, 1], divisor_override=2, count_include_pad=False + ) + + @tvm.script.ir_module + class expected5: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool2d( + input_1, + pool_size=[2, 1], + strides=[2, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + ceil_mode=False, + count_include_pad=False, + layout="NCHW", + out_layout="NCHW", + ) + gv = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(AvgPool2d1(), example_args, {}, expected1) + verify_model(AvgPool2d2(), example_args, {}, expected2) + verify_model(AvgPool2d3(), example_args, {}, expected2) + verify_model(AvgPool2d4(), example_args, {}, expected4) + verify_model(AvgPool2d5(), example_args, {}, expected5) + + +def test_avg_pool3d(): + class AvgPool3d1(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool3d(kernel_size=1) + + def forward(self, input): + return self.pool(input) @tvm.script.ir_module class expected1: @@ -1748,8 
+2511,10 @@ def main( inp_2: R.Tensor((4, 256, 512), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): with R.dataflow(): - lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) - lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0) + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( + inp_1, inp_2, out_dtype="float32" + ) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(inp_0, lv) gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) R.output(gv) return gv @@ -1770,7 +2535,9 @@ def main( inp_2: R.Tensor((4, 256, 512), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): with R.dataflow(): - lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( + inp_1, inp_2, out_dtype="float32" + ) lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( lv, R.const(2, "float32") ) @@ -1794,14 +2561,16 @@ def main( inp_2: R.Tensor((4, 256, 512), dtype="float32"), ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): with R.dataflow(): - lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( + inp_1, inp_2, out_dtype="float32" + ) lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( lv, R.const(2, "float32") ) lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( inp_0, R.const(3, "float32") ) - lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2) + lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv2, lv1) gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -1816,6 +2585,7 @@ def main( example_args, {}, Expected1, + run_ep_decomposition=True, ) verify_model( @@ -1823,6 +2593,7 @@ def main( example_args, {}, Expected2, + run_ep_decomposition=True, ) verify_model( @@ -1830,6 +2601,7 @@ def main( example_args, {}, Expected3, + run_ep_decomposition=True, ) @@ -1866,6 
+2638,7 @@ def main( example_args, {}, Expected, + run_ep_decomposition=True, ) @@ -2371,13 +3144,25 @@ def main( x: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( - x, - pad_width=[0, 0, 0, 0, 2, 2, 1, 1], - pad_mode="reflect", - pad_value=0.0, + lv: R.Tensor((14,), dtype="int64") = R.arange( + R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64" ) - gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + lv1: R.Tensor((14,), dtype="int64") = R.abs(lv) + lv2: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv1) + lv3: R.Tensor((14,), dtype="int64") = R.abs(lv2) + lv4: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv3) + lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv4, axis=2, mode="fast") + lv6: R.Tensor((12,), dtype="int64") = R.arange( + R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64" + ) + lv7: R.Tensor((12,), dtype="int64") = R.abs(lv6) + lv8: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv7) + lv9: R.Tensor((12,), dtype="int64") = R.abs(lv8) + lv10: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv9) + lv11: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take( + lv5, lv10, axis=3, mode="fast" + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv11,) R.output(gv) return gv @@ -2388,13 +3173,19 @@ def main( x: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( - x, - pad_width=[0, 0, 0, 0, 2, 2, 1, 1], - pad_mode="replicate", - pad_value=0.0, + lv: R.Tensor((14,), dtype="int64") = R.arange( + R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64" ) - gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + lv1: R.Tensor((14,), dtype="int64") = 
R.clip(lv, R.prim_value(0), R.prim_value(9)) + lv2: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv1, axis=2, mode="fast") + lv3: R.Tensor((12,), dtype="int64") = R.arange( + R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64" + ) + lv4: R.Tensor((12,), dtype="int64") = R.clip(lv3, R.prim_value(0), R.prim_value(9)) + lv5: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take( + lv2, lv4, axis=3, mode="fast" + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv5,) R.output(gv) return gv @@ -2405,21 +3196,195 @@ def main( x: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( + lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.zeros( + R.shape([1, 3, 14, 12]), dtype="float32" + ) + + lv1: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice( + lv, + (R.prim_value(3),), + (R.prim_value(1),), + (R.prim_value(11),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( x, - pad_width=[0, 0, 0, 0, 2, 2, 1, 1], - pad_mode="circular", - pad_value=0.0, + (R.prim_value(3),), + (R.prim_value(0),), + (R.prim_value(10),), + (R.prim_value(1),), + assume_inbound=False, ) - gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) + + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + lv1, + (R.prim_value(2),), + (R.prim_value(2),), + (R.prim_value(12),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + lv2, + (R.prim_value(2),), + (R.prim_value(0),), + (R.prim_value(10),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to( + lv4, R.shape([1, 3, 10, 10]) + ) + + lv6: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice( + lv, + (R.prim_value(3),), + (R.prim_value(1),), + 
(R.prim_value(11),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv7: R.Tensor((1, 3, 14, 10), dtype="float32") = R.slice_scatter( + lv6, lv5, R.prim_value(2), R.prim_value(12), R.prim_value(1), axis=2 + ) + + lv8: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter( + lv, lv7, R.prim_value(1), R.prim_value(11), R.prim_value(1), axis=3 + ) + + lv9: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice( + lv8, + (R.prim_value(3),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv10: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice( + lv8, + (R.prim_value(3),), + (R.prim_value(10),), + (R.prim_value(11),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv11: R.Tensor((1, 3, 14, 1), dtype="float32") = R.broadcast_to( + lv10, R.shape([1, 3, 14, 1]) + ) + + lv12: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter( + lv8, lv11, R.prim_value(0), R.prim_value(1), R.prim_value(1), axis=3 + ) + + lv13: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice( + lv12, + (R.prim_value(3),), + (R.prim_value(11),), + (R.prim_value(12),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv14: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice( + lv12, + (R.prim_value(3),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv15: R.Tensor((1, 3, 14, 1), dtype="float32") = R.broadcast_to( + lv14, R.shape([1, 3, 14, 1]) + ) + lv16: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter( + lv12, lv15, R.prim_value(11), R.prim_value(12), R.prim_value(1), axis=3 + ) + + lv17: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice( + lv16, + (R.prim_value(2),), + (R.prim_value(0),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv18: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice( + lv16, + (R.prim_value(2),), + (R.prim_value(10),), + (R.prim_value(12),), + (R.prim_value(1),), 
+ assume_inbound=False, + ) + + lv19: R.Tensor((1, 3, 2, 12), dtype="float32") = R.broadcast_to( + lv18, R.shape([1, 3, 2, 12]) + ) + + lv20: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter( + lv16, lv19, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=2 + ) + lv21: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice( + lv20, + (R.prim_value(2),), + (R.prim_value(12),), + (R.prim_value(14),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv22: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice( + lv20, + (R.prim_value(2),), + (R.prim_value(2),), + (R.prim_value(4),), + (R.prim_value(1),), + assume_inbound=False, + ) + + lv23: R.Tensor((1, 3, 2, 12), dtype="float32") = R.broadcast_to( + lv22, R.shape([1, 3, 2, 12]) + ) + + lv24: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter( + lv20, lv23, R.prim_value(12), R.prim_value(14), R.prim_value(1), axis=2 + ) + gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv24,) R.output(gv) return gv example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant) - verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, expected_reflect) - verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, {}, expected_replicate) - verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular) + verify_model( + PadModel(pad=[1, 1, 2, 2], mode="reflect"), + example_args, + {}, + expected_reflect, + run_ep_decomposition=True, + ) + verify_model( + PadModel(pad=[1, 1, 2, 2], mode="replicate"), + example_args, + {}, + expected_replicate, + run_ep_decomposition=True, + ) + verify_model( + PadModel(pad=[1, 1, 2, 2], mode="circular"), + example_args, + {}, + expected_circular, + run_ep_decomposition=True, + ) def test_pixel_shuffle(): @@ -2446,10 +3411,16 @@ def main( x: R.Tensor((1, 8, 10, 15), dtype="float32") ) -> R.Tuple(R.Tensor((1, 2, 20, 
30), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle( - x, upscale_factor=2 + lv: R.Tensor((1, 2, 2, 2, 10, 15), dtype="float32") = R.reshape( + x, R.shape([1, 2, 2, 2, 10, 15]) + ) + lv1: R.Tensor((1, 2, 10, 2, 15, 2), dtype="float32") = R.permute_dims( + lv, axes=[0, 1, 4, 2, 5, 3] + ) + lv2: R.Tensor((1, 2, 20, 30), dtype="float32") = R.reshape( + lv1, R.shape([1, 2, 20, 30]) ) - gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv2,) R.output(gv) return gv @@ -2500,10 +3471,10 @@ def main( return gv example_args = (torch.randn(4, 4, dtype=torch.float32),) - verify_model(Einsum1(), example_args, {}, Expected1) + verify_model(Einsum1(), example_args, {}, Expected1, run_ep_decomposition=False) example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32)) - verify_model(Einsum2(), example_args, {}, Expected2) + verify_model(Einsum2(), example_args, {}, Expected2, run_ep_decomposition=False) def test_outer(): @@ -2515,11 +3486,12 @@ def forward(self, x, y): class expected: @R.function def main( - a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32") + x: R.Tensor((3,), dtype="float32"), y: R.Tensor((4,), dtype="float32") ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): with R.dataflow(): - lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b) - gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) + lv: R.Tensor((3, 1), dtype="float32") = R.reshape(x, R.shape([3, 1])) + lv1: R.Tensor((3, 4), dtype="float32") = R.multiply(lv, y) + gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,) R.output(gv) return gv @@ -2724,12 +3696,14 @@ def main( ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) - lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( - input_1, lv, out_dtype="float32" + 
lv: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10])) + lv1: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0]) + lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32") + lv3: R.Tensor((30, 7), dtype="float32") = R.add(w2, lv2) + lv4: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape( + lv3, R.shape([1, 3, 10, 7]) ) - lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2) - gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv2,) + gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv4,) R.output(gv) return gv @@ -2750,11 +3724,13 @@ def main( ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) - lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( - input_1, lv, out_dtype="float32" + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0]) + lv1: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10])) + lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv1, lv, out_dtype="float32") + lv3: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape( + lv2, R.shape([1, 3, 10, 7]) ) - gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv1,) + gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -2804,16 +3780,24 @@ def main( input_1: R.Tensor((1, 3, 8), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool1d( - input_1, - pool_size=[2], - strides=[2], - dilation=[1], - padding=[0, 0], - layout="NCW", - out_layout="NCW", + lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d( + lv, + pool_size=[1, 2], + strides=[1, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", ) - gv = (lv,) + lv2: R.Tensor((1, 3, 1, 4), 
dtype="float32") = R.zeros_like(lv1) + lv3: R.Tuple( + R.Tensor((1, 3, 1, 4), dtype="float32"), + R.Tensor((1, 3, 1, 4), dtype="float32"), + ) = (lv1, lv2) + lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0] + lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,) R.output(gv) return gv @@ -2824,16 +3808,24 @@ def main( input_1: R.Tensor((1, 3, 8), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool1d( - input_1, - pool_size=[2], - strides=[2], - dilation=[1], - padding=[0, 0], - layout="NCW", - out_layout="NCW", + lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d( + lv, + pool_size=[1, 2], + strides=[1, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", ) - gv = (lv,) + lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1) + lv3: R.Tuple( + R.Tensor((1, 3, 1, 4), dtype="float32"), + R.Tensor((1, 3, 1, 4), dtype="float32"), + ) = (lv1, lv2) + lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0] + lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,) R.output(gv) return gv @@ -2844,16 +3836,24 @@ def main( input_1: R.Tensor((1, 3, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool1d( - input_1, - pool_size=[3], - strides=[2], - dilation=[1], - padding=[0, 0], - layout="NCW", - out_layout="NCW", + lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2]) + lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d( + lv, + pool_size=[1, 3], + strides=[1, 2], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", ) - gv = (lv,) + lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1) + lv3: 
R.Tuple( + R.Tensor((1, 3, 1, 4), dtype="float32"), + R.Tensor((1, 3, 1, 4), dtype="float32"), + ) = (lv1, lv2) + lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0] + lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2]) + gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,) R.output(gv) return gv @@ -2901,7 +3901,13 @@ def main( layout="NCHW", out_layout="NCHW", ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -2930,7 +3936,12 @@ def main( layout="NCHW", out_layout="NCHW", ) - gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,) + lv1: R.Tensor((1, 3, 4, 4), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 4, 4), dtype="float32"), R.Tensor((1, 3, 4, 4), dtype="float32") + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 4, 4), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -2959,7 +3970,12 @@ def main( layout="NCHW", out_layout="NCHW", ) - gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,) + lv1: R.Tensor((1, 3, 6, 6), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 6, 6), dtype="float32"), R.Tensor((1, 3, 6, 6), dtype="float32") + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 6, 6), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -2993,7 +4009,7 @@ def main( input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool3d( + lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.max_pool3d( input_1, pool_size=[1, 1, 1], 
strides=[1, 1, 1], @@ -3002,7 +4018,13 @@ def main( layout="NCDHW", out_layout="NCDHW", ) - gv = (lv,) + lv1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 4, 4, 4), dtype="float32"), + R.Tensor((1, 3, 4, 4, 4), dtype="float32"), + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -3021,7 +4043,7 @@ def main( input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool3d( + lv: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.nn.max_pool3d( input_1, pool_size=[2, 2, 2], strides=[2, 2, 2], @@ -3030,7 +4052,13 @@ def main( layout="NCDHW", out_layout="NCDHW", ) - gv = (lv,) + lv1: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 3, 3, 3), dtype="float32"), + R.Tensor((1, 3, 3, 3, 3), dtype="float32"), + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -3049,7 +4077,7 @@ def main( input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")): with R.dataflow(): - lv = R.nn.max_pool3d( + lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d( input_1, pool_size=[3, 3, 3], strides=[2, 2, 2], @@ -3058,7 +4086,13 @@ def main( layout="NCDHW", out_layout="NCDHW", ) - gv = (lv,) + lv1: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.zeros_like(lv) + lv2: R.Tuple( + R.Tensor((1, 3, 5, 5, 5), dtype="float32"), + R.Tensor((1, 3, 5, 5, 5), dtype="float32"), + ) = (lv, lv1) + lv3: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv2[0] + gv: R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -3083,22 +4117,22 @@ def forward(self, q, k, v): class Expected1: @R.function def main( - 
inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), - inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), - inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), + q: R.Tensor((32, 8, 128, 64), dtype="float32"), + k: R.Tensor((32, 8, 128, 64), dtype="float32"), + v: R.Tensor((32, 8, 128, 64), dtype="float32"), ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): with R.dataflow(): lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( - inp_0, axes=[0, 2, 1, 3] + q, axes=[0, 2, 1, 3] ) lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( - inp_1, axes=[0, 2, 1, 3] + k, axes=[0, 2, 1, 3] ) lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( - inp_2, axes=[0, 2, 1, 3] + v, axes=[0, 2, 1, 3] ) lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( - lv, lv1, lv2, scale=None + lv, lv1, lv2, scale=None, causal_mask=None, window_size=None ) lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( lv3, axes=[0, 2, 1, 3] @@ -3115,23 +4149,23 @@ def forward(self, q, k, v, mask): class Expected2: @R.function def main( - inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), - inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), - inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), - inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"), + q: R.Tensor((32, 8, 128, 64), dtype="float32"), + k: R.Tensor((32, 8, 128, 64), dtype="float32"), + v: R.Tensor((32, 8, 128, 64), dtype="float32"), + mask: R.Tensor((32, 8, 128, 128), dtype="float32"), ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): with R.dataflow(): lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( - inp_0, axes=[0, 2, 1, 3] + q, axes=[0, 2, 1, 3] ) lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( - inp_1, axes=[0, 2, 1, 3] + k, axes=[0, 2, 1, 3] ) lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( - inp_2, axes=[0, 2, 1, 3] + v, axes=[0, 2, 1, 3] ) - lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.nn.attention( - lv, lv1, lv2, inp_3, scale=None + lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention_bias( + lv, lv1, lv2, mask, scale=None, causal_mask=None, window_size=None ) lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( lv3, axes=[0, 2, 1, 3] @@ -3149,6 +4183,7 @@ def main( ), {}, Expected1, + run_ep_decomposition=False, ) verify_model( @@ -3161,6 +4196,46 @@ def main( ), {}, Expected2, + run_ep_decomposition=False, + ) + + # Test 2D input (seq_len, head_dim) - bug fix for #18441 + class Attention2D(Module): + def forward(self, x): + return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False) + + @I.ir_module + class Expected2D: + @R.function + def main( + x: R.Tensor((8, 32), dtype="float32"), + ) -> R.Tuple(R.Tensor((8, 32), dtype="float32")): + with R.dataflow(): + # Expand to add batch dimension for query, key, value separately + # (8, 32) -> (1, 8, 32) + lv: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) + lv1: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) + lv2: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) + # Expand to add num_heads dimension: (1, 8, 32) -> (1, 1, 8, 32) + lv3: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=[1]) + lv4: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv1, axis=[1]) + lv5: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv2, axis=[1]) + # Attention operation: (1, 1, 8, 32) -> (1, 1, 8, 32) + lv6: R.Tensor((1, 1, 8, 32), dtype="float32") = R.nn.attention( + lv3, lv4, lv5, scale=None, causal_mask=None, window_size=None + ) + # Squeeze batch and num_heads dimensions: (1, 1, 8, 32) -> (8, 32) + lv7: R.Tensor((8, 32), dtype="float32") = R.squeeze(lv6, axis=[0, 1]) + gv: R.Tuple(R.Tensor((8, 32), dtype="float32")) = (lv7,) + R.output(gv) + return gv + + verify_model( + Attention2D(), + (torch.randn(8, 32, dtype=torch.float32),), + {}, + Expected2D, + 
run_ep_decomposition=False, ) @@ -3173,7 +4248,7 @@ def forward(self, data): class expected1: @R.function def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + data: R.Tensor((3, 3, 10, 10), dtype="float32") ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -3181,30 +4256,38 @@ def main( ): # block 0 with R.dataflow(): - lv: R.Tuple( - R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((1, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=3, axis=0) - lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] - lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) - lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] - lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) - lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] - lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) - lv7: R.Tuple( + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(2),), + (R.prim_value(3),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[0]) + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) + lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[0]) + gv: R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv2, lv4, lv6) - lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] - 
lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] - lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] - gv: R.Tuple( - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv8, lv9, lv10) + ) = (lv3, lv4, lv5) R.output(gv) return gv @@ -3216,7 +4299,7 @@ def forward(self, data): class expected2: @R.function def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + data: R.Tensor((3, 3, 10, 10), dtype="float32") ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -3224,36 +4307,66 @@ def main( ): # block 0 with R.dataflow(): - lv: R.Tuple( - R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 1, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=3, axis=1) - lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] - lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) - lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] - lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) - lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] - lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) - lv7: R.Tuple( - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv2, lv4, lv6) - lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] - lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] - lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + lv: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv2: R.Tensor((3, 1, 
10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(3),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[1]) + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[1]) gv: R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv8, lv9, lv10) + ) = (lv3, lv4, lv5) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected3: + @R.function + def main( + data: R.Tensor((3, 1, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 1, 3), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((3, 3), dtype="float32") = R.squeeze(lv, axis=[1]) + gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv1,) R.output(gv) return gv example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) verify_model(Unbind1(), example_args, {}, expected1) verify_model(Unbind2(), example_args, {}, expected2) + single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),) + verify_model(Unbind2(), single_dim_args, {}, expected3) def test_interpolate(): @@ -3325,14 +4438,455 @@ class expected_bicubic: def main( input: R.Tensor((1, 3, 112, 112), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): - # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( + lv: R.Tensor((1, 3, 112, 112), dtype="float32") = R.astype(input, dtype="float32") + lv1: R.Tensor((1, 3, 112, 112), dtype="float32") = R.astype(lv, dtype="float32") + lv2: R.Tensor((224,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(224), R.prim_value(1), dtype="int64" + ) + 
lv3: R.Tensor((224,), dtype="float32") = R.astype(lv2, dtype="float32") + lv4: R.Tensor((224,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(224), R.prim_value(1), dtype="int64" + ) + lv5: R.Tensor((224,), dtype="float32") = R.astype(lv4, dtype="float32") + lv6: R.Tensor((224,), dtype="float32") = R.add(lv5, R.const(0.5, "float32")) + lv7: R.Tensor((224,), dtype="float32") = R.multiply(lv6, R.const(0.5, "float32")) + lv8: R.Tensor((224,), dtype="float32") = R.subtract(lv7, R.const(0.5, "float32")) + lv9: R.Tensor((224,), dtype="float32") = R.add(lv3, R.const(0.5, "float32")) + lv10: R.Tensor((224,), dtype="float32") = R.multiply(lv9, R.const(0.5, "float32")) + lv11: R.Tensor((224,), dtype="float32") = R.subtract(lv10, R.const(0.5, "float32")) + lv12: R.Tensor((224, 1), dtype="float32") = R.expand_dims(lv11, axis=[-1]) + lv13: R.Tensor((224,), dtype="float32") = R.floor(lv8) + lv14: R.Tensor((224, 1), dtype="float32") = R.floor(lv12) + lv15: R.Tensor((224, 1), dtype="float32") = R.subtract(lv12, lv14) + lv16: R.Tensor((224, 1), dtype="float32") = R.clip( + lv15, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(1.0)) + ) + lv17: R.Tensor((224,), dtype="float32") = R.subtract(lv8, lv13) + lv18: R.Tensor((224,), dtype="float32") = R.clip( + lv17, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(1.0)) + ) + lv19: R.Tensor((224,), dtype="int64") = R.astype(lv13, dtype="int64") + lv20: R.Tensor((224, 1), dtype="int64") = R.astype(lv14, dtype="int64") + lv21: R.Tensor((224, 1), dtype="int64") = R.subtract(lv20, R.const(1, "int64")) + lv22: R.Tensor((224, 1), dtype="int64") = R.add(lv20, R.const(1, "int64")) + lv23: R.Tensor((224, 1), dtype="int64") = R.add(lv20, R.const(2, "int64")) + lv24: R.Tensor((224,), dtype="int64") = R.subtract(lv19, R.const(1, "int64")) + lv25: R.Tensor((224,), dtype="int64") = R.add(lv19, R.const(1, "int64")) + lv26: R.Tensor((224,), dtype="int64") = R.add(lv19, R.const(2, "int64")) + lv27: R.Tensor((224,), dtype="float32") = 
R.subtract(R.const(1.0, "float32"), lv18) + lv28: R.Tensor((448,), dtype="float32") = R.concat((lv18, lv27), axis=0) + lv29: R.Tensor((2, 224), dtype="float32") = R.reshape(lv28, R.shape([2, 224])) + lv30: R.Tensor((224,), dtype="float32") = R.add(lv18, R.const(1.0, "float32")) + lv31: R.Tensor((224,), dtype="float32") = R.subtract(R.const(2.0, "float32"), lv18) + lv32: R.Tensor((448,), dtype="float32") = R.concat((lv30, lv31), axis=0) + lv33: R.Tensor((2, 224), dtype="float32") = R.reshape(lv32, R.shape([2, 224])) + lv34: R.Tensor((2, 224), dtype="float32") = R.multiply( + lv33, R.const(-0.75, "float32") + ) + lv35: R.Tensor((2, 224), dtype="float32") = R.subtract( + lv34, R.const(-3.75, "float32") + ) + lv36: R.Tensor((2, 224), dtype="float32") = R.multiply(lv35, lv33) + lv37: R.Tensor((2, 224), dtype="float32") = R.add(lv36, R.const(-6.0, "float32")) + lv38: R.Tensor((2, 224), dtype="float32") = R.multiply(lv37, lv33) + lv39: R.Tensor((2, 224), dtype="float32") = R.subtract( + lv38, R.const(-3.0, "float32") + ) + lv40: R.Tensor((2, 224), dtype="float32") = R.multiply( + lv29, R.const(1.25, "float32") + ) + lv41: R.Tensor((2, 224), dtype="float32") = R.subtract( + lv40, R.const(2.25, "float32") + ) + lv42: R.Tensor((2, 224), dtype="float32") = R.multiply(lv41, lv29) + lv43: R.Tensor((2, 224), dtype="float32") = R.multiply(lv42, lv29) + lv44: R.Tensor((2, 224), dtype="float32") = R.add(lv43, R.const(1.0, "float32")) + lv45: R.Tensor((1, 224), dtype="float32") = R.strided_slice( + lv39, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv46: R.Tensor((1, 224), dtype="float32") = R.strided_slice( + lv39, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv47: R.Tensor((224,), dtype="float32") = R.squeeze(lv45, axis=[0]) + lv48: R.Tensor((224,), dtype="float32") = R.squeeze(lv46, axis=[0]) + lv49: R.Tensor((1, 224), dtype="float32") 
= R.strided_slice( + lv44, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv50: R.Tensor((1, 224), dtype="float32") = R.strided_slice( + lv44, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv51: R.Tensor((224,), dtype="float32") = R.squeeze(lv49, axis=[0]) + lv52: R.Tensor((224,), dtype="float32") = R.squeeze(lv50, axis=[0]) + lv53: R.Tensor((224, 1), dtype="float32") = R.subtract( + R.const(1.0, "float32"), lv16 + ) + lv54: R.Tensor((448, 1), dtype="float32") = R.concat((lv16, lv53), axis=0) + lv55: R.Tensor((2, 224, 1), dtype="float32") = R.reshape(lv54, R.shape([2, 224, 1])) + lv56: R.Tensor((224, 1), dtype="float32") = R.add(lv16, R.const(1.0, "float32")) + lv57: R.Tensor((224, 1), dtype="float32") = R.subtract( + R.const(2.0, "float32"), lv16 + ) + lv58: R.Tensor((448, 1), dtype="float32") = R.concat((lv56, lv57), axis=0) + lv59: R.Tensor((2, 224, 1), dtype="float32") = R.reshape(lv58, R.shape([2, 224, 1])) + lv60: R.Tensor((2, 224, 1), dtype="float32") = R.multiply( + lv59, R.const(-0.75, "float32") + ) + lv61: R.Tensor((2, 224, 1), dtype="float32") = R.subtract( + lv60, R.const(-3.75, "float32") + ) + lv62: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv61, lv59) + lv63: R.Tensor((2, 224, 1), dtype="float32") = R.add(lv62, R.const(-6.0, "float32")) + lv64: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv63, lv59) + lv65: R.Tensor((2, 224, 1), dtype="float32") = R.subtract( + lv64, R.const(-3.0, "float32") + ) + lv66: R.Tensor((2, 224, 1), dtype="float32") = R.multiply( + lv55, R.const(1.25, "float32") + ) + lv67: R.Tensor((2, 224, 1), dtype="float32") = R.subtract( + lv66, R.const(2.25, "float32") + ) + lv68: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv67, lv55) + lv69: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv68, lv55) + lv70: R.Tensor((2, 224, 1), dtype="float32") = R.add(lv69, 
R.const(1.0, "float32")) + lv71: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice( + lv65, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv72: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice( + lv65, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv73: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv71, axis=[0]) + lv74: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv72, axis=[0]) + lv75: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice( + lv70, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv76: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice( + lv70, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv77: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv75, axis=[0]) + lv78: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv76, axis=[0]) + lv79: R.Tensor((224, 1), dtype="int64") = R.clip( + lv21, R.prim_value(0), R.prim_value(111) + ) + lv80: R.Tensor((224,), dtype="int64") = R.clip( + lv24, R.prim_value(0), R.prim_value(111) + ) + lv81: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv80, axis=3, mode="fast" + ) + lv82: R.Tensor((224,), dtype="int64") = R.squeeze(lv79, axis=None) + lv83: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv81, lv82, axis=2, mode="fast" + ) + lv84: R.Tensor((224, 1), dtype="int64") = R.clip( + lv21, R.prim_value(0), R.prim_value(111) + ) + lv85: R.Tensor((224,), dtype="int64") = R.clip( + lv19, R.prim_value(0), R.prim_value(111) + ) + lv86: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv85, axis=3, mode="fast" + ) + lv87: R.Tensor((224,), dtype="int64") = R.squeeze(lv84, axis=None) + lv88: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv86, lv87, axis=2, mode="fast" + 
) + lv89: R.Tensor((224, 1), dtype="int64") = R.clip( + lv21, R.prim_value(0), R.prim_value(111) + ) + lv90: R.Tensor((224,), dtype="int64") = R.clip( + lv25, R.prim_value(0), R.prim_value(111) + ) + lv91: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv90, axis=3, mode="fast" + ) + lv92: R.Tensor((224,), dtype="int64") = R.squeeze(lv89, axis=None) + lv93: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv91, lv92, axis=2, mode="fast" + ) + lv94: R.Tensor((224, 1), dtype="int64") = R.clip( + lv21, R.prim_value(0), R.prim_value(111) + ) + lv95: R.Tensor((224,), dtype="int64") = R.clip( + lv26, R.prim_value(0), R.prim_value(111) + ) + lv96: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv95, axis=3, mode="fast" + ) + lv97: R.Tensor((224,), dtype="int64") = R.squeeze(lv94, axis=None) + lv98: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv96, lv97, axis=2, mode="fast" + ) + lv99: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv83, lv47) + lv100: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv88, lv51) + lv101: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv99, lv100) + lv102: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv93, lv52) + lv103: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv101, lv102) + lv104: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv98, lv48) + lv105: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv103, lv104) + lv106: R.Tensor((224, 1), dtype="int64") = R.clip( + lv20, R.prim_value(0), R.prim_value(111) + ) + lv107: R.Tensor((224,), dtype="int64") = R.clip( + lv24, R.prim_value(0), R.prim_value(111) + ) + lv108: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv107, axis=3, mode="fast" + ) + lv109: R.Tensor((224,), dtype="int64") = R.squeeze(lv106, axis=None) + lv110: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv108, lv109, axis=2, mode="fast" + ) + lv111: R.Tensor((224, 1), dtype="int64") = R.clip( + lv20, 
R.prim_value(0), R.prim_value(111) + ) + lv112: R.Tensor((224,), dtype="int64") = R.clip( + lv19, R.prim_value(0), R.prim_value(111) + ) + lv113: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv112, axis=3, mode="fast" + ) + lv114: R.Tensor((224,), dtype="int64") = R.squeeze(lv111, axis=None) + lv115: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv113, lv114, axis=2, mode="fast" + ) + lv116: R.Tensor((224, 1), dtype="int64") = R.clip( + lv20, R.prim_value(0), R.prim_value(111) + ) + lv117: R.Tensor((224,), dtype="int64") = R.clip( + lv25, R.prim_value(0), R.prim_value(111) + ) + lv118: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv117, axis=3, mode="fast" + ) + lv119: R.Tensor((224,), dtype="int64") = R.squeeze(lv116, axis=None) + lv120: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv118, lv119, axis=2, mode="fast" + ) + lv121: R.Tensor((224, 1), dtype="int64") = R.clip( + lv20, R.prim_value(0), R.prim_value(111) + ) + lv122: R.Tensor((224,), dtype="int64") = R.clip( + lv26, R.prim_value(0), R.prim_value(111) + ) + lv123: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv122, axis=3, mode="fast" + ) + lv124: R.Tensor((224,), dtype="int64") = R.squeeze(lv121, axis=None) + lv125: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv123, lv124, axis=2, mode="fast" + ) + lv126: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv110, lv47) + lv127: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv115, lv51) + lv128: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv126, lv127) + lv129: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv120, lv52) + lv130: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv128, lv129) + lv131: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv125, lv48) + lv132: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv130, lv131) + lv133: R.Tensor((224, 1), dtype="int64") = R.clip( + lv22, R.prim_value(0), R.prim_value(111) + ) 
+ lv134: R.Tensor((224,), dtype="int64") = R.clip( + lv24, R.prim_value(0), R.prim_value(111) + ) + lv135: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv134, axis=3, mode="fast" + ) + lv136: R.Tensor((224,), dtype="int64") = R.squeeze(lv133, axis=None) + lv137: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv135, lv136, axis=2, mode="fast" + ) + lv138: R.Tensor((224, 1), dtype="int64") = R.clip( + lv22, R.prim_value(0), R.prim_value(111) + ) + lv139: R.Tensor((224,), dtype="int64") = R.clip( + lv19, R.prim_value(0), R.prim_value(111) + ) + lv140: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv139, axis=3, mode="fast" + ) + lv141: R.Tensor((224,), dtype="int64") = R.squeeze(lv138, axis=None) + lv142: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv140, lv141, axis=2, mode="fast" + ) + lv143: R.Tensor((224, 1), dtype="int64") = R.clip( + lv22, R.prim_value(0), R.prim_value(111) + ) + lv144: R.Tensor((224,), dtype="int64") = R.clip( + lv25, R.prim_value(0), R.prim_value(111) + ) + lv145: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv144, axis=3, mode="fast" + ) + lv146: R.Tensor((224,), dtype="int64") = R.squeeze(lv143, axis=None) + lv147: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv145, lv146, axis=2, mode="fast" + ) + lv148: R.Tensor((224, 1), dtype="int64") = R.clip( + lv22, R.prim_value(0), R.prim_value(111) + ) + lv149: R.Tensor((224,), dtype="int64") = R.clip( + lv26, R.prim_value(0), R.prim_value(111) + ) + lv150: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv149, axis=3, mode="fast" + ) + lv151: R.Tensor((224,), dtype="int64") = R.squeeze(lv148, axis=None) + lv152: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv150, lv151, axis=2, mode="fast" + ) + lv153: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv137, lv47) + lv154: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv142, lv51) + lv155: R.Tensor((1, 3, 224, 224), 
dtype="float32") = R.add(lv153, lv154) + lv156: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv147, lv52) + lv157: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv155, lv156) + lv158: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv152, lv48) + lv159: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv157, lv158) + lv160: R.Tensor((224, 1), dtype="int64") = R.clip( + lv23, R.prim_value(0), R.prim_value(111) + ) + lv161: R.Tensor((224,), dtype="int64") = R.clip( + lv24, R.prim_value(0), R.prim_value(111) + ) + lv162: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv161, axis=3, mode="fast" + ) + lv163: R.Tensor((224,), dtype="int64") = R.squeeze(lv160, axis=None) + lv164: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv162, lv163, axis=2, mode="fast" + ) + lv165: R.Tensor((224, 1), dtype="int64") = R.clip( + lv23, R.prim_value(0), R.prim_value(111) + ) + lv166: R.Tensor((224,), dtype="int64") = R.clip( + lv19, R.prim_value(0), R.prim_value(111) + ) + lv167: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv166, axis=3, mode="fast" + ) + lv168: R.Tensor((224,), dtype="int64") = R.squeeze(lv165, axis=None) + lv169: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv167, lv168, axis=2, mode="fast" + ) + lv170: R.Tensor((224, 1), dtype="int64") = R.clip( + lv23, R.prim_value(0), R.prim_value(111) + ) + lv171: R.Tensor((224,), dtype="int64") = R.clip( + lv25, R.prim_value(0), R.prim_value(111) + ) + lv172: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv171, axis=3, mode="fast" + ) + lv173: R.Tensor((224,), dtype="int64") = R.squeeze(lv170, axis=None) + lv174: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv172, lv173, axis=2, mode="fast" + ) + lv175: R.Tensor((224, 1), dtype="int64") = R.clip( + lv23, R.prim_value(0), R.prim_value(111) + ) + lv176: R.Tensor((224,), dtype="int64") = R.clip( + lv26, R.prim_value(0), R.prim_value(111) + ) + lv177: R.Tensor((1, 3, 112, 224), 
dtype="float32") = R.take( + lv1, lv176, axis=3, mode="fast" + ) + lv178: R.Tensor((224,), dtype="int64") = R.squeeze(lv175, axis=None) + lv179: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv177, lv178, axis=2, mode="fast" + ) + lv180: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv164, lv47) + lv181: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv169, lv51) + lv182: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv180, lv181) + lv183: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv174, lv52) + lv184: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv182, lv183) + lv185: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv179, lv48) + lv186: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv184, lv185) + lv187: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv105, lv73) + lv188: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv132, lv77) + lv189: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv187, lv188) + lv190: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv159, lv78) + lv191: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv189, lv190) + lv192: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv186, lv74) + lv193: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv191, lv192) + lv194: R.Tensor((1, 3, 224, 224), dtype="float32") = R.astype( + lv193, dtype="float32" + ) + lv195: R.Tensor((1, 3, 224, 224), dtype="float32") = R.astype( + lv194, dtype="float32" + ) + gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv195,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),) + verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear) + verify_model(InterpolateNearest(), example_args, {}, expected_nearest) + verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic) + + +def test_interpolate_antialiased(): + """Test bilinear interpolation with antialiasing 
enabled.""" + + class InterpolateBilinearAA(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, size=(64, 64), mode="bilinear", align_corners=False, antialias=True + ) + + @tvm.script.ir_module + class expected_bilinear_aa: + @R.function + def main( + input: R.Tensor((1, 3, 32, 32), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 64, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 64, 64), dtype="float32") = R.image.resize2d( input, - R.shape([224, 224]), + R.shape([64, 64]), roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], layout="NCHW", - method="cubic", + method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", cubic_alpha=-0.75, @@ -3340,14 +4894,12 @@ def main( extrapolation_value=0.0, out_dtype="void", ) - gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((1, 3, 64, 64), dtype="float32")) = (lv,) R.output(gv) return gv - example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),) - verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear) - verify_model(InterpolateNearest(), example_args, {}, expected_nearest) - verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic) + example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),) + verify_model(InterpolateBilinearAA(), example_args, {}, expected_bilinear_aa) def test_mean(): @@ -3637,14 +5189,13 @@ def main( input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32") ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")): with R.dataflow(): - lv: R.Tuple( - R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32") - ) = R.meshgrid((input1, input2), indexing="ij") - lv1: R.Tensor((3, 3), dtype="float32") = lv[0] - lv2: R.Tensor((3, 3), dtype="float32") = lv[1] + lv: R.Tensor((3, 1), dtype="float32") = R.reshape(input1, R.shape([3, 1])) + lv1: R.Tensor((3, 3), 
dtype="float32") = R.broadcast_to(lv, R.shape([3, 3])) + lv2: R.Tensor((1, 3), dtype="float32") = R.reshape(input2, R.shape([1, 3])) + lv3: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv2, R.shape([3, 3])) gv: R.Tuple( R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32") - ) = (lv1, lv2) + ) = (lv1, lv3) R.output(gv) return gv @@ -3655,14 +5206,13 @@ def main( input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32") ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")): with R.dataflow(): - lv: R.Tuple( - R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32") - ) = R.meshgrid((input1, input2), indexing="xy") - lv1: R.Tensor((3, 3), dtype="float32") = lv[0] - lv2: R.Tensor((3, 3), dtype="float32") = lv[1] + lv: R.Tensor((3, 1), dtype="float32") = R.reshape(input2, R.shape([3, 1])) + lv1: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv, R.shape([3, 3])) + lv2: R.Tensor((1, 3), dtype="float32") = R.reshape(input1, R.shape([1, 3])) + lv3: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv2, R.shape([3, 3])) gv: R.Tuple( R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32") - ) = (lv1, lv2) + ) = (lv3, lv1) R.output(gv) return gv @@ -3812,25 +5362,14 @@ class Expected1: def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): with R.dataflow(): lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8])) - lv1: R.Tensor((7,), dtype="int64") = R.strided_slice( - lv, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(7)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) - lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( - lv, - axes=[0], - begin=[R.prim_value(7)], - end=[R.prim_value(8)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv1: R.Tensor((8,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(8), R.prim_value(1), dtype="int64" ) - lv3: R.Tensor((8,), dtype="int64") = 
R.concat((lv2, lv1), axis=0) - lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,) + lv2: R.Tensor((8,), dtype="int64") = R.add(lv1, R.const(7, "int64")) + lv3: R.Tensor((8,), dtype="int64") = R.mod(lv2, R.const(8, "int64")) + lv4: R.Tensor((8,), dtype="int64") = R.take(lv, lv3, axis=0, mode="fast") + lv5: R.Tensor((4, 2), dtype="int64") = R.reshape(lv4, R.shape([4, 2])) + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) R.output(gv) return gv @@ -3840,24 +5379,13 @@ class Expected2: @R.function def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): with R.dataflow(): - lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(1)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) - lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(1)], - end=[R.prim_value(4)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv: R.Tensor((4,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(4), R.prim_value(1), dtype="int64" ) - lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,) + lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(1, "int64")) + lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4, "int64")) + lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0, mode="fast") + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv3,) R.output(gv) return gv @@ -3868,43 +5396,20 @@ class Expected3: def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): with R.dataflow(): # First roll along dim=0 with shift=2 - lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(0)], - end=[R.prim_value(2)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) - lv1: R.Tensor((2, 2), 
dtype="int64") = R.strided_slice( - x, - axes=[0], - begin=[R.prim_value(2)], - end=[R.prim_value(4)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv: R.Tensor((4,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(4), R.prim_value(1), dtype="int64" ) - lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) - + lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(2, "int64")) + lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4, "int64")) + lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0, mode="fast") # Second roll along dim=1 with shift=1 - lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( - lv2, - axes=[1], - begin=[R.prim_value(0)], - end=[R.prim_value(1)], - strides=[R.prim_value(1)], - assume_inbound=False, + lv4: R.Tensor((2,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(2), R.prim_value(1), dtype="int64" ) - lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( - lv2, - axes=[1], - begin=[R.prim_value(1)], - end=[R.prim_value(2)], - strides=[R.prim_value(1)], - assume_inbound=False, - ) - lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) - gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) + lv5: R.Tensor((2,), dtype="int64") = R.add(lv4, R.const(1, "int64")) + lv6: R.Tensor((2,), dtype="int64") = R.mod(lv5, R.const(2, "int64")) + lv7: R.Tensor((4, 2), dtype="int64") = R.take(lv3, lv6, axis=1, mode="fast") + gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv7,) R.output(gv) return gv @@ -4038,12 +5543,33 @@ def main( R.output(gv) return gv + class SliceScatterNegative(Module): + def forward(self, input, src): + return torch.slice_scatter(input, src, dim=1, start=0, end=-2, step=1) + + @tvm.script.ir_module + class expected_slice_scatter: + @R.function + def main( + a: R.Tensor((2, 5), dtype="float32"), b: R.Tensor((2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 5), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 5), dtype="float32") = 
R.slice_scatter( + a, b, R.prim_value(0), R.prim_value(3), R.prim_value(1), axis=1 + ) + gv: R.Tuple(R.Tensor((2, 5), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32), torch.randn(8, 3, 10, 10)) verify_model(SliceScatter1(), example_args, {}, expected1) example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6, 16)) verify_model(SliceScatter2(), example_args, {}, expected2) + example_args = (torch.randn(2, 5, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32)) + verify_model(SliceScatterNegative(), example_args, {}, expected_slice_scatter) + def test_split(): class Chunk(Module): @@ -4054,7 +5580,7 @@ def forward(self, input): class Expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), @@ -4066,7 +5592,7 @@ def main( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=3, axis=1) + ) = R.split(input, indices_or_sections=[1, 2], axis=1) lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1] lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2] @@ -4086,7 +5612,7 @@ def forward(self, data): class expected1: @R.function def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + data: R.Tensor((3, 3, 10, 10), dtype="float32") ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -4094,33 +5620,41 @@ def main( ): # block 0 with R.dataflow(): - lv: R.Tuple( - R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((1, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=3, axis=0) - lv1: R.Tensor((1, 3, 10, 10), 
dtype="float32") = lv[0] - lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) - lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] - lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) - lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] - lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) - lv7: R.Tuple( - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv2, lv4, lv6) - lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] - lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] - lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] - gv: R.Tuple( - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv8, lv9, lv10) - R.output(gv) - return gv - + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(0),), + (R.prim_value(2),), + (R.prim_value(3),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[0]) + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) + lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[0]) + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv3, lv4, lv5) + R.output(gv) + return gv + class Unbind2(Module): def forward(self, data): return torch.unbind(data, dim=1) @@ -4129,7 +5663,7 @@ def forward(self, data): class 
expected2: @R.function def main( - input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + data: R.Tensor((3, 3, 10, 10), dtype="float32") ) -> R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), @@ -4137,30 +5671,38 @@ def main( ): # block 0 with R.dataflow(): - lv: R.Tuple( - R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 1, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=3, axis=1) - lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] - lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) - lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] - lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) - lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] - lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) - lv7: R.Tuple( - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv2, lv4, lv6) - lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] - lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] - lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + lv: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv2: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice( + data, + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(3),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[1]) + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv5: R.Tensor((3, 10, 10), dtype="float32") = 
R.squeeze(lv2, axis=[1]) gv: R.Tuple( R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), R.Tensor((3, 10, 10), dtype="float32"), - ) = (lv8, lv9, lv10) + ) = (lv3, lv4, lv5) R.output(gv) return gv @@ -4197,18 +5739,35 @@ def forward(self, input): class Expected2: @R.function def main( - inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + input: R.Tensor((3, 1, 4, 1), dtype="float32") ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): with R.dataflow(): - lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) + lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[0, 1, 2, 3]) gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) R.output(gv) return gv + class Squeeze3(Module): + def forward(self, input): + return input.squeeze(2) + + @I.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 1, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[2]) + gv: R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),) verify_model(Squeeze1(), example_args, {}, Expected1) verify_model(Squeeze2(), example_args, {}, Expected2) + verify_model(Squeeze3(), example_args, {}, Expected3) def test_stack(): @@ -4232,12 +5791,13 @@ def forward(self, x, y): class Expected0: @R.function def main( - inp_0: R.Tensor((2, 3), dtype="float32"), - inp_1: R.Tensor((2, 3), dtype="float32"), + x: R.Tensor((2, 3), dtype="float32"), + y: R.Tensor((2, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")): with R.dataflow(): - lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0) - gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,) + lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), axis=0) + lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3])) 
+ gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,) R.output(gv) return gv @@ -4245,12 +5805,13 @@ def main( class Expected1: @R.function def main( - inp_0: R.Tensor((2, 3), dtype="float32"), - inp_1: R.Tensor((2, 3), dtype="float32"), + x: R.Tensor((2, 3), dtype="float32"), + y: R.Tensor((2, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")): with R.dataflow(): - lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=1) - gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,) + lv: R.Tensor((2, 6), dtype="float32") = R.concat((x, y), axis=1) + lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3])) + gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,) R.output(gv) return gv @@ -4258,12 +5819,14 @@ def main( class Expected3: @R.function def main( - inp_0: R.Tensor((2, 3), dtype="float32"), - inp_1: R.Tensor((2, 3), dtype="float32"), + x: R.Tensor((2, 3), dtype="float32"), + y: R.Tensor((2, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((2, 3, 2), dtype="float32")): with R.dataflow(): - lv: R.Tensor((2, 3, 2), dtype="float32") = R.stack((inp_0, inp_1), axis=-1) - gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv,) + lv: R.Tensor((2, 3, 1), dtype="float32") = R.expand_dims(x, axis=[2]) + lv1: R.Tensor((2, 3, 1), dtype="float32") = R.expand_dims(y, axis=[2]) + lv2: R.Tensor((2, 3, 2), dtype="float32") = R.concat((lv, lv1), axis=-1) + gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv2,) R.output(gv) return gv @@ -4296,7 +5859,7 @@ def main( ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2]) + lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, repeats=[1, 2]) gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,) R.output(gv) return gv @@ -4309,7 +5872,7 @@ def main( ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 
2]) + lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, repeats=[4, 2]) gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,) R.output(gv) return gv @@ -4405,6 +5968,44 @@ def main( verify_model(View(), example_args, {}, expected1) +def test_as_strided(): + class AsStrided(Module): + def forward(self, x): + return torch.ops.aten.as_strided.default(x, (3, 2, 2), (4, 2, 1)) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 2, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 2, 2), dtype="float32") = R.reshape(x, (3, 2, 2)) + gv: R.Tuple(R.Tensor((3, 2, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class AsStridedNonContiguous(Module): + def forward(self, x): + return torch.ops.aten.as_strided.default(x, (2, 2, 2), (6, 3, 1)) + + class AsStridedWithStorageOffset(Module): + def forward(self, x): + return torch.ops.aten.as_strided.default(x, (2, 2), (2, 1), 1) + + example_args = (torch.randn(2, 2, 3, dtype=torch.float32),) + verify_model(AsStrided(), example_args, {}, Expected) + + exported = export(AsStridedNonContiguous(), args=example_args) + with pytest.raises(AssertionError, match="non-contiguous stride"): + from_exported_program(exported) + + example_args = (torch.randn(2, 2, dtype=torch.float32),) + exported = export(AsStridedWithStorageOffset(), args=example_args) + with pytest.raises(AssertionError, match="storage_offset"): + from_exported_program(exported) + + def test_arange(): class Arange(Module): def forward(self, input): @@ -4502,7 +6103,7 @@ def forward(self, input): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32") + input: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.zeros( @@ -4516,6 +6117,27 @@ def main( verify_model(Empty(), example_args, {}, Expected) +def test_empty_without_dtype(): + class 
EmptyWithoutDtype(Module): + def forward(self, input): + return torch.empty((5, 5)) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 5), dtype="float32") = R.zeros(R.shape([5, 5]), dtype="float32") + gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(EmptyWithoutDtype(), example_args, {}, Expected) + + def test_fill(): class Fill(Module): def forward(self, input: torch.Tensor): @@ -4525,11 +6147,11 @@ def forward(self, input: torch.Tensor): class Expected: @R.function def main( - inp_0: R.Tensor((10, 10), dtype="float32") + input: R.Tensor((10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.full( - R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32" + lv: R.Tensor((10, 10), dtype="float32") = R.full_like( + input, R.const(1.5, "float32"), dtype="void" ) gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) R.output(gv) @@ -4549,13 +6171,15 @@ def forward(self, input: torch.Tensor): class Expected: @R.function def main( - x: R.Tensor((2, 3), dtype="float32") - ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): + input: R.Tensor((2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")): with R.dataflow(): - lv: R.Tensor((2, 3), dtype="float32") = R.full( - R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32" + lv: R.Tensor((2, 3), dtype="float32") = R.full_like( + input, R.const(42.0, "float32"), dtype="void" ) - gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) + gv: R.Tuple( + R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32") + ) = (lv, lv) R.output(gv) return gv @@ -4575,9 +6199,7 @@ def main( input: R.Tensor((128, 128), 
dtype="float32"), mask: R.Tensor((128, 128), dtype="bool") ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.full_like( - input, R.const(0, "int32"), dtype="void" - ) + lv: R.Tensor((), dtype="float32") = R.const(0.0, "float32") lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input) gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,) R.output(gv) @@ -4597,13 +6219,13 @@ class Expected: @R.function def main( input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool") - ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.full_like( - input, R.const(1.5, "float32"), dtype="void" - ) + lv: R.Tensor((), dtype="float32") = R.const(1.5, "float32") lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input) - gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,) + gv: R.Tuple( + R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32") + ) = (lv1, lv1) R.output(gv) return gv @@ -4611,6 +6233,43 @@ def main( verify_model(Masked_Fill_Inplace(), example_args, {}, Expected) +def test_masked_select(): + class MaskedSelect(Module): + def forward(self, data: torch.Tensor, mask: torch.Tensor): + return torch.masked_select(data, mask) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((2, 3), dtype="float32"), mask: R.Tensor((2, 3), dtype="bool") + ) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)): + R.func_attr( + { + "tir_var_lower_bound": {"u0": 0, "u1": 0}, + "tir_var_upper_bound": {"u0": 6, "u1": 6}, + } + ) + with R.dataflow(): + lv: R.Tensor((6,), dtype="float32") = R.reshape(data, R.shape([6])) + lv1: R.Tensor((6,), dtype="bool") = R.reshape(mask, R.shape([6])) + lv2: R.Tensor(dtype="int64", ndim=2) = R.nonzero(lv1) + lv3: 
R.Tensor(dtype="int64", ndim=1) = R.squeeze(lv2, axis=[0]) + lv4: R.Tensor(dtype="float32", ndim=1) = R.take(lv, lv3, axis=0, mode="fast") + lv5: R.Tensor((), dtype="int64") = R.const(0, "int64") + lv6: R.Tensor((), dtype="bool") = R.const(True, "bool") + lv7: R.Tensor((), dtype="bool") = R.const(True, "bool") + gv: R.Tuple(R.Tensor(dtype="float32", ndim=1)) = (lv4,) + R.output(gv) + return gv + + example_args = ( + torch.randn(2, 3, dtype=torch.float32), + torch.tensor([[True, False, True], [False, True, False]]), + ) + verify_model(MaskedSelect(), example_args, {}, Expected) + + def test_new_ones(): class NewOnes(Module): def forward(self, x): @@ -4658,6 +6317,34 @@ def main( verify_model(NewZeros(), example_args, {}, expected1) +def test_copy(): + class CopyBroadcast(Module): + def forward(self, x, src): + x.copy_(src) + return x + + @tvm.script.ir_module + class expected_copy: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), src: R.Tensor((), dtype="int64") + ) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.astype(src, dtype="float32") + lv1: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(lv, (2, 3)) + gv: R.Tuple( + R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32") + ) = ( + lv1, + lv1, + ) + R.output(gv) + return gv + + example_args = (torch.zeros(2, 3, dtype=torch.float32), torch.tensor(1, dtype=torch.int64)) + verify_model(CopyBroadcast(), example_args, {}, expected_copy) + + def test_to_copy(): # float class ToFloat(Module): @@ -4794,6 +6481,7 @@ def main( example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) model = Conv2D1() + exported_program = torch.export.export(model, example_args) mod = from_exported_program(exported_program, keep_params_as_input=True) mod, params = detach_params(mod) @@ -4802,9 +6490,9 @@ def main( params = params["main"] assert len(params) == len(func.params) - 1 - for param_var, 
param_ndarray in zip(func.params[1:], params): - assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape - assert param_var.struct_info.dtype == param_ndarray.dtype + for param_var, param_tensor in zip(func.params[1:], params): + assert tuple(x.value for x in param_var.struct_info.shape.values) == param_tensor.shape + assert param_var.struct_info.dtype == param_tensor.dtype tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().detach().numpy()) tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().detach().numpy()) @@ -4830,9 +6518,7 @@ def main( return gv example_args = (torch.randn(256, 256, dtype=torch.float32),) - exported_program = export(Identity(), args=example_args) - mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True) - tvm.ir.assert_structural_equal(mod, Expected) + verify_model(Identity(), example_args, {}, Expected, unwrap_unit_return_tuple=True) def test_no_bind_return_tuple(): @@ -4860,9 +6546,56 @@ def main( torch.randn(256, 256, dtype=torch.float32), torch.randn(256, 256, dtype=torch.float32), ) - exported_program = export(Identity(), args=example_args) - mod = from_exported_program(exported_program, no_bind_return_tuple=True) - tvm.ir.assert_structural_equal(mod, Expected) + verify_model(Identity(), example_args, {}, Expected, no_bind_return_tuple=True) + + +def test_register_buffer(): + class ModelWithBuffer(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("my_buffer", torch.randn(3, 4), persistent=False) + + def forward(self, x): + return x + self.my_buffer + + example_args = (torch.randn(2, 3, 4),) + ep = export(ModelWithBuffer(), args=example_args) + # Just verify that import works. 
+ from_exported_program(ep) + + +def test_custom_op(): + class AddOp(Module): + def forward(self, x, y): + return torch.ops.aten.add.Tensor(x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((5,), dtype="float32"), + y: R.Tensor((5,), dtype="float32"), + ) -> R.Tuple(R.Tensor((5,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5,), dtype="float32") = R.subtract(x, y) + gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + from tvm.relax.frontend.torch.exported_program_translator import ( + ExportedProgramImporter, + ) + + def custom_add_converter(node: torch.fx.Node, self: ExportedProgramImporter) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + + return self.block_builder.emit(R.subtract(x, y)) + + example_args = (torch.randn(5, dtype=torch.float32), torch.randn(5, dtype=torch.float32)) + verify_model( + AddOp(), example_args, {}, Expected, custom_convert_map={"add.Tensor": custom_add_converter} + ) def test_empty_like(): @@ -4874,10 +6607,10 @@ def forward(self, data): class Expected: @R.function def main( - inp_0: R.Tensor((5,), dtype="float32"), + data: R.Tensor((5,), dtype="float32"), ) -> R.Tuple(R.Tensor((5,), dtype="float32")): with R.dataflow(): - lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void") + lv: R.Tensor((5,), dtype="float32") = R.zeros(R.shape([5]), dtype="float32") gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,) R.output(gv) return gv @@ -4896,13 +6629,16 @@ def forward(self, indices): class Expected: @R.function def main( - inp_0: R.Tensor((5,), dtype="int64"), + indices: R.Tensor((5,), dtype="int64"), ) -> R.Tuple(R.Tensor((5, 10), dtype="int64")): with R.dataflow(): - lv: R.Tensor((5, 10), dtype="int64") = R.one_hot( - inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1 + lv: R.Tensor((10,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64" ) - gv: 
R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,) + lv1: R.Tensor((5, 1), dtype="int64") = R.expand_dims(indices, axis=[-1]) + lv2: R.Tensor((5, 10), dtype="bool") = R.equal(lv1, lv) + lv3: R.Tensor((5, 10), dtype="int64") = R.astype(lv2, dtype="int64") + gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv3,) R.output(gv) return gv @@ -4923,7 +6659,9 @@ def main( input: R.Tensor((128, 128), dtype="float32") ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, dtype="void") + lv: R.Tensor((128, 128), dtype="float32") = R.full_like( + input, R.const(1, "int32"), dtype="void" + ) gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) R.output(gv) return gv @@ -4943,10 +6681,17 @@ class Expected: @R.function def main( input: R.Tensor((128, 128), dtype="float32") - ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")): with R.dataflow(): - lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void") - gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) + lv: R.Tensor((128, 128), dtype="float32") = R.full_like( + input, R.const(0, "int32"), dtype="void" + ) + gv: R.Tuple( + R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32") + ) = ( + lv, + lv, + ) R.output(gv) return gv @@ -4967,7 +6712,9 @@ def main( input: R.Tensor((128, 128), dtype="float32") ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): with R.dataflow(): - lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 2]), dtype="float32") + lv: R.Tensor((5, 2), dtype="float32") = R.full( + R.shape([5, 2]), R.const(0.0, "float32"), dtype="float32" + ) gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) R.output(gv) return gv @@ -4989,7 +6736,9 @@ def main( input: R.Tensor((128, 128), dtype="float32") ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): with R.dataflow(): - lv: 
R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void") + lv: R.Tensor((128, 128), dtype="float32") = R.full_like( + input, R.const(0, "int32"), dtype="void" + ) gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) R.output(gv) return gv @@ -5173,12 +6922,15 @@ def main( data: R.Tensor((64,), dtype="float32"), indices_0: R.Tensor((128,), dtype="int64"), values: R.Tensor((128,), dtype="float32"), - ) -> R.Tuple(R.Tensor((64,), dtype="float32")): + ) -> R.Tuple(R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")): with R.dataflow(): lv: R.Tensor((64,), dtype="float32") = R.index_put( data, R.tuple(indices_0), values, accumulate=False ) - gv: R.Tuple(R.Tensor((64,), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = ( + lv, + lv, + ) R.output(gv) return gv @@ -5203,12 +6955,14 @@ def main( indices_0: R.Tensor((128,), dtype="int64"), indices_1: R.Tensor((128,), dtype="int64"), values: R.Tensor((128,), dtype="float32"), - ) -> R.Tuple(R.Tensor((32, 64), dtype="float32")): + ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")): with R.dataflow(): lv: R.Tensor((32, 64), dtype="float32") = R.index_put( data, R.tuple(indices_0, indices_1), values, accumulate=False ) - gv: R.Tuple(R.Tensor((32, 64), dtype="float32")) = (lv,) + gv: R.Tuple( + R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32") + ) = (lv, lv) R.output(gv) return gv @@ -5235,12 +6989,16 @@ def main( indices_1: R.Tensor((128,), dtype="int64"), indices_2: R.Tensor((128,), dtype="int64"), values: R.Tensor((128,), dtype="float32"), - ) -> R.Tuple(R.Tensor((16, 32, 64), dtype="float32")): + ) -> R.Tuple( + R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32") + ): with R.dataflow(): lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put( data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False ) - gv: 
R.Tuple(R.Tensor((16, 32, 64), dtype="float32")) = (lv,) + gv: R.Tuple( + R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32") + ) = (lv, lv) R.output(gv) return gv @@ -5269,7 +7027,10 @@ def main( indices_2: R.Tensor((128,), dtype="int64"), indices_3: R.Tensor((128,), dtype="int64"), values: R.Tensor((128,), dtype="float32"), - ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")): + ) -> R.Tuple( + R.Tensor((8, 16, 32, 64), dtype="float32"), + R.Tensor((8, 16, 32, 64), dtype="float32"), + ): with R.dataflow(): lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put( data, @@ -5277,7 +7038,10 @@ def main( values, accumulate=False, ) - gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,) + gv: R.Tuple( + R.Tensor((8, 16, 32, 64), dtype="float32"), + R.Tensor((8, 16, 32, 64), dtype="float32"), + ) = (lv, lv) R.output(gv) return gv @@ -5308,7 +7072,10 @@ def main( indices_3: R.Tensor((128,), dtype="int64"), indices_4: R.Tensor((128,), dtype="int64"), values: R.Tensor((128,), dtype="float32"), - ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")): + ) -> R.Tuple( + R.Tensor((4, 8, 16, 32, 64), dtype="float32"), + R.Tensor((4, 8, 16, 32, 64), dtype="float32"), + ): with R.dataflow(): lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put( data, @@ -5316,55 +7083,226 @@ def main( values, accumulate=False, ) - gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) = (lv,) + gv: R.Tuple( + R.Tensor((4, 8, 16, 32, 64), dtype="float32"), + R.Tensor((4, 8, 16, 32, 64), dtype="float32"), + ) = (lv, lv) R.output(gv) return gv - # Run verification for each case - verify_model(IndexPut1D(), example_args_1d, {}, Expected1D) - verify_model(IndexPut2D(), example_args_2d, {}, Expected2D) - verify_model(IndexPut3D(), example_args_3d, {}, Expected3D) - verify_model(IndexPut4D(), example_args_4d, {}, Expected4D) - verify_model(IndexPut5D(), example_args_5d, {}, Expected5D) - - -def test_flip(): - class Flip0(Module): - 
def forward(self, data): - return torch.flip(data, [0]) + # Test case 6: 2D input with multi-dimensional index (broadcasting) + # This tests the multi-dimensional index support with broadcasting + class IndexPutBroadcast1D(Module): + def forward(self, data, indices_1): + indices_0 = torch.arange(data.shape[0]).unsqueeze(1) + values = torch.ones(data.shape[0], len(indices_1), dtype=data.dtype) + return data.index_put_((indices_0, indices_1), values, accumulate=False) - class Flip1(Module): - def forward(self, data): - return torch.flip(data, [1]) + example_args_broadcast1 = ( + torch.randn(32, 64, dtype=torch.float32), + torch.randint(0, 64, (10,), dtype=torch.int64), + ) - @tvm.script.ir_module - class Expected0: + @I.ir_module + class ExpectedBroadcast1D: @R.function def main( - inp_0: R.Tensor((2, 2), dtype="float32"), - ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")): + data: R.Tensor((32, 64), dtype="float32"), + indices_1: R.Tensor((10,), dtype="int64"), + ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")): with R.dataflow(): - lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0) - gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,) + lv: R.Tensor((32,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(32), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((32, 1), dtype="int64") = R.expand_dims(lv, axis=[1]) + lv2: R.Tensor((32, 10), dtype="float32") = R.full( + R.shape([32, 10]), R.const(1.0, "float32"), dtype="float32" + ) + lv3: R.Tensor((32, 64), dtype="float32") = R.index_put( + data, R.tuple(lv1, indices_1), lv2, accumulate=False + ) + gv: R.Tuple( + R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32") + ) = (lv3, lv3) R.output(gv) return gv - @tvm.script.ir_module - class Expected1: + # Test case 7: 2D input with multi-dimensional index (second position) + class IndexPutBroadcast2D(Module): + def forward(self, data, indices_0): + indices_1 = 
torch.arange(data.shape[1]).unsqueeze(1) + values = torch.ones(len(indices_0), data.shape[1], dtype=data.dtype) + return data.index_put_((indices_0, indices_1), values, accumulate=False) + + example_args_broadcast2 = ( + torch.randn(32, 64, dtype=torch.float32), + torch.randint(0, 32, (10,), dtype=torch.int64), + ) + + @I.ir_module + class ExpectedBroadcast2D: @R.function def main( - inp_0: R.Tensor((2, 2), dtype="float32"), - ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")): + data: R.Tensor((32, 64), dtype="float32"), + indices_0: R.Tensor((10,), dtype="int64"), + ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")): with R.dataflow(): - lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1) - gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,) + lv: R.Tensor((64,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(64), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((64, 1), dtype="int64") = R.expand_dims(lv, axis=[1]) + lv2: R.Tensor((10, 64), dtype="float32") = R.full( + R.shape([10, 64]), R.const(1.0, "float32"), dtype="float32" + ) + lv3: R.Tensor((32, 64), dtype="float32") = R.index_put( + data, R.tuple(indices_0, lv1), lv2, accumulate=False + ) + gv: R.Tuple( + R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32") + ) = (lv3, lv3) R.output(gv) return gv - example_args = (torch.randn(2, 2, dtype=torch.float32),) + # Test case 8: 3D input with mixed 1D and 2D indices + class IndexPutBroadcast3D(Module): + def forward(self, data, indices_1): + indices_0 = torch.arange(data.shape[0]).unsqueeze(1) + indices_2 = torch.arange(data.shape[2]).unsqueeze(1) + values = torch.ones(data.shape[0], len(indices_1), data.shape[2], dtype=data.dtype) + return data.index_put_((indices_0, indices_1, indices_2), values, accumulate=False) - verify_model(Flip0(), example_args, {}, Expected0) - verify_model(Flip1(), example_args, {}, Expected1) + example_args_broadcast3d = ( + torch.randn(16, 32, 64, 
dtype=torch.float32), + torch.randint(0, 32, (10,), dtype=torch.int64), + ) + + @I.ir_module + class ExpectedBroadcast3D: + @R.function + def main( + data: R.Tensor((16, 32, 64), dtype="float32"), + indices_1: R.Tensor((10,), dtype="int64"), + ) -> R.Tuple( + R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((16,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(16), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((16, 1), dtype="int64") = R.expand_dims(lv, axis=[1]) + lv2: R.Tensor((64,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(64), R.prim_value(1), dtype="int64" + ) + lv3: R.Tensor((64, 1), dtype="int64") = R.expand_dims(lv2, axis=[1]) + lv4: R.Tensor((16, 10, 64), dtype="float32") = R.full( + R.shape([16, 10, 64]), R.const(1.0, "float32"), dtype="float32" + ) + lv5: R.Tensor((16, 32, 64), dtype="float32") = R.index_put( + data, R.tuple(lv1, indices_1, lv3), lv4, accumulate=False + ) + gv: R.Tuple( + R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32") + ) = (lv5, lv5) + R.output(gv) + return gv + + # Test case 9: batched indexing with slice (e.g., M[:, rows, cols] = x) + class IndexPutBatchedWithNone(Module): + def forward(self, x): + B = x.size(0) + M = torch.zeros(B, 11, 11) + rows = torch.arange(10) + cols = rows + 1 + M[:, rows, cols] = x # Batched index assignment + return M + + example_args_batched_none = (torch.randn(2, 10, dtype=torch.float32),) + + @I.ir_module + class ExpectedBatchedWithNone: + @R.function + def main( + x: R.Tensor((2, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 11, 11), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 11, 11), dtype="float32") = R.full( + R.shape([2, 11, 11]), R.const(0.0, "float32"), dtype="float32" + ) + lv1: R.Tensor((10,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64" + ) + lv2: R.Tensor((10,), dtype="int64") = 
R.add(lv1, R.const(1, "int64")) + lv3: R.Tensor((2, 11, 11), dtype="float32") = R.strided_slice( + lv, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(9223372036854775807),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv4: R.Tensor((2,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(2), R.prim_value(1), dtype="int64" + ) + lv5: R.Tensor((2, 1), dtype="int64") = R.reshape(lv4, R.shape([2, 1])) + lv6: R.Tensor((2, 11, 11), dtype="float32") = R.index_put( + lv3, (lv5, lv1, lv2), x, accumulate=False + ) + lv7: R.Tensor((2, 11, 11), dtype="float32") = R.slice_scatter( + lv, lv6, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=0 + ) + gv: R.Tuple(R.Tensor((2, 11, 11), dtype="float32")) = (lv7,) + R.output(gv) + return gv + + # Run verification for each case + verify_model(IndexPut1D(), example_args_1d, {}, Expected1D) + verify_model(IndexPut2D(), example_args_2d, {}, Expected2D) + verify_model(IndexPut3D(), example_args_3d, {}, Expected3D) + verify_model(IndexPut4D(), example_args_4d, {}, Expected4D) + verify_model(IndexPut5D(), example_args_5d, {}, Expected5D) + verify_model(IndexPutBroadcast1D(), example_args_broadcast1, {}, ExpectedBroadcast1D) + verify_model(IndexPutBroadcast2D(), example_args_broadcast2, {}, ExpectedBroadcast2D) + verify_model(IndexPutBroadcast3D(), example_args_broadcast3d, {}, ExpectedBroadcast3D) + verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone) + + +def test_flip(): + class Flip0(Module): + def forward(self, data): + return torch.flip(data, [0]) + + class Flip1(Module): + def forward(self, data): + return torch.flip(data, [1]) + + @tvm.script.ir_module + class Expected0: + @R.function + def main( + inp_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0) + gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + 
+ @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1) + gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 2, dtype=torch.float32),) + + verify_model(Flip0(), example_args, {}, Expected0) + verify_model(Flip1(), example_args, {}, Expected1) def test_take(): @@ -5376,12 +7314,12 @@ def forward(self, data, indices): class Expected: @R.function def main( - inp_0: R.Tensor((5,), dtype="float32"), - inp_1: R.Tensor((3,), dtype="int64"), + data: R.Tensor((5,), dtype="float32"), + indices: R.Tensor((3,), dtype="int64"), ) -> R.Tuple(R.Tensor((3,), dtype="float32")): with R.dataflow(): - lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, dtype="int32") - lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv, axis=None) + lv: R.Tensor((5,), dtype="float32") = R.reshape(data, R.shape([5])) + lv1: R.Tensor((3,), dtype="float32") = R.take(lv, indices, axis=0, mode="fast") gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,) R.output(gv) return gv @@ -5394,6 +7332,29 @@ def main( verify_model(Take(), example_args, {}, Expected) +def test_any(): + class AnyAten(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.any(x, dim=1) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3), dtype="bool"), + ) -> R.Tuple(R.Tensor((2,), dtype="bool")): + with R.dataflow(): + lv: R.Tensor((2, 3), dtype="int8") = relax.op.astype(x, dtype="int8") + lv2: R.Tensor((2,), dtype="int8") = relax.op.max(lv, axis=1, keepdims=False) + lv3: R.Tensor((2,), dtype="bool") = relax.op.astype(lv2, dtype="bool") + gv: R.Tuple(R.Tensor((2,), dtype="bool")) = (lv3,) + R.output(gv) + return gv + + example_args = (torch.tensor([[0, 0, 0], [0, 1, 0]], dtype=torch.bool),) + verify_model(AnyAten(), 
example_args, {}, Expected) + + def test_std(): class Std(Module): def forward(self, x): @@ -5403,11 +7364,12 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((5, 3), dtype="float32"), + x: R.Tensor((5, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): - lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, keepdims=False) - gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False) + lv1: R.Tensor((), dtype="float32") = R.sqrt(lv) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,) R.output(gv) return gv @@ -5424,10 +7386,10 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((5, 3), dtype="float32"), + x: R.Tensor((5, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): - lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False) + lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False) gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) R.output(gv) return gv @@ -5445,10 +7407,10 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((5, 3), dtype="float32"), + x: R.Tensor((5, 3), dtype="float32"), ) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): - lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False) + lv: R.Tensor((), dtype="float32") = R.prod(x, axis=None, keepdims=False) gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) R.output(gv) return gv @@ -5542,7 +7504,13 @@ def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype lv: R.Tensor((5, 3), dtype="int32") = R.argsort( x, axis=1, descending=True, dtype="int32" ) - gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,) + lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1) + lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = 
( + lv1, + lv, + ) + lv3: R.Tensor((5, 3), dtype="int32") = lv2[1] + gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,) R.output(gv) return gv @@ -5592,6 +7560,7 @@ def main( lhs: R.Tensor((B, 4), dtype="float32"), rhs: R.Tensor((B, 4), dtype="float32"), ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")): + R.func_attr({"tir_var_lower_bound": {"s0": 0}}) with R.dataflow(): lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs) gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,) @@ -5602,7 +7571,14 @@ def main( batch = torch.export.Dim("batch") dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} - verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes) + verify_model( + DynamicModel(), + example_args, + {}, + Expected, + dynamic_shapes=dynamic_shapes, + run_ep_decomposition=True, + ) def test_broadcast_to(): @@ -5644,6 +7620,7 @@ def main( (R.prim_value(1),), (R.prim_value(0),), (R.prim_value(2),), + (R.prim_value(1),), assume_inbound=False, ) gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) @@ -5798,8 +7775,20 @@ def main( input: R.Tensor((3, 5), dtype="float32") ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")): with R.dataflow(): - lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32") - gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,) + lv: R.Tensor((3,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((5,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" + ) + lv2: R.Tensor((3, 1), dtype="int64") = R.expand_dims(lv, axis=[-1]) + lv3: R.Tensor((3, 5), dtype="bool") = R.equal(lv2, lv1) + lv4: R.Tensor((1,), dtype="float32") = R.full( + R.shape([1]), R.const(1.0, "float32"), dtype="float32" + ) + lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32") + lv6: R.Tensor((3, 5), dtype="float32") = R.where(lv3, lv4, lv5) + gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv6,) R.output(gv) 
return gv @@ -5814,8 +7803,20 @@ def main( input: R.Tensor((5,), dtype="float32") ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")): with R.dataflow(): - lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32") - gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,) + lv: R.Tensor((5,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((5,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" + ) + lv2: R.Tensor((5, 1), dtype="int64") = R.expand_dims(lv, axis=[-1]) + lv3: R.Tensor((5, 5), dtype="bool") = R.equal(lv2, lv1) + lv4: R.Tensor((1,), dtype="float32") = R.full( + R.shape([1]), R.const(1.0, "float32"), dtype="float32" + ) + lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32") + lv6: R.Tensor((5, 5), dtype="float32") = R.where(lv3, lv4, lv5) + gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv6,) R.output(gv) return gv @@ -5839,16 +7840,34 @@ def forward(self, x): @tvm.script.ir_module class Expected1: @R.function - def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): + def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")): with R.dataflow(): - lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1) - lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss( - lv, - targets=R.const([0, 1, 2, 1], dtype="int64"), - reduction="mean", - ignore_index=-100, + lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32") + lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1) + lv2: R.Tensor((4,), dtype="bool") = R.not_equal( + R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") ) - gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,) + lv3: R.Tensor((), dtype="int64") = R.const(0, "int64") + lv4: R.Tensor((4,), dtype="int64") = R.where( + lv2, R.const([0, 1, 2, 1], dtype="int64"), lv3 + ) + lv5: R.Tensor((4, 1), dtype="int64") = 
R.expand_dims(lv4, axis=[1]) + lv6: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv1, lv5, axis=1) + lv7: R.Tensor((4,), dtype="float32") = R.squeeze(lv6, axis=[1]) + lv8: R.Tensor((4,), dtype="float32") = R.negative(lv7) + lv9: R.Tensor((4,), dtype="bool") = R.not_equal( + R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") + ) + lv10: R.Tensor((), dtype="float32") = R.const(0.0, "float32") + lv11: R.Tensor((4,), dtype="float32") = R.where(lv9, lv8, lv10) + lv12: R.Tensor((4,), dtype="bool") = R.not_equal( + R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") + ) + lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False) + lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32") + lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False) + lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14) + gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,) R.output(gv) return gv @@ -5868,8 +7887,19 @@ def main( input: R.Tensor((9, 9), dtype="float32") ) -> R.Tuple(R.Tensor((9,), dtype="float32")): with R.dataflow(): - lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32") - gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,) + lv: R.Tensor((9,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64" + ) + lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64")) + lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32") + lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32")) + lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32")) + lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv) + lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32") + lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32")) + lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7) + lv9: 
R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8) + gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,) R.output(gv) return gv @@ -5914,6 +7944,727 @@ def main( verify_model(Model(), example_args, {}, Expected) +def test_mm(): + class MatrixMultiply(Module): + def forward(self, a, b): + return torch.mm(a, b) + + example_args = ( + torch.randn(2, 3, dtype=torch.float32), + torch.randn(3, 4, dtype=torch.float32), + ) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + a: R.Tensor((2, 3), dtype="float32"), + b: R.Tensor((3, 4), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32") + gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(MatrixMultiply(), example_args, {}, Expected) + + +def test_sparse_mm(): + class SparseMatrixMultiply(Module): + def forward(self, sparse_input, dense_input): + return torch.sparse.mm(sparse_input, dense_input) + + indices = torch.tensor([[0, 1, 2], [2, 0, 1]]) + values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + sparse_input = torch.sparse_coo_tensor(indices, values, size=(3, 100)) + dense_input = torch.randn(100, 50, dtype=torch.float32) + + example_args = (sparse_input, dense_input) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + sparse_input: R.Tensor((3, 100), dtype="float32"), + dense_input: R.Tensor((100, 50), dtype="float32"), + ) -> R.Tuple(R.Tensor((3, 50), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 50), dtype="float32") = R.full( + R.shape([3, 50]), R.const(0.0, "float32"), dtype="float32" + ) + lv1: R.Tensor((3, 50), dtype="float32") = R.matmul( + sparse_input, dense_input, out_dtype="float32" + ) + gv: R.Tuple(R.Tensor((3, 50), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + verify_model(SparseMatrixMultiply(), example_args, {}, Expected) + + +@tvm.testing.requires_llvm +def 
test_lstm(): + class LSTM(nn.Module): + def __init__(self, input_size, hidden_size, batch_first, bidirectional): + super().__init__() + self.lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + batch_first=batch_first, + bidirectional=bidirectional, + ) + + def forward(self, x): + y, _ = self.lstm(x) + return y + + # Unidirectional LSTM with batch_first=True + torch.manual_seed(42) + x = torch.randn(2, 3, 4, dtype=torch.float32) + verify_model_numerically(LSTM(4, 8, batch_first=True, bidirectional=False), (x,)) + + # Unidirectional LSTM with batch_first=False + torch.manual_seed(43) + x2 = torch.randn(4, 2, 3, dtype=torch.float32) + verify_model_numerically(LSTM(3, 6, batch_first=False, bidirectional=False), (x2,)) + + # Bidirectional LSTM with batch_first=True + torch.manual_seed(44) + x3 = torch.randn(2, 3, 4, dtype=torch.float32) + verify_model_numerically(LSTM(4, 8, batch_first=True, bidirectional=True), (x3,)) + + # Bidirectional LSTM with batch_first=False + torch.manual_seed(45) + x4 = torch.randn(4, 2, 3, dtype=torch.float32) + verify_model_numerically(LSTM(3, 6, batch_first=False, bidirectional=True), (x4,)) + + +def test_tensor_none_tuple(): + example_args = (torch.tensor([1.0, 2.0, 3.0]),) + + class TensorNoneModel(Module): + def forward(self, x): + return x + 1, None + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((3,), dtype="float32") + ) -> R.Tuple(R.Tensor((3,), dtype="float32"), R.Object): + with R.dataflow(): + lv: R.Tensor((3,), dtype="float32") = R.add(x, R.const(1.0, "float32")) + gv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Object) = (lv, R.null_value()) + R.output(gv) + return gv + + verify_model(TensorNoneModel(), example_args, {}, Expected) + + +def test_gru(): + class BasicGRU(nn.Module): + def __init__(self): + super().__init__() + self.gru = nn.GRU( + input_size=4, + hidden_size=8, + num_layers=1, + batch_first=True, + bidirectional=False, + ) + + def forward(self, 
x): + y, _ = self.gru(x) + return y + + torch.manual_seed(42) + x = torch.randn(2, 3, 4, dtype=torch.float32) + model = BasicGRU() + with torch.no_grad(): + pytorch_output = model(x) + exported_program = export(model, args=(x,)) + mod = from_exported_program(exported_program) + target = tvm.target.Target("llvm") + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_tvm = tvm.runtime.tensor(x.numpy()) + tvm_output = vm["main"](x_tvm) + if hasattr(tvm_output, "numpy"): + tvm_output_np = tvm_output.numpy() + else: + tvm_output_np = tvm_output[0].numpy() + assert ( + pytorch_output.shape == tvm_output_np.shape + ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}" + tvm.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5) + + class SeqFirstGRU(nn.Module): + def __init__(self): + super().__init__() + self.gru = nn.GRU( + input_size=3, + hidden_size=6, + num_layers=1, + batch_first=False, + bidirectional=False, + ) + + def forward(self, x): + y, _ = self.gru(x) + return y + + torch.manual_seed(43) + x2 = torch.randn(4, 2, 3, dtype=torch.float32) + model2 = SeqFirstGRU() + with torch.no_grad(): + pytorch_output2 = model2(x2) + exported_program2 = export(model2, args=(x2,)) + mod2 = from_exported_program(exported_program2) + ex2 = relax.build(mod2, target) + vm2 = relax.VirtualMachine(ex2, tvm.cpu()) + x2_tvm = tvm.runtime.tensor(x2.numpy()) + tvm_output2 = vm2["main"](x2_tvm) + if hasattr(tvm_output2, "numpy"): + tvm_output2_np = tvm_output2.numpy() + else: + tvm_output2_np = tvm_output2[0].numpy() + assert pytorch_output2.shape == tvm_output2_np.shape + tvm.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5) + + # Test bidirectional GRU with batch_first=True + class BidirectionalGRU(nn.Module): + def __init__(self): + super().__init__() + self.gru = nn.GRU( + input_size=4, + hidden_size=5, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + 
+ def forward(self, x): + y, _ = self.gru(x) + return y + + torch.manual_seed(44) + x3 = torch.randn(2, 3, 4, dtype=torch.float32) + model3 = BidirectionalGRU() + with torch.no_grad(): + pytorch_output3 = model3(x3) + + # Verify output shape is correct (hidden_size * 2 due to bidirectional) + assert pytorch_output3.shape == ( + 2, + 3, + 10, + ), f"Expected shape (2, 3, 10), got {pytorch_output3.shape}" + + exported_program3 = export(model3, args=(x3,)) + mod3 = from_exported_program(exported_program3) + ex3 = relax.build(mod3, target) + vm3 = relax.VirtualMachine(ex3, tvm.cpu()) + x3_tvm = tvm.runtime.tensor(x3.numpy()) + tvm_output3 = vm3["main"](x3_tvm) + if hasattr(tvm_output3, "numpy"): + tvm_output3_np = tvm_output3.numpy() + else: + tvm_output3_np = tvm_output3[0].numpy() + assert ( + pytorch_output3.shape == tvm_output3_np.shape + ), f"Shape mismatch: PyTorch {pytorch_output3.shape} vs TVM {tvm_output3_np.shape}" + tvm.testing.assert_allclose(pytorch_output3.numpy(), tvm_output3_np, rtol=1e-4, atol=1e-5) + + # Test bidirectional GRU with batch_first=False + class SeqFirstBidirectionalGRU(nn.Module): + def __init__(self): + super().__init__() + self.gru = nn.GRU( + input_size=3, + hidden_size=4, + num_layers=1, + batch_first=False, + bidirectional=True, + ) + + def forward(self, x): + y, _ = self.gru(x) + return y + + torch.manual_seed(45) + x4 = torch.randn(4, 2, 3, dtype=torch.float32) # (seq_len, batch, input_size) + model4 = SeqFirstBidirectionalGRU() + with torch.no_grad(): + pytorch_output4 = model4(x4) + + # Verify output shape (seq_len, batch, hidden_size * 2) + assert pytorch_output4.shape == ( + 4, + 2, + 8, + ), f"Expected shape (4, 2, 8), got {pytorch_output4.shape}" + + exported_program4 = export(model4, args=(x4,)) + mod4 = from_exported_program(exported_program4) + ex4 = relax.build(mod4, target) + vm4 = relax.VirtualMachine(ex4, tvm.cpu()) + x4_tvm = tvm.runtime.tensor(x4.numpy()) + tvm_output4 = vm4["main"](x4_tvm) + if hasattr(tvm_output4, 
"numpy"): + tvm_output4_np = tvm_output4.numpy() + else: + tvm_output4_np = tvm_output4[0].numpy() + assert pytorch_output4.shape == tvm_output4_np.shape + tvm.testing.assert_allclose(pytorch_output4.numpy(), tvm_output4_np, rtol=1e-4, atol=1e-5) + + +def test_dynamic_shape_with_range_constraints(): + class DynamicModel(torch.nn.Module): + def forward(self, x1, x2): + return torch.ops.aten.add.Tensor(x1, x2) + + @I.ir_module + class Expected: + @R.function + def main( + x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")): + s0 = T.int64(is_size_var=True) + R.func_attr({"tir_var_lower_bound": {"s0": 1}, "tir_var_upper_bound": {"s0": 64}}) + with R.dataflow(): + lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2) + gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(8, 4), torch.randn(8, 4)) + batch = torch.export.Dim("batch", min=1, max=64) + dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + + verify_model( + DynamicModel(), + example_args, + {}, + Expected, + dynamic_shapes=dynamic_shapes, + map_free_vars=True, + ) + + +def test_dynamic_shape_with_addition_constraints(): + class ConcatModel(torch.nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=0) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")): + s0 = T.int64(is_size_var=True) + s0___1 = T.int64(is_size_var=True) + R.func_attr( + { + "tir_var_lower_bound": {"s0": 1, "s0___1": 2}, + "tir_var_upper_bound": {"s0": 64, "s0___1": 65}, + } + ) + with R.dataflow(): + lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0) + gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + batch = torch.export.Dim("batch", min=1, max=64) + 
example_args = (torch.randn(8, 4), torch.randn(9, 4)) + dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}} + + verify_model( + ConcatModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes, map_free_vars=True + ) + + +def test_dynamic_shape_with_subtraction_constraints(): + class ConcatModel(torch.nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=0) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(("s1___1", 4), dtype="float32"), y: R.Tensor(("s1", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s1___1 + s1", 4), dtype="float32")): + s1___1 = T.int64(is_size_var=True) + s1 = T.int64(is_size_var=True) + R.func_attr( + { + "tir_var_lower_bound": {"s1": 0, "s1___1": 1}, + "tir_var_upper_bound": {"s1": 63, "s1___1": 64}, + } + ) + with R.dataflow(): + lv: R.Tensor((s1___1 + s1, 4), dtype="float32") = R.concat((x, y), axis=0) + gv: R.Tuple(R.Tensor((s1___1 + s1, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + batch = torch.export.Dim("batch", min=1, max=64) + example_args = (torch.randn(8, 4), torch.randn(7, 4)) + dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}} + + verify_model( + ConcatModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes, map_free_vars=True + ) + + +def test_dynamic_shape_with_multiplication_constraints(): + class ConcatModel(torch.nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=0) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0_2", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s0 + s0_2", 4), dtype="float32")): + s0 = T.int64(is_size_var=True) + s0_2 = T.int64(is_size_var=True) + R.func_attr( + { + "tir_var_lower_bound": {"s0": 1, "s0_2": 2}, + "tir_var_upper_bound": {"s0": 64, "s0_2": 128}, + } + ) + with R.dataflow(): + lv: R.Tensor((s0 + s0_2, 4), dtype="float32") = R.concat((x, y), axis=0) + gv: R.Tuple(R.Tensor((s0 + s0_2, 4), dtype="float32")) = (lv,) + 
R.output(gv) + return gv + + batch = torch.export.Dim("batch", min=1, max=64) + example_args = (torch.randn(8, 4), torch.randn(16, 4)) + dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}} + + verify_model( + ConcatModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes, map_free_vars=True + ) + + +def test_dynamic_shape_with_unbounded_constraints(): + class DynamicModel(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.add.Tensor(x, x) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(("s0", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")): + s0 = T.int64(is_size_var=True) + R.func_attr({"tir_var_lower_bound": {"s0": 2}}) + with R.dataflow(): + lv: R.Tensor((s0, 4), dtype="float32") = R.add(x, x) + gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(8, 4),) + batch = torch.export.Dim("batch", min=2) + dynamic_shapes = {"x": {0: batch}} + + verify_model( + DynamicModel(), + example_args, + {}, + Expected, + dynamic_shapes=dynamic_shapes, + map_free_vars=True, + ) + + +def test_sym_size_int(): + class SymSizeInt(Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + # TODO(@mshr-h): `torch.ops.aten.sym_size.int(x, self.dim)` would be ideal, but currently + # the ep frontend is not able to handle it. 
+ return torch.add(x[0], torch.ops.aten.sym_size.int(x, self.dim)) + + @I.ir_module + class Expected1: + @R.function + def main( + x: R.Tensor((1, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 4), dtype="float32") = R.take( + x, R.const(0, "int64"), axis=0, mode="fast" + ) + lv1: R.Tensor((3, 4), dtype="float32") = R.add(lv, R.const(3.0, "float32")) + gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args_1 = (torch.randn(1, 3, 4),) + verify_model(SymSizeInt(dim=1), example_args_1, {}, Expected1) + verify_model(SymSizeInt(dim=-2), example_args_1, {}, Expected1) + + class SymSizeIntDynamic(Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + shape_dim = torch.ops.aten.sym_size.int(x, self.dim) + return x.reshape(shape_dim, -1) + + @I.ir_module + class Expected2: + @R.function + def main( + x: R.Tensor(("s0", 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s0", 12), dtype="float32")): + s0 = T.int64(is_size_var=True) + R.func_attr({"tir_var_lower_bound": {"s0": 0}}) + with R.dataflow(): + lv: R.Tensor((s0, 12), dtype="float32") = R.reshape(x, R.shape([s0, 12])) + gv: R.Tuple(R.Tensor((s0, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args_2 = (torch.randn(2, 3, 4),) + dynamic_shapes = {"x": {0: torch.export.Dim("dim")}} + verify_model( + SymSizeIntDynamic(dim=0), example_args_2, {}, Expected2, dynamic_shapes=dynamic_shapes + ) + + +def test_exponential(): + class Exponential(Module): + def forward(self, x): + return x.exponential_() + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((4, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 8), dtype="float32"), R.Tensor((4, 8), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 8), dtype="float32") = R.zeros_like(x, dtype="void") + gv: R.Tuple( + R.Tensor((4, 8), dtype="float32"), R.Tensor((4, 8), 
dtype="float32") + ) = (lv, lv) + R.output(gv) + return gv + + example_args = (torch.randn(4, 8, dtype=torch.float32),) + verify_model(Exponential(), example_args, {}, Expected) + + +def test_max_dim(): + class MaxDim1(Module): + def forward(self, x): + return torch.max(x, dim=1) + + class MaxDim2(Module): + def forward(self, x): + return torch.max(x, dim=1, keepdim=True) + + @I.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((4, 8, 16), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 16), dtype="float32"), R.Tensor((4, 16), dtype="int64")): + with R.dataflow(): + lv: R.Tuple( + R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64") + ) = R.topk(x, k=1, axis=1, ret_type="both", largest=True, dtype="int64") + lv1: R.Tensor((4, 1, 16), dtype="float32") = lv[0] + lv2: R.Tensor((4, 16), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv3: R.Tensor((4, 1, 16), dtype="int64") = lv[1] + lv4: R.Tensor((4, 16), dtype="int64") = R.squeeze(lv3, axis=[1]) + lv5: R.Tuple( + R.Tensor((4, 16), dtype="float32"), R.Tensor((4, 16), dtype="int64") + ) = (lv2, lv4) + lv6: R.Tensor((4, 16), dtype="float32") = lv5[0] + lv7: R.Tensor((4, 16), dtype="int64") = lv5[1] + gv: R.Tuple( + R.Tensor((4, 16), dtype="float32"), R.Tensor((4, 16), dtype="int64") + ) = (lv6, lv7) + R.output(gv) + return gv + + @I.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((4, 8, 16), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64")): + with R.dataflow(): + lv: R.Tuple( + R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64") + ) = R.topk(x, k=1, axis=1, ret_type="both", largest=True, dtype="int64") + lv1: R.Tensor((4, 1, 16), dtype="float32") = lv[0] + lv2: R.Tensor((4, 1, 16), dtype="int64") = lv[1] + lv3: R.Tuple( + R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64") + ) = (lv1, lv2) + lv4: R.Tensor((4, 1, 16), dtype="float32") = lv3[0] + lv5: 
R.Tensor((4, 1, 16), dtype="int64") = lv3[1] + gv: R.Tuple( + R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64") + ) = (lv4, lv5) + R.output(gv) + return gv + + example_args = (torch.randn(4, 8, 16, dtype=torch.float32),) + verify_model(MaxDim1(), example_args, {}, expected1) + verify_model(MaxDim2(), example_args, {}, expected2) + + +def test_alias(): + class Alias(Module): + def forward(self, x): + return torch.ops.aten.alias(x) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((4, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 8), dtype="float32")): + with R.dataflow(): + gv: R.Tuple(R.Tensor((4, 8), dtype="float32")) = (x,) + R.output(gv) + return gv + + example_args = (torch.randn(4, 8, dtype=torch.float32),) + verify_model(Alias(), example_args, {}, Expected) + + +def test_scatter_value(): + class ScatterValue(Module): + def forward(self, x, index): + return x.scatter(1, index, 0.5) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((4, 8), dtype="float32"), + index: R.Tensor((4, 2), dtype="int64"), + ) -> R.Tuple(R.Tensor((4, 8), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 2), dtype="float32") = R.broadcast_to( + R.const(0.5, "float32"), R.shape([4, 2]) + ) + lv1: R.Tensor((4, 8), dtype="float32") = R.scatter_elements(x, index, lv, axis=1) + gv: R.Tuple(R.Tensor((4, 8), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = ( + torch.randn(4, 8, dtype=torch.float32), + torch.randint(0, 8, (4, 2), dtype=torch.int64), + ) + verify_model(ScatterValue(), example_args, {}, Expected) + + +def test_grid_sample(): + class GridSample(Module): + def forward(self, input, grid): + return torch.nn.functional.grid_sample( + input, grid, mode="bilinear", padding_mode="zeros", align_corners=True + ) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 4, 4), dtype="float32"), + grid: R.Tensor((1, 2, 2, 2), dtype="float32"), 
+ ) -> R.Tuple(R.Tensor((1, 3, 2, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 2, 2), dtype="float32") = R.image.grid_sample( + input_1, + grid, + method="bilinear", + layout="NCHW", + padding_mode="zeros", + align_corners=True, + ) + gv: R.Tuple(R.Tensor((1, 3, 2, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(1, 3, 4, 4, dtype=torch.float32), + torch.randn(1, 2, 2, 2, dtype=torch.float32), + ) + verify_model(GridSample(), example_args, {}, expected) + + +def test_upsample_nearest2d(): + class UpsampleNearest2dScale(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest") + + class UpsampleNearest2dSize(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(20, 20), mode="nearest") + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + @tvm.script.ir_module + class expected_scale: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 20, 20), dtype="float32") = R.image.resize2d( + input_1, + size=(20, 20), + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="half_pixel", + ) + gv: R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_size: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 20, 20), dtype="float32") = R.image.resize2d( + input_1, + size=(20, 20), + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="half_pixel", + ) + gv: R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(UpsampleNearest2dScale(), example_args, {}, expected_scale) + 
verify_model(UpsampleNearest2dSize(), example_args, {}, expected_size) + + if __name__ == "__main__": tvm.testing.main() -1 diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 47ca0819a9c8..b7aeea6687e8 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -38,7 +38,7 @@ def verify_model(torch_model, input_info, binding, expected): graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): mod = from_fx(graph_model, input_info) - binding = {k: tvm.nd.array(v) for k, v in binding.items()} + binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()} expected = relax.transform.BindParams("main", binding)(expected) tvm.ir.assert_structural_equal(mod, expected) @@ -1434,6 +1434,7 @@ def main( strides=[1, 1], dilation=[1, 1], padding=[0, 0, 0, 0], + count_include_pad=True, layout="NCHW", out_layout="NCHW", ) @@ -1467,6 +1468,7 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): dilation=[1, 1], padding=[2, 2, 2, 2], ceil_mode=True, + count_include_pad=True, layout="NCHW", out_layout="NCHW", ) @@ -1490,6 +1492,7 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): dilation=[1, 1], padding=[0, 0, 0, 0], ceil_mode=False, + count_include_pad=True, layout="NCHW", out_layout="NCHW", ) @@ -2749,6 +2752,27 @@ def main( verify_model(Unary(), input_info, {}, expected_unary) +def test_sqrt_integer_input_fx(): + input_info = [([1, 4], "int64")] + + class SqrtIntModel(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected: + @R.function + def main(input_1: R.Tensor((1, 4), dtype="int64")) -> R.Tensor((1, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32") + lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv) + gv: R.Tensor((1, 4), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(SqrtIntModel(), 
input_info, {}, expected) + + operator_bool_unary = [ (torch.isnan, R.isnan), (torch.isinf, R.isinf), @@ -3646,6 +3670,121 @@ def main( verify_model(Interpolate4(), input_info, {}, expected4) +def test_interpolate_nhwc_layout(): + # First verify backward compatibility - default should still be NCHW + input_info_nchw = [([1, 3, 10, 10], "float32")] + + class InterpolateDefault(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (5, 5)) + + @tvm.script.ir_module + class expected_default_nchw: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 5), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d( + input_1, + (5, 5), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + # Verify default behavior (no default_image_layout parameter) uses NCHW + graph_model_default = fx.symbolic_trace(InterpolateDefault()) + with torch.no_grad(): + mod_default = from_fx(graph_model_default, input_info_nchw) + tvm.ir.assert_structural_equal(mod_default, expected_default_nchw) + + # Now test NHWC layout + input_info = [([1, 10, 10, 3], "float32")] + + class InterpolateNHWC(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (5, 5)) + + @tvm.script.ir_module + class expected_nhwc: + @R.function + def main( + input_1: R.Tensor((1, 10, 10, 3), dtype="float32") + ) -> R.Tensor((1, 5, 5, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 5, 5, 3), dtype="float32") = R.image.resize2d( + input_1, + (5, 5), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NHWC", + method="nearest_neighbor", + 
coordinate_transformation_mode="asymmetric", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 5, 5, 3), dtype="float32") = lv + R.output(gv) + return gv + + # Test with NHWC layout + graph_model = fx.symbolic_trace(InterpolateNHWC()) + with torch.no_grad(): + mod = from_fx(graph_model, input_info, default_image_layout="NHWC") + tvm.ir.assert_structural_equal(mod, expected_nhwc) + + # Test with bilinear interpolation and NHWC layout + class InterpolateNHWC2(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, size=None, scale_factor=2.0, mode="bilinear", align_corners=False + ) + + @tvm.script.ir_module + class expected_nhwc2: + @R.function + def main( + input_1: R.Tensor((1, 10, 10, 3), dtype="float32") + ) -> R.Tensor((1, 20, 20, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 20, 20, 3), dtype="float32") = R.image.resize2d( + input_1, + (20, 20), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NHWC", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 20, 20, 3), dtype="float32") = lv + R.output(gv) + return gv + + graph_model2 = fx.symbolic_trace(InterpolateNHWC2()) + with torch.no_grad(): + mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC") + tvm.ir.assert_structural_equal(mod2, expected_nhwc2) + + def test_addmm(): input_info = [ ([10, 10], "float32"), @@ -4578,9 +4717,9 @@ def main( params = params["main"] assert len(params) == len(func.params) - 1 - for param_var, param_ndarray in zip(func.params[1:], params): - assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape - assert param_var.struct_info.dtype == param_ndarray.dtype + for param_var, param_tensor in zip(func.params[1:], params): + assert tuple(x.value for 
x in param_var.struct_info.shape.values) == param_tensor.shape + assert param_var.struct_info.dtype == param_tensor.dtype tvm.testing.assert_allclose(params[0].numpy(), model.conv.bias.detach().detach().numpy()) tvm.testing.assert_allclose(params[1].numpy(), model.conv.weight.detach().detach().numpy()) @@ -5118,11 +5257,34 @@ def main( R.output(gv) return gv + class SliceScatterNegative(Module): + def forward(self, input, src): + return torch.slice_scatter(input, src, dim=1, start=0, end=-2, step=1) + + @tvm.script.ir_module + class expected_slice_scatter: + @R.function + def main( + a: R.Tensor((2, 5), dtype="float32"), b: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 5), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 5), dtype="float32") = R.slice_scatter( + a, b, R.prim_value(0), R.prim_value(3), R.prim_value(1), axis=1 + ) + gv: R.Tensor((2, 5), dtype="float32") = lv + R.output(gv) + return gv + verify_model( SliceScatter1(), [((8, 8, 10, 10), "float32"), ((8, 3, 10, 10), "float32")], {}, expected1 ) - verify_model(SliceScatter2(), [((8, 16), "float32"), ((6, 16), "float32")], {}, expected2) + verify_model( + SliceScatterNegative(), + [((2, 5), "float32"), ((2, 3), "float32")], + {}, + expected_slice_scatter, + ) def test_masked_scatter(): @@ -5713,7 +5875,28 @@ def main( inp_1: R.Tensor((1, 2, 3, 4), dtype="float32"), ) -> R.Tensor((1, 2, 3, 4), dtype="float32"): with R.dataflow(): - gv: R.Tensor((1, 2, 3, 4), dtype="float32") = inp_1 + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.broadcast_to( + inp_1, R.shape([1, 2, 3, 4]) + ) + gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + class CopyBroadcast(Module): + def forward(self, x, src): + x.copy_(src) + return x + + @tvm.script.ir_module + class expected_copy: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), src: R.Tensor((), dtype="int64") + ) -> R.Tensor((2, 3), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = 
R.astype(src, dtype="float32") + lv1: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(lv, (2, 3)) + gv: R.Tensor((2, 3), dtype="float32") = lv1 R.output(gv) return gv @@ -5723,6 +5906,7 @@ def main( {}, Expected, ) + verify_model(CopyBroadcast(), [((2, 3), "float32"), ((), "int64")], {}, expected_copy) def test_clone(): @@ -6139,11 +6323,22 @@ def main( @pytest.mark.parametrize( "torch_dtype, relax_dtype", [ - (torch.float32, "float32"), + # Float types (torch.float16, "float16"), + (torch.float32, "float32"), + (torch.float64, "float64"), (torch.bfloat16, "bfloat16"), - (torch.int64, "int64"), + # Signed integer types + (torch.int8, "int8"), + (torch.int16, "int16"), (torch.int32, "int32"), + (torch.int64, "int64"), + # Unsigned integer types + (torch.uint8, "uint8"), + (torch.uint16, "uint16"), + (torch.uint32, "uint32"), + (torch.uint64, "uint64"), + # Boolean (torch.bool, "bool"), ], ) @@ -6204,5 +6399,80 @@ def forward(self, input): ) +def test_round(): + input_info = [([3, 4], "float32")] + + class Round(Module): + def __init__(self, decimals=0): + super().__init__() + self.decimals = decimals + + def forward(self, x): + if self.decimals == 0: + return torch.round(x) + else: + return torch.round(x, decimals=self.decimals) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((3, 4), dtype="float32"), + ) -> R.Tensor((3, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((3, 4), dtype="float32") = R.round(inp_0) + gv: R.Tensor((3, 4), dtype="float32") = lv + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((3, 4), dtype="float32"), + ) -> R.Tensor((3, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((3, 4), dtype="float32") = R.multiply(inp_0, R.const(100.0, "float32")) + lv1: R.Tensor((3, 4), dtype="float32") = R.round(lv) + lv2: R.Tensor((3, 4), dtype="float32") = R.divide(lv1, R.const(100.0, "float32")) + gv: R.Tensor((3, 4), 
dtype="float32") = lv2 + R.output(gv) + return gv + + rounds = [ + (0, Expected1), + (2, Expected2), + ] + + for decimals, expected in rounds: + verify_model(Round(decimals), input_info, {}, expected) + + # Test numerical accuracy with decimals + test_data = torch.tensor( + [ + [1.2345, 2.3456, 3.4567, 4.5678], + [5.6789, 6.7890, 7.8901, 8.9012], + [9.1234, 10.2345, 11.3456, 12.4567], + ] + ) + + for decimals in [0, 1, 2, 3]: + torch_model = Round(decimals) + graph_model = fx.symbolic_trace(torch_model) + with torch.no_grad(): + mod = from_fx(graph_model, input_info) + + target = tvm.target.Target("llvm") + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + torch_result = torch_model(test_data).numpy() + tvm_input = tvm.runtime.tensor(test_data.numpy()) + tvm_result = vm["main"](tvm_input).numpy() + + # Use relaxed tolerance due to floating-point precision in decimal operations + tvm.testing.assert_allclose(tvm_result, torch_result, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_nn_debug.py b/tests/python/relax/test_frontend_nn_debug.py index a055631a4d51..f3ead2e9c011 100644 --- a/tests/python/relax/test_frontend_nn_debug.py +++ b/tests/python/relax/test_frontend_nn_debug.py @@ -22,7 +22,7 @@ from tvm import tir from tvm.relax.frontend import nn from tvm.relax.frontend.nn import op, spec -from tvm.runtime import NDArray +from tvm.runtime import Tensor def test_debug_print(): @@ -43,10 +43,10 @@ def forward(self, x: nn.Tensor): # pylint: disable=invalid-name def test_debug_func(): - @tvm.register_func("testing.relax.frontend.nn.test_debug_func") + @tvm.register_global_func("testing.relax.frontend.nn.test_debug_func") def _debug( # pylint: disable=too-many-arguments lineno: str, - tensor: NDArray, + tensor: Tensor, const_int: int, const_float: float, const_str: str, diff --git a/tests/python/relax/test_frontend_nn_extern_module.py 
b/tests/python/relax/test_frontend_nn_extern_module.py index cbc2e7f42922..d5b73bec4c7f 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -57,8 +57,8 @@ def _var_equal(a, b): # pylint: disable=invalid-name def _test_scalar_add(func): # pylint: disable=invalid-name - x = tvm.nd.array(np.array(1.0).astype("float32")) - y = tvm.nd.array(np.array(3.0).astype("float32")) + x = tvm.runtime.tensor(np.array(1.0).astype("float32")) + y = tvm.runtime.tensor(np.array(3.0).astype("float32")) z = func(x, y).numpy() # pylint: enable=invalid-name assert z.ndim == 0 @@ -68,8 +68,8 @@ def _test_scalar_add(func): def _test_infer_sym(func, x, y, z): # pylint: disable=invalid-name # pylint: disable=invalid-name - a = tvm.nd.array(np.random.uniform(size=(x, y, 1)).astype("float32")) - b = tvm.nd.array(np.random.uniform(size=(y, z, 5)).astype("float32")) + a = tvm.runtime.tensor(np.random.uniform(size=(x, y, 1)).astype("float32")) + b = tvm.runtime.tensor(np.random.uniform(size=(y, z, 5)).astype("float32")) c = func(a, b).numpy() # pylint: enable=invalid-name assert c.shape == (x, y, z, 9) diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index 23250f28aa9f..e9a4a6f62424 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -715,5 +715,22 @@ def forward(self, x: nn.Tensor): assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys())) +def test_module_dict(): + class Module(nn.Module): + def __init__(self): + self.layers = nn.ModuleDict( + {"linear0": nn.Linear(4, 4, bias=False), "linear1": nn.Linear(4, 4, bias=False)} + ) + + def forward(self, x: nn.Tensor): + x = self.layers["linear0"](x) + x = self.layers["linear1"](x) + return x + + mod = Module() + named_params = dict(mod.named_parameters()) + assert ["layers.linear0.weight", "layers.linear1.weight"] == 
sorted(list(named_params.keys())) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_nn_mutator.py b/tests/python/relax/test_frontend_nn_mutator.py index ffb6586159b5..253e24a4eddf 100644 --- a/tests/python/relax/test_frontend_nn_mutator.py +++ b/tests/python/relax/test_frontend_nn_mutator.py @@ -65,6 +65,37 @@ def visit_param(self, name: str, node: nn.Parameter) -> Any: mutator.visit("mod3", mod3) +def test_mutator_naming_moduledict(): + class Module(nn.Module): + def __init__(self, dtype) -> None: + super().__init__() + self.param = nn.Parameter((32, 128), dtype) + + class Mutator(nn.Mutator): + def visit_param(self, name: str, node: nn.Parameter) -> Any: + if node.dtype == "float64": + assert name == "mod_dict.k0.0.param" + return node + elif node.dtype == "float32": + assert name == "mod_dict.k0.1.param" + return node + elif node.dtype == "float16": + assert name == "mod_dict.k1.0.param" + return node + elif node.dtype == "float8": + assert name == "mod_dict.k1.1.param" + return node + + mod_dict = nn.ModuleDict( + { + "k0": nn.ModuleList([Module("float64"), Module("float32")]), + "k1": nn.ModuleList([Module("float16"), Module("float8")]), + } + ) + mutator = Mutator() + mutator.visit("mod_dict", mod_dict) + + def test_mutator_naming_modulelist(): class Module(nn.Module): def __init__(self, dtype) -> None: @@ -124,6 +155,37 @@ def visit_module(self, name: str, node: nn.Module) -> Any: assert isinstance(module.mod, SubModule2) +def test_mutator_moduledict(): + class Module1(nn.Module): + def __init__(self) -> None: + super().__init__() + + class Module2(nn.Module): + def __init__(self) -> None: + super().__init__() + + class Module3(nn.Module): + def __init__(self) -> None: + super().__init__() + + class Mutator(nn.Mutator): + def visit_module(self, name: str, node: nn.Module) -> Any: + if isinstance(node, Module3): + return Module1() + else: + return node + + mutator = Mutator() + module_dict = nn.ModuleDict({"k0": 
Module1(), "k1": Module2(), "k2": Module3()}) + assert isinstance(module_dict["k0"], Module1) + assert isinstance(module_dict["k1"], Module2) + assert isinstance(module_dict["k2"], Module3) + module_dict = mutator.visit("", module_dict) + assert isinstance(module_dict["k0"], Module1) + assert isinstance(module_dict["k1"], Module2) + assert isinstance(module_dict["k2"], Module1) + + def test_mutator_modulelist(): class Module1(nn.Module): def __init__(self) -> None: @@ -150,7 +212,6 @@ def visit_module(self, name: str, node: nn.Module) -> Any: assert isinstance(module_list[1], Module2) assert isinstance(module_list[2], Module3) module_list = mutator.visit("", module_list) - print(module_list[2]) assert isinstance(module_list[0], Module1) assert isinstance(module_list[1], Module2) assert isinstance(module_list[2], Module1) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 5c400ef8be28..28c11f6dfaf5 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -384,6 +384,8 @@ def test( def test_nn(): class Model(Module): def test(self, x: Tensor, weight: Tensor, bias: Tensor): + log_out = op.log(x) + floor_out = op.floor(x) relu_out = op.relu(x) relu6_out = op.relu6(x) silu_out = op.silu(x) @@ -409,6 +411,8 @@ def test( ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): + log: R.Tensor((2, 3, 4, 5), dtype="float32") = R.log(x) + floor: R.Tensor((2, 3, 4, 5), dtype="float32") = R.floor(x) relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x) relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x) silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x) @@ -463,6 +467,8 @@ def test(self, x: Tensor): ) zeros_out = op.zeros([10, 10]) zeros_fp16_out = op.zeros([10, 10], dtype="float16") + + arange_out = op.arange(0, 10, 1, "float32") return x # fmt: off @@ -476,6 +482,7 @@ def test(x: 
R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32") zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10, 10]), dtype="float32") zeros1: R.Tensor((10, 10), dtype="float16") = R.zeros(R.shape([10, 10]), dtype="float16") + arange: R.Tensor((10,), dtype="float32") = R.arange(T.int64(0), T.int64(10), T.int64(1), dtype="float32") gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,) R.output(gv1) return gv1 @@ -504,7 +511,10 @@ def test( lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32") lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1]) lv3: R.Tensor((5,), dtype="float32") = R.arange( - R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="float32" + R.prim_value(T.int64(0)), + R.prim_value(T.int64(5)), + R.prim_value(T.int64(1)), + dtype="float32", ) lv4: R.Tensor((5,), dtype="float32") = R.multiply( R.const(-9.2103404998779297, "float32"), lv3 @@ -899,7 +909,7 @@ def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), d def test_empty(): - @tvm.register_func("test_empty_assert", override=True) + @tvm.register_global_func("test_empty_assert", override=True) def test_empty_assert(_lineo, x): assert x.shape == (10, 10) assert x.dtype == "float32" @@ -976,10 +986,12 @@ def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1) np_rand = np.random.rand(*prob_shape).astype(np.float32) # normalize it to get the random prob np_prob = np_rand / np_rand.sum(axis=1, keepdims=True) - nd_prob = tvm.nd.array(np_prob, dev) + nd_prob = tvm.runtime.tensor(np_prob, dev) # special sample to get deterministic results - nd_sample = tvm.nd.array(np.array([[1], [0], [1], [1], [0], [1]]).astype(np.float32), dev) - nd_sample_indices = tvm.nd.array(np.array([[0], [1], [1], [2], [2], [2]]).astype(np.int64), dev) + nd_sample = 
tvm.runtime.tensor(np.array([[1], [0], [1], [1], [0], [1]]).astype(np.float32), dev) + nd_sample_indices = tvm.runtime.tensor( + np.array([[0], [1], [1], [2], [2], [2]]).astype(np.int64), dev + ) inputs = [nd_prob, nd_sample, nd_sample_indices, effects] res = vm["foo"](*inputs) tvm.testing.assert_allclose( @@ -1104,12 +1116,14 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), index: R.Tensor((2, 3), dtype=" vm = relax.VirtualMachine(ex, dev) effects = vm["_initialize_effect"]() - sorted_prob = tvm.nd.array(np.array([[0.5, 0.4, 0.1], [0.4, 0.3, 0.3]]).astype(np.float32), dev) - indices = tvm.nd.array(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64), dev) - top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev) - top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev) - usample = tvm.nd.array(np.array([[0.5], [0.6], [0.7]]).astype(np.float32), dev) - sample_indices = tvm.nd.array(np.array([[0], [1], [1]]).astype(np.int64), dev) + sorted_prob = tvm.runtime.tensor( + np.array([[0.5, 0.4, 0.1], [0.4, 0.3, 0.3]]).astype(np.float32), dev + ) + indices = tvm.runtime.tensor(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64), dev) + top_p = tvm.runtime.tensor(np.array([[0.6], [0.9]]).astype(np.float32), dev) + top_k = tvm.runtime.tensor(np.array([[3], [2]]).astype(np.int64), dev) + usample = tvm.runtime.tensor(np.array([[0.5], [0.6], [0.7]]).astype(np.float32), dev) + sample_indices = tvm.runtime.tensor(np.array([[0], [1], [1]]).astype(np.int64), dev) inputs = [sorted_prob, indices, top_p, top_k, usample, sample_indices, effects] @@ -1220,10 +1234,12 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), d vm = relax.VirtualMachine(ex, dev) effects = vm["_initialize_effect"]() - prob = tvm.nd.array(np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]]).astype(np.float32), dev) - sorted_prob = tvm.nd.array(np.array([[0.5, 0.3, 0.2], [0.4, 0.3, 0.3]]).astype(np.float32), dev) - top_p = tvm.nd.array(np.array([[0.6], 
[0.9]]).astype(np.float32), dev) - top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev) + prob = tvm.runtime.tensor(np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]]).astype(np.float32), dev) + sorted_prob = tvm.runtime.tensor( + np.array([[0.5, 0.3, 0.2], [0.4, 0.3, 0.3]]).astype(np.float32), dev + ) + top_p = tvm.runtime.tensor(np.array([[0.6], [0.9]]).astype(np.float32), dev) + top_k = tvm.runtime.tensor(np.array([[3], [2]]).astype(np.int64), dev) inputs = [prob, sorted_prob, top_p, top_k, effects] diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index b55489a623f0..23348cf84757 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -172,12 +172,12 @@ def _check_output(tvm_out, ort_out): assert len(tvm_out) == len(ort_out), "Unequal number of outputs" for tvm_out_i, ort_out_i in zip(tvm_out, ort_out): _check_output(tvm_out_i, ort_out_i) - elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray): + elif isinstance(tvm_out, tvm.runtime.Tensor) and isinstance(ort_out, np.ndarray): if check_dtypes: assert tvm_out.numpy().dtype == ort_out.dtype tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol) elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray): - shape_out = tvm.nd.array([int(i) for i in tvm_out]) + shape_out = tvm.runtime.tensor([int(i) for i in tvm_out]) if check_dtypes: assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out) tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol) @@ -789,6 +789,74 @@ def test_transpose(): verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]}) +def test_transpose_scalar(): + """Test Transpose with scalar inputs - should return scalar unchanged.""" + # Test scalar with no perm attribute (default behavior) + scalar_node = helper.make_node("Transpose", ["x"], ["y"]) + graph = helper.make_graph( + 
[scalar_node], + "transpose_scalar_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])], + ) + model = helper.make_model(graph, producer_name="transpose_scalar_test") + check_correctness(model) + + # Test with scalar constant and transpose without perm + scalar_constant = helper.make_node( + "Constant", + [], + ["scalar"], + value=helper.make_tensor("value", TensorProto.FLOAT, [], [5.0]), + ) + + transpose_node = helper.make_node("Transpose", ["scalar"], ["y"]) + graph = helper.make_graph( + [scalar_constant, transpose_node], + "transpose_scalar_constant_test", + inputs=[], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])], + ) + model = helper.make_model(graph, producer_name="transpose_scalar_constant_test") + check_correctness(model) + + +def test_transpose_axes_validation(): + """Test Transpose validation - perm axes count must match tensor dimensions""" + # Test 1D tensor with correct perm + transpose_1d_valid = helper.make_node("Transpose", ["x"], ["y"], perm=[0]) + graph_1d_valid = helper.make_graph( + [transpose_1d_valid], + "transpose_1d_valid_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [10])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [10])], + ) + model_1d_valid = helper.make_model(graph_1d_valid, producer_name="transpose_1d_valid_test") + check_correctness(model_1d_valid) + + # Test 2D tensor with correct perm + transpose_2d_valid = helper.make_node("Transpose", ["x"], ["y"], perm=[1, 0]) + graph_2d_valid = helper.make_graph( + [transpose_2d_valid], + "transpose_2d_valid_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 3])], + ) + model_2d_valid = helper.make_model(graph_2d_valid, producer_name="transpose_2d_valid_test") + check_correctness(model_2d_valid) + + # Test 3D tensor with correct perm 
+ transpose_3d_valid = helper.make_node("Transpose", ["x"], ["y"], perm=[2, 0, 1]) + graph_3d_valid = helper.make_graph( + [transpose_3d_valid], + "transpose_3d_valid_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 2, 3])], + ) + model_3d_valid = helper.make_model(graph_3d_valid, producer_name="transpose_3d_valid_test") + check_correctness(model_3d_valid) + + def test_unsqueeze(): unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) @@ -828,6 +896,36 @@ def test_bias_gelu(): verify_binary("BiasGelu", [32, 32], [32], [32, 32], domain="com.microsoft") +def test_fast_gelu(): + """Test FastGelu with and without bias""" + # Test FastGelu without bias + fast_gelu_node = helper.make_node("FastGelu", ["x"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [fast_gelu_node], + "fast_gelu_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [32, 32])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="fast_gelu_test") + check_correctness(model) + + # Test FastGelu with bias + fast_gelu_with_bias_node = helper.make_node( + "FastGelu", ["x", "bias"], ["y"], domain="com.microsoft" + ) + graph_with_bias = helper.make_graph( + [fast_gelu_with_bias_node], + "fast_gelu_with_bias_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, [32]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + model_with_bias = helper.make_model(graph_with_bias, producer_name="fast_gelu_with_bias_test") + check_correctness(model_with_bias) + + def test_where(): where_node = helper.make_node("Where", ["a", "b", "c"], ["d"]) @@ -1909,6 +2007,106 @@ def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data): 
_test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data) +def test_expand_incompatible_broadcasting(): + """ + This test case reproduces the error where input tensor shape at dim 1 is 25 + and target shape at dim 3 is 56, which violates ONNX broadcasting rules + """ + + def _test_expand_error_case(name, data_shape, target_shape_vals): + data = np.random.uniform(size=data_shape).astype(np.float32) + + shape_array = np.array(target_shape_vals, dtype=np.int64) + shape_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["shape"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT64, + dims=shape_array.shape, + vals=shape_array.flatten(), + ), + ) + + expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) + + graph = helper.make_graph( + [shape_node, expand_node], + "expand_error_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)], + ) + + model = helper.make_model(graph, producer_name=name) + + with pytest.raises(ValueError) as exc_info: + from_onnx(model, keep_params_in_input=True) + + error_msg = str(exc_info.value) + assert ( + "broadcast" in error_msg.lower() or "incompatible" in error_msg.lower() + ), f"Expected broadcasting error, but got: {error_msg}" + + # Test case 1: Reproduce the exact error from the issue-17769 + # Input shape: (25,), target shape: (1, 1, 1, 56) + # This should faill because input dim 1 (25) != target dim 3 (56) and neither is 1 + _test_expand_error_case( + "expand_incompatible_25_to_56", + data_shape=(25,), + target_shape_vals=(1, 1, 1, 56), + ) + + # Test case 2: Another incompatible case + # Input shape: (1, 25), target shape: (1, 1, 1, 56) + # After right-alignment, input (1, 1, 1, 25) vs. 
target (1, 1, 1, 56) + # This should fail because 25 != 56 and neither is 1 + _test_expand_error_case( + "expand_incompatible_aligned_25_to_56", + data_shape=(1, 25), + target_shape_vals=(1, 1, 1, 56), + ) + + # Test case 3: Valid case for comparison - should not raise error + def _test_expand_valid_case(): + """Test a valid expand case to ensure our fix doesn't break valid operations""" + data_shape = (1, 25) + target_shape_vals = [2, 25] # Valid: input (1, 25) can broadcast to (2, 25) + + data = np.random.uniform(size=data_shape).astype(np.float32) + shape_array = np.array(target_shape_vals, dtype=np.int64) + + shape_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["shape"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT64, + dims=shape_array.shape, + vals=shape_array.flatten(), + ), + ) + + expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) + + graph = helper.make_graph( + [shape_node, expand_node], + "expand_valid_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)], + ) + + model = helper.make_model(graph, producer_name="expand_valid_test_case") + + try: + tvm_model = from_onnx(model, keep_params_in_input=True) + except Exception as e: + pytest.fail(f"Valid expand case should not fail, but got error: {e}") + + _test_expand_valid_case() + + # TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed. 
@pytest.mark.skip("Produces ill-formed IR") def test_constantofshape(): @@ -3130,6 +3328,7 @@ def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", " gv: R.Tensor((A, B, A // B), dtype="float32") = x R.output(gv) return gv + # fmt: on tvm.ir.assert_structural_equal(tvm_model, Expected) @@ -3169,5 +3368,430 @@ def main(x: R.Tensor(("A", "B", "A // B"), dtype="float32")) -> R.Tensor(("A", " tvm.ir.assert_structural_equal(tvm_model, Expected) +def test_nms(): + """Test NonMaxSuppression operator conversion using our AllClassNMS implementation.""" + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + boxes_shape = [1, 5, 4] # batch_size, num_boxes, 4 + scores_shape = [1, 2, 5] # batch_size, num_classes, num_boxes + + graph = helper.make_graph( + [nms_node], + "nms_test", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [3]), + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]), + helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]), + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [0, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test") + model.opset_import[0].version = 11 + + # Use deterministic random inputs for consistent testing + bg = np.random.MT19937(0) + rg = np.random.Generator(bg) + boxes = rg.standard_normal(size=boxes_shape).astype(np.float32) + scores = rg.standard_normal(size=scores_shape).astype(np.float32) + inputs = {"boxes": boxes, "scores": scores} + + # Run ONNX Runtime + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + ort_output 
= ort_session.run([], inputs) + + # Run TVM + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + if isinstance(tvm_output, (list, tuple)): + tvm_selected = tvm_output[0].numpy() + else: + tvm_selected = tvm_output.numpy() + ort_selected = ort_output[0] + + min_rows = min(tvm_selected.shape[0], ort_selected.shape[0]) + if min_rows > 0: + tvm.testing.assert_allclose( + tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 + ) + + +def test_nms_algorithm_correctness(): + """Test NMS algorithm correctness with fixed data to verify suppression logic.""" + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + # Create fixed test data with known expected results + # Boxes: [x1, y1, x2, y2] format + boxes_data = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], # Box 0: [0,0,1,1] - should be selected + [ + 0.5, + 0.5, + 1.5, + 1.5, + ], # Box 1: [0.5,0.5,1.5,1.5] - overlaps with box 0, should be suppressed + [2.0, 2.0, 3.0, 3.0], + ] + ], # Box 2: [2,2,3,3] - no overlap, should be selected + dtype=np.float32, + ) + + # Scores: higher score = better + scores_data = np.array( + [ + [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4]] # Class 0: [0.9, 0.8, 0.7] - box 0 has highest score + ], # Class 1: [0.6, 0.5, 0.4] - box 0 has highest score + 
dtype=np.float32, + ) + + boxes_shape = [1, 3, 4] # batch_size, num_boxes, 4 + scores_shape = [1, 2, 3] # batch_size, num_classes, num_boxes + + graph = helper.make_graph( + [nms_node], + "nms_test_correctness", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor( + "max_output_boxes_per_class", TensorProto.INT64, [1], [2] + ), # Only 2 boxes per class + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]), # IoU threshold 0.5 + helper.make_tensor( + "score_threshold", TensorProto.FLOAT, [1], [0.1] + ), # Score threshold 0.1 + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [4, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test_correctness") + + # Use fixed inputs instead of random + inputs = { + "boxes": boxes_data, + "scores": scores_data, + } + + check_correctness(model, inputs=inputs, opset=11) + + +def test_nms_iou_suppression(): + """Test that NMS correctly suppresses overlapping boxes based on IoU threshold.""" + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + # Create overlapping boxes where box 0 has higher score and should be kept + boxes_data = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], # Box 0: [0,0,1,1] - highest score + [ + 0.1, + 0.1, + 1.1, + 1.1, + ], # Box 1: [0.1,0.1,1.1,1.1] - high IoU with box 0, should be suppressed + [2.0, 2.0, 3.0, 3.0], + ] + ], # Box 2: [2,2,3,3] - no overlap, should be kept + dtype=np.float32, + ) + + # Box 0 has highest score, Box 1 should be suppressed due to IoU with box 0 + scores_data = np.array([[[0.9, 0.8, 0.7]]], dtype=np.float32) + + boxes_shape = [1, 3, 4] + scores_shape = [1, 1, 3] + + graph = helper.make_graph( + [nms_node], + 
"nms_test_iou_suppression", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [2]), + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]), # IoU threshold 0.5 + helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]), + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [2, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test_iou_suppression") + model.opset_import[0].version = 11 + + inputs = { + "boxes": boxes_data, + "scores": scores_data, + } + + # Run ONNX Runtime + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + ort_output = ort_session.run([], inputs) + + # Run TVM + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + # Custom NMS output comparison + if isinstance(tvm_output, (list, tuple)): + tvm_selected = tvm_output[0].numpy() + else: + tvm_selected = tvm_output.numpy() + ort_selected = ort_output[0] + + # For NMS, compare only the valid rows + min_rows = min(tvm_selected.shape[0], ort_selected.shape[0]) + if min_rows > 0: + tvm.testing.assert_allclose( + tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 
+ ) + + +def test_nms_max_boxes_limit(): + """Test that NMS correctly limits the number of boxes per class.""" + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + # Create data with 4 boxes, but limit to 2 per class + boxes_data = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], # Box 0 + [2.0, 0.0, 3.0, 1.0], # Box 1 + [0.0, 2.0, 1.0, 3.0], # Box 2 + [2.0, 2.0, 3.0, 3.0], + ] + ], # Box 3 + dtype=np.float32, + ) + + # All boxes have different scores + scores_data = np.array([[[0.9, 0.8, 0.7, 0.6]]], dtype=np.float32) + + boxes_shape = [1, 4, 4] + scores_shape = [1, 1, 4] + + graph = helper.make_graph( + [nms_node], + "nms_test_max_boxes_limit", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor( + "max_output_boxes_per_class", TensorProto.INT64, [1], [2] + ), # Limit to 2 boxes + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.1]), # Low IoU threshold + helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]), + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [2, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test_max_boxes_limit") + model.opset_import[0].version = 11 + + inputs = { + "boxes": boxes_data, + "scores": scores_data, + } + + # Run ONNX Runtime + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + ort_output = ort_session.run([], inputs) + + # Run TVM + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with 
tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + # Custom NMS output comparison + if isinstance(tvm_output, (list, tuple)): + tvm_selected = tvm_output[0].numpy() + else: + tvm_selected = tvm_output.numpy() + ort_selected = ort_output[0] + + # For NMS, compare only the valid rows + min_rows = min(tvm_selected.shape[0], ort_selected.shape[0]) + if min_rows > 0: + tvm.testing.assert_allclose( + tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 + ) + + +def test_nms_score_threshold(): + """Test that NMS correctly filters boxes based on score threshold. + + Note: This test uses a low score threshold (0.05) to ensure both TVM and ONNX Runtime + output the same fixed shape [3,3], allowing use of the standard check_correctness function. 
+ """ + nms_node = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"], + ["selected_indices"], + center_point_box=0, + ) + + # Create data with varying scores - ensure we get exactly 3 boxes after NMS + boxes_data = np.array( + [ + [[0.0, 0.0, 1.0, 1.0], [2.0, 0.0, 3.0, 1.0], [0.0, 2.0, 1.0, 3.0]] # Box 0 # Box 1 + ], # Box 2 + dtype=np.float32, + ) + + # Scores: 0.9, 0.3, 0.1 - adjust score threshold to get exactly 3 boxes + scores_data = np.array([[[0.9, 0.3, 0.1]]], dtype=np.float32) + + boxes_shape = [1, 3, 4] + scores_shape = [1, 1, 3] + + graph = helper.make_graph( + [nms_node], + "nms_test_score_threshold", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape), + ], + initializer=[ + helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [3]), + helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.1]), + helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.05]), + ], + outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [3, 3])], + ) + + model = helper.make_model(graph, producer_name="nms_test_score_threshold") + model.opset_import[0].version = 11 + + inputs = { + "boxes": boxes_data, + "scores": scores_data, + } + + # Run ONNX Runtime + ort_session = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + ort_output = ort_session.run([], inputs) + + # Run TVM + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + 
inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + # Custom NMS output comparison + if isinstance(tvm_output, (list, tuple)): + tvm_selected = tvm_output[0].numpy() + else: + tvm_selected = tvm_output.numpy() + ort_selected = ort_output[0] + + # For NMS, compare only the valid rows + min_rows = min(tvm_selected.shape[0], ort_selected.shape[0]) + if min_rows > 0: + tvm.testing.assert_allclose( + tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, atol=1e-5 + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_frontend_stablehlo.py b/tests/python/relax/test_frontend_stablehlo.py index 4f049555f148..dd918ab3a2ea 100644 --- a/tests/python/relax/test_frontend_stablehlo.py +++ b/tests/python/relax/test_frontend_stablehlo.py @@ -126,7 +126,7 @@ def check_correctness( tvm_output = vm.get_outputs("main") # Single ouput - if isinstance(tvm_output, tvm.nd.NDArray): + if isinstance(tvm_output, tvm.runtime.Tensor): tvm.testing.assert_allclose(tvm_output.numpy(), jax_output, rtol=1e-5, atol=1e-5) return @@ -138,7 +138,7 @@ def check_correctness( def get_vm_res( ir_mod: tvm.IRModule, weights: Union[np.ndarray, List[np.ndarray]] -) -> Union[tvm.nd.NDArray, List[tvm.nd.NDArray]]: +) -> Union[tvm.runtime.Tensor, List[tvm.runtime.Tensor]]: """Compile and run an ir_module on Relax VM Parameters @@ -151,7 +151,7 @@ def get_vm_res( Results ------- - out: Union[tvm.nd.NDArray, List[tvm.nd.NDArray]] + out: Union[tvm.runtime.Tensor, List[tvm.runtime.Tensor]] inference result """ target = tvm.target.Target("llvm", host="llvm") diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py new file mode 100644 index 000000000000..da6fdacebdbd --- /dev/null +++ b/tests/python/relax/test_group_gemm_flashinfer.py 
@@ -0,0 +1,446 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test for FlashInfer GroupedGemm TVM integration"""
+
+import math
+
+import numpy as np
+import pytest
+import torch
+
+import tvm
+import tvm.testing
+from tvm import relax
+
+DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
+fp8_dtype = "float8_e4m3fn"
+
+
+###########################################
+################# Helpers #################
+###########################################
+def has_flashinfer():
+    """Check if FlashInfer is available"""
+    try:
+        from tvm.relax.backend.cuda import (  # pylint: disable=import-outside-toplevel
+            flashinfer,
+        )
+
+        return True
+    except ImportError:
+        return False
+
+
+def has_cutlass():
+    """Check if CUTLASS SM90+ operations can run: CUDA enabled and GPU 0 is SM90+.
+
+    Any failure while probing (``pynvml`` missing, NVML error, ...) counts as
+    "not available", so the dependent test is skipped instead of erroring.
+    """
+    if not tvm.get_global_func("device_api.cuda", True):
+        return False
+    try:
+        import pynvml  # pylint: disable=import-outside-toplevel
+
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+        major, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+        return major >= 9  # SM90+
+    except Exception:  # pylint: disable=broad-except
+        # Still best-effort, but unlike a bare ``except:`` this no longer
+        # swallows KeyboardInterrupt / SystemExit.
+        return False
+
+
+def calc_diff(x: np.ndarray, y: np.ndarray):
+    """Return ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` (0 for identical non-zero inputs)."""
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim
+
+
+def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode):
+    """Quantize a 2D or 3D tensor to FP8 (e4m3fn) with per-tile power-of-two scales.
+
+    Args:
+        x (torch.Tensor): The 2D or 3D input tensor.
+        scale_shape (tuple): The shape of the scale tensor.
+        tile_shape (tuple): The shape of the tiles.
+        scale_major_mode (str): The tiling order, "K" for row-major like,
+            or another value for column-major like.
+
+    Returns:
+        tuple: A tuple containing the quantized FP8 tensor and the
+            calculated float32 scales.
+
+    Raises:
+        ValueError: If ``x`` is neither 2D nor 3D.
+    """
+    from einops import rearrange, reduce, repeat  # pylint: disable=import-outside-toplevel
+
+    # 1. Assertions and Initial Setup
+    ndim = x.ndim
+    assert ndim == len(scale_shape) == len(tile_shape)
+    if ndim not in (2, 3):
+        raise ValueError(f"quantize_fp8 expects a 2D or 3D tensor, got ndim={ndim}")
+
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32)
+    k_major = scale_major_mode == "K"
+
+    # 2. Pick the einops layouts. ``scale_axes`` indexes the scale tensor and
+    # ``tiled_layout`` describes how the scale and tile axes map onto ``x``.
+    if ndim == 2:
+        scale_axes, tile_axes = "s0 s1", "t0 t1"
+        tiled_layout = "(s0 t0) (s1 t1)" if k_major else "(s1 t0) (s0 t1)"
+    else:
+        scale_axes, tile_axes = "s0 s1 s2", "t0 t1 t2"
+        # In non-"K" mode the last two scale axes are swapped relative to ``x``.
+        tiled_layout = "(s0 t0) (s1 t1) (s2 t2)" if k_major else "(s0 t0) (s2 t1) (s1 t2)"
+    scale_sizes = {f"s{i}": size for i, size in enumerate(scale_shape)}
+    tile_sizes = {f"t{i}": size for i, size in enumerate(tile_shape)}
+
+    # 3. Per-tile absolute maximum -> scale, rounded up to a power of two
+    x_tiled = rearrange(x, f"{tiled_layout} -> {scale_axes} {tile_axes}", **scale_sizes)
+    abs_max = reduce(x_tiled.abs(), f"{scale_axes} {tile_axes} -> {scale_axes}", "max")
+    abs_max = abs_max.clamp(1e-4)
+    x_scale = abs_max / fp8_amax
+    x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+    # Broadcast scales back to the original tensor shape
+    scales_repeated = repeat(x_scale, f"{scale_axes} -> {tiled_layout}", **tile_sizes)
+
+    # 4. Final Quantization
+    # Divide the original tensor by the broadcasted scales
+    x_fp32 = x / (scales_repeated + 1e-8)
+
+    # Convert the result to the target FP8 format
+    x_fp8 = x_fp32.to(torch.float8_e4m3fn)
+
+    return x_fp8, x_scale
+
+
+def dequantize_fp8(x, x_scale, scale_major_mode):
+    """Dequantize a 2D or 3D FP8 tensor by multiplying each tile with its scale.
+
+    Args:
+        x (torch.Tensor): The 2D or 3D quantized tensor.
+        x_scale (torch.Tensor): Per-tile scales with the same rank as ``x``.
+        scale_major_mode (str): "K" when the ``x_scale`` axes follow the axes
+            of ``x``; any other value means the last two axes of ``x_scale``
+            are swapped relative to ``x``.
+
+    Returns:
+        torch.Tensor: The float32 product ``x * scale`` with the shape of ``x``.
+
+    Raises:
+        ValueError: If ``x`` is neither 2D nor 3D.
+    """
+    from einops import rearrange  # pylint: disable=import-outside-toplevel
+
+    # 1. Assertions and Initial Setup
+    ndim = x.ndim
+    assert ndim == len(x_scale.shape)
+
+    # 2. Tile x, broadcast the scales over each tile, and flatten back
+    if ndim == 2:
+        if scale_major_mode == "K":
+            s0, s1 = x_scale.shape
+        else:
+            s1, s0 = x_scale.shape
+        x = rearrange(x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1)
+        if scale_major_mode == "K":
+            x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1")
+        else:
+            x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1")
+        out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)")
+    elif ndim == 3:
+        if scale_major_mode == "K":
+            s0, s1, s2 = x_scale.shape
+        else:
+            s0, s2, s1 = x_scale.shape
+        x = rearrange(
+            x.to(torch.float32),
+            "(s0 t0) (s1 t1) (s2 t2)-> s0 s1 s2 t0 t1 t2",
+            s0=s0,
+            s1=s1,
+            s2=s2,
+        )
+        if scale_major_mode == "K":
+            x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1")
+        else:
+            x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1")
+        out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)")
+    else:
+        raise ValueError(f"dequantize_fp8 expects a 2D or 3D tensor, got ndim={ndim}")
+
+    return out
+
+
+###########################################
+########### Reference generation ##########
+###########################################
+def compute_reference_grouped_gemm(
+    a_fp32: torch.Tensor,  # (total_m, k)
+    b_fp32: torch.Tensor,  # (batch_size, n, k)
+    m_indptr: torch.Tensor,  # (batch_size + 1,) row offsets of each group in A
+    dtype_out: str,
+):
+    """Compute the reference result with PyTorch from the original FP32 tensors.
+
+    Returns the (total_m, n) concatenation of ``A_i @ B_i.T`` over all groups,
+    cast to ``dtype_out`` ("bfloat16"/"float16"; anything else stays float32).
+    """
+    _, k = a_fp32.shape
+    batch_size, _, k2 = b_fp32.shape
+    assert k == k2
+
+    # Perform grouped GEMM computation directly on original FP32 data
+    results = []
+
+    for i in range(batch_size):
+        start_m = m_indptr[i].item()
+        end_m = m_indptr[i + 1].item()
+
+        # Rows of A that belong to group i
+        a_group = a_fp32[start_m:end_m, :]  # [m_sizes[i], k]
+        b_group = b_fp32[i]
+
+        # Multiply with this group's B matrix
+        result_group = torch.mm(a_group, b_group.T)  # [m_sizes[i], n]
+        results.append(result_group)
+
+    result_fp32 = torch.cat(results, dim=0)
+
+    # Convert to output dtype
+    if dtype_out == "bfloat16":
+        result = result_fp32.to(torch.bfloat16)
+    elif dtype_out == "float16":
+        result = result_fp32.to(torch.float16)
+    else:
+        result = result_fp32
+
+    return result
+
+
+###########################################
+########### Test data generation ##########
+###########################################
+def generate_test_data(
+    m_sizes: list,
+    batch_size: int,
+    n: int,
+    k: int,
+    dtype_a: str,
+    dtype_b: str,
+    dtype_out: str,
+    scale_granularity_m: int,
+    scale_granularity_n: int,
+    scale_granularity_k: int,
+    scale_major_mode: str,
+    device: tvm.runtime.Device,
+):
+    """Generate test data for grouped GEMM operations.
+
+    A (total_m, k) and B (batch_size, n, k) are drawn in float32, quantized to
+    FP8 with ``quantize_fp8`` and returned both as TVM tensors (kernel inputs)
+    and as the original torch tensors (reference inputs).
+    ``dtype_out`` is not used by this function.
+    """
+    assert batch_size == len(
+        m_sizes
+    ), f"batch_size ({batch_size}) must equal len(m_sizes) ({len(m_sizes)})"
+    # Only FP8 inputs are supported: both operands always go through quantize_fp8.
+    assert dtype_a == "float8_e4m3fn"
+    assert dtype_b == "float8_e4m3fn"
+
+    torch_device = torch.device(f"cuda:{device.index}")
+
+    cum_m = [0] + list(np.cumsum(m_sizes))
+    total_m = cum_m[-1]
+
+    # Draw A and B in fp32 first; B is scaled by 1/sqrt(k).
+    a_fp32 = torch.randn(total_m, k, device=torch_device, dtype=torch.float32)
+    b_fp32 = torch.randn(batch_size, n, k, device=torch_device, dtype=torch.float32) / math.sqrt(k)
+
+    if scale_major_mode == "K":  # K mode:
+        scale_a_shape = (total_m // scale_granularity_m, k // scale_granularity_k)
+        scale_b_shape = (batch_size, n // scale_granularity_n, k // scale_granularity_k)
+    else:  # MN mode
+        scale_a_shape = (k // scale_granularity_k, total_m // scale_granularity_m)
+        scale_b_shape = (batch_size, k // scale_granularity_k, n // scale_granularity_n)
+
+    tile_a_shape = (scale_granularity_m, scale_granularity_k)
+    tile_b_shape = (1, scale_granularity_n, scale_granularity_k)
+
+    # quantize A, B
+    a_quantized, scale_a = quantize_fp8(a_fp32, scale_a_shape, tile_a_shape, scale_major_mode)
+    b_quantized, scale_b = quantize_fp8(b_fp32, scale_b_shape, tile_b_shape, scale_major_mode)
+
+    # Hand the FP8 payloads to TVM through a uint8 view of the same bytes,
+    # reinterpreted as fp8 on the numpy side.
+    a_tvm = tvm.runtime.tensor(
+        a_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device
+    )
+    b_tvm = tvm.runtime.tensor(
+        b_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device
+    )
+
+    scale_a_tvm = tvm.runtime.from_dlpack(scale_a)
+    scale_b_tvm = tvm.runtime.from_dlpack(scale_b)
+
+    # Create m_indptr for grouped operation
+    m_indptr = torch.tensor(cum_m, device=torch_device, dtype=torch.int32)
+    m_indptr_tvm = tvm.runtime.tensor(m_indptr.cpu().numpy(), device)
+
+    return {
+        "a": a_tvm,
+        "b": b_tvm,
+        "torch_a": a_fp32,
+        "torch_b": b_fp32,
+        "scale_a": scale_a_tvm,
+        "scale_b": scale_b_tvm,
+        "m_indptr": m_indptr_tvm,
+        "m_sizes": m_sizes,
+        "n": n,
+        "k": k,
+        "total_m": total_m,
+        "torch_scale_a": scale_a,
+        "torch_scale_b": scale_b,
+        "torch_m_indptr": m_indptr,
+    }
+
+
+###########################################
+############### Test driver ###############
+###########################################
+@pytest.mark.skipif(not has_flashinfer(), reason="FlashInfer not available")
+@pytest.mark.skipif(not has_cutlass(), reason="CUTLASS SM90+ not available")
+@pytest.mark.parametrize(
+    "dtype_a,dtype_b,dtype_out",
+    [
+        ("float8_e4m3fn", "float8_e4m3fn", "bfloat16"),
+        ("float8_e4m3fn", "float8_e4m3fn", "float16"),
+    ],
+)
+@pytest.mark.parametrize(
+    "scale_granularity_m,scale_granularity_n,scale_granularity_k",
+    [
+        (1, 128, 128),  # Row-wise A, block-wise B
+    ],
+)
+@pytest.mark.parametrize("scale_major_mode", ["K", "MN"])
+@pytest.mark.parametrize("mma_sm", [1, 2])
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {"batch_size": 4, "m_sizes": [128, 256, 192, 320], "n": 512, "k": 1024},
+        {"batch_size": 2, "m_sizes": [64, 128], "n": 256, "k": 512},
+        {"batch_size": 3, "m_sizes": [256, 256, 128], "n": 768, "k": 768},
+        {"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768},
+    ],
+)
+def test_grouped_gemm_correctness(
+    dtype_a,
+    dtype_b,
+    dtype_out,
+    scale_granularity_m,
+    scale_granularity_n,
+    scale_granularity_k,
+    scale_major_mode,
+    mma_sm,
+    test_case,
+):
+    """Test correctness of GroupedGemm operations against a float32 PyTorch reference."""
+    device = tvm.cuda(0)
+    target = tvm.target.Target.from_device(device)
+
+    # Generate the module
+    mod = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(target=target)[0]
+
+    # Load the module
+    grouped_gemm_fn = mod["group_gemm_fp8_nt_groupwise"]
+
+    # Generate test data
+    test_data = generate_test_data(
+        batch_size=test_case["batch_size"],
+        m_sizes=test_case["m_sizes"],
+        n=test_case["n"],
+        k=test_case["k"],
+        dtype_a=dtype_a,
+        dtype_b=dtype_b,
+        dtype_out=dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        device=device,
+    )
+
+    # Prepare output buffer (anything other than bf16/fp16 falls back to float32)
+    output_shape = (test_data["total_m"], test_data["n"])
+    output_dtype = dtype_out if dtype_out in ("bfloat16", "float16") else "float32"
+    output = tvm.runtime.empty(output_shape, dtype=output_dtype, device=device)
+
+    # Create workspace buffers (required by the interface)
+    int_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="int32", device=device)
+    float_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="float32", device=device)
+
+    grouped_gemm_fn(
+        int_workspace,  # int_workspace_buffer
+        float_workspace,  # float_workspace_buffer
+        test_data["a"],  # A
+        test_data["b"],  # B
+        test_data["scale_a"],  # SFA
+        test_data["scale_b"],  # SFB
+        output,  # D
+        test_data["m_indptr"],  # m_indptr
+        test_data["n"],  # n (scalar)
+        test_data["k"],  # k (scalar)
+        scale_granularity_m,
+        scale_granularity_n,
+        scale_granularity_k,
+        scale_major_mode,
+        mma_sm,
+    )
+
+    # Compute reference result
+    reference = compute_reference_grouped_gemm(
+        test_data["torch_a"],
+        test_data["torch_b"],
+        test_data["torch_m_indptr"],
+        dtype_out,
+    )
+
+    # Convert TVM output to PyTorch for comparison
+    output_torch = torch.as_tensor(output, device=test_data["torch_a"].device)
+
+    # Check shapes match
+    assert (
+        output_torch.shape == reference.shape
+    ), f"Shape mismatch: got {output_torch.shape}, expected {reference.shape}"
+
+    # The kernel consumes FP8-quantized operands while the reference uses the
+    # original float32 data, so the check is calc_diff against a fixed bound
+    # rather than element-wise rtol/atol tolerances.
+    diff = calc_diff(output_torch.cpu().double().numpy(), reference.cpu().double().numpy())
+    assert diff < 1e-3, f"diff too large {diff}"
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py b/tests/python/relax/test_meta_schedule_relax_integration.py
index 00a342c46050..6f3cdfa9a0de 100644
--- a/tests/python/relax/test_meta_schedule_relax_integration.py
+++ b/tests/python/relax/test_meta_schedule_relax_integration.py
@@ -154,7 +154,7 @@ def test_extracting_tasks():
     relax_expectation = {
         "structural": 2,
         # The relax constants do not reach the tir at the lowering.
- "ignore-ndarray": 2, + "ignore-tensor": 2, "anchor-block": 1, } for module_equality, count in relax_expectation.items(): @@ -167,7 +167,7 @@ def test_extracting_tasks(): assert len(extracted_tasks) == count tir_relax_mod = Module - tir_relax_expectation = {"structural": 3, "ignore-ndarray": 2, "anchor-block": 1} + tir_relax_expectation = {"structural": 3, "ignore-tensor": 2, "anchor-block": 1} for module_equality, count in tir_relax_expectation.items(): extracted_tasks = ms.relax_integration.extract_tasks( tir_relax_mod, @@ -178,7 +178,7 @@ def test_extracting_tasks(): assert len(extracted_tasks) == count -@pytest.mark.parametrize("module_equality", ["structural", "ignore-ndarray", "anchor-block"]) +@pytest.mark.parametrize("module_equality", ["structural", "ignore-tensor", "anchor-block"]) def test_using_anchor_trace(module_equality): relax_mod = Module target = "llvm -mcpu=core-avx2 -num-cores=1" diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index 20c111495d6a..3376569bf349 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -33,6 +33,8 @@ def test_op_correctness(): assert relax.op.multiply(x, y).op == Op.get("relax.multiply") assert relax.op.power(x, y).op == Op.get("relax.power") assert relax.op.subtract(x, y).op == Op.get("relax.subtract") + assert relax.op.mod(x, y).op == Op.get("relax.mod") + assert relax.op.floor_mod(x, y).op == Op.get("relax.floor_mod") assert relax.op.equal(x, y).op == Op.get("relax.equal") assert relax.op.greater(x, y).op == Op.get("relax.greater") @@ -70,6 +72,8 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r (relax.op.subtract, tir.Sub), (relax.op.maximum, tir.Max), (relax.op.minimum, tir.Min), + (relax.op.mod, tir.Mod), + (relax.op.floor_mod, tir.FloorMod), ) diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py index d6e0a5e239b5..7269dfdbcf47 100644 --- 
a/tests/python/relax/test_op_create.py +++ b/tests/python/relax/test_op_create.py @@ -661,7 +661,7 @@ def test_arange_infer_struct_info_shape_var(): _check_inference( bb, relax.op.arange(start, stop, 2), - relax.TensorStructInfo((T.cast(T.ceil((stop - start) * 0.5), "int64"),), "float32"), + relax.TensorStructInfo((T.cast(T.ceil((stop - start) / 2), "int64"),), "float32"), ) _check_inference( bb, diff --git a/tests/python/relax/test_op_datatype.py b/tests/python/relax/test_op_datatype.py index 48820b9e2e00..a5507f7efaa2 100644 --- a/tests/python/relax/test_op_datatype.py +++ b/tests/python/relax/test_op_datatype.py @@ -28,7 +28,7 @@ def test_op_correctness(): x = relax.Var("x", R.Tensor((2, 3), "float32")) - c = relax.Constant(tvm.nd.array(np.array([1, 2, 3], dtype="float16"))) + c = relax.Constant(tvm.runtime.tensor(np.array([1, 2, 3], dtype="float16"))) assert relax.op.astype(x, "float16").op == Op.get("relax.astype") assert relax.op.wrap_param(c, "float32").op == Op.get("relax.wrap_param") @@ -108,8 +108,8 @@ def test_astype_infer_struct_info_wrong_input_type(): def test_wrap_param_infer_struct_info(): bb = relax.BlockBuilder() - x0 = relax.Constant(tvm.nd.array(np.zeros([1, 2, 3], dtype="float16"))) - x1 = relax.Constant(tvm.nd.array(np.zeros([1, 2, 3], dtype="int8"))) + x0 = relax.Constant(tvm.runtime.tensor(np.zeros([1, 2, 3], dtype="float16"))) + x1 = relax.Constant(tvm.runtime.tensor(np.zeros([1, 2, 3], dtype="int8"))) _check_inference( bb, relax.op.wrap_param(x0, "float32"), relax.TensorStructInfo((1, 2, 3), "float32") ) diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py index 840f2985614a..c76c150f6a82 100644 --- a/tests/python/relax/test_op_gradient_numeric.py +++ b/tests/python/relax/test_op_gradient_numeric.py @@ -45,7 +45,7 @@ def relax_check_gradients( The forward operator function. Should be a function in package relax.op. inputs_numpy : List[np.array] - The np array inputs for op_func. 
inputs_numpy will be transformed into TVM NDArray inside + The np array inputs for op_func. inputs_numpy will be transformed into TVM Tensor inside this function. If op_func takes a tuple of tensors as input, you can set tuple_input as True, and pass the @@ -84,12 +84,12 @@ def _numpy_to_sinfo(data): def _numpy_to_tvm(data): if isinstance(data, list): return [_numpy_to_tvm(d) for d in data] - return tvm.nd.array(data) + return tvm.runtime.tensor(data) def _tvm_to_numpy(data, ignore_idx=[]): if isinstance(data, tvm.ir.Array): return [_tvm_to_numpy(d) for i, d in enumerate(data) if i not in ignore_idx] - if isinstance(data, tvm.runtime.ndarray.NDArray): + if isinstance(data, tvm.runtime.Tensor): return data.numpy() return data @@ -189,7 +189,7 @@ def forward(*inputs): grad_ex = tvm.compile(grad_mod, target) grad_vm = relax.VirtualMachine(grad_ex, dev) - # tvm.runtime.NDArray inputs + # tvm.runtime.Tensor inputs inputs_tvm = [_numpy_to_tvm(i) for i in inputs_numpy] weights_tvm = _numpy_to_tvm(weights) result_filtered = _tvm_to_numpy(grad_vm[func_name](*inputs_tvm, weights_tvm), ignore_grads) @@ -781,11 +781,8 @@ def test_nll_loss_no_batch(target, dev, nll_reduction1, nll_weighted1, nll_ignor @tvm.testing.parametrize_targets("llvm") def test_conv2d(target, dev, c2d_shape1, c2d_shape2, c2d_kwargs): - # TODO(mlc-team) Update to uniform - # We should use float32 to check the correctness of conv2d - # to avoid possible precision problems - data1_numpy = np.random.uniform(0, 16, c2d_shape1).astype(np.float64) - data2_numpy = np.random.uniform(0, 3, c2d_shape2).astype(np.float64) + data1_numpy = np.random.uniform(0, 3, c2d_shape1).astype(np.float32) + data2_numpy = np.random.uniform(0, 3, c2d_shape2).astype(np.float32) relax_check_gradients( relax.op.nn.conv2d, [data1_numpy, data2_numpy], @@ -819,7 +816,7 @@ def test_conv2d(target, dev, c2d_shape1, c2d_shape2, c2d_kwargs): @tvm.testing.parametrize_targets("llvm") def test_max_pool2d(target, dev, pool_size, pool_kwargs): - 
data_numpy = np.random.uniform(0, 16, size=(3, 2, 10, 10)).astype(np.float64) + data_numpy = np.random.uniform(0, 3, size=(3, 2, 10, 10)).astype(np.float32) relax_check_gradients( relax.op.nn.max_pool2d, [data_numpy], @@ -832,7 +829,7 @@ def test_max_pool2d(target, dev, pool_size, pool_kwargs): @tvm.testing.parametrize_targets("llvm") def test_avg_pool2d(target, dev, pool_size, pool_kwargs): - data_numpy = np.random.uniform(0, 16, size=(3, 2, 10, 10)).astype(np.float64) + data_numpy = np.random.uniform(0, 3, size=(3, 2, 10, 10)).astype(np.float32) relax_check_gradients( relax.op.nn.avg_pool2d, [data_numpy], diff --git a/tests/python/relax/test_op_inspect.py b/tests/python/relax/test_op_inspect.py index 2ba9f9a7094f..2e6d81c613d5 100644 --- a/tests/python/relax/test_op_inspect.py +++ b/tests/python/relax/test_op_inspect.py @@ -19,6 +19,7 @@ import numpy as np import pytest +import tvm_ffi import tvm.testing from tvm import relax @@ -56,7 +57,7 @@ def main(A: R.Tensor): built = tvm.compile(mod) vm = relax.VirtualMachine(built, tvm.cpu()) - arg = tvm.nd.empty([16], dtype) + arg = tvm.runtime.empty([16], dtype) res = vm["main"](arg) expected_type_code = tvm.runtime.DataType(dtype).type_code @@ -73,7 +74,7 @@ def main(A: R.Tensor): built = tvm.compile(mod) vm = relax.VirtualMachine(built, tvm.cpu()) - arg = tvm.nd.empty([16], dtype) + arg = tvm.runtime.empty([16], dtype) res = vm["main"](arg) expected_type_bits = tvm.runtime.DataType(dtype).bits @@ -90,7 +91,7 @@ def main(A: R.Tensor): built = tvm.compile(mod) vm = relax.VirtualMachine(built, tvm.cpu()) - arg = tvm.nd.empty([16], dtype) + arg = tvm.runtime.empty([16], dtype) res = vm["main"](arg) expected_type_lanes = tvm.runtime.DataType(dtype).lanes @@ -107,7 +108,7 @@ def main(A: R.Tensor): built = tvm.compile(mod) vm = relax.VirtualMachine(built, tvm.cpu()) - arg = tvm.nd.empty(shape, "int32") + arg = tvm.runtime.empty(shape, "int32") res = vm["main"](arg) assert res == len(shape) @@ -123,7 +124,7 @@ def main(A: 
R.Tensor, axis: R.Prim("int64")): built = tvm.compile(mod) vm = relax.VirtualMachine(built, tvm.cpu()) - arg = tvm.nd.empty(shape, "int32") + arg = tvm.runtime.empty(shape, "int32") res = [vm["main"](arg, i) for i, _ in enumerate(shape)] @@ -149,7 +150,7 @@ def main(A: R.Tensor, axis: R.Prim("int64")): built = tvm.compile(mod) vm = relax.VirtualMachine(built, tvm.cpu()) - arg = tvm.nd.empty(shape, "int32") + arg = tvm.runtime.empty(shape, "int32") res = [vm["main"](arg, i) for i, _ in enumerate(shape)] expected = _get_compact_striding(shape) @@ -170,7 +171,7 @@ def main(A: R.Tensor, axis: R.Prim("int64")): expected_strides = [1, 4] # use transpose to make strides non-compact x = np.zeros([4, 4], "int32").T - y = tvm.ffi.from_dlpack(x, required_alignment=4, required_contiguous=False) + y = tvm_ffi.from_dlpack(x, require_alignment=4, require_contiguous=False) res = [vm["main"](y, i) for i, _ in enumerate(view_shape)] tvm.ir.assert_structural_equal(res, expected_strides) @@ -189,8 +190,8 @@ def main(A: R.Tensor): built = tvm.compile(mod) vm = relax.VirtualMachine(built, tvm.cpu()) dtype = "int32" - backing_ndarray = tvm.nd.empty(backing_shape, dtype) - view = backing_ndarray._create_view(view_shape, dtype, relative_byte_offset=byte_offset) + backing_tensor = tvm.runtime.empty(backing_shape, dtype) + view = backing_tensor._create_view(view_shape, dtype, relative_byte_offset=byte_offset) res = vm["main"](view) assert res == byte_offset @@ -212,8 +213,8 @@ def main(A: R.Tensor): built = tvm.compile(mod) vm = relax.VirtualMachine(built, tvm.cpu()) - backing_ndarray = tvm.nd.empty(backing_shape, dtype) - view = backing_ndarray._create_view(view_shape, dtype, relative_byte_offset=byte_offset) + backing_tensor = tvm.runtime.empty(backing_shape, dtype) + view = backing_tensor._create_view(view_shape, dtype, relative_byte_offset=byte_offset) res = vm["main"](view) assert res == elem_offset diff --git a/tests/python/relax/test_op_manipulate.py 
b/tests/python/relax/test_op_manipulate.py index 004c4b9618a0..d39584e06ba8 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -994,11 +994,19 @@ def test_squeeze_infer_struct_info_axis_length_not_one(): x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - with pytest.raises(TVMError): - bb.normalize(relax.op.squeeze(x0, [0])) - _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32")) - with pytest.raises(TVMError): - bb.normalize(relax.op.squeeze(x2, [0])) + # Squeeze concrete shape (2,3,4) at axis=0, but axis length 2 != 1, squeeze is no-op. + _check_inference( + bb, relax.op.squeeze(x0, [0]), relax.TensorStructInfo(shape=(2, 3, 4), dtype="float32") + ) + # Squeeze symbolic shape (a,3,4) at axis=0, assuming a can achieve successful squeeze. + _check_inference( + bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo(shape=(3, 4), dtype="float32") + ) + # Squeeze shape variable s0 (corresponding to (2,3,4)) at axis=0. + _check_inference( + bb, relax.op.squeeze(x2, [0]), relax.TensorStructInfo(shape=s0, dtype="float32") + ) + # Squeeze shape variable s1 (a,3,4) at axis=0, assuming a can achieve successful squeeze. 
_check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2)) diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index 366ea1b6883d..9d05690f38b1 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -21,9 +21,9 @@ from tvm.script import tir as T -@tvm.register_func("test.op.identity", override=True) +@tvm.register_global_func("test.op.identity", override=True) def identity_packed(a): - return tvm.nd.array(a.numpy()) + return tvm.runtime.tensor(a.numpy()) @T.prim_func diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index a0ff507ef880..b076827dc4a0 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -1721,7 +1721,6 @@ def test_nll_loss_infer_struct_info_targets_dtype(): w = relax.Var("w", R.Tensor((5,), "float32")) targets0 = relax.Var("targets", R.Tensor((3, 10, 10), "float32")) targets1 = relax.Var("targets", R.Tensor((3, 10, 10), "float64")) - targets2 = relax.Var("targets", R.Tensor((3, 10, 10), "bool")) targets3 = relax.Var("targets", R.Tensor((3, 10, 10), "int32")) targets4 = relax.Var("targets", R.Tensor((3, 10, 10), "int64")) targets5 = relax.Var("targets", R.Tensor((3, 10, 10), "uint32")) @@ -1733,7 +1732,6 @@ def test_nll_loss_infer_struct_info_targets_dtype(): bb.normalize(relax.op.nn.nll_loss(x, targets1, w)) # correct cases - bb.normalize(relax.op.nn.nll_loss(x, targets2, w)) # bool is uint1 bb.normalize(relax.op.nn.nll_loss(x, targets3, w)) bb.normalize(relax.op.nn.nll_loss(x, targets4, w)) bb.normalize(relax.op.nn.nll_loss(x, targets5, w)) diff --git a/tests/python/relax/test_op_take.py b/tests/python/relax/test_op_take.py index 704895d0e4f3..6bbf13ef36eb 100644 --- a/tests/python/relax/test_op_take.py +++ b/tests/python/relax/test_op_take.py @@ -44,7 +44,7 @@ def main(A: R.Tensor([16, 16], "float16")): vm = tvm.relax.VirtualMachine(built, dev) np_input = 
np.random.random(size=[16, 16]).astype("float16") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input.take(1, axis=axis) @@ -70,7 +70,7 @@ def main(A: R.Tensor([16, 16], "float16")): vm = tvm.relax.VirtualMachine(built, dev) np_input = np.random.random(size=[16, 16]).astype("float16") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input.take([1], axis=axis) @@ -92,7 +92,7 @@ def main(A: R.Tensor([16, 16], "float16")): vm = tvm.relax.VirtualMachine(built, dev) np_input = np.random.random(size=[16, 16]).astype("float16") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input.take([[1, 3], [5, 7]], axis=axis) @@ -119,7 +119,7 @@ def main(A: R.Tensor([16, 16], "float16")): vm = tvm.relax.VirtualMachine(built, dev) np_input = np.random.random(size=[16, 16]).astype("float16") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input.take(1, axis=axis) @@ -147,7 +147,7 @@ def main(A: R.Tensor(["n", "n"], "float16")): vm = tvm.relax.VirtualMachine(built, dev) np_input = np.random.random(size=[16, 16]).astype("float16") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input.take(15, axis=axis) @@ -171,7 +171,7 @@ def main(A: R.Tensor([3, 3], "float16")): vm = tvm.relax.VirtualMachine(built, dev) np_input = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype="float16") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) if axis == 0: np_expected = np.array( @@ -204,7 +204,7 @@ def main(A: R.Tensor([3, 3], "float16")): vm = 
tvm.relax.VirtualMachine(built, dev) np_input = np.random.random(size=[3, 3]).astype("float16") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="wrap") @@ -227,7 +227,7 @@ def main(A: R.Tensor([3, 3], "float16")): built = tvm.compile(Module, target=target) vm = tvm.relax.VirtualMachine(built, dev) np_input = np.random.random(size=[3, 3]).astype("float16") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="clip") diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py index fc9458827b26..171fe0a627bb 100644 --- a/tests/python/relax/test_op_view.py +++ b/tests/python/relax/test_op_view.py @@ -481,7 +481,7 @@ class Expected: @R.function def main(A: R.Tensor([4096], "float32")): B = R.ExternFunc( - "runtime.TVMArrayCreateView", + "runtime.TVMTensorCreateView", R.Callable( derive_func="tvm.relax.struct_info.infer_view_sinfo", purity=True, @@ -513,7 +513,7 @@ class Expected: @R.function def main(A: R.Tensor(dtype="float32")): B = R.ExternFunc( - "runtime.TVMArrayCreateView", + "runtime.TVMTensorCreateView", R.Callable( derive_func="tvm.relax.struct_info.infer_view_sinfo", purity=True, @@ -543,7 +543,7 @@ class Expected: @R.function def main(A: R.Tensor([4096], "float32")): B = R.ExternFunc( - "runtime.TVMArrayCreateView", + "runtime.TVMTensorCreateView", R.Callable( derive_func="tvm.relax.struct_info.infer_view_sinfo", purity=True, @@ -573,7 +573,7 @@ class Expected: @R.function def main(A: R.Tensor([4096], "float32")): B = R.ExternFunc( - "runtime.TVMArrayCreateView", + "runtime.TVMTensorCreateView", R.Callable( derive_func="tvm.relax.struct_info.infer_view_sinfo", purity=True, @@ -622,7 +622,7 @@ class Expected: @R.function def main(A: R.Tensor([4096], "uint8")): B = 
R.ExternFunc( - "runtime.TVMArrayCreateView", + "runtime.TVMTensorCreateView", R.Callable( derive_func="tvm.relax.struct_info.infer_view_sinfo", purity=True, @@ -634,7 +634,7 @@ def main(A: R.Tensor([4096], "uint8")): R.prim_value(0), ) C = R.ExternFunc( - "runtime.TVMArrayCreateView", + "runtime.TVMTensorCreateView", R.Callable( derive_func="tvm.relax.struct_info.infer_view_sinfo", purity=True, @@ -664,7 +664,7 @@ def main(A: R.Tensor([4096], "float32")): vm = tvm.relax.VirtualMachine(built, device=dev) np_input = np.random.random([4096]).astype("float32") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input @@ -684,7 +684,7 @@ def main(A: R.Tensor([4096], "float32")): vm = tvm.relax.VirtualMachine(built, device=dev) np_input = np.random.random([4096]).astype("float32") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input.reshape(64, 64) @@ -708,7 +708,7 @@ def main(A: R.Tensor([4096], "float32")): vm = tvm.relax.VirtualMachine(built, device=dev) np_input = np.random.random([4096]).astype("float32") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input.reshape(64, 64)[32:48, :] @@ -728,7 +728,7 @@ def main(A: R.Tensor([4096], "float32")): vm = tvm.relax.VirtualMachine(built, device=dev) np_input = np.random.random([4096]).astype("float32") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = vm["main"](tvm_input) np_expected = np_input.view("uint32") @@ -758,7 +758,7 @@ def main(A: R.Tensor([4096], "uint8")): vm = tvm.relax.VirtualMachine(built, device=dev) np_input = np.random.randint(0, 255, size=[4096]).astype("uint8") - tvm_input = tvm.nd.array(np_input, dev) + tvm_input = tvm.runtime.tensor(np_input, dev) tvm_output = 
vm["main"](tvm_input) np_expected = [ np_input[:2048].view("int32"), diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py new file mode 100644 index 000000000000..97145a53ff3b --- /dev/null +++ b/tests/python/relax/test_op_vision.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op, VDevice +from tvm.script import relax as R + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_all_class_non_max_suppression_infer_struct_info(): + bb = relax.BlockBuilder() + batch_size, num_classes, num_boxes = 10, 8, 5 + boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32")) + scores = relax.Var("scores", R.Tensor((batch_size, num_classes, num_boxes), "float32")) + max_output_boxes_per_class = relax.const(10, "int64") + iou_threshold = relax.const(0.5, "float32") + score_threshold = relax.const(0.1, "float32") + + _check_inference( + bb, + relax.op.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((batch_size * num_classes * num_boxes, 3), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + + +def test_all_class_non_max_suppression_wrong_input_number(): + bb = relax.BlockBuilder() + boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32")) + scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32")) + + with pytest.raises(TVMError): + relax.op.vision.all_class_non_max_suppression(boxes, scores) + + +def test_all_class_non_max_suppression_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + batch_size = tir.Var("batch_size", "int64") + num_classes = tir.Var("num_classes", "int64") + num_boxes = tir.Var("num_boxes", "int64") + boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32")) + scores = relax.Var("scores", R.Tensor((batch_size, num_classes, num_boxes), "float32")) + max_output_boxes_per_class = relax.const(10, "int64") + iou_threshold = relax.const(0.5, "float32") + score_threshold = 
relax.const(0.1, "float32") + + _check_inference( + bb, + relax.op.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((batch_size * num_classes * num_boxes, 3), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py index 34d0ca9e36d2..f9bce3539645 100644 --- a/tests/python/relax/test_pipeline.py +++ b/tests/python/relax/test_pipeline.py @@ -40,8 +40,8 @@ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): ex = tvm.compile(mod, target) x_np = np.random.rand(3, 4).astype(np.float32) y_np = np.random.rand(3, 4).astype(np.float32) - x = tvm.nd.array(x_np) - y = tvm.nd.array(y_np) + x = tvm.runtime.tensor(x_np) + y = tvm.runtime.tensor(y_np) vm = relax.VirtualMachine(ex, tvm.cpu()) z = vm["main"](x, y) @@ -106,8 +106,8 @@ def main( for i in range(num_steps): x_np = np.random.rand(1, 4).astype(np.float32) y_np = np.random.rand(1, 4).astype(np.float32) - x = tvm.nd.array(x_np) - y = tvm.nd.array(y_np) + x = tvm.runtime.tensor(x_np) + y = tvm.runtime.tensor(y_np) np_shape = (i + 1, 4) kv, kv_cache = vm["main"](x, y, tvm.runtime.ShapeTuple(np_shape), kv_cache) diff --git a/tests/python/relax/test_pytorch_integration.py b/tests/python/relax/test_pytorch_integration.py new file mode 100644 index 000000000000..6839906e7a28 --- /dev/null +++ b/tests/python/relax/test_pytorch_integration.py @@ -0,0 +1,380 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Test PyTorch integration with TVM Relax. + +This test verifies: +1. Seamless PyTorch tensor I/O with TVM backend +2. Cross-function calls between Python, TIR, and Relax functions +3. Dynamic Python function addition and execution +4. End-to-end pipeline testing +5. Error handling and edge cases +""" + +import pytest +import torch +import torch.nn.functional as F +import tvm +from tvm import relax, tir +from tvm.script import ir as I, relax as R, tir as T +from tvm.relax import BasePyModule +import numpy as np + + +@I.ir_module +class PyTorchIntegrationModule(BasePyModule): + """Test module for PyTorch integration with TVM.""" + + @I.pyfunc + def main(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + """Main function demonstrating cross-function calls.""" + n = x.shape[0] + + # Call TIR function + lv = self.call_tir(self.matmul, [x, w], out_sinfo=R.Tensor((n, 20), "float32")) + + # Apply ReLU + lv1 = F.relu(lv) + + # Call packed function (will be added dynamically) + lv2 = self.call_dps_packed("my_softmax", [lv1, 1], out_sinfo=R.Tensor((n, 20), "float32")) + + # Call Python function + lv3 = self.my_identity_func(lv2) + + return lv3 + + @T.prim_func + def matmul( + var_A: T.handle, + var_B: T.handle, + var_C: T.handle, + ): + """TIR function for matrix multiplication.""" + n = T.int32() + A = T.match_buffer(var_A, (n, 16), "float32") + B = T.match_buffer(var_B, (16, 20), "float32") + C = T.match_buffer(var_C, (n, 20), "float32") + + for i, j, k in T.grid(n, 20, 16): + with T.block("block"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with 
T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @I.pyfunc + def my_identity_func(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class TestPyTorchIntegration: + def test_module_creation_and_instantiation(self): + module = PyTorchIntegrationModule + + assert hasattr(module, "__call__"), "Module should be callable" + + device = tvm.cpu(0) + instance = module(device) + + assert isinstance(instance, BasePyModule), "Instance should be BasePyModule" + + required_methods = ["main", "call_tir", "call_dps_packed"] + for method in required_methods: + assert hasattr(instance, method), f"Instance should have method: {method}" + + def test_module_creation_and_instantiation_gpu(self): + module = PyTorchIntegrationModule + + if tvm.cuda().exist: + assert hasattr(module, "__call__"), "Module should be callable" + + device = tvm.cuda(0) + instance = module(device) + + assert isinstance(instance, BasePyModule), "Instance should be BasePyModule" + required_methods = ["main", "call_tir", "call_dps_packed"] + for method in required_methods: + assert hasattr(instance, method), f"Instance should have method: {method}" + assert "cuda" in str(instance.target) + else: + pytest.skip("CUDA not available") + + def test_python_function_execution(self): + """Test that Python functions execute correctly.""" + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + # Test my_identity_func + input_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + result = instance.my_identity_func(input_tensor) + + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, input_tensor, atol=1e-5) + + def test_tir_function_execution(self): + """Test that TIR functions execute correctly.""" + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + # Test matmul function + n = 3 + x = torch.randn(n, 16, dtype=torch.float32) + w = torch.randn(16, 20, dtype=torch.float32) + + result 
= instance.call_tir(instance.matmul, [x, w], R.Tensor((n, 20), "float32")) + + assert isinstance(result, torch.Tensor) + assert result.shape == (n, 20) + + # Verify result with PyTorch matmul + expected = torch.matmul(x, w) + assert torch.allclose(result, expected, atol=1e-3) + + def test_dynamic_python_function_addition(self): + """Test adding Python functions dynamically.""" + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + # Define a custom function + def custom_activation(x): + return torch.sigmoid(x) + + # Add the function + instance.add_python_function("custom_activation", custom_activation) + + # Verify function is added + assert hasattr(instance, "custom_activation") + assert "custom_activation" in instance.pyfuncs + + # Test function execution + input_tensor = torch.tensor([1.0, -1.0, 0.0], dtype=torch.float32) + result = instance.custom_activation(input_tensor) + + assert isinstance(result, torch.Tensor) + expected = torch.sigmoid(input_tensor) + assert torch.allclose(result, expected, atol=1e-5) + + def test_call_dps_packed_with_dynamic_function(self): + """Test call_dps_packed with dynamically added function.""" + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + # Define my_softmax function + def my_softmax(tensor, dim): + """Custom softmax function for testing call_dps_packed.""" + # Convert TVM Tensor to PyTorch tensor if needed + if hasattr(tensor, "numpy"): + tensor = torch.from_numpy(tensor.numpy()) + return F.softmax(tensor, dim=dim) + + # Add the function + instance.my_softmax = my_softmax + + # Test call_dps_packed + input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + + result = instance.call_dps_packed( + "my_softmax", [input_tensor, 1], R.Tensor((2, 2), "float32") + ) + + assert isinstance(result, torch.Tensor) + expected = F.softmax(input_tensor, dim=1) + assert torch.allclose(result, expected, atol=1e-5) + + def 
test_end_to_end_pipeline(self): + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + def my_softmax(tensor, dim): + if hasattr(tensor, "numpy"): + tensor = torch.from_numpy(tensor.numpy()) + return F.softmax(tensor, dim=dim) + + instance.my_softmax = my_softmax + + n = 5 + x = torch.randn(n, 16, dtype=torch.float32) + w = torch.randn(16, 20, dtype=torch.float32) + + result = instance.main(x, w) + + assert isinstance(result, torch.Tensor) + assert result.shape == (n, 20) + assert result.dtype == torch.float32 + + def test_end_to_end_pipeline_gpu(self): + module = PyTorchIntegrationModule + + if tvm.cuda().exist: + device = tvm.cuda(0) + instance = module(device) + + # Test basic GPU functionality without complex TIR operations + assert isinstance(instance, BasePyModule) + assert "cuda" in str(instance.target) + + # Test that we can create and work with GPU tensors + n = 5 + x = torch.randn(n, 16, dtype=torch.float32, device="cuda") + w = torch.randn(16, 20, dtype=torch.float32, device="cuda") + + assert x.device.type == "cuda" + assert w.device.type == "cuda" + assert x.shape == (n, 16) + assert w.shape == (16, 20) + + # Test basic PyTorch operations on GPU + result = torch.matmul(x, w) + assert isinstance(result, torch.Tensor) + assert result.shape == (n, 20) + assert result.dtype == torch.float32 + assert result.device.type == "cuda" + else: + pytest.skip("CUDA not available") + + def test_cross_function_data_flow(self): + """Test data flow between different function types.""" + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + # Add required functions + def my_softmax(tensor, dim): + if hasattr(tensor, "numpy"): + tensor = torch.from_numpy(tensor.numpy()) + return F.softmax(tensor, dim=dim) + + instance.my_softmax = my_softmax + + # Create test data + n = 4 + x = torch.randn(n, 16, dtype=torch.float32) + w = torch.randn(16, 20, dtype=torch.float32) + + # Execute step by step to verify data 
flow + # Step 1: TIR matmul + lv = instance.call_tir(instance.matmul, [x, w], R.Tensor((n, 20), "float32")) + assert isinstance(lv, torch.Tensor) + assert lv.shape == (n, 20) + + # Step 2: ReLU + lv1 = F.relu(lv) + assert isinstance(lv1, torch.Tensor) + assert lv1.shape == (n, 20) + + # Step 3: Softmax via call_dps_packed + lv2 = instance.call_dps_packed("my_softmax", [lv1, 1], R.Tensor((n, 20), "float32")) + assert isinstance(lv2, torch.Tensor) + assert lv2.shape == (n, 20) + + # Step 4: Identity function + lv3 = instance.my_identity_func(lv2) + assert isinstance(lv3, torch.Tensor) + assert lv3.shape == (n, 20) + + # Verify final result matches expected + expected = F.softmax(F.relu(torch.matmul(x, w)), dim=1) + assert torch.allclose(lv3, expected, atol=1e-3) + + def test_error_handling(self): + """Test error handling for various edge cases.""" + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + # Test with missing function + with pytest.raises(Exception): + instance.call_dps_packed( + "non_existent_function", [torch.tensor([1.0])], R.Tensor((1,), "float32") + ) + + # Test with wrong tensor shapes + x = torch.randn(3, 16, dtype=torch.float32) + w = torch.randn(15, 20, dtype=torch.float32) # Wrong shape + + with pytest.raises(Exception): + instance.call_tir(instance.matmul, [x, w], R.Tensor((3, 20), "float32")) + + def test_tensor_type_preservation(self): + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + def my_softmax(tensor, dim): + if hasattr(tensor, "numpy"): + tensor = torch.from_numpy(tensor.numpy()) + return F.softmax(tensor, dim=dim) + + instance.my_softmax = my_softmax + + # Test with float32 data type (TIR function is hardcoded for float32) + test_dtype = torch.float32 + n = 3 + x = torch.randn(n, 16, dtype=test_dtype) + w = torch.randn(16, 20, dtype=test_dtype) + + result = instance.main(x, w) + + # Verify type preservation + assert result.dtype == test_dtype + assert 
isinstance(result, torch.Tensor) + assert result.shape == (n, 20) + assert result.dtype == torch.float32 + + def test_batch_processing(self): + """Test processing multiple inputs in batch.""" + module = PyTorchIntegrationModule + device = tvm.cpu(0) + instance = module(device) + + # Add required functions + def my_softmax(tensor, dim): + if hasattr(tensor, "numpy"): + tensor = torch.from_numpy(tensor.numpy()) + return F.softmax(tensor, dim=dim) + + instance.my_softmax = my_softmax + + # Process multiple inputs + batch_size = 5 + results = [] + + for i in range(batch_size): + n = 3 + i # Varying batch sizes + x = torch.randn(n, 16, dtype=torch.float32) + w = torch.randn(16, 20, dtype=torch.float32) + + result = instance.main(x, w) + results.append(result) + + assert isinstance(result, torch.Tensor) + assert result.shape == (n, 20) + + # Verify all results are valid + assert len(results) == batch_size + for result in results: + assert isinstance(result, torch.Tensor) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index c94dd9f5789d..897082dd792f 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -56,7 +56,7 @@ def run_cpu(mod, func_name, *args, exec_mode): def test_unique(exec_mode): # TODO(prakalp): also add test for compiling and running on cuda device. 
data_numpy = np.random.randint(0, 16, (16, 16)) - data = tvm.nd.array(data_numpy) + data = tvm.runtime.tensor(data_numpy) result, result_sorted = run_cpu(InputModule, "foo", data, exec_mode=exec_mode) expected_output_sorted, indices = np.unique(data_numpy, return_index=True) @@ -91,7 +91,7 @@ def test_print(exec_mode): run_cpu( PrintTest, "foo", - tvm.nd.array(np.array(1).astype("int32")), + tvm.runtime.tensor(np.array(1).astype("int32")), exec_mode=exec_mode, ) test_out.seek(0) @@ -108,7 +108,7 @@ def func(x: R.Tensor((), "int32")): _ = R.assert_op(relax.const(True)) return x - run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode) + run_cpu(func, tvm.runtime.tensor(np.array(1).astype("int32")), exec_mode=exec_mode) def test_assert_passes_with_format_args(exec_mode): @@ -117,7 +117,7 @@ def func(x: R.Tensor((), "int32")): _ = R.assert_op(relax.const(True), x, format="You won't see me") return x - run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode) + run_cpu(func, tvm.runtime.tensor(np.array(1).astype("int32")), exec_mode=exec_mode) def test_assert_fails(exec_mode): @@ -127,7 +127,7 @@ def func(x: R.Tensor((), "int32")): return x with pytest.raises(AssertionError, match="Assertion Failed"): - run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode) + run_cpu(func, tvm.runtime.tensor(np.array(1).astype("int32")), exec_mode=exec_mode) def test_assert_fails_with_message(exec_mode): @@ -137,7 +137,7 @@ def func(x: R.Tensor((), "int32")): return x with pytest.raises(AssertionError, match="I failed..."): - run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), exec_mode=exec_mode) + run_cpu(func, tvm.runtime.tensor(np.array(1).astype("int32")), exec_mode=exec_mode) def test_assert_fails_with_args(exec_mode): @@ -147,7 +147,7 @@ def func(x: R.Tensor((), "int32")): return x with pytest.raises(AssertionError, match="5, 5"): - run_cpu(func, tvm.nd.array(np.array(5).astype("int32")), exec_mode=exec_mode) 
+ run_cpu(func, tvm.runtime.tensor(np.array(5).astype("int32")), exec_mode=exec_mode) def test_assert_fails_with_formatted_args(exec_mode): @@ -157,7 +157,7 @@ def func(x: R.Tensor((), "int32")): return x with pytest.raises(AssertionError, match="Number: 6"): - run_cpu(func, tvm.nd.array(np.array(6).astype("int32")), exec_mode=exec_mode) + run_cpu(func, tvm.runtime.tensor(np.array(6).astype("int32")), exec_mode=exec_mode) def test_assert_on_argument_passes(exec_mode): @@ -166,8 +166,8 @@ def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")): _ = R.assert_op(condition) return x - condition = tvm.nd.array(np.array(True)) - x = tvm.nd.array(np.array(5).astype("int32")) + condition = tvm.runtime.tensor(np.array(True)) + x = tvm.runtime.tensor(np.array(5).astype("int32")) run_cpu(func, condition, x, exec_mode=exec_mode) @@ -177,8 +177,8 @@ def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")): _ = R.assert_op(condition) return x - condition = tvm.nd.array(np.array(False)) - x = tvm.nd.array(np.array(5).astype("int32")) + condition = tvm.runtime.tensor(np.array(False)) + x = tvm.runtime.tensor(np.array(5).astype("int32")) with pytest.raises(AssertionError): run_cpu(func, condition, x, exec_mode=exec_mode) @@ -190,7 +190,7 @@ def func(x: R.Tensor(["N"], "int32")): _ = R.assert_op(R.prim_value(N % 8 == 0)) return x - x = tvm.nd.array(np.arange(8, dtype="int32")) + x = tvm.runtime.tensor(np.arange(8, dtype="int32")) run_cpu(func, x, exec_mode=exec_mode) @@ -201,7 +201,7 @@ def func(x: R.Tensor(["N"], "int32")): _ = R.assert_op(R.prim_value(N % 8 == 0)) return x - x = tvm.nd.array(np.arange(10, dtype="int32")) + x = tvm.runtime.tensor(np.arange(10, dtype="int32")) with pytest.raises(AssertionError): run_cpu(func, x, exec_mode=exec_mode) @@ -238,14 +238,17 @@ def test_op_shape_of(exec_mode): assert const_shape == tvm.runtime.ShapeTuple([2, 2]) scalar_shape = run_cpu( - ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")), 
exec_mode=exec_mode + ShapeOfTest, + "get_shape", + tvm.runtime.tensor(np.array(1, dtype="int32")), + exec_mode=exec_mode, ) assert scalar_shape == tvm.runtime.ShapeTuple([]) tensor_shape = run_cpu( ShapeOfTest, "get_shape", - tvm.nd.array(np.zeros((1, 2, 3)).astype("int32")), + tvm.runtime.tensor(np.zeros((1, 2, 3)).astype("int32")), exec_mode=exec_mode, ) assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3]) @@ -253,7 +256,7 @@ def test_op_shape_of(exec_mode): constrained_shape = run_cpu( ShapeOfTest, "get_constrained_shape", - tvm.nd.array(np.zeros((1,)).astype("int32")), + tvm.runtime.tensor(np.zeros((1,)).astype("int32")), exec_mode=exec_mode, ) assert constrained_shape == tvm.runtime.ShapeTuple([1]) @@ -283,25 +286,25 @@ def test_op_shape_to_tensor(exec_mode): out2d = run_cpu( ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2]), exec_mode=exec_mode ) - assert isinstance(out2d, tvm.runtime.ndarray.NDArray) + assert isinstance(out2d, tvm.runtime.Tensor) assert np.array_equal(out2d.numpy(), np.array([3, 2])) out3d = run_cpu( ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2]), exec_mode=exec_mode ) - assert isinstance(out3d, tvm.runtime.ndarray.NDArray) + assert isinstance(out3d, tvm.runtime.Tensor) assert np.array_equal(out3d.numpy(), np.array([3, 3, 2])) out4d = run_cpu( ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2, 2]), exec_mode=exec_mode ) - assert isinstance(out4d, tvm.runtime.ndarray.NDArray) + assert isinstance(out4d, tvm.runtime.Tensor) assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2])) outs = run_cpu( ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2]), exec_mode=exec_mode ) - assert isinstance(outs, tvm.runtime.ndarray.NDArray) + assert isinstance(outs, tvm.runtime.Tensor) assert np.array_equal(outs.numpy(), np.array([3, 2])) @@ -317,7 +320,7 @@ def pure_copy(x: R.Tensor((3, 4), "float32")): np.random.seed(0) # to avoid flakiness arr = np.random.rand(3, 4).astype("float32") 
- copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr), exec_mode=exec_mode) + copy_found = run_cpu(CallPureTest, "pure_copy", tvm.runtime.tensor(arr), exec_mode=exec_mode) assert (copy_found.numpy() == arr).all() @@ -335,7 +338,7 @@ def pure_copy(x: R.Tensor((3, 4), "float32")): ) return z - @tvm.register_func("test.inplace.add", override=True) + @tvm.register_global_func("test.inplace.add", override=True) def inplace_add(a, b): arr_a = a.numpy() arr_b = b.numpy() @@ -362,18 +365,18 @@ def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): arr_a = np.random.rand(3, 4).astype("float32") arr_b = np.random.rand(3, 4).astype("float32") sum = arr_a + arr_b - tvm_arr_a = tvm.nd.array(arr_a) + tvm_arr_a = tvm.runtime.tensor(arr_a) result = run_cpu( - CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b), exec_mode=exec_mode + CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.runtime.tensor(arr_b), exec_mode=exec_mode ) assert result == tvm_arr_a assert (result.numpy() == sum).all() - @tvm.register_func("test.inplace.tuple_add", override=True) + @tvm.register_global_func("test.inplace.tuple_add", override=True) def inplace_tuple_add(a, b): arr_a = a.numpy() arr_b = b.numpy() - c = tvm.nd.array(arr_a + arr_b) + c = tvm.runtime.tensor(arr_a + arr_b) for i in range(len(arr_a)): for j in range(len(arr_a[i])): arr_a[i][j] = arr_a[i][j] + arr_b[i][j] @@ -397,8 +400,8 @@ def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") arr_a = np.random.rand(3, 4).astype("float32") arr_b = np.random.rand(3, 4).astype("float32") sum = arr_a + arr_b - tvm_arr_a = tvm.nd.array(arr_a) - tvm_arr_b = tvm.nd.array(arr_b) + tvm_arr_a = tvm.runtime.tensor(arr_a) + tvm_arr_b = tvm.runtime.tensor(arr_b) result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b, exec_mode=exec_mode) assert result[0] == tvm_arr_a assert (result[0].numpy() == sum).all() @@ -406,6 +409,82 @@ def inplace_tuple(x: R.Tensor((3, 4), 
"float32"), y: R.Tensor((3, 4), "float32") assert (result[1].numpy() == sum).all() +def test_op_call_py_func(exec_mode): + """Test R.call_py_func operator functionality.""" + import torch + + def torch_relu(x): + if isinstance(x, tvm.runtime.Tensor): + x_torch = torch.from_numpy(x.numpy()) + elif hasattr(x, "asnumpy"): + x_torch = torch.from_numpy(x.asnumpy()) + else: + x_np = np.array(x) + if isinstance(x_np, tvm.runtime.Tensor): + x_torch = torch.from_numpy(x_np.numpy()) + elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor): + x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np])) + if x_torch.ndim > 1: + x_torch = x_torch.flatten() + else: + x_torch = torch.from_numpy(x_np) + result = torch.relu(x_torch) + return tvm.runtime.tensor(result.numpy()) + + def torch_sigmoid(x): + if isinstance(x, tvm.runtime.Tensor): + x_torch = torch.from_numpy(x.numpy()) + elif hasattr(x, "asnumpy"): + x_torch = torch.from_numpy(x.asnumpy()) + else: + x_np = np.array(x) + if isinstance(x_np, tvm.runtime.Tensor): + x_torch = torch.from_numpy(x_np.numpy()) + elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor): + x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np])) + if x_torch.ndim > 1: + x_torch = x_torch.flatten() + else: + x_torch = torch.from_numpy(x_np) + result = torch.sigmoid(x_torch) + return tvm.runtime.tensor(result.numpy()) + + register_func = tvm.get_global_func("vm.builtin.register_py_func") + register_func("torch_relu", torch_relu) + register_func("torch_sigmoid", torch_sigmoid) + + @tvm.script.ir_module + class CallPyFuncTest: + @R.function + def simple_call(x: R.Tensor((3,), "float32")): + result = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((3,), "float32")) + return result + + @R.function + def multiple_calls(x: R.Tensor((2,), "float32")): + y = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((2,), "float32")) + z = R.call_py_func(R.str("torch_sigmoid"), (y,), out_sinfo=R.Tensor((2,), 
"float32")) + return z + + np.random.seed(0) + x_data = np.array([-1.0, 0.0, 1.0], dtype=np.float32) + x_tvm = tvm.runtime.tensor(x_data) + + result = run_cpu(CallPyFuncTest, "simple_call", x_tvm, exec_mode=exec_mode) + expected = np.maximum(x_data, 0.0) + assert (result.numpy() == expected).all() + + y_data = np.array([-0.5, 0.5], dtype=np.float32) + y_tvm = tvm.runtime.tensor(y_data) + + result2 = run_cpu(CallPyFuncTest, "multiple_calls", y_tvm, exec_mode=exec_mode) + expected2 = 1.0 / (1.0 + np.exp(-np.maximum(y_data, 0.0))) + assert (result2.numpy() == expected2).all() + + clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry") + clear_func() + + def test_op_to_device(exec_mode): @tvm.script.ir_module class CallToDevice: @@ -422,7 +501,7 @@ def to_dev(x: R.Tensor((3, 4), "float32")): np.random.seed(0) # to avoid flakiness arr = np.random.rand(3, 4).astype("float32") - copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr), exec_mode=exec_mode) + copy_found = run_cpu(CallToDevice, "to_dev", tvm.runtime.tensor(arr), exec_mode=exec_mode) assert (copy_found.numpy() == arr).all() @@ -439,7 +518,7 @@ def to_vdev(x: R.Tensor((3, 4), "float32")): np.random.seed(0) arr = np.random.rand(3, 4).astype("float32") - copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr), exec_mode=exec_mode) + copy_found = run_cpu(ToVDevice, "to_vdev", tvm.runtime.tensor(arr), exec_mode=exec_mode) assert (copy_found.numpy() == arr).all() @@ -454,10 +533,10 @@ def func(condition: R.Tensor((), "bool")): out = R.prim_value(10) return out - res = run_cpu(func, tvm.nd.array(np.array(True)), exec_mode=exec_mode) + res = run_cpu(func, tvm.runtime.tensor(np.array(True)), exec_mode=exec_mode) assert res == 5 - res = run_cpu(func, tvm.nd.array(np.array(False)), exec_mode=exec_mode) + res = run_cpu(func, tvm.runtime.tensor(np.array(False)), exec_mode=exec_mode) assert res == 10 @@ -491,10 +570,10 @@ def func(x: R.Tensor(["N"], "int64")): out = R.prim_value(10) return out - 
res = run_cpu(func, tvm.nd.array(np.arange(16)), exec_mode=exec_mode) + res = run_cpu(func, tvm.runtime.tensor(np.arange(16)), exec_mode=exec_mode) assert res == 5 - res = run_cpu(func, tvm.nd.array(np.arange(20)), exec_mode=exec_mode) + res = run_cpu(func, tvm.runtime.tensor(np.arange(20)), exec_mode=exec_mode) assert res == 10 diff --git a/tests/python/relax/test_relax_to_pyfunc_converter.py b/tests/python/relax/test_relax_to_pyfunc_converter.py new file mode 100644 index 000000000000..a2f189297ae0 --- /dev/null +++ b/tests/python/relax/test_relax_to_pyfunc_converter.py @@ -0,0 +1,1042 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Comprehensive test cases for Relax to PyFunc converter. +Tests all major features including basic operations, call_tir, call_dps_packed, and symbolic shapes. 
+""" + + +import pytest +import torch +import torch.nn.functional as F +import numpy as np + + +import tvm +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R +from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter + + +@I.ir_module +class ComprehensiveTestModule: + """Test module covering all converter features.""" + + @T.prim_func + def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + """TIR function for addition.""" + x = T.match_buffer(var_x, (5,), "float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + for i in range(5): + out[i] = x[i] + y[i] + + @T.prim_func + def mul_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + """TIR function for multiplication.""" + x = T.match_buffer(var_x, (3, 4), "float32") + y = T.match_buffer(var_y, (3, 4), "float32") + out = T.match_buffer(var_out, (3, 4), "float32") + for i in range(3): + for j in range(4): + out[i, j] = x[i, j] * y[i, j] + + @R.function + def simple_add( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.add(x, y) + + @R.function + def with_relu(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.nn.relu(x) + + @R.function + def with_call_tir( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + cls = ComprehensiveTestModule + return R.call_tir(cls.add_tir, (x, y), out_sinfo=R.Tensor((5,), "float32")) + + @R.function + def with_call_dps_packed(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.call_dps_packed( + "my_softmax", (x, R.prim_value(1)), out_sinfo=R.Tensor((5,), "float32") + ) + + @R.function + def complex_function( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + added = R.add(x, y) + relued = R.nn.relu(added) + cls = ComprehensiveTestModule + tir_result = 
R.call_tir(cls.add_tir, (relued, y), out_sinfo=R.Tensor((5,), "float32")) + return R.nn.relu(tir_result) + + @R.function + def symbolic_add( + x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32") + ) -> R.Tensor(("n",), "float32"): + return R.add(x, y) + + @R.function + def symbolic_matmul( + x: R.Tensor(("batch", "m", "k"), "float32"), y: R.Tensor(("batch", "k", "n"), "float32") + ) -> R.Tensor(("batch", "m", "n"), "float32"): + return R.matmul(x, y) + + @R.function + def symbolic_expand_dims( + x: R.Tensor(("batch", "seq_len"), "float32") + ) -> R.Tensor(("batch", "seq_len", 1), "float32"): + return R.expand_dims(x, axis=2) + + @R.function + def multi_ops( + x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") + ) -> R.Tensor((3, 4), "float32"): + added = R.add(x, y) + multiplied = R.multiply(added, y) + powered = R.power(multiplied, R.const(2.0)) + maxed = R.maximum(powered, x) + return maxed + + @R.function + def reduction_ops(x: R.Tensor((5,), "float32")) -> R.Tensor((), "float32"): + sum_val = R.sum(x) + mean_val = R.mean(x) + max_val = R.max(x) + return R.add(R.add(sum_val, mean_val), max_val) + + @R.function + def comparison_ops( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): + eq_val = R.equal(x, y) + gt_val = R.greater(x, y) + return R.logical_and(eq_val, gt_val) + + @R.function + def test_reshape(x: R.Tensor((2, 3), "float32")) -> R.Tensor((6,), "float32"): + return R.reshape(x, (6,)) + + @R.function + def test_permute_dims(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((4, 2, 3), "float32"): + return R.permute_dims(x, axes=[2, 0, 1]) + + @R.function + def test_concat( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((4, 3), "float32"): + return R.concat((x, y), axis=0) + + @R.function + def test_split(x: R.Tensor((4, 3), "float32")) -> R.Tuple: + return R.split(x, indices_or_sections=2, axis=0) + + @R.function + def test_stack( + x: R.Tensor((2, 3), 
"float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 2, 3), "float32"): + return R.stack((x, y), axis=1) + + @R.function + def test_take( + x: R.Tensor((3, 4), "float32"), indices: R.Tensor((2,), "int64") + ) -> R.Tensor((2,), "float32"): + return R.take(x, indices, axis=0) + + @R.function + def test_flip(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + return R.flip(x, axis=1) + + @R.function + def test_tile(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 6), "float32"): + return R.tile(x, (2, 2)) + + @R.function + def test_repeat(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 3), "float32"): + return R.repeat(x, repeats=2, axis=0) + + @R.function + def test_expand_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3, 1), "float32"): + return R.expand_dims(x, axis=2) + + @R.function + def test_squeeze(x: R.Tensor((2, 3, 1), "float32")) -> R.Tensor((2, 3), "float32"): + return R.squeeze(x, axis=2) + + @R.function + def test_sum_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.sum(x, axis=0) + + @R.function + def test_max_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.max(x, axis=0) + + +def create_mock_packed_function(): + """Create a mock packed function for testing.""" + + def mock_softmax(x, axis): + """Mock softmax function that just returns the input.""" + return x + + # Register the function globally + tvm.register_global_func("my_softmax", mock_softmax) + + +class TestRelaxToPyFuncConverter: + """Comprehensive test class for Relax to PyFunc converter.""" + + @classmethod + def setup_class(cls): + """Set up test fixtures.""" + cls.ir_mod = ComprehensiveTestModule + cls.converter = RelaxToPyFuncConverter(cls.ir_mod) + create_mock_packed_function() + + def test_basic_operations(self): + """Test basic arithmetic operations.""" + converted_ir_mod = self.converter.convert(["simple_add", "with_relu"]) + + # Test simple_add + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 
5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["simple_add"](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + + # Test with_relu + x_neg = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["with_relu"](x_neg) + expected = torch.nn.functional.relu(x_neg) + assert torch.allclose(result, expected) + + def test_call_tir(self): + """Test call_tir functionality with DLPack conversion.""" + converted_ir_mod = self.converter.convert(["with_call_tir"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["with_call_tir"](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + assert result.shape == expected.shape + + def test_call_dps_packed(self): + """Test call_dps_packed functionality.""" + converted_ir_mod = self.converter.convert(["with_call_dps_packed"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["with_call_dps_packed"](x) + expected = x + assert torch.allclose(result, expected) + + def test_complex_function(self): + """Test complex function with multiple operations.""" + converted_ir_mod = self.converter.convert(["complex_function"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["complex_function"](x, y) + + # Expected: relu(add(relu(add(x, y)), y)) + step1 = torch.add(x, y) + step2 = torch.nn.functional.relu(step1) + step3 = torch.add(step2, y) # TIR call + expected = torch.nn.functional.relu(step3) + + assert torch.allclose(result, expected) + + def test_symbolic_shapes(self): + """Test symbolic shape handling.""" + converted_ir_mod = self.converter.convert( + 
["symbolic_add", "symbolic_matmul", "symbolic_expand_dims"] + ) + + # Test symbolic_add + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["symbolic_add"](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + + # Test symbolic_matmul + x = torch.randn(2, 3, 4, dtype=torch.float32) # (batch=2, m=3, k=4) + y = torch.randn(2, 4, 5, dtype=torch.float32) # (batch=2, k=4, n=5) + result = converted_ir_mod.pyfuncs["symbolic_matmul"](x, y) + expected = torch.matmul(x, y) + assert torch.allclose(result, expected) + assert result.shape == (2, 3, 5) + + # Test symbolic_expand_dims + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["symbolic_expand_dims"](x) + expected = torch.unsqueeze(x, dim=2) + assert torch.allclose(result, expected) + assert result.shape == (2, 2, 1) + + def test_multi_operations(self): + """Test multiple operations in sequence.""" + converted_ir_mod = self.converter.convert(["multi_ops"]) + + x = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + dtype=torch.float32, + ) + y = torch.tensor( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=torch.float32 + ) + + result = converted_ir_mod.pyfuncs["multi_ops"](x, y) + + # Expected: maximum(power(multiply(add(x, y), y), 2), x) + step1 = torch.add(x, y) + step2 = torch.mul(step1, y) + step3 = torch.pow(step2, 2.0) + expected = torch.maximum(step3, x) + + assert torch.allclose(result, expected) + + def test_reduction_operations(self): + """Test reduction operations.""" + converted_ir_mod = self.converter.convert(["reduction_ops"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["reduction_ops"](x) + + # Expected: sum(x) + mean(x) + max(x) + expected = torch.sum(x) + torch.mean(x) + torch.max(x) + + assert 
torch.allclose(result, expected) + assert result.shape == () + + def test_comparison_operations(self): + """Test comparison operations.""" + converted_ir_mod = self.converter.convert(["comparison_ops"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([1.0, 2.5, 3.0, 4.5, 5.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["comparison_ops"](x, y) + + # Expected: logical_and(equal(x, y), greater(x, y)) + eq_val = torch.eq(x, y) + gt_val = torch.gt(x, y) + expected = torch.logical_and(eq_val, gt_val) + + assert torch.allclose(result, expected) + assert result.dtype == torch.bool + + def test_operator_mapping_completeness(self): + """Test that operator mapping is comprehensive.""" + operator_map = RelaxToPyFuncConverter._get_op_map() + + # Check that we have a good number of operators + assert len(operator_map) > 100, f"Expected >100 operators, got {len(operator_map)}" + + # Check key operator categories + binary_ops = [ + op + for op in operator_map.keys() + if op.startswith("relax.") and not op.startswith("relax.nn.") + ] + nn_ops = [op for op in operator_map.keys() if op.startswith("relax.nn.")] + + assert len(binary_ops) > 20, f"Expected >20 binary ops, got {len(binary_ops)}" + assert len(nn_ops) > 30, f"Expected >30 nn ops, got {len(nn_ops)}" + + # Check specific important operators + important_ops = [ + "relax.add", + "relax.multiply", + "relax.nn.relu", + "relax.nn.softmax", + "relax.matmul", + "relax.reshape", + "relax.sum", + "relax.mean", + ] + + for op in important_ops: + assert op in operator_map, f"Missing important operator: {op}" + + def test_error_handling(self): + """Test error handling for invalid inputs.""" + converted_ir_mod = self.converter.convert(["simple_add"]) + + # Test with wrong number of arguments + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + + with pytest.raises(ValueError, match="Expected 2 arguments"): + converted_ir_mod.pyfuncs["simple_add"](x) # Missing second argument + 
+ # Test with incompatible shapes - this should raise a RuntimeError + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + y = torch.tensor([1.0, 2.0], dtype=torch.float32) # Different shape + + # This should raise a RuntimeError because shapes don't match + with pytest.raises(RuntimeError, match="The size of tensor a"): + converted_ir_mod.pyfuncs["simple_add"](x, y) + + def test_conversion_metadata(self): + """Test that conversion preserves metadata correctly.""" + converted_ir_mod = self.converter.convert(["simple_add"]) + + # Check that pyfuncs attribute exists + assert hasattr(converted_ir_mod, "pyfuncs") + assert "simple_add" in converted_ir_mod.pyfuncs + + # Check function metadata + pyfunc = converted_ir_mod.pyfuncs["simple_add"] + assert hasattr(pyfunc, "__name__") + assert hasattr(pyfunc, "__doc__") + assert pyfunc.__name__ == "simple_add" + + def test_tensor_operations(self): + """Test tensor manipulation operations.""" + converted_ir_mod = self.converter.convert( + [ + "test_reshape", + "test_permute_dims", + "test_concat", + "test_split", + "test_stack", + "test_take", + "test_flip", + "test_tile", + "test_repeat", + "test_expand_dims", + "test_squeeze", + "test_sum_with_axis", + "test_max_with_axis", + ] + ) + + # Test reshape + x1 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result1 = converted_ir_mod.pyfuncs["test_reshape"](x1) + expected1 = torch.reshape(x1, (6,)) + assert torch.allclose(result1, expected1), "Reshape operation failed" + + # Test permute_dims + x2 = torch.randn(2, 3, 4) + result2 = converted_ir_mod.pyfuncs["test_permute_dims"](x2) + expected2 = torch.permute(x2, (2, 0, 1)) + assert torch.allclose(result2, expected2), "Permute_dims operation failed" + + # Test concat + x3 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + y3 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) + result3 = converted_ir_mod.pyfuncs["test_concat"](x3, y3) + expected3 = 
torch.cat([x3, y3], dim=0) + assert torch.allclose(result3, expected3), "Concat operation failed" + + # Test split + x4 = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], + dtype=torch.float32, + ) + result4 = converted_ir_mod.pyfuncs["test_split"](x4) + expected4 = torch.split(x4, 2, dim=0) + assert len(result4) == len(expected4), "Split operation failed - wrong number of tensors" + for r, e in zip(result4, expected4): + assert torch.allclose(r, e), "Split operation failed - tensor mismatch" + + # Test stack + x5 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + y5 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) + result5 = converted_ir_mod.pyfuncs["test_stack"](x5, y5) + expected5 = torch.stack([x5, y5], dim=1) + assert torch.allclose(result5, expected5), "Stack operation failed" + + # Test take + x6 = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + dtype=torch.float32, + ) + indices = torch.tensor([0, 2], dtype=torch.int64) + result6 = converted_ir_mod.pyfuncs["test_take"](x6, indices) + expected6 = x6[indices] + assert torch.allclose(result6, expected6), "Take operation failed" + + # Test flip + x7 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result7 = converted_ir_mod.pyfuncs["test_flip"](x7) + expected7 = torch.flip(x7, dims=[1]) + assert torch.allclose(result7, expected7), "Flip operation failed" + + # Test tile + x8 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result8 = converted_ir_mod.pyfuncs["test_tile"](x8) + expected8 = torch.tile(x8, (2, 2)) + assert torch.allclose(result8, expected8), "Tile operation failed" + + # Test repeat + x9 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result9 = converted_ir_mod.pyfuncs["test_repeat"](x9) + expected9 = torch.repeat_interleave(x9, repeats=2, dim=0) + assert torch.allclose(result9, expected9), 
"Repeat operation failed" + + # Test expand_dims + x10 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result10 = converted_ir_mod.pyfuncs["test_expand_dims"](x10) + expected10 = torch.unsqueeze(x10, dim=2) + assert torch.allclose(result10, expected10), "Expand_dims operation failed" + + # Test squeeze + x11 = torch.tensor([[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], dtype=torch.float32) + result11 = converted_ir_mod.pyfuncs["test_squeeze"](x11) + expected11 = torch.squeeze(x11, dim=2) + assert torch.allclose(result11, expected11), "Squeeze operation failed" + + # Test sum with axis + x12 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result12 = converted_ir_mod.pyfuncs["test_sum_with_axis"](x12) + expected12 = torch.sum(x12, dim=0) + assert torch.allclose(result12, expected12), "Sum with axis operation failed" + + # Test max with axis + x13 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result13 = converted_ir_mod.pyfuncs["test_max_with_axis"](x13) + expected13 = torch.max(x13, dim=0)[0] # torch.max returns (values, indices) + assert torch.allclose(result13, expected13), "Max with axis operation failed" + + +@I.ir_module +class ExtendedOperatorsModule: + """Extended test module with additional operators not covered in ComprehensiveTestModule.""" + + # Unary operations not covered in ComprehensiveTestModule + @R.function + def test_abs(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.abs(x) + + @R.function + def test_neg(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.negative(x) + + @R.function + def test_exp(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.exp(x) + + @R.function + def test_log(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.log(x) + + @R.function + def test_sqrt(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sqrt(x) + + @R.function + def test_sin(x: 
R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sin(x) + + @R.function + def test_cos(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.cos(x) + + @R.function + def test_tanh(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.tanh(x) + + @R.function + def test_sigmoid(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sigmoid(x) + + # Comparison operations not covered in ComprehensiveTestModule + @R.function + def test_less( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): + return R.less(x, y) + + @R.function + def test_not_equal( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): + return R.not_equal(x, y) + + # Binary operations not covered in ComprehensiveTestModule + @R.function + def test_multiply( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.multiply(x, y) + + @R.function + def test_divide( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.divide(x, y) + + @R.function + def test_power( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.power(x, y) + + @R.function + def test_maximum( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.maximum(x, y) + + @R.function + def test_minimum( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.minimum(x, y) + + @R.function + def test_subtract( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.subtract(x, y) + + # Additional tensor operations with different parameters + @R.function + def test_transpose_2d(x: R.Tensor((2, 4), "float32")) -> R.Tensor((4, 2), "float32"): + return R.permute_dims(x, axes=[1, 0]) + + @R.function + 
def test_mean_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.mean(x, axis=0) + + @R.function + def test_min_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.min(x, axis=0) + + # Neural network operations not covered in ComprehensiveTestModule + @R.function + def test_gelu_nn(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.nn.gelu(x) + + @R.function + def test_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"): + return R.nn.softmax(x, axis=1) + + @R.function + def test_log_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"): + return R.nn.log_softmax(x, axis=1) + + # Advanced tensor operations with different parameters + @R.function + def test_tile_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 9), "float32"): + return R.tile(x, (2, 3)) + + @R.function + def test_repeat_axis(x: R.Tensor((3,), "float32")) -> R.Tensor((6,), "float32"): + return R.repeat(x, repeats=2, axis=0) + + +class TestExtendedOperators: + """Test class for extended operator coverage.""" + + @classmethod + def setup_class(cls): + """Set up test fixtures.""" + cls.ir_mod = ExtendedOperatorsModule + cls.converter = RelaxToPyFuncConverter(cls.ir_mod) + + def test_unary_operations(self): + """Test unary operations.""" + converted_ir_mod = self.converter.convert( + ["test_abs", "test_neg", "test_exp", "test_log", "test_sqrt"] + ) + + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32) + + # Test abs + result = converted_ir_mod.pyfuncs["test_abs"](x) + expected = torch.abs(x) + assert torch.allclose(result, expected) + + # Test negative + result = converted_ir_mod.pyfuncs["test_neg"](x) + expected = torch.neg(x) + assert torch.allclose(result, expected) + + # Test exp + result = converted_ir_mod.pyfuncs["test_exp"](x) + expected = torch.exp(x) + assert torch.allclose(result, expected) + + # Test log (with positive values) + x_pos = torch.tensor([1.0, 2.0, 
3.0, 4.0, 5.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["test_log"](x_pos) + expected = torch.log(x_pos) + assert torch.allclose(result, expected) + + # Test sqrt + result = converted_ir_mod.pyfuncs["test_sqrt"](x_pos) + expected = torch.sqrt(x_pos) + assert torch.allclose(result, expected) + + def test_trigonometric_operations(self): + """Test trigonometric operations.""" + converted_ir_mod = self.converter.convert( + ["test_sin", "test_cos", "test_tanh", "test_sigmoid"] + ) + + x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0], dtype=torch.float32) + + # Test sin + result = converted_ir_mod.pyfuncs["test_sin"](x) + expected = torch.sin(x) + assert torch.allclose(result, expected) + + # Test cos + result = converted_ir_mod.pyfuncs["test_cos"](x) + expected = torch.cos(x) + assert torch.allclose(result, expected) + + # Test tanh + result = converted_ir_mod.pyfuncs["test_tanh"](x) + expected = torch.tanh(x) + assert torch.allclose(result, expected) + + # Test sigmoid + result = converted_ir_mod.pyfuncs["test_sigmoid"](x) + expected = torch.sigmoid(x) + assert torch.allclose(result, expected) + + def test_comparison_operations(self): + """Test comparison operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert(["test_less", "test_not_equal"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32) + + # Test less + result = converted_ir_mod.pyfuncs["test_less"](x, y) + expected = torch.lt(x, y) + assert torch.equal(result, expected) + + # Test not equal + result = converted_ir_mod.pyfuncs["test_not_equal"](x, y) + expected = torch.ne(x, y) + assert torch.equal(result, expected) + + def test_binary_operations(self): + """Test binary operations.""" + converted_ir_mod = self.converter.convert( + [ + "test_multiply", + "test_divide", + "test_power", + "test_maximum", + "test_minimum", + "test_subtract", + ] + ) + + x = torch.tensor([1.0, 2.0, 
3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32) + + # Test multiply + result = converted_ir_mod.pyfuncs["test_multiply"](x, y) + expected = torch.mul(x, y) + assert torch.allclose(result, expected) + + # Test divide + result = converted_ir_mod.pyfuncs["test_divide"](x, y) + expected = torch.div(x, y) + assert torch.allclose(result, expected) + + # Test power + result = converted_ir_mod.pyfuncs["test_power"](x, y) + expected = torch.pow(x, y) + assert torch.allclose(result, expected) + + # Test maximum + result = converted_ir_mod.pyfuncs["test_maximum"](x, y) + expected = torch.maximum(x, y) + assert torch.allclose(result, expected) + + # Test minimum + result = converted_ir_mod.pyfuncs["test_minimum"](x, y) + expected = torch.minimum(x, y) + assert torch.allclose(result, expected) + + # Test subtract + result = converted_ir_mod.pyfuncs["test_subtract"](x, y) + expected = torch.sub(x, y) + assert torch.allclose(result, expected) + + def test_tensor_operations(self): + """Test tensor operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert(["test_transpose_2d"]) + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32) + + # Test transpose + result = converted_ir_mod.pyfuncs["test_transpose_2d"](x) + expected = torch.transpose(x, 0, 1) + assert torch.allclose(result, expected) + assert result.shape == (4, 2) + + def test_reduction_operations(self): + """Test reduction operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert(["test_mean_axis", "test_min_axis"]) + + x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + + # Test mean + result = converted_ir_mod.pyfuncs["test_mean_axis"](x) + expected = torch.mean(x, dim=0) + assert torch.allclose(result, expected) + assert result.shape == (3,) + + # Test min + result = converted_ir_mod.pyfuncs["test_min_axis"](x) + expected = 
torch.min(x, dim=0)[0] + assert torch.allclose(result, expected) + assert result.shape == (3,) + + def test_neural_network_operations(self): + """Test neural network operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert( + ["test_gelu_nn", "test_softmax_nn", "test_log_softmax_nn"] + ) + + x = torch.tensor( + [[-2.0, -1.0, 0.0, 1.0, 2.0], [0.5, 1.5, 2.5, 3.5, 4.5]], dtype=torch.float32 + ) + + # Test gelu + result = converted_ir_mod.pyfuncs["test_gelu_nn"](x[0]) + expected = F.gelu(x[0]) + assert torch.allclose(result, expected) + + # Test softmax + result = converted_ir_mod.pyfuncs["test_softmax_nn"](x) + expected = F.softmax(x, dim=1) + assert torch.allclose(result, expected) + + # Test log_softmax + result = converted_ir_mod.pyfuncs["test_log_softmax_nn"](x) + expected = F.log_softmax(x, dim=1) + assert torch.allclose(result, expected) + + def test_advanced_tensor_operations(self): + """Test tile/repeat with inputs that match the Relax functions' declared shapes.""" + converted_ir_mod = self.converter.convert(["test_tile_dims", "test_repeat_axis"]) + + x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + + # Test tile; the (2, 3) input matches test_tile_dims' declared parameter shape + result = converted_ir_mod.pyfuncs["test_tile_dims"](x) + expected = torch.tile(x, (2, 3)) + assert torch.allclose(result, expected) + assert result.shape == (4, 9) + + # Test repeat with different parameters + x_1d = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["test_repeat_axis"](x_1d) + expected = torch.repeat_interleave(x_1d, repeats=2, dim=0) + assert torch.allclose(result, expected) + assert result.shape == (6,) + + +class TestDLPackAndTupleSupport: + """Test DLPack conversion, tuple handling, and API compatibility features.""" + + def test_dlpack_conversion_fallback(self): + """Test DLPack conversion with numpy fallback.""" + + @I.ir_module + class DLPackTestModule: + @T.prim_func + def test_tir(var_x: T.handle, 
var_y: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (4,), "float32") + y = T.match_buffer(var_y, (4,), "float32") + out = T.match_buffer(var_out, (4,), "float32") + for i in range(4): + out[i] = x[i] + y[i] + + @R.function + def test_func( + x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32") + ) -> R.Tensor((4,), "float32"): + return R.call_tir( + DLPackTestModule.test_tir, (x, y), out_sinfo=R.Tensor((4,), "float32") + ) + + converter = RelaxToPyFuncConverter(DLPackTestModule) + converted_ir_mod = converter.convert(["test_func"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["test_func"](x, y) + expected = torch.add(x, y) + + assert torch.allclose(result, expected), "DLPack conversion with numpy fallback failed" + + def test_tuple_return_handling(self): + """Test proper handling of tuple returns (e.g., split operation).""" + + @I.ir_module + class TupleTestModule: + @R.function + def test_split(x: R.Tensor((6,), "float32")) -> R.Tuple: + return R.split(x, indices_or_sections=3, axis=0) + + converter = RelaxToPyFuncConverter(TupleTestModule) + converted_ir_mod = converter.convert(["test_split"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["test_split"](x) + expected = torch.split(x, 2, dim=0) + + assert isinstance(result, tuple), "Split should return tuple" + assert len(result) == len(expected), "Split should return correct number of tensors" + for r, e in zip(result, expected): + assert torch.allclose(r, e), "Split tensor values should match" + + def test_tvm_runtime_api_compatibility(self): + """Test compatibility with tvm.runtime API instead of deprecated tvm.nd.""" + + @I.ir_module + class RuntimeAPITestModule: + @T.prim_func + def test_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (3,), "float32") + y = 
T.match_buffer(var_y, (3,), "float32") + out = T.match_buffer(var_out, (3,), "float32") + for i in range(3): + out[i] = x[i] * y[i] + + @R.function + def test_func( + x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32") + ) -> R.Tensor((3,), "float32"): + return R.call_tir( + RuntimeAPITestModule.test_tir, (x, y), out_sinfo=R.Tensor((3,), "float32") + ) + + converter = RelaxToPyFuncConverter(RuntimeAPITestModule) + converted_ir_mod = converter.convert(["test_func"]) + + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + y = torch.tensor([2.0, 3.0, 4.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["test_func"](x, y) + expected = torch.mul(x, y) + + assert torch.allclose(result, expected) + + def test_packed_function_with_primvalue_args(self): + """Test packed function calls with PrimValue arguments.""" + # Register a test packed function + def test_packed_func(x, axis): + return x # Simple identity function + + tvm.register_global_func("test_packed_func", test_packed_func) + + @I.ir_module + class PackedFuncTestModule: + @R.function + def test_dps(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"): + return R.call_dps_packed( + "test_packed_func", (x, R.const(0)), out_sinfo=R.Tensor((4,), "float32") + ) + + converter = RelaxToPyFuncConverter(PackedFuncTestModule) + converted_ir_mod = converter.convert(["test_dps"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["test_dps"](x) + expected = x # Identity function + + assert torch.allclose(result, expected), "Packed function with PrimValue args failed" + + def test_mixed_tir_and_relax_operations(self): + """Test mixed TIR and Relax operations in a single function.""" + + @I.ir_module + class MixedOpsTestModule: + @T.prim_func + def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (4,), "float32") + y = T.match_buffer(var_y, (4,), "float32") + out = T.match_buffer(var_out, (4,), "float32") + 
for i in range(4): + out[i] = x[i] + y[i] + + @R.function + def test_mixed( + x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32") + ) -> R.Tensor((4,), "float32"): + # TIR operation + tir_result = R.call_tir( + MixedOpsTestModule.add_tir, (x, y), out_sinfo=R.Tensor((4,), "float32") + ) + # Relax operations + relued = R.nn.relu(tir_result) + powered = R.power(relued, R.const(2.0)) + return R.nn.gelu(powered) + + converter = RelaxToPyFuncConverter(MixedOpsTestModule) + converted_ir_mod = converter.convert(["test_mixed"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["test_mixed"](x, y) + + # Manual computation for expected result + added = torch.add(x, y) + relued = F.relu(added) + powered = torch.pow(relued, 2.0) + expected = F.gelu(powered) + + assert torch.allclose(result, expected) + + def test_error_handling_improvements(self): + """Test improved error handling with tensor fallbacks.""" + + @I.ir_module + class ErrorHandlingTestModule: + @R.function + def test_error_handling(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"): + # This should trigger fallback mechanisms + return R.nn.relu(x) + + converter = RelaxToPyFuncConverter(ErrorHandlingTestModule) + converted_ir_mod = converter.convert(["test_error_handling"]) + + x = torch.tensor([-2.0, -1.0, 0.0, 1.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["test_error_handling"](x) + expected = F.relu(x) + + assert torch.allclose(result, expected), "Error handling with tensor fallbacks failed" + assert isinstance(result, torch.Tensor), "Result should be a tensor, not a string" + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py index fb4c8abdf9e6..8abdcda15267 100644 --- a/tests/python/relax/test_runtime_builtin.py +++ b/tests/python/relax/test_runtime_builtin.py @@ -28,7 
+28,7 @@ def test_make_shape(): MK = MakeShapeCode make_shape = tvm.get_global_func("vm.builtin.make_shape") - heap = tvm.nd.array(np.arange(10).astype("int64")) + heap = tvm.runtime.tensor(np.arange(10).astype("int64")) s = make_shape(heap, 3, MK.USE_IMM, 10, MK.LOAD_SHAPE, 0, MK.LOAD_SHAPE, 2) assert s == tvm.runtime.container.ShapeTuple([10, 0, 2]) @@ -37,12 +37,12 @@ def test_make_shape(): def test_match_shape(): MS = MatchShapeCode match_shape = tvm.get_global_func("vm.builtin.match_shape") - heap = tvm.nd.array(np.zeros(10).astype("int64")) + heap = tvm.runtime.tensor(np.zeros(10).astype("int64")) assert heap.numpy()[2] == 0 s = tvm.runtime.container.ShapeTuple([1, 2, 3]) - x = tvm.nd.array(np.zeros([1, 2, 3])) + x = tvm.runtime.tensor(np.zeros([1, 2, 3])) match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "") @@ -86,7 +86,7 @@ def test_check_shape_info(): def test_check_tensor_info(): check_tensor_info = tvm.get_global_func("vm.builtin.check_tensor_info") - x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + x = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32")) check_tensor_info(x, 2, "int32", "") check_tensor_info(x, -1, "int32", "") @@ -116,7 +116,7 @@ def test_check_tensor_info(): def test_check_tuple_info(): check_tuple_info = tvm.get_global_func("vm.builtin.check_tuple_info") - x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + x = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32")) t = tvm.runtime.convert([x, x, x]) check_tuple_info(t, 3, "") @@ -133,7 +133,7 @@ def test_check_tuple_info(): def test_check_func_info(): check_func_info = tvm.get_global_func("vm.builtin.check_func_info") f = tvm.runtime.convert(lambda x: x) - x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + x = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32")) check_func_info(f, "") @@ -144,8 +144,8 @@ def test_check_func_info(): def test_tuple_getitem(): tuple_getitem = tvm.get_global_func("vm.builtin.tuple_getitem") - x = 
tvm.nd.array(np.zeros((2, 3)).astype("int32")) - y = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + x = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32")) + y = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32")) t = tvm.runtime.convert([x, y]) assert tuple_getitem(t, 0) == x @@ -157,10 +157,10 @@ def test_attention_kv_cache(): fappend = tvm.get_global_func("vm.builtin.attention_kv_cache_append") fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view") - cache = fcreate(tvm.nd.empty((1, 2), dtype="int32"), tvm.runtime.ShapeTuple([2, 2]), 0) + cache = fcreate(tvm.runtime.empty((1, 2), dtype="int32"), tvm.runtime.ShapeTuple([2, 2]), 0) num_steps = 2 for i in range(num_steps): - cache = fappend(cache, tvm.nd.array(i * np.ones((1, 2)).astype("int32"))) + cache = fappend(cache, tvm.runtime.tensor(i * np.ones((1, 2)).astype("int32"))) res = fview(cache, tvm.runtime.ShapeTuple((num_steps, 2))).numpy() for i in range(num_steps): @@ -168,8 +168,8 @@ def test_attention_kv_cache(): assert res[i][1] == i -def test_ndarray_cache(): - fload = tvm.get_global_func("vm.builtin.ndarray_cache.load") +def test_tensor_cache(): + fload = tvm.get_global_func("vm.builtin.tensor_cache.load") fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache") param_dict = { @@ -178,18 +178,18 @@ def test_ndarray_cache(): } temp = utils.tempdir() - tvmjs.dump_ndarray_cache(param_dict, temp.path, encode_format="f32-to-bf16") - fload(str(temp.path), tvm.cpu().device_type, 0) + tvmjs.dump_tensor_cache(param_dict, temp.path, encode_format="f32-to-bf16") + fload(str(temp.path), tvm.cpu().dlpack_device_type(), 0) res = fget_params("x", -1) for i, v in enumerate(res): v_np = param_dict[f"x_{i}"] if v_np.dtype == "float32": v_np = tvmjs._convert_bf16_to_f32(tvmjs._convert_f32_to_bf16(v_np)) - np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6) + tvm.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6) -def test_ndarray_cache_update(): - fload = 
tvm.get_global_func("vm.builtin.ndarray_cache.load") +def test_tensor_cache_update(): + fload = tvm.get_global_func("vm.builtin.tensor_cache.load") fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache") param_dict = { @@ -198,19 +198,19 @@ def test_ndarray_cache_update(): } temp = utils.tempdir() - tvmjs.dump_ndarray_cache(param_dict, temp.path, encode_format="f32-to-bf16") + tvmjs.dump_tensor_cache(param_dict, temp.path, encode_format="f32-to-bf16") param_dict["x_1"] = np.random.uniform(size=[10, 20]).astype("float32") param_dict["x_2"] = np.random.uniform(size=[10]).astype("float32") - tvmjs.dump_ndarray_cache( + tvmjs.dump_tensor_cache( param_dict, temp.path, encode_format="f32-to-bf16", update_if_exists=True ) - fload(str(temp.path), tvm.cpu().device_type, 0) + fload(str(temp.path), tvm.cpu().dlpack_device_type(), 0) res = fget_params("x", -1) for i, v in enumerate(res): v_np = param_dict[f"x_{i}"] if v_np.dtype == "float32": v_np = tvmjs._convert_bf16_to_f32(tvmjs._convert_f32_to_bf16(v_np)) - np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6) + tvm.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6) def test_attention_kv_cache_window_override(): @@ -220,7 +220,7 @@ def test_attention_kv_cache_window_override(): current_pos = 4 cache = fcreate( - tvm.nd.array(np.full((16, 2), -1).astype("int32")), + tvm.runtime.tensor(np.full((16, 2), -1).astype("int32")), tvm.runtime.ShapeTuple([16, 2]), current_pos, ) @@ -230,7 +230,7 @@ def test_attention_kv_cache_window_override(): for i in range(1, num_steps): np_array = i * np.ones((i, 2)).astype("int32") np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0) - cache = foverride(cache, tvm.nd.array(np_array), 16) + cache = foverride(cache, tvm.runtime.tensor(np_array), 16) current_pos = (current_pos + i) % 16 res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy() @@ -252,7 +252,7 @@ def test_attention_kv_cache_window_override_with_sinks(): current_pos = 0 cache = 
fcreate( - tvm.nd.array(np.full((16, 2), -1).astype("int32")), + tvm.runtime.tensor(np.full((16, 2), -1).astype("int32")), tvm.runtime.ShapeTuple([16, 2]), current_pos, ) @@ -262,7 +262,7 @@ def test_attention_kv_cache_window_override_with_sinks(): for i in range(num_steps): np_array = i * np.ones((1, 2)).astype("int32") np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0) - cache = foverride(cache, tvm.nd.array(np_array), 16, num_attention_sinks) + cache = foverride(cache, tvm.runtime.tensor(np_array), 16, num_attention_sinks) if has_sink: current_pos = max((current_pos + 1) % 16, num_attention_sinks) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py index 1941edeaa715..970cf3826055 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py @@ -140,7 +140,7 @@ def set_global_func(head_dim, dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f.main) ( ftranspose_append, @@ -182,7 +182,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): rope_scale, rope_theta, None, # rope_ext_factors - tvm.nd.empty((), dtype, device=device), + tvm.runtime.empty((), dtype, device=device), ftranspose_append, None, # f_transpose_append_mla ["tir", fattn_prefill_ragged], @@ -244,8 +244,8 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): values_expected = expected_v[seq_id] assert keys_expected.shape == values_expected.shape seq_length = expected_k[seq_id].shape[1] - keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device) - values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device) + keys = tvm.runtime.empty(keys_expected.shape, dtype=dtype, device=device) + values = 
tvm.runtime.empty(values_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values) tvm.testing.assert_allclose(keys.numpy(), keys_expected, rtol=1e-3, atol=1e-3) tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3) @@ -395,8 +395,8 @@ def apply_attention( queries_np = global_new_q[layer_id] keys_np = global_new_k[layer_id] values_np = global_new_v[layer_id] - qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device) - outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) + qkv = tvm.runtime.tensor(np.concatenate([queries_np, keys_np, values_np], axis=1), device) + outputs = tvm.runtime.empty(queries_np.shape, dtype, device=device) fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs) # Compute attention expected results. diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index ffd345229200..4aae9dec5995 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -23,7 +23,6 @@ import tvm.testing from tvm import dlight as dl from tvm import relax -from tvm.contrib import utils from tvm.relax.frontend.nn.llm.kv_cache import ( AttnKind, RopeMode, @@ -78,7 +77,7 @@ fcompact_copy = None -def set_global_func(): +def set_global_func(rope_mode: RopeMode): global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fpopn global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv, fdebug_get_kv global fattention_prefill, fattention_decode, fattention_prefill_ragged @@ -98,48 +97,30 @@ def set_global_func(): ) fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv") - def load_module(name: str, static_modules: List[tvm.runtime.Module]): - assert len(static_modules) > 0 
- if len(static_modules) == 1: - return static_modules[0] - static_mod = static_modules[0] - for mod in static_modules[1:]: - static_mod.import_module(mod) - temp = utils.tempdir() - mod_path = temp.relpath(f"{name}.so") - static_mod.export_library(mod_path) - return tvm.runtime.load_module(mod_path) - target = tvm.target.Target.from_device(device) - flashinfer_prefill_mod = load_module( - "flashinfer_prefill", - relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - qk_head_dim=head_dim, - v_head_dim=head_dim, - target=target, - ), - ) - flashinfer_decode_mod = load_module( - "flashinfer_decode", - relax.backend.cuda.flashinfer.gen_flashinfer_decode_module( - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - qk_head_dim=head_dim, - v_head_dim=head_dim, - target=target, - ), - ) - - fattention_prefill = flashinfer_prefill_mod["batch_prefill_with_paged_kv_cache_run"] - fattention_prefill_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] - fattention_prefill_ragged = flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"] - fattention_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] - fattention_decode = flashinfer_decode_mod["batch_decode_with_paged_kv_cache_run"] - fattention_decode_plan = flashinfer_decode_mod["batch_decode_with_paged_kv_cache_plan"] + flashinfer_prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=head_dim, + v_head_dim=head_dim, + enable_inline_rope=rope_mode == RopeMode.INLINE, + )[0] + flashinfer_decode_mod = relax.backend.cuda.flashinfer.gen_flashinfer_decode_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=head_dim, + v_head_dim=head_dim, + enable_inline_rope=rope_mode == RopeMode.INLINE, + )[0] + + fattention_prefill = flashinfer_prefill_mod["batch_prefill_paged_run"] + fattention_prefill_plan = 
flashinfer_prefill_mod["batch_prefill_plan"] + fattention_prefill_ragged = flashinfer_prefill_mod["batch_prefill_ragged_run"] + fattention_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_plan"] + fattention_decode = flashinfer_decode_mod["batch_decode_run"] + fattention_decode_plan = flashinfer_decode_mod["batch_decode_plan"] builts = [] for tir_func in [ @@ -156,7 +137,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f.main) ( ftranspose_append, @@ -192,7 +173,7 @@ def create_kv_cache(rope_mode): rope_scale, rope_theta, None, # rope_ext_factors - tvm.nd.empty((), dtype, device=device), + tvm.runtime.empty((), dtype, device=device), ftranspose_append, None, # f_transpose_append_mla ["flashinfer", fattention_prefill_ragged, fattention_prefill_ragged_plan], @@ -224,8 +205,8 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): values_expected = expected_v[seq_id] assert keys_expected.shape == values_expected.shape seq_length = expected_k[seq_id].shape[1] - keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device) - values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device) + keys = tvm.runtime.empty(keys_expected.shape, dtype=dtype, device=device) + values = tvm.runtime.empty(values_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values) torch.testing.assert_close( torch.from_numpy(keys.numpy()).to(device_torch), keys_expected, rtol=1e-3, atol=1e-3 @@ -365,8 +346,10 @@ def apply_attention( queries_np = global_new_q[layer_id] keys_np = global_new_k[layer_id] values_np = global_new_v[layer_id] - qkv = tvm.nd.array(torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device) - outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) + qkv = tvm.runtime.tensor( + 
torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device + ) + outputs = tvm.runtime.empty(queries_np.shape, dtype, device=device) fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs) # Compute attention expected results. @@ -558,8 +541,8 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode): if __name__ == "__main__": - set_global_func() - for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]: + for rope_mode in [RopeMode.NONE, RopeMode.NORMAL]: + set_global_func(rope_mode) cache = create_kv_cache(rope_mode) test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode)) test_paged_attention_kv_cache_remove_sequence((cache, rope_mode)) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py index 2f726064a71b..cd76f9ce20a7 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py @@ -25,7 +25,6 @@ import tvm.testing from tvm import dlight as dl from tvm import relax -from tvm.contrib import utils from tvm.relax.frontend.nn.llm.kv_cache import ( AttnKind, RopeMode, @@ -84,7 +83,7 @@ # Register a dumb function for testing purpose. 
-@tvm.register_func("test.dumb_function", override=True) +@tvm.register_global_func("test.dumb_function", override=True) def _dumb_function(): raise RuntimeError("Dumb function isn't supposed to be accessed.") @@ -115,47 +114,27 @@ def set_global_func(dtype): fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty") fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla") - def load_module(name: str, static_modules: List[tvm.runtime.Module]): - assert len(static_modules) > 0 - if len(static_modules) == 1: - return static_modules[0] - static_mod = static_modules[0] - for mod in static_modules[1:]: - static_mod.import_module(mod) - temp = utils.tempdir() - mod_path = temp.relpath(f"{name}.so") - static_mod.export_library(mod_path) - return tvm.runtime.load_module(mod_path) - target = tvm.target.Target.from_device(device) - flashinfer_prefill_mod = load_module( - "flashinfer_prefill", - relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, - v_head_dim=v_head_dim, - target=target, - enable_inline_rope=False, - ), - ) - flashinfer_mla_mod = load_module( - "flashinfer_mla", - relax.backend.cuda.flashinfer.gen_flashinfer_mla_module( - dtype_q=dtype, - dtype_kv=dtype, - dtype_o=dtype, - head_dim_ckv=kv_lora_rank, - head_dim_kpe=qk_rope_head_dim, - target=target, - ), - ) - - fattn_prefill_ragged = flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"] - fattn_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] - fmla_prefill = flashinfer_mla_mod["batch_mla_paged_attention_run"] - fmla_prefill_plan = flashinfer_mla_mod["batch_mla_paged_attention_plan"] + flashinfer_prefill_mod = relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + 
enable_inline_rope=False, + )[0] + flashinfer_mla_mod = relax.backend.cuda.flashinfer.gen_flashinfer_mla_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + head_dim_ckv=kv_lora_rank, + head_dim_kpe=qk_rope_head_dim, + )[0] + + fattn_prefill_ragged = flashinfer_prefill_mod["batch_prefill_ragged_run"] + fattn_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_plan"] + fmla_prefill = flashinfer_mla_mod["batch_mla_run"] + fmla_prefill_plan = flashinfer_mla_mod["batch_mla_plan"] builts = [] for tir_func in [ @@ -169,7 +148,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f.main) ( ftranspose_append, @@ -218,10 +197,16 @@ def create_kv_cache(dtype): 1, 10000, None, # rope_ext_factors - tvm.nd.empty((), dtype, device=device), + tvm.runtime.empty((), dtype, device=device), None, # f_transpose_append_mha ftranspose_append, - ["flashinfer", fattn_prefill_ragged, fattn_prefill_ragged_plan], # fattn_prefill_ragged + [ + "flashinfer", + fattn_prefill_ragged, + fattn_prefill_ragged_plan, + qk_nope_head_dim + qk_rope_head_dim, + v_head_dim, + ], # fattn_prefill_ragged [], # fattn_prefill [], # fattn_decode [], # fattn_prefill_sliding_window @@ -251,7 +236,7 @@ def verify_cached_kv(kv_cache, seq_ids, expected_kv): for seq_id in seq_ids: kv_expected = expected_kv[seq_id] seq_length = expected_kv[seq_id].shape[1] - kv_actual = tvm.nd.empty(kv_expected.shape, dtype=dtype, device=device) + kv_actual = tvm.runtime.empty(kv_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, kv_actual) torch.testing.assert_close( torch.from_numpy(kv_actual.numpy()).to(device_torch), kv_expected, rtol=1e-3, atol=1e-3 @@ -334,17 +319,17 @@ def apply_attention( is_decode_request = False for layer_id in range(num_layers): - queries = 
tvm.nd.array(global_new_q[layer_id].cpu().numpy(), device) - key_value = tvm.nd.array(global_new_kv[layer_id].cpu().numpy(), device) + queries = tvm.runtime.tensor(global_new_q[layer_id].cpu().numpy(), device) + key_value = tvm.runtime.tensor(global_new_kv[layer_id].cpu().numpy(), device) total_seq_length = global_new_q[layer_id].shape[0] - outputs1 = tvm.nd.empty( + outputs1 = tvm.runtime.empty( (total_seq_length, num_attention_heads, v_head_dim), dtype, device=device ) - lse1 = tvm.nd.empty((total_seq_length, num_attention_heads), "float32", device=device) - outputs2 = tvm.nd.empty( + lse1 = tvm.runtime.empty((total_seq_length, num_attention_heads), "float32", device=device) + outputs2 = tvm.runtime.empty( (total_seq_length, num_attention_heads, kv_lora_rank), dtype, device=device ) - lse2 = tvm.nd.empty((total_seq_length, num_attention_heads), "float32", device=device) + lse2 = tvm.runtime.empty((total_seq_length, num_attention_heads), "float32", device=device) fappend_mla_kv(kv_cache, layer_id, key_value) if not is_decode_request: @@ -361,8 +346,8 @@ def apply_attention( total_seq_length, num_attention_heads, qk_rope_head_dim ) keys = torch.cat([keys, k_pe_expanded], dim=2) - keys_tvm = tvm.nd.array(keys.cpu().numpy(), device) - values_tvm = tvm.nd.array(values.cpu().numpy(), device) + keys_tvm = tvm.runtime.tensor(keys.cpu().numpy(), device) + values_tvm = tvm.runtime.tensor(values.cpu().numpy(), device) fself_attn(kv_cache, layer_id, sm_scale, queries, keys_tvm, values_tvm, outputs1, lse1) if not all_new_sequences or is_decode_request: @@ -373,9 +358,9 @@ def apply_attention( queries_lora_np = torch.cat( [torch.bmm(queries_lora_np.permute(1, 0, 2), w_uk).permute(1, 0, 2), q_pe], dim=2 ) - queries_lora = tvm.nd.array(queries_lora_np.cpu().numpy(), device) + queries_lora = tvm.runtime.tensor(queries_lora_np.cpu().numpy(), device) fcross_attn(kv_cache, layer_id, sm_scale, queries_lora, outputs2, lse2) - cross_attn_output = tvm.nd.array( + cross_attn_output = 
tvm.runtime.tensor( torch.bmm( torch.from_numpy(outputs2.numpy()).to(device_torch).permute(1, 0, 2), w_uv ) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py index b2982abdb0a5..efc0a5694ca6 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py @@ -79,7 +79,7 @@ # Register a dumb function for testing purpose. -@tvm.register_func("test.dumb_function", override=True) +@tvm.register_global_func("test.dumb_function", override=True) def _dumb_function(): raise RuntimeError("Dumb function isn't supposed to be accessed.") @@ -134,7 +134,7 @@ def set_global_func(dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f.main) ( ftranspose_append, @@ -185,7 +185,7 @@ def create_kv_cache(dtype): 1, 10000, None, # rope_ext_factors - tvm.nd.empty((), dtype, device=device), + tvm.runtime.empty((), dtype, device=device), None, # f_transpose_append_mha ftranspose_append, ["tir", fmla_prefill_ragged], # fattn_prefill_ragged @@ -218,7 +218,7 @@ def verify_cached_kv(kv_cache, seq_ids, expected_kv): for seq_id in seq_ids: kv_expected = expected_kv[seq_id] seq_length = expected_kv[seq_id].shape[1] - kv_actual = tvm.nd.empty(kv_expected.shape, dtype=dtype, device=device) + kv_actual = tvm.runtime.empty(kv_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, kv_actual) torch.testing.assert_close( torch.from_numpy(kv_actual.numpy()).to(device_torch), kv_expected, rtol=1e-3, atol=1e-3 @@ -301,17 +301,17 @@ def apply_attention( is_decode_request = False for layer_id in range(num_layers): - queries = tvm.nd.array(global_new_q[layer_id].cpu().numpy(), device) - key_value = 
tvm.nd.array(global_new_kv[layer_id].cpu().numpy(), device) + queries = tvm.runtime.tensor(global_new_q[layer_id].cpu().numpy(), device) + key_value = tvm.runtime.tensor(global_new_kv[layer_id].cpu().numpy(), device) total_seq_length = global_new_q[layer_id].shape[0] - outputs1 = tvm.nd.empty( + outputs1 = tvm.runtime.empty( (total_seq_length, num_attention_heads, v_head_dim), dtype, device=device ) - lse1 = tvm.nd.empty((total_seq_length, num_attention_heads), "float32", device=device) - outputs2 = tvm.nd.empty( + lse1 = tvm.runtime.empty((total_seq_length, num_attention_heads), "float32", device=device) + outputs2 = tvm.runtime.empty( (total_seq_length, num_attention_heads, kv_lora_rank), dtype, device=device ) - lse2 = tvm.nd.empty((total_seq_length, num_attention_heads), "float32", device=device) + lse2 = tvm.runtime.empty((total_seq_length, num_attention_heads), "float32", device=device) fappend_mla_kv(kv_cache, layer_id, key_value) if not is_decode_request: @@ -328,8 +328,8 @@ def apply_attention( total_seq_length, num_attention_heads, qk_rope_head_dim ) keys = torch.cat([keys, k_pe_expanded], dim=2) - keys_tvm = tvm.nd.array(keys.cpu().numpy(), device) - values_tvm = tvm.nd.array(values.cpu().numpy(), device) + keys_tvm = tvm.runtime.tensor(keys.cpu().numpy(), device) + values_tvm = tvm.runtime.tensor(values.cpu().numpy(), device) fself_attn(kv_cache, layer_id, sm_scale, queries, keys_tvm, values_tvm, outputs1, lse1) if not all_new_sequences or is_decode_request: @@ -340,9 +340,9 @@ def apply_attention( queries_lora_np = torch.cat( [torch.bmm(queries_lora_np.permute(1, 0, 2), w_uk).permute(1, 0, 2), q_pe], dim=2 ) - queries_lora = tvm.nd.array(queries_lora_np.cpu().numpy(), device) + queries_lora = tvm.runtime.tensor(queries_lora_np.cpu().numpy(), device) fcross_attn(kv_cache, layer_id, sm_scale, queries_lora, outputs2, lse2) - cross_attn_output = tvm.nd.array( + cross_attn_output = tvm.runtime.tensor( torch.bmm( 
torch.from_numpy(outputs2.numpy()).to(device_torch).permute(1, 0, 2), w_uv ) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 8cd3a737402e..b80bd1acb7b7 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -142,7 +142,7 @@ def set_global_func(head_dim, dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f.main) ( ftranspose_append, @@ -184,7 +184,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): rope_scale, rope_theta, None, # rope_ext_factors - tvm.nd.empty((), dtype, device=device), + tvm.runtime.empty((), dtype, device=device), ftranspose_append, None, # f_transpose_append_mla ["tir", fattn_prefill_ragged], @@ -235,8 +235,8 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): values_expected = expected_v[seq_id] assert keys_expected.shape == values_expected.shape seq_length = expected_k[seq_id].shape[1] - keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device) - values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device) + keys = tvm.runtime.empty(keys_expected.shape, dtype=dtype, device=device) + values = tvm.runtime.empty(values_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values) torch.testing.assert_close( torch.from_numpy(keys.numpy()).to(device_torch), keys_expected, rtol=1e-3, atol=1e-3 @@ -428,8 +428,10 @@ def apply_attention( queries_np = global_new_q[layer_id] keys_np = global_new_k[layer_id] values_np = global_new_v[layer_id] - qkv = tvm.nd.array(torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device) - outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) + qkv = 
tvm.runtime.tensor( + torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device + ) + outputs = tvm.runtime.empty(queries_np.shape, dtype, device=device) fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs) # Compute attention expected results. diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py index 095aba8b83e5..515c6ee648ff 100644 --- a/tests/python/relax/test_runtime_builtin_rnn_state.py +++ b/tests/python/relax/test_runtime_builtin_rnn_state.py @@ -81,7 +81,7 @@ def _build(tir_func): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) # pylint: disable=not-callable f = tvm.tir.build(mod["main"], target=target) - return f.entry_func + return f.main _f_tir_gets, _f_tir_sets = [], [] for state in states: @@ -95,7 +95,10 @@ def _build(tir_func): def create_rnn_state(): f_create = tvm.get_global_func("vm.builtin.rnn_state_create") - init_values = [tvm.nd.array(np_zero, device=device), tvm.nd.array(np_one, device=device)] + init_values = [ + tvm.runtime.tensor(np_zero, device=device), + tvm.runtime.tensor(np_one, device=device), + ] return f_create(num_layers, reserved_nseq, max_history, f_tir_gets, f_tir_sets, init_values) @@ -119,8 +122,8 @@ def test_rnn_state_get(rnn_state): # pylint: disable=redefined-outer-name f_clear(state) f_add_sequence(state, 0) f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1])) - tvm_nd_0 = tvm.nd.array(np.empty((1, 16, 16), "float16"), device=device) - tvm_nd_1 = tvm.nd.array(np.empty((1, 32, 32), "float32"), device=device) + tvm_nd_0 = tvm.runtime.tensor(np.empty((1, 16, 16), "float16"), device=device) + tvm_nd_1 = tvm.runtime.tensor(np.empty((1, 32, 32), "float32"), device=device) f_get(state, 0, 0, tvm_nd_0) f_get(state, 0, 1, tvm_nd_1) f_end_forward(state) @@ -136,8 +139,8 @@ def test_rnn_state_set(rnn_state): # pylint: disable=redefined-outer-name f_add_sequence(state, seq_id) f_begin_forward(state, 
ShapeTuple([0, 2]), ShapeTuple([1, 1])) - f_set(state, 0, 0, tvm.nd.array(np.full((2, 16, 16), 2.0, "float16"), device=device)) - f_set(state, 0, 1, tvm.nd.array(np.full((2, 32, 32), 3.0, "float32"), device=device)) + f_set(state, 0, 0, tvm.runtime.tensor(np.full((2, 16, 16), 2.0, "float16"), device=device)) + f_set(state, 0, 1, tvm.runtime.tensor(np.full((2, 32, 32), 3.0, "float32"), device=device)) f_end_forward(state) expected_values = [[np_two, np_three], [np_zero, np_one], [np_two, np_three]] @@ -151,8 +154,8 @@ def test_rnn_state_popn(rnn_state): # pylint: disable=redefined-outer-name f_add_sequence(state, 0) f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1])) - f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device)) - f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32), device=device)) + f_set(state, 0, 0, tvm.runtime.tensor(np_two.reshape(1, 16, 16), device=device)) + f_set(state, 0, 1, tvm.runtime.tensor(np_three.reshape(1, 32, 32), device=device)) f_end_forward(state) verify_state(state, [0], [[np_two, np_three]]) @@ -169,8 +172,8 @@ def test_rnn_state_fork_sequence(rnn_state): # pylint: disable=redefined-outer- f_add_sequence(state, 0) f_begin_forward(state, ShapeTuple([0]), ShapeTuple([1])) - f_set(state, 0, 0, tvm.nd.array(np_two.reshape(1, 16, 16), device=device)) - f_set(state, 0, 1, tvm.nd.array(np_three.reshape(1, 32, 32), device=device)) + f_set(state, 0, 0, tvm.runtime.tensor(np_two.reshape(1, 16, 16), device=device)) + f_set(state, 0, 1, tvm.runtime.tensor(np_three.reshape(1, 32, 32), device=device)) f_end_forward(state) f_fork_sequence(state, 0, 1, -1) verify_state(state, [0, 1], [[np_two, np_three], [np_two, np_three]]) diff --git a/tests/python/relax/test_runtime_sampling_flashinfer.py b/tests/python/relax/test_runtime_sampling_flashinfer.py index dc3a3c86e69a..8dcd7bf61289 100644 --- a/tests/python/relax/test_runtime_sampling_flashinfer.py +++ b/tests/python/relax/test_runtime_sampling_flashinfer.py @@ -51,8 
+51,8 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): probs_np = np.array([[0.1, 0.2, 0.3, 0.2, 0.2] for _ in range(batch_size)], dtype="float32") dev = tvm.cuda(0) - prob_tvm = tvm.nd.array(probs_np, device=dev) - output_tvm = tvm.nd.empty((batch_size,), "int32", device=dev) + prob_tvm = tvm.runtime.tensor(probs_np, device=dev) + output_tvm = tvm.runtime.empty((batch_size,), "int32", device=dev) device = tvm.cuda() target = tvm.target.Target.from_device(device) diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py index d7ca2a672b55..4061da3a9c2e 100644 --- a/tests/python/relax/test_tir_call_source_kernel.py +++ b/tests/python/relax/test_tir_call_source_kernel.py @@ -92,8 +92,8 @@ def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle): assert len(Module.get_attr("external_mods")) == 1 device = tvm.cuda(0) - x_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device) - y_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device) + x_nd = tvm.runtime.tensor(np.random.rand(256).astype(np.float32), device) + y_nd = tvm.runtime.tensor(np.random.rand(256).astype(np.float32), device) output_np = x_nd.numpy() + y_nd.numpy() with tvm.target.Target("cuda"): diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py index 6a9c34a5fb94..f2106ea2c2e7 100644 --- a/tests/python/relax/test_training_optimizer_numeric.py +++ b/tests/python/relax/test_training_optimizer_numeric.py @@ -37,7 +37,7 @@ def _legalize_and_build(mod: IRModule, target, dev): def _numpy_to_tvm(data): if isinstance(data, (list, tuple)): return [_numpy_to_tvm(_data) for _data in data] - return tvm.nd.array(data) + return tvm.runtime.tensor(data) def _tvm_to_numpy(data): diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index e3274aea886a..b0bec5e858af 100644 --- 
a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -17,6 +17,7 @@ import pytest import tvm +import tvm.testing from tvm import relax import tvm.script diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index 2e9845f73f40..c46701d33a85 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -53,8 +53,8 @@ def main( x_np = np.random.rand(16, 16).astype(np.float32) w_np = np.random.rand(16, 16).astype(np.float32) - x_tvm = tvm.nd.array(x_np) - w_tvm = tvm.nd.array(w_np) + x_tvm = tvm.runtime.tensor(x_np) + w_tvm = tvm.runtime.tensor(w_np) params_dict = {"w": w_np if use_np_array else w_tvm} mod = relax.transform.BindParams("main", params_dict)(InputModule) assert len(mod["main"].params) == 1 @@ -97,10 +97,10 @@ def main( return out m, n, k = 4, 6, 8 - w0_tvm = tvm.nd.array(np.random.rand(n, m).astype(np.float32)) - b0_tvm = tvm.nd.array(np.random.rand(n).astype(np.float32)) - w1_tvm = tvm.nd.array(np.random.rand(k, n).astype(np.float32)) - b1_tvm = tvm.nd.array(np.random.rand(k).astype(np.float32)) + w0_tvm = tvm.runtime.tensor(np.random.rand(n, m).astype(np.float32)) + b0_tvm = tvm.runtime.tensor(np.random.rand(n).astype(np.float32)) + w1_tvm = tvm.runtime.tensor(np.random.rand(k, n).astype(np.float32)) + b1_tvm = tvm.runtime.tensor(np.random.rand(k).astype(np.float32)) params_dict = {"w0": w0_tvm, "b0": b0_tvm, "w1": w1_tvm, "b1": b1_tvm} mod = relax.transform.BindParams("main", params_dict)(Before) diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index b997eb9c6bc0..dbddc60f8cd9 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -106,8 +106,8 @@ def setup_test(): np0 = np.random.rand(16, 16).astype(np.float32) np1 = np.random.rand(16, 16).astype(np.float32) - data0 = 
tvm.nd.array(np0, dev) - data1 = tvm.nd.array(np1, dev) + data0 = tvm.runtime.tensor(np0, dev) + data1 = tvm.runtime.tensor(np1, dev) inputs = [data0, data1] # Ground truth should be generated before annotation diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index 262e37b91b1b..83b81a6898a7 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -206,10 +206,9 @@ def main( lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast( lv0, R.Tensor((N, H, W, C), dtype="float32") ) - lv3: R.Tensor((N, C, H, W), dtype="float32") = R.permute_dims( - lv2, axes=[0, 3, 1, 2] - ) - gv: R.Tensor(dtype="float32", ndim=4) = R.add(lv3, w) + lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3) + gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4, axes=[0, 3, 1, 2]) R.output(gv) return gv @@ -4585,5 +4584,413 @@ def main( verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) +def test_conv2d_conv2d_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((8, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = 
R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((96, 32, 2, 2), dtype="float32"), + w2: R.Tensor((32, 96, 2, 2), dtype="float32"), + w3: R.Tensor((8, 96, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 96, 1, 1), dtype="float32"), + bias2: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 40, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((24, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 24, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv1: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.relu(gv1) + lv3: R.Tensor((8, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv3, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv4: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias2, + index_map=T.index_map( + lambda i0, 
i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv4: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.add(gv3, lv4) + gv5: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.relu(gv4) + lv5: R.Tensor((2, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 2, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv5, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv6: R.Tensor((2, 10, 10, 10, 4), dtype="float32") = R.concat((gv3, gv6), axis=1) + gv7: R.Tensor((2, 40, 10, 10), dtype="float32") = R.layout_transform( + lv6, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv7) + return gv7 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_conv2d_callback_to_buffer_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], 
out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((96, 32, 2, 2), dtype="float32"), + w2: R.Tensor((32, 96, 2, 2), dtype="float32"), + w3: R.Tensor((5, 96, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 96, 1, 1), dtype="float32"), + bias2: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 37, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((24, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 24, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv1: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.relu(gv1) + lv3: R.Tensor((8, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv3, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv4: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), 
index_dtype="int32" + ), + ) + gv4: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.add(gv3, lv4) + gv5: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.relu(gv4) + lv5: R.Tensor((2, 96, 20, 20), dtype="float32") = R.layout_transform( + gv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 5, 10, 10), dtype="float32") = R.nn.conv2d( + lv5, + w3, + strides=[2, 2], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv6: R.Tensor((2, 32, 10, 10), dtype="float32") = R.layout_transform( + gv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv7: R.Tensor((2, 37, 10, 10), dtype="float32") = R.concat((lv6, gv6), axis=1) + R.output(gv7) + return gv7 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 1, 1), "float32"), + w2: R.Tensor((32, 32, 2, 2), "float32"), + w3: R.Tensor((32, 32, 1, 1), "float32"), + w4: R.Tensor((32, 32, 2, 2), "float32"), + bias1: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") + gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) + gv2 = R.nn.conv2d( + gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32" + ) + gv3 = R.add(gv2, 
bias1) + gv4 = R.nn.relu(gv3) + gv5 = R.nn.conv2d( + gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32" + ) + gv6 = R.nn.conv2d( + gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32" + ) + gv7 = R.nn.relu(gv6) + gv8 = R.add(gv2, gv5) + gv9 = R.add(gv8, gv6) + R.output(gv9) + return gv9 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((32, 32, 1, 1), dtype="float32"), + w2: R.Tensor((32, 32, 2, 2), dtype="float32"), + w3: R.Tensor((32, 32, 1, 1), dtype="float32"), + w4: R.Tensor((32, 32, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 32, 20, 20), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((8, 32, 1, 1, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv1: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.max_pool2d( + gv, pool_size=[2, 2], strides=[2, 2], layout="NCHW4c", out_layout="NCHW4c" + ) + lv2: R.Tensor((8, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv2, + padding=[0, 0, 1, 1], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv3: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 
4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv2, lv3) + gv4: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.relu(gv3) + lv4: R.Tensor((8, 32, 1, 1, 4), dtype="float32") = R.layout_transform( + w3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv5: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv4, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv5: R.Tensor((8, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w4, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv5, + strides=[1, 1], + padding=[0, 1, 1, 0], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv7: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.relu(gv6) + gv8: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv2, gv5) + lv6: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv8, gv6) + gv9: R.Tensor((2, 32, 20, 20), dtype="float32") = R.layout_transform( + lv6, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv9) + return gv9 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index bb10704acbb7..5b12480e253c 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -63,8 +63,8 @@ def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32" lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32")) # we expect to bind the repeated large constants lv1 = R.add( - 
R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))), - R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))), + R.const(tvm.runtime.tensor(np.zeros((2, 2), dtype="int32"))), + R.const(tvm.runtime.tensor(np.zeros((2, 2), dtype="int32"))), ) gv = (lv0, lv1) R.output(gv) @@ -77,8 +77,8 @@ def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32" with R.dataflow(): lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32")) lv1 = R.add( - R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))), - R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))), + R.const(tvm.runtime.tensor(np.zeros((2, 2), dtype="int32"))), + R.const(tvm.runtime.tensor(np.zeros((2, 2), dtype="int32"))), ) gv = (lv0, lv1) R.output(gv) diff --git a/tests/python/relax/test_transform_few_shot_tuning.py b/tests/python/relax/test_transform_few_shot_tuning.py index c640deee5496..e769c911a3f0 100644 --- a/tests/python/relax/test_transform_few_shot_tuning.py +++ b/tests/python/relax/test_transform_few_shot_tuning.py @@ -343,7 +343,7 @@ def _expected_results( func = func.with_attr("global_symbol", "main") rt_mod = tvm.compile(func, target="llvm") data = [ - tvm.nd.array(x) + tvm.runtime.tensor(x) for x in [ *inputs, np.zeros(output_shape, dtype=output_dtype), @@ -359,7 +359,7 @@ def _actual_results( target = _target() actual_rt_mod = tvm.compile(actual, target=target) actual_data = [ - tvm.nd.array(x, device=tvm.cuda() if target.kind.name == "cuda" else tvm.cpu()) + tvm.runtime.tensor(x, device=tvm.cuda() if target.kind.name == "cuda" else tvm.cpu()) for x in [ *inputs, np.zeros(output_shape, dtype=output_dtype), diff --git a/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py b/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py index 4b17829fa0d7..d47fa1166510 100644 --- a/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py +++ b/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py @@ -70,13 +70,13 @@ def 
test_fold_batchnorm_info_conv2d(): mod_fold = get_conv2d_batchnorm_sample() target = tvm.target.Target("llvm", host="llvm") - data_in = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype(np.float32)) + data_in = tvm.runtime.tensor(np.random.rand(1, 3, 224, 224).astype(np.float32)) - weight_data = tvm.nd.array(np.random.rand(32, 3, 3, 3).astype(np.float32)) - gamma_data = tvm.nd.array(np.random.rand(32).astype(np.float32)) - beta_data = tvm.nd.array(np.random.rand(32).astype(np.float32)) - mean_data = tvm.nd.array(np.random.rand(32).astype(np.float32)) - variance_data = tvm.nd.array(np.random.rand(32).astype(np.float32)) + weight_data = tvm.runtime.tensor(np.random.rand(32, 3, 3, 3).astype(np.float32)) + gamma_data = tvm.runtime.tensor(np.random.rand(32).astype(np.float32)) + beta_data = tvm.runtime.tensor(np.random.rand(32).astype(np.float32)) + mean_data = tvm.runtime.tensor(np.random.rand(32).astype(np.float32)) + variance_data = tvm.runtime.tensor(np.random.rand(32).astype(np.float32)) params_np = { "weight": weight_data, "gamma": gamma_data, @@ -121,11 +121,11 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re def test_fold_batchnorm_info_conv2d_transform(): mod = get_conv2d_batchnorm_sample() mod = relax.transform.FoldBatchnormToConv2D()(mod) - weight_data = tvm.nd.array(np.random.rand(32, 3, 3, 3).astype(np.float32)) - gamma_data = tvm.nd.array(np.random.rand(32).astype(np.float32)) - beta_data = tvm.nd.array(np.random.rand(32).astype(np.float32)) - mean_data = tvm.nd.array(np.random.rand(32).astype(np.float32)) - variance_data = tvm.nd.array(np.random.rand(32).astype(np.float32)) + weight_data = tvm.runtime.tensor(np.random.rand(32, 3, 3, 3).astype(np.float32)) + gamma_data = tvm.runtime.tensor(np.random.rand(32).astype(np.float32)) + beta_data = tvm.runtime.tensor(np.random.rand(32).astype(np.float32)) + mean_data = tvm.runtime.tensor(np.random.rand(32).astype(np.float32)) + variance_data = 
tvm.runtime.tensor(np.random.rand(32).astype(np.float32)) params_np = { "weight": weight_data, "gamma": gamma_data, diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index 9f2e3a4a092d..c62a01768eec 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -38,7 +38,7 @@ def gen_mod(mod, name, binding): The const parameter bindings """ funcs = {} - binding = {k: tvm.nd.array(v) for k, v in binding.items()} + binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()} for k, v in mod.functions.items(): if isinstance(v, tvm.relax.Function): @@ -431,12 +431,14 @@ def expected( ) -> R.Tensor((1, 1), dtype="int64"): return new_shape - before = gen_mod(Module, "before", {"indices": tvm.nd.array(np.array([0]).astype("int64"))}) + before = gen_mod( + Module, "before", {"indices": tvm.runtime.tensor(np.array([0]).astype("int64"))} + ) after = relax.transform.FoldConstant()(before) np_take = np.take([5, 4, 3, 2], [0], axis=0) np_expand = np.expand_dims(np_take, axis=[0]) np_concat = np.concatenate([np_expand], axis=0) - expected = gen_mod(Module, "expected", {"new_shape": tvm.nd.array(np_concat)}) + expected = gen_mod(Module, "expected", {"new_shape": tvm.runtime.tensor(np_concat)}) tvm.ir.assert_structural_equal(after, expected) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 8e583b3dd4cc..a67bc63f9bf2 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -2444,5 +2444,77 @@ def main( relax.transform.FuseTIR()(Before) +def test_block_name_numeric_suffix_deduplication(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def add1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")): + T.func_attr({"tir.noalias": True}) + for i in range(10): + with T.block("compute1"): + vi = T.axis.spatial(10, i) 
+ y[vi] = x[vi] + T.float32(1.0) + + @T.prim_func(private=True) + def mul1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")): + T.func_attr({"tir.noalias": True}) + for i in range(10): + with T.block("compute1"): + vi = T.axis.spatial(10, i) + y[vi] = x[vi] * T.float32(2.0) + + @R.function(private=True) + def fused_add_mul(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": True}) + cls = Before + with R.dataflow(): + lv1 = R.call_tir(cls.add1, (x,), out_sinfo=R.Tensor((10,), dtype="float32")) + lv2 = R.call_tir(cls.mul1, (lv1,), out_sinfo=R.Tensor((10,), dtype="float32")) + R.output(lv2) + return lv2 + + @R.function + def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Before + with R.dataflow(): + gv = cls.fused_add_mul(x) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def fused_add_mul(p_x: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": True}) + x = T.match_buffer(p_x, (T.int64(10),)) + y_intermediate_1 = T.match_buffer(p_output0, (T.int64(10),), elem_offset=T.int32(0)) + with T.block("root"): + T.reads() + T.writes() + y_intermediate = T.alloc_buffer((T.int64(10),), elem_offset=T.int32(0)) + for i in range(10): + with T.block("compute1"): + vi = T.axis.spatial(10, i) + T.reads(x[vi]) + T.writes(y_intermediate[vi]) + y_intermediate[vi] = x[vi] + T.float32(1.0) + for i in range(10): + with T.block("compute2"): + vi = T.axis.spatial(10, i) + T.reads(y_intermediate[vi]) + T.writes(y_intermediate_1[vi]) + y_intermediate_1[vi] = y_intermediate[vi] * T.float32(2.0) + + @R.function + def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Expected + with R.dataflow(): + gv = R.call_tir(cls.fused_add_mul, (x,), out_sinfo=R.Tensor((10,), dtype="float32")) + R.output(gv) + return gv + + _check(Before, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git 
a/tests/python/relax/test_transform_gradient_numeric.py b/tests/python/relax/test_transform_gradient_numeric.py index 70d6da8d7109..3b1d1dcefee4 100644 --- a/tests/python/relax/test_transform_gradient_numeric.py +++ b/tests/python/relax/test_transform_gradient_numeric.py @@ -24,7 +24,7 @@ def rand(dtype, *shape): - return tvm.nd.array(np.random.rand(*shape).astype(dtype)) + return tvm.runtime.tensor(np.random.rand(*shape).astype(dtype)) def _legalize_and_build(mod, target, dev): @@ -118,7 +118,9 @@ def test_mlp_blockbuilder(target, dev): for arg in After["MLP_adjoint"].params: shape = [int(l) for l in arg.struct_info.shape] if arg.struct_info.dtype == "int64": - args.append(tvm.nd.array(np.random.randint(0, out_size, size=shape).astype(np.int64))) + args.append( + tvm.runtime.tensor(np.random.randint(0, out_size, size=shape).astype(np.int64)) + ) else: # float32 args.append(rand("float32", *shape)) @@ -127,7 +129,7 @@ def test_mlp_blockbuilder(target, dev): _, grad = vm_after["MLP_adjoint"](*args) def func(*inputs): - loss = vm_before["MLP"](args[0], *[tvm.nd.array(i) for i in inputs], args[-1]) + loss = vm_before["MLP"](args[0], *[tvm.runtime.tensor(i) for i in inputs], args[-1]) return loss.numpy() check_numerical_grads(func, [i.numpy() for i in args[1:-1]], [i.numpy() for i in grad]) @@ -183,7 +185,7 @@ def main(x: R.Tensor((6,), "float32"), y: R.Tensor((6, 3, 4), "float32")): _, grad = vm_after["main_adjoint"](*args) def func(*inputs): - loss = vm_before["main"](*[tvm.nd.array(i) for i in inputs]) + loss = vm_before["main"](*[tvm.runtime.tensor(i) for i in inputs]) return loss.numpy() check_numerical_grads(func, [i.numpy() for i in args], [i.numpy() for i in grad]) @@ -220,7 +222,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")): _, grad = vm_after["main_adjoint"](*args) def func(*inputs): - loss = vm_before["main"](*[tvm.nd.array(i) for i in inputs]) + loss = vm_before["main"](*[tvm.runtime.tensor(i) for i in inputs]) return 
loss.numpy() check_numerical_grads(func, [i.numpy() for i in args], [i.numpy() for i in grad]) diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 25d483fc449c..ae0521a0e2f8 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -660,11 +660,11 @@ def transform_params( transformed = {} expected = [params[0].transpose(1, 0, 2, 3), params[1]] - @tvm.register_func("get_item", override=True) + @tvm.register_global_func("get_item", override=True) def get_item(i): - return tvm.nd.array(params[i], dev) + return tvm.runtime.tensor(params[i], dev) - @tvm.register_func("set_item", override=True) + @tvm.register_global_func("set_item", override=True) def set_item(i, value): assert i not in transformed, f"Set item called multiple times for index {i}" transformed[i] = value.numpy() diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index ff03ab4152c9..de2f183a102e 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -949,7 +949,7 @@ def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64 T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4]) T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4]) T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) - adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] * T.float32(0.020408163265306121) + adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] / T.float32(49.0) # fmt: on mod = LegalizeOps()(AdaptiveAvgPool2D) @@ -1104,15 +1104,14 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): return gv @T.prim_func(private=True) - def leaky_relu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), 
compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + def leaky_relu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.Select(T.float32(0) < rxplaceholder[i0_1, i1_1], rxplaceholder[i0_1, i1_1], \ - rxplaceholder[i0_1, i1_1] * T.float32(0.02)) + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(x[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * T.float32(0.02)) # fmt: on mod = LegalizeOps()(LeakyRelu) @@ -1140,19 +1139,17 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): return gv @T.prim_func(private=True) - def leaky_relu(var_rxplaceholder: T.handle, var_compute: T.handle): + def leaky_relu(var_x: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - compute = T.match_buffer(var_compute, [m, n], dtype="float32") + m, n = T.int64(), T.int64() + x = T.match_buffer(var_x, (m, n)) + compute = T.match_buffer(var_compute, (m, n)) for i0, i1 in T.grid(m, n): with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.Select(T.float32(0) < rxplaceholder[i0_1, i1_1], rxplaceholder[i0_1, i1_1], \ - rxplaceholder[i0_1, i1_1] * T.float32(0.03)) + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(x[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * T.float32(0.029999999999999999)) # fmt: on mod = LegalizeOps()(LeakyRelu) @@ -1259,42 
+1256,42 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): return gv @T.prim_func(private=True) - def gelu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): + def gelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) - T_multiply_1 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") - compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") - T_multiply_2 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") - T_divide = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") - for i0, i1 in T.grid(T.int64(2), T.int64(3)): + T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3))) + compute = T.alloc_buffer((T.int64(2), T.int64(3))) + T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3))) + T_add = T.alloc_buffer((T.int64(2), T.int64(3))) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_multiply"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1]) - T.writes(T_multiply_1[ax0, ax1]) - T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * T.float32(0.70710678118654757) + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1]) + T.writes(T_multiply_1[v_ax0, v_ax1]) + T_multiply_1[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T.float32(0.70710678118654757) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(T_multiply_1[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], dtype="float32") - for i0, i1 in T.grid(T.int64(2), T.int64(3)): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_1[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_multiply_1"): - ax0, 
ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(compute[ax0, ax1]) - T.writes(T_multiply_2[ax0, ax1]) - T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5) - for i0, i1 in T.grid(T.int64(2), T.int64(3)): - with T.block("T_divide"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(T_multiply_2[ax0, ax1]) - T.writes(T_divide[ax0, ax1]) - T_divide[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1] - for i0, i1 in T.grid(T.int64(2), T.int64(3)): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(compute[v_ax0, v_ax1]) + T.writes(T_multiply_2[v_ax0, v_ax1]) + T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_2[v_ax0, v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1] + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_multiply_2"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1], T_divide[ax0, ax1]) - T.writes(T_multiply[ax0, ax1]) - T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * T_divide[ax0, ax1] + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T_add[v_ax0, v_ax1] # fmt: on mod = LegalizeOps()(Gelu) @@ -1322,46 +1319,45 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): return gv @T.prim_func(private=True) - def gelu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): + def gelu(var_x: T.handle, var_T_multiply: T.handle): T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") - T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32") - T_multiply_1 = T.alloc_buffer([m, n], dtype="float32") - compute = T.alloc_buffer([m, n], 
dtype="float32") - T_multiply_2 = T.alloc_buffer([m, n], dtype="float32") - T_add = T.alloc_buffer([m, n], dtype="float32") - for i0, i1 in T.grid(m, n): + m, n = T.int64(), T.int64() + x = T.match_buffer(var_x, (m, n)) + T_multiply = T.match_buffer(var_T_multiply, (m, n)) + T_multiply_1 = T.alloc_buffer((m, n)) + compute = T.alloc_buffer((m, n)) + T_multiply_2 = T.alloc_buffer((m, n)) + T_add = T.alloc_buffer((m, n)) + for ax0, ax1 in T.grid(m, n): with T.block("T_multiply"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1]) - T.writes(T_multiply_1[ax0, ax1]) - T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * T.float32(0.70710678118654757) + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1]) + T.writes(T_multiply_1[v_ax0, v_ax1]) + T_multiply_1[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T.float32(0.70710678118654757) for i0, i1 in T.grid(m, n): with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(T_multiply_1[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], dtype="float32") - for i0, i1 in T.grid(m, n): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_1[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.erf(T_multiply_1[v_i0, v_i1]) + for ax0, ax1 in T.grid(m, n): with T.block("T_multiply_1"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(compute[ax0, ax1]) - T.writes(T_multiply_2[ax0, ax1]) - T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5) - for i0, i1 in T.grid(m, n): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(compute[v_ax0, v_ax1]) + T.writes(T_multiply_2[v_ax0, v_ax1]) + T_multiply_2[v_ax0, v_ax1] = compute[v_ax0, v_ax1] * T.float32(0.5) + for ax0, ax1 in T.grid(m, n): with T.block("T_add"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(T_multiply_2[ax0, ax1]) - T.writes(T_add[ax0, ax1]) - T_add[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1] - for i0, i1 in 
T.grid(m, n): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_multiply_2[v_ax0, v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = T.float32(0.5) + T_multiply_2[v_ax0, v_ax1] + for ax0, ax1 in T.grid(m, n): with T.block("T_multiply_2"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1], T_add[ax0, ax1]) - T.writes(T_multiply[ax0, ax1]) - T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * T_add[ax0, ax1] + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1], T_add[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T_add[v_ax0, v_ax1] # fmt: on mod = LegalizeOps()(Gelu) @@ -1887,29 +1883,29 @@ def main(x: R.Tensor((3,), dtype="float32"), y: R.Tensor((3,), dtype="float32")) return gv @T.prim_func(private=True) - def cross_entropy_with_logits(rxplaceholder: T.Buffer(T.int64(3), "float32"), rxplaceholder_1: T.Buffer(T.int64(3), "float32"), T_multiply: T.Buffer((), "float32")): + def cross_entropy_with_logits(x: T.Buffer((T.int64(3),), "float32"), y: T.Buffer((T.int64(3),), "float32"), T_multiply: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) - T_multiply_1 = T.alloc_buffer([T.int64(3)], dtype="float32") - T_multiply_red = T.alloc_buffer([], dtype="float32") - for i0 in T.serial(T.int64(3)): + T_multiply_1 = T.alloc_buffer((T.int64(3),)) + T_multiply_red = T.alloc_buffer(()) + for ax0 in range(T.int64(3)): with T.block("T_multiply"): - ax0 = T.axis.spatial(T.int64(3), i0) - T.reads(rxplaceholder[ax0], rxplaceholder_1[ax0]) - T.writes(T_multiply_1[ax0]) - T_multiply_1[ax0] = rxplaceholder[ax0] * rxplaceholder_1[ax0] - for i0 in T.serial(T.int64(3)): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(x[v_ax0], y[v_ax0]) + T.writes(T_multiply_1[v_ax0]) + T_multiply_1[v_ax0] = x[v_ax0] * y[v_ax0] + for k0 in range(T.int64(3)): with T.block("T_multiply_red"): - k0 = T.axis.reduce(T.int64(3), i0) - T.reads(T_multiply_1[k0]) + v_k0 = 
T.axis.reduce(T.int64(3), k0) + T.reads(T_multiply_1[v_k0]) T.writes(T_multiply_red[()]) with T.init(): - T_multiply_red[()] = T.float32(0) - T_multiply_red[()] = T_multiply_red[()] + T_multiply_1[k0] + T_multiply_red[()] = T.float32(0.0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply_1[v_k0] with T.block("T_multiply_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_multiply[()]) - T_multiply[()] = T_multiply_red[()] * T.float32(-1) + T_multiply[()] = T_multiply_red[()] * T.float32(-1.0) # fmt: on mod = LegalizeOps()(CrossEntropyWithLogits) @@ -1933,35 +1929,35 @@ def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float3 return gv @T.prim_func(private=True) - def cross_entropy_with_logits(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((), "float32")): + def cross_entropy_with_logits(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) - T_multiply = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") - T_multiply_red = T.alloc_buffer([], dtype="float32") - T_multiply_1 = T.alloc_buffer([], dtype="float32") - for i0, i1 in T.grid(T.int64(2), T.int64(3)): + T_multiply = T.alloc_buffer((T.int64(2), T.int64(3))) + T_multiply_red = T.alloc_buffer(()) + T_multiply_1 = T.alloc_buffer(()) + for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_multiply"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1]) - T.writes(T_multiply[ax0, ax1]) - T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * rxplaceholder_1[ax0, ax1] - for i0, i1 in T.grid(T.int64(2), T.int64(3)): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1], y[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * 
y[v_ax0, v_ax1] + for k0, k1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_multiply_red"): - k0, k1 = T.axis.remap("RR", [i0, i1]) - T.reads(T_multiply[k0, k1]) + v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) + T.reads(T_multiply[v_k0, v_k1]) T.writes(T_multiply_red[()]) with T.init(): - T_multiply_red[()] = T.float32(0) - T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, k1] + T_multiply_red[()] = T.float32(0.0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1] with T.block("T_multiply_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_multiply_1[()]) - T_multiply_1[()] = T_multiply_red[()] * T.float32(-1) + T_multiply_1[()] = T_multiply_red[()] * T.float32(-1.0) with T.block("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_1[()]) T.writes(T_divide[()]) - T_divide[()] = T_multiply_1[()] * T.float32(0.5) + T_divide[()] = T_multiply_1[()] / T.float32(2) # fmt: on mod = LegalizeOps()(CrossEntropyWithLogits) @@ -1987,34 +1983,33 @@ def main(x: R.Tensor(("n", "m"), dtype="float32"), y: R.Tensor(("n", "m"), dtype return gv @T.prim_func(private=True) - def cross_entropy_with_logits(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, T_divide: T.Buffer((), "float32")): + def cross_entropy_with_logits(var_x: T.handle, var_y: T.handle, T_divide: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) - m = T.int64() - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") - T_multiply = T.alloc_buffer([n, m], dtype="float32") - T_multiply_red = T.alloc_buffer([], dtype="float32") - T_multiply_1 = T.alloc_buffer([], dtype="float32") + m, n = T.int64(), T.int64() + x = T.match_buffer(var_x, (n, m)) + y = T.match_buffer(var_y, (n, m)) + T_multiply = T.alloc_buffer((n, m)) + T_multiply_red = T.alloc_buffer(()) + T_multiply_1 = T.alloc_buffer(()) for ax0, ax1 in T.grid(n, m): with 
T.block("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax0, v_ax1]) + T.reads(x[v_ax0, v_ax1], y[v_ax0, v_ax1]) T.writes(T_multiply[v_ax0, v_ax1]) - T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * rxplaceholder_1[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * y[v_ax0, v_ax1] for k0, k1 in T.grid(n, m): with T.block("T_multiply_red"): v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) T.reads(T_multiply[v_k0, v_k1]) T.writes(T_multiply_red[()]) with T.init(): - T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T.float32(0.0) T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1] with T.block("T_multiply_1"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_multiply_1[()]) - T_multiply_1[()] = T_multiply_red[()] * T.float32(-1) + T_multiply_1[()] = T_multiply_red[()] * T.float32(-1.0) with T.block("T_divide"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_1[()]) @@ -2217,7 +2212,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(x_red[v_ax0]) T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = x_red[v_ax0] * T.float32(0.00063775510204081628) + T_divide_1[v_ax0] = x_red[v_ax0] / T.float32(1568) for ax0 in range(T.int64(3)): with T.block("T_multiply_2"): v_ax0 = T.axis.spatial(T.int64(3), ax0) @@ -2303,7 +2298,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(T_multiply_red[v_ax0]) T.writes(T_divide_2[v_ax0]) - T_divide_2[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) + T_divide_2[v_ax0] = T_multiply_red[v_ax0] / T.float32(1568) for ax0 in range(T.int64(3)): with T.block("T_multiply_5"): v_ax0 = T.axis.spatial(T.int64(3), ax0) @@ -2676,7 +2671,7 @@ def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.in ax0, ax1, ax2, ax3 = 
T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3]) T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) - T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.05) - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05) * (rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3] + T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] / T.float32(20) - rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20) * (rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3] # fmt: on mod = LegalizeOps()(LayerNorm) tvm.ir.assert_structural_equal(mod, Expected) @@ -2720,7 +2715,7 @@ def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffe v_ax0 = T.axis.spatial(T.int64(3), ax0) T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0]) T.writes(T_layer_norm[v_ax0]) - T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] * T.float32(0.33333333333333331)) * T.rsqrt(x_red_temp_v1[()] * T.float32(0.33333333333333331) - x_red_temp_v0[()] * T.float32(0.33333333333333331) * (x_red_temp_v0[()] * T.float32(0.33333333333333331)) + T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0] + T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] / T.float32(3)) * T.rsqrt(x_red_temp_v1[()] / T.float32(3) - x_red_temp_v0[()] / T.float32(3) * (x_red_temp_v0[()] / T.float32(3)) + T.float32(1.0000000000000001e-05)) * 
layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0] @R.function def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"): @@ -2911,7 +2906,7 @@ def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.in v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) - T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.float32(40) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): with T.block("T_reshape_3"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -2996,7 +2991,7 @@ def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.in v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, 
ax4]) T.reads(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) - T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.Cast("float16", (T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05))) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.Cast("float16", (T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / T.float32(40) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.float32(40)) + T.float32(1.0000000000000001e-05))) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): with T.block("T_reshape_3"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -3143,7 +3138,7 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(rsqrt[v_ax0, v_ax1]) - rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] / T.float32(20) + T.float32(1.0000000000000001e-05)) for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): with T.block("T_cast_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -3219,7 
+3214,7 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(rsqrt[v_ax0, v_ax1]) - rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] / T.float32(20) + T.float32(1.0000000000000001e-05)) for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): with T.block("T_cast_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -3381,7 +3376,7 @@ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "floa v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(rsqrt[v_ax0, v_ax1]) - rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05)) + rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] / T.float32(20) + T.float32(1.0000000000000001e-05)) for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): with T.block("T_cast_1"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -3424,7 +3419,7 @@ def main(q: R.Tensor((4, 16, 32, 8), "float32"), k: R.Tensor((4, 8, 32, 8), "flo @tvm.script.ir_module class Expected: @T.prim_func(private=True) - def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8)), "float32"), B: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(8)), "float32"), C: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"), D: T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(16)), "float32")): + def attention_bias(q: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8)), "float32"), k: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(8)), "float32"), v: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"), bias: 
T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(16)), "float32")): T.func_attr({"tir.noalias": True}) # with T.block("root"): T_transpose_1 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8))) @@ -3450,9 +3445,9 @@ def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)): with T.block("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) + T.reads(q[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] + T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = q[v_ax0, v_ax2, v_ax1, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) @@ -3462,23 +3457,23 @@ def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(8)): with T.block("T_transpose_1"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(B[v_ax0, v_ax2, v_ax1, v_ax3]) + T.reads(k[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = B[v_ax0, v_ax2, v_ax1, v_ax3] + T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = k[v_ax0, v_ax2, v_ax1, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(8)): with T.block("T_reshape_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % 
T.int64(8)]) T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2]) T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)] - for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(8), T.int64(8)): + for b, i, j, k_1 in T.grid(T.int64(128), T.int64(16), T.int64(8), T.int64(8)): with T.block("T_batch_matmul_NT"): - v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k]) + v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k_1]) T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k]) T.writes(T_batch_matmul_NT[v_b, v_i, v_j]) T.block_attr({"layout_free_placeholders": [T_reshape_1]}) with T.init(): - T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0) + T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0.0) T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): with T.block("T_multiply"): @@ -3495,9 +3490,9 @@ def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)): with T.block("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], D[v_ax0, v_ax1, v_ax2, v_ax3]) + T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], bias[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] + D[v_ax0, v_ax1, v_ax2, v_ax3] + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] + bias[v_ax0, v_ax1, v_ax2, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): with T.block("T_reshape_3"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, 
ax1, ax2]) @@ -3509,14 +3504,14 @@ def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_reshape_3[v_i0, v_i1, v_i2]) T.writes(trilu[v_i0, v_i1, v_i2]) - trilu[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, T_reshape_3[v_i0, v_i1, v_i2], T.float32(0)) + trilu[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, T_reshape_3[v_i0, v_i1, v_i2], T.float32(0.0)) for ax0, ax1, ax2, k2 in T.grid(T.int64(128), T.int64(16), T.int64(1), T.int64(8)): with T.block("trilu_red"): v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2]) T.reads(trilu[v_ax0, v_ax1, v_k2]) T.writes(trilu_red[v_ax0, v_ax1, v_ax2]) with T.init(): - trilu_red[v_ax0, v_ax1, v_ax2] = T.float32(-3.4028234663852886e+38) + trilu_red[v_ax0, v_ax1, v_ax2] = T.float32(-340282346638528859811704183484516925440.0) trilu_red[v_ax0, v_ax1, v_ax2] = T.max(trilu_red[v_ax0, v_ax1, v_ax2], trilu[v_ax0, v_ax1, v_k2]) for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)): with T.block("T_subtract"): @@ -3535,14 +3530,14 @@ def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(compute[v_i0, v_i1, v_i2]) T.writes(trilu_1[v_i0, v_i1, v_i2]) - trilu_1[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, compute[v_i0, v_i1, v_i2], T.float32(0)) + trilu_1[v_i0, v_i1, v_i2] = T.Select(v_i2 <= v_i1, compute[v_i0, v_i1, v_i2], T.float32(0.0)) for ax0, ax1, ax2, k2 in T.grid(T.int64(128), T.int64(16), T.int64(1), T.int64(8)): with T.block("trilu_red_1"): v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, ax2, k2]) T.reads(trilu_1[v_ax0, v_ax1, v_k2]) T.writes(trilu_red_1[v_ax0, v_ax1, v_ax2]) with T.init(): - trilu_red_1[v_ax0, v_ax1, v_ax2] = T.float32(0) + trilu_red_1[v_ax0, v_ax1, v_ax2] = T.float32(0.0) trilu_red_1[v_ax0, v_ax1, v_ax2] = trilu_red_1[v_ax0, v_ax1, v_ax2] + trilu_1[v_ax0, v_ax1, v_k2] for ax0, ax1, ax2 in T.grid(T.int64(128), 
T.int64(16), T.int64(8)): with T.block("T_divide"): @@ -3553,23 +3548,23 @@ def attention_bias(A: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8) for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(16)): with T.block("T_transpose_2"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(C[v_ax0, v_ax2, v_ax1, v_ax3]) + T.reads(v[v_ax0, v_ax2, v_ax1, v_ax3]) T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = C[v_ax0, v_ax2, v_ax1, v_ax3] + T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = v[v_ax0, v_ax2, v_ax1, v_ax3] for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(16)): with T.block("T_reshape_4"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)]) T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2]) T_reshape_4[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)] - for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(16), T.int64(8)): + for b, i, j, k_1 in T.grid(T.int64(128), T.int64(16), T.int64(16), T.int64(8)): with T.block("T_batch_matmul_NN"): - v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k]) + v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k_1]) T.reads(T_divide[v_b, v_i, v_k], T_reshape_4[v_b, v_k, v_j]) T.writes(T_batch_matmul_NN[v_b, v_i, v_j]) T.block_attr({"layout_free_placeholders": [T_reshape_4]}) with T.init(): - T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0) + T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0.0) T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, 
v_i, v_j] + T_divide[v_b, v_i, v_k] * T_reshape_4[v_b, v_k, v_j] for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(16)): with T.block("T_reshape_5"): @@ -3589,7 +3584,6 @@ def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 8, 32, 8) cls = Expected gv = R.call_tir(cls.attention_bias, (q, k, v, bias), out_sinfo=R.Tensor((4, 16, 32, 16), dtype="float32")) return gv - # fmt: on mod = LegalizeOps()(Attention) tvm.ir.assert_structural_equal(mod, Expected) diff --git a/tests/python/relax/test_transform_legalize_ops_qdq.py b/tests/python/relax/test_transform_legalize_ops_qdq.py index 55f1acadb134..09706c637ef7 100644 --- a/tests/python/relax/test_transform_legalize_ops_qdq.py +++ b/tests/python/relax/test_transform_legalize_ops_qdq.py @@ -212,7 +212,7 @@ def quantize( "int8", T.max( T.min( - T.round(A[v_i0, v_i1] * T.float32(0.5)) + T.float32(1), + T.round(A[v_i0, v_i1] / T.float32(2)) + T.float32(1), T.float32(127), ), T.float32(-128), @@ -311,7 +311,7 @@ def quantize( "int8", T.max( T.min( - T.round(A[v_i0, v_i1] * T.float16(0.5)) + T.float16(1), + T.round(A[v_i0, v_i1] / T.float16(2)) + T.float16(1), T.float16(127), ), T.float16(-128), diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index f8dab8981552..7edfff3dfc43 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -627,7 +627,7 @@ def mean(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5) ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder_red[ax0, ax1]) T.writes(T_divide[ax0, ax1]) - T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] * T.float32(0.1) + T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] / T.float32(10) # fmt: on mod = LegalizeOps()(Mean) @@ -718,7 +718,7 @@ def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)) v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.0083333333333333332) + T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.float32(120.0) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_subtract"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -743,7 +743,7 @@ def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)) vi = T.axis.spatial(1, T.int64(0)) T.reads(T_multiply_red[()]) T.writes(T_divide_1[()]) - T_divide_1[()] = T_multiply_red[()] * T.float32(0.0083333333333333332) + T_divide_1[()] = T_multiply_red[()] / T.float32(120.0) with T.block("compute"): vi = T.axis.spatial(1, T.int64(0)) T.reads(T_divide_1[()]) @@ -881,7 +881,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3]) T.writes(T_divide_1[ax0, ax1, ax2, ax3]) - T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] * T.float32(0.10000000000000001) + T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] / T.float32(10.0) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -907,7 +907,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_multiply_red[ax0, ax1, ax2, ax3]) T.writes(T_divide[ax0, ax1, ax2, ax3]) - T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] * T.float32(0.10000000000000001) + T_divide[ax0, ax1, ax2, ax3] 
= T_multiply_red[ax0, ax1, ax2, ax3] / T.float32(10) # fmt: on mod = LegalizeOps()(Variance) @@ -1027,7 +1027,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.10000000000000001) + T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.float32(10) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_subtract"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -1053,7 +1053,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(T_multiply_red[v_ax0, v_ax1]) T.writes(T_divide[v_ax0, v_ax1]) - T_divide[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] * T.float32(0.10000000000000001) + T_divide[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] / T.float32(10) @R.function def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype="float32"): diff --git a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py new file mode 100644 index 000000000000..d92570025fce --- /dev/null +++ b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py @@ -0,0 +1,344 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.relax.transform.legalize_ops import adreno as legalize_adreno +from tvm.ir.module import IRModule +from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor + + +@visitor +class ValidateBufferScopes(PyExprVisitor): # pylint: disable=abstract-method + def __init__(self, is_matched: bool) -> None: + self.is_matched = is_matched + + def visit(self, mod: IRModule) -> None: + """Entry point: validate the call_tir sites of every Relax function in ``mod``.""" + self.mod = mod # needed by visit_call_ to look up the called PrimFunc + for _, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.visit_expr(func) + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if call.op.name == "relax.call_tir": + pfunc = self.mod[call.args[0]] + if not self.is_matched: + # All scopes should be global in before pass + for _, buf in pfunc.buffer_map.items(): + assert ( + "global" == buf.data.type_annotation.storage_scope + ), f"expected to be global scoped, but got {buf.data.type_annotation.storage_scope}" + else: + for idx, arg in enumerate(call.args[1]): + arg_sinfo = arg.struct_info + assert isinstance( + arg_sinfo, relax.TensorStructInfo + ), f"Expected TensorStructInfo but got {type(arg_sinfo)}" + buf = pfunc.buffer_map[pfunc.params[idx]] + assert ( + arg_sinfo.vdevice.memory_scope == buf.data.type_annotation.storage_scope + ), f"scope mismatched after specialization {arg_sinfo.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + if isinstance(call.sinfo_args[0], 
relax.TensorStructInfo): + buf = pfunc.buffer_map[pfunc.params[-1]] + assert ( + call.sinfo_args[0].vdevice.memory_scope + == buf.data.type_annotation.storage_scope + ), f"scope mismatched after specialization {call.sinfo_args[0].vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + else: + assert isinstance( + call.sinfo_args[0], relax.TupleStructInfo + ), f"Expected TupleStructInfo but got {type(call.sinfo_args[0])}" + for idx, sinfo in enumerate(call.sinfo_args[0].fields): + buf = pfunc.buffer_map[pfunc.params[len(call.args[1]) + idx]] + assert ( + sinfo.vdevice.memory_scope == buf.data.type_annotation.storage_scope + ), f"scope mismatched after specialization {sinfo.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + + +def verify(input): + ValidateBufferScopes(False).visit(input) + mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(input) + ValidateBufferScopes(True).visit(mod) + + +def test_single_arg_return(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": 
"meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # 
noqa: F722 + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv2 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5: R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = R.to_vdevice(lv2, dst_vdevice="opencl:1:global") + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + verify(Input) + + +def test_multi_arg_return(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def conv2d_NCHWc_OIHWo_opencl( + lv: T.Buffer((T.int64(2), T.int64(4), T.int64(28), T.int64(28), T.int64(4)), "float32"), + lv1: T.Buffer((T.int64(1), T.int64(16), T.int64(3), T.int64(3), T.int64(4)), "float32"), + conv2d_NCHWc_OIHWo: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + conv2d_NCHWc_OIHWo[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def fused_relu_concatenate_split( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + T_split_sections_intermediate: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + T_split_sections_intermediate_1: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + T_split_sections_intermediate[0, 0, 0, 0, 0] = T.float32(0.0) + T_split_sections_intermediate_1[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform( 
+ x: T.Buffer((T.int64(2), T.int64(16), T.int64(28), T.int64(28)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(28), T.int64(28), T.int64(4)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform1( + w: T.Buffer((T.int64(4), T.int64(16), T.int64(3), T.int64(3)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(1), T.int64(16), T.int64(3), T.int64(3), T.int64(4)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform2( + lv3: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0] = T.float32(0.0) + + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + w: R.Tensor((4, 16, 3, 3), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ): + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 4, 28, 28, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv1 = R.call_tir( + cls.te_layout_transform1, + (w,), + out_sinfo=R.Tensor( + (1, 16, 3, 3, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + gv = R.call_tir( + cls.conv2d_NCHWc_OIHWo_opencl, + (lv, lv1), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv_1 = R.call_tir( + cls.fused_relu_concatenate_split, + (gv,), + out_sinfo=[ + R.Tensor((2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global"), + R.Tensor((2, 1, 26, 26, 4), 
dtype="float32", vdevice="opencl:1:global"), + ], + ) + lv3: R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = lv_1[0] + lv4 = R.call_tir( + cls.te_layout_transform2, + (lv3,), + out_sinfo=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), + ) + lv5: R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = lv_1[1] + lv6 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), + ) + gv4: R.Tuple( + R.Tensor( + (2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ), + R.Tensor( + (2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + verify(Input) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 83e4d264c6a3..06e4ea142e95 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -1347,6 +1347,18 @@ def main(x: R.Tensor((2, "n"), dtype="float32")): relax.transform.StaticPlanBlockMemory()(Module) +def test_invalid_tir_var_lower_bound(): + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, "n"), dtype="float32")): + R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True}) + return x + + with pytest.raises((TVMError, TypeError)): + relax.transform.StaticPlanBlockMemory()(Module) + + def test_add(): @I.ir_module class Module: diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py index 658f80a06ec5..4e90216f9bc0 100644 --- a/tests/python/relax/test_transform_to_mixed_precision.py +++ b/tests/python/relax/test_transform_to_mixed_precision.py @@ -836,7 +836,7 @@ 
def main( "w2": np.random.uniform(size=(4, 4, 1, 1)).astype("float16"), "w3": np.random.uniform(size=(4,)).astype("float16"), } - binding = {k: tvm.nd.array(v) for k, v in binding.items()} + binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()} Input = relax.transform.BindParams("main", binding)(Input) Expected = relax.transform.BindParams("main", binding)(Expected) Expected2 = relax.transform.BindParams("main", binding)(Expected2) @@ -975,7 +975,7 @@ def main( "w": np.random.uniform(size=(512, 4, 3, 3)).astype("float32"), "bias": np.random.uniform(size=(512,)).astype("float32"), } - binding = {k: tvm.nd.array(v) for k, v in binding_np.items()} + binding = {k: tvm.runtime.tensor(v) for k, v in binding_np.items()} Input_bound = relax.transform.BindParams("main", binding)(Input) Expected = relax.transform.BindParams("main", binding)(Expected) @@ -983,7 +983,7 @@ def main( _assert_test(Input_bound, expected2=Expected) binding_np["bias"][0] = 70000 # Out of fp16 range - binding = {k: tvm.nd.array(v) for k, v in binding_np.items()} + binding = {k: tvm.runtime.tensor(v) for k, v in binding_np.items()} Input_bound = relax.transform.BindParams("main", binding)(Input) Expected_no_bias_cast = relax.transform.BindParams("main", binding)(Expected_no_bias_cast) diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py index 694e7a688cf7..c0ff78ca4c6b 100644 --- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py +++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py @@ -439,5 +439,20 @@ def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): _check(foo, bb.get()["foo"]) +def test_hint_on_device_scoped(): + @R.function + def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + r = R.hint_on_device(x, R.device(4, 2), "global.texture") + return r + + x = relax.Var("x", R.Tensor((), "int32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + tensor = 
bb.emit(relax.op.hint_on_device(x, R.opencl(2), "global.texture")) + bb.emit_func_output(tensor) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py new file mode 100644 index 000000000000..66e0adac3d22 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_vision.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_all_class_non_max_suppression(): + @R.function + def foo( + boxes: R.Tensor((10, 5, 4), "float32"), + scores: R.Tensor((10, 8, 5), "float32"), + max_output_boxes_per_class: R.Tensor((), "int64"), + iou_threshold: R.Tensor((), "float32"), + score_threshold: R.Tensor((), "float32"), + ) -> R.Tuple(R.Tensor((400, 3), "int64"), R.Tensor((1,), "int64")): + gv: R.Tuple( + R.Tensor((400, 3), "int64"), R.Tensor((1,), "int64") + ) = R.vision.all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + "onnx", + ) + return gv + + boxes = relax.Var("boxes", R.Tensor((10, 5, 4), "float32")) + scores = relax.Var("scores", R.Tensor((10, 8, 5), "float32")) + max_output_boxes_per_class = relax.Var("max_output_boxes_per_class", R.Tensor((), "int64")) + iou_threshold = relax.Var("iou_threshold", R.Tensor((), "float32")) + score_threshold = relax.Var("score_threshold", R.Tensor((), "float32")) + + bb = relax.BlockBuilder() + with bb.function( + "foo", [boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold] + ): + gv = bb.emit( + relax.op.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_pyfunc.py b/tests/python/relax/test_tvmscript_pyfunc.py new file mode 100644 index 
000000000000..7b3c4052fa93 --- /dev/null +++ b/tests/python/relax/test_tvmscript_pyfunc.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Test TVMScript @I.pyfunc decorator functionality. + +This test verifies: +1. @I.pyfunc decorator works correctly +2. Python functions are properly integrated into IRModule +3. BasePyModule inheritance is handled correctly +4. 
ExternFunc nodes are created for Python functions +""" + +import pytest +import torch +import tvm +from tvm import relax +from tvm.script import ir as I, relax as R, tir as T +from tvm.relax import BasePyModule +import numpy as np + + +@I.ir_module +class TestPyFuncModule(BasePyModule): + """Test module with Python functions using @I.pyfunc decorator.""" + + @I.pyfunc + def pytorch_processor(x: torch.Tensor) -> torch.Tensor: + """Python function that processes PyTorch tensors.""" + return torch.nn.functional.relu(x) * 2.0 + + @I.pyfunc + def pytorch_adder(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Python function that adds two PyTorch tensors.""" + return x + y + + @I.pyfunc + def pytorch_complex_ops(x: torch.Tensor) -> torch.Tensor: + """Complex PyTorch operations.""" + result = torch.nn.functional.softmax(x, dim=0) + result = torch.nn.functional.dropout(result, p=0.1, training=False) + return result * 10.0 + + @T.prim_func + def simple_tir_func( + var_A: T.handle, + var_B: T.handle, + ): + T.func_attr({"tir.noalias": True}) + n = T.int32() + A = T.match_buffer(var_A, (n,), "float32") + B = T.match_buffer(var_B, (n,), "float32") + + for i in T.grid(n): + with T.block("copy"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + +class TestTVMScriptPyFunc: + def test_pyfunc_decorator_creates_pyfuncs_attribute(self): + module = TestPyFuncModule + + assert hasattr(module, "pyfuncs"), "Module should have pyfuncs attribute" + + pyfuncs = module.pyfuncs + assert isinstance(pyfuncs, dict), "pyfuncs should be a dictionary" + + expected_functions = ["pytorch_processor", "pytorch_adder", "pytorch_complex_ops"] + for func_name in expected_functions: + assert func_name in pyfuncs, f"Function {func_name} should be in pyfuncs" + + def test_pyfunc_functions_are_callable(self): + """Test that Python functions in pyfuncs are callable.""" + module = TestPyFuncModule + pyfuncs = module.pyfuncs + + # Test pytorch_processor + processor_func = pyfuncs["pytorch_processor"] + 
assert callable(processor_func), "pytorch_processor should be callable" + + # Test pytorch_adder + adder_func = pyfuncs["pytorch_adder"] + assert callable(adder_func), "pytorch_adder should be callable" + + # Test pytorch_complex_ops + complex_func = pyfuncs["pytorch_complex_ops"] + assert callable(complex_func), "pytorch_complex_ops should be callable" + + def test_pyfunc_functions_execute_correctly(self): + """Test that Python functions execute correctly.""" + module = TestPyFuncModule + pyfuncs = module.pyfuncs + + # Create test data + x = torch.tensor([1.0, -2.0, 3.0, -4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + # Test pytorch_processor + processor_func = pyfuncs["pytorch_processor"] + processor_result = processor_func(x) + + assert isinstance(processor_result, torch.Tensor) + expected = torch.nn.functional.relu(x) * 2.0 + assert torch.allclose(processor_result, expected, atol=1e-5) + + # Test pytorch_adder + adder_func = pyfuncs["pytorch_adder"] + adder_result = adder_func(x, y) + + assert isinstance(adder_result, torch.Tensor) + expected = x + y + assert torch.allclose(adder_result, expected, atol=1e-5) + + # Test pytorch_complex_ops + complex_func = pyfuncs["pytorch_complex_ops"] + complex_result = complex_func(x) + + assert isinstance(complex_result, torch.Tensor) + # Note: dropout runs with training=False here; this test only checks shape and dtype + assert complex_result.shape == x.shape + assert complex_result.dtype == x.dtype + + def test_pyfunc_module_has_functions_attribute(self): + """Test that the module has functions attribute for IRModule operations.""" + module = TestPyFuncModule + + # Check if functions attribute exists + assert hasattr(module, "functions"), "Module should have functions attribute" + + functions = module.functions + # TVM IRModule.functions is not a standard dict, but has dict-like behavior + assert hasattr(functions, "__getitem__"), "functions should support dict-like access" + assert
hasattr(functions, "__iter__"), "functions should be iterable" + + def test_pyfunc_module_script_method(self): + """Test that the module has script() method for TVMScript output.""" + module = TestPyFuncModule + + # Check if script method exists + assert hasattr(module, "script"), "Module should have script method" + + # Test script method execution + script_output = module.script() + assert isinstance(script_output, str), "script() should return a string" + assert len(script_output) > 0, "script() should return non-empty string" + + def test_pyfunc_module_inheritance_flag(self): + """Test that the module has BasePyModule inheritance flag.""" + module = TestPyFuncModule + + # Check if inheritance flag exists (this might not be set in all implementations) + if hasattr(module, "_base_py_module_inherited"): + assert module._base_py_module_inherited, "Inheritance flag should be True" + else: + # Alternative: check if the module supports Python functions + assert hasattr(module, "pyfuncs"), "Module should support Python functions" + + # Check if original class is preserved (this might not be set in all implementations) + if hasattr(module, "_original_class"): + assert module._original_class is not None, "Original class should be preserved" + else: + # Alternative: check if module is callable (ModuleFactory) + assert hasattr(module, "__call__"), "Module should be callable (ModuleFactory)" + + def test_pyfunc_module_creation_and_execution(self): + module = TestPyFuncModule + + assert hasattr(module, "__call__"), "Module should be callable" + + device = tvm.cpu(0) + instance = module(device) + + assert isinstance(instance, BasePyModule), "Instance should be BasePyModule" + assert hasattr(instance, "pyfuncs"), "Instance should have pyfuncs" + + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + result = instance.pytorch_processor(x) + + assert isinstance(result, torch.Tensor) + expected = torch.nn.functional.relu(x) * 2.0 + assert torch.allclose(result, expected, 
atol=1e-5) + + def test_pyfunc_module_creation_and_execution_gpu(self): + module = TestPyFuncModule + + if tvm.cuda().exist: + device = tvm.cuda(0) + instance = module(device) + + assert isinstance(instance, BasePyModule), "Instance should be BasePyModule" + assert hasattr(instance, "pyfuncs"), "Instance should have pyfuncs" + + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda") + result = instance.pytorch_processor(x) + + assert isinstance(result, torch.Tensor) + assert result.device.type == "cuda" + expected = torch.nn.functional.relu(x) * 2.0 + assert torch.allclose(result, expected, atol=1e-5) + else: + pytest.skip("CUDA not available") + + def test_pyfunc_with_tir_integration(self): + """Test that Python functions can work with TIR functions.""" + module = TestPyFuncModule + + # Create instance + device = tvm.cpu(0) + instance = module(device) + + # Test TIR function execution + n = 5 + input_tensor = torch.randn(n, dtype=torch.float32) + + # Call TIR function - simple_tir_func declares two handles: input (var_A), output (var_B) + # Only the input is passed; call_tir presumably allocates the output buffer (confirm) + # Note: TIR functions expect TVM types, not Python types + result = instance.call_tir( + instance.simple_tir_func, + [input_tensor], # Only pass input tensor, let call_tir handle the rest + R.Tensor((n,), "float32"), + ) + + # Verify result + assert isinstance(result, torch.Tensor) + assert result.shape == (n,) + assert torch.allclose(result, input_tensor, atol=1e-5) + + def test_pyfunc_decorator_preserves_function_signatures(self): + """Test that @I.pyfunc decorator preserves function signatures.""" + module = TestPyFuncModule + pyfuncs = module.pyfuncs + + # Check function signatures + import inspect + + # pytorch_processor signature + processor_func = pyfuncs["pytorch_processor"] + sig = inspect.signature(processor_func) + params = list(sig.parameters.keys()) + assert len(params) == 1, "pytorch_processor should have 1 parameter" + assert params[0] ==
"x", "First parameter should be 'x'" + + # pytorch_adder signature + adder_func = pyfuncs["pytorch_adder"] + sig = inspect.signature(adder_func) + params = list(sig.parameters.keys()) + assert len(params) == 2, "pytorch_adder should have 2 parameters" + assert params[0] == "x", "First parameter should be 'x'" + assert params[1] == "y", "Second parameter should be 'y'" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_vm_alloc_storage_with_scope.py b/tests/python/relax/test_vm_alloc_storage_with_scope.py index ec6696000429..3839ae123406 100644 --- a/tests/python/relax/test_vm_alloc_storage_with_scope.py +++ b/tests/python/relax/test_vm_alloc_storage_with_scope.py @@ -67,7 +67,7 @@ def test_alloc_storage_with_scope_global(): dev = tvm.cpu() # This is the important line which tests nd allocator vm_rt = relax.VirtualMachine(lib, dev, memory_cfg="naive") - x = tvm.nd.array(arg0, dev) + x = tvm.runtime.tensor(arg0, dev) vm_rt.set_input("main", x) vm_rt.invoke_stateful("main") output = vm_rt.get_outputs("main").numpy() diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index da8b905193fc..efd2f7ecbf59 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -52,8 +52,8 @@ def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): mod = TestVMCompileStage0 target = tvm.target.Target("llvm", host="llvm") ex = relax.build(mod, target, exec_mode=exec_mode) - inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) - inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp1 = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) + inp2 = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) vm = relax.VirtualMachine(ex, tvm.cpu()) vm["foo"](inp1, inp2) tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) @@ -72,8 +72,8 @@ def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): 
return y ex = relax.build(mod, exec_mode=exec_mode) - inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) - inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp1 = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) + inp2 = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) vm = relax.VirtualMachine(ex, tvm.cpu()) vm["foo"](inp1, inp2) tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) @@ -90,10 +90,10 @@ def foo(x: R.Tensor(["n", "m"], "int32"), y: R.Object) -> R.Tensor(["m", "n"], d target = tvm.target.Target("llvm", host="llvm") ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x0 = tvm.nd.array(np.zeros((1, 2)).astype("int32")) - y0 = tvm.nd.array(np.zeros((2, 1)).astype("float32")) - y1 = tvm.nd.array(np.zeros((1, 2)).astype("float32")) - y2 = tvm.nd.array(np.zeros((2, 1, 1)).astype("float32")) + x0 = tvm.runtime.tensor(np.zeros((1, 2)).astype("int32")) + y0 = tvm.runtime.tensor(np.zeros((2, 1)).astype("float32")) + y1 = tvm.runtime.tensor(np.zeros((1, 2)).astype("float32")) + y2 = tvm.runtime.tensor(np.zeros((2, 1, 1)).astype("float32")) vm["foo"](x0, y0) @@ -119,18 +119,18 @@ def foo(x: R.Tensor(dtype="float32")) -> R.Shape: vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) - arr = tvm.nd.array(np.random.rand(*shape).astype("float32")) + arr = tvm.runtime.tensor(np.random.rand(*shape).astype("float32")) res = vm["foo"](arr) assert res[0] == shape[0] * 2 assert res[1] == shape[1] * 3 # dtype mismatch with pytest.raises(ValueError, match=".*dtype.*"): - vm["foo"](tvm.nd.array(np.zeros((1, 2)).astype("int32"))) + vm["foo"](tvm.runtime.tensor(np.zeros((1, 2)).astype("int32"))) # ndim mismatch with pytest.raises(ValueError, match=".*match_cast.*ndim.*"): - vm["foo"](tvm.nd.array(np.zeros((1,)).astype("float32"))) + vm["foo"](tvm.runtime.tensor(np.zeros((1,)).astype("float32"))) # type mismach with pytest.raises(TypeError): @@ -153,7 +153,7 @@ 
def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor: vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) - inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + inp = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) res = vm["foo"](inp) tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) @@ -177,7 +177,7 @@ def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) - inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + inp = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) res = check_saved_func(vm, "foo", inp) tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), rtol=1e-7, atol=1e-7) @@ -217,8 +217,8 @@ def func( ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) - weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + data = tvm.runtime.tensor(np.random.rand(32, 16).astype(np.float32)) + weight = tvm.runtime.tensor(np.random.rand(16, 32).astype(np.float32)) res = check_saved_func(vm, "func", data, weight) expected = np.dot(data.numpy(), weight.numpy()) tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) @@ -265,9 +265,9 @@ def main( ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x = tvm.nd.array(np.zeros((2, 3)).astype(np.int32)) - y = tvm.nd.array(np.zeros((2, 3)).astype(np.int32)) - z = tvm.nd.array(np.ones((2, 3)).astype(np.int32)) + x = tvm.runtime.tensor(np.zeros((2, 3)).astype(np.int32)) + y = tvm.runtime.tensor(np.zeros((2, 3)).astype(np.int32)) + z = tvm.runtime.tensor(np.ones((2, 3)).astype(np.int32)) vm.set_input("main", x, y, z) vm.invoke_stateful("main") outs = vm.get_outputs("main") @@ -312,12 +312,12 @@ def main( ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x = 
tvm.nd.array(np.ones((2, 3)).astype(np.int32)) - y = tvm.nd.array(np.ones((2, 3)).astype(np.int32)) + x = tvm.runtime.tensor(np.ones((2, 3)).astype(np.int32)) + y = tvm.runtime.tensor(np.ones((2, 3)).astype(np.int32)) vm.set_input("main", x, y) vm.invoke_stateful("main") out = vm.get_outputs("main") - expected = tvm.nd.array(np.full((2, 3), 2).astype(np.int32)) + expected = tvm.runtime.tensor(np.full((2, 3), 2).astype(np.int32)) assert x == out tvm.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) @@ -342,8 +342,8 @@ def test_vm_emit_te_extern(exec_mode): ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) - weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + data = tvm.runtime.tensor(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.runtime.tensor(np.random.rand(32, 16).astype(np.float32)) res = check_saved_func(vm, "rx_cblas_matmul", data, weight) expected = np.dot(data.numpy(), weight.numpy()) tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) @@ -370,12 +370,12 @@ def te_func(A, B): ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - inp = tvm.nd.array( + inp = tvm.runtime.tensor( np.random.rand( 1, ).astype(np.float32) ) - inp2 = tvm.nd.array( + inp2 = tvm.runtime.tensor( np.random.rand( 2, ).astype(np.float32) @@ -406,13 +406,13 @@ def te_func(A): ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - inp = tvm.nd.array( + inp = tvm.runtime.tensor( np.random.rand( 1, ).astype(np.float32) ) res = check_saved_func(vm, "rx_func", inp) - np.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16")) + tvm.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16")) def test_vm_emit_te_floor_symbolic_shape(exec_mode): @@ -435,7 +435,7 @@ def te_func(A): vm = relax.VirtualMachine(ex, tvm.cpu()) shape 
= (9,) - inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + inp = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) res = check_saved_func(vm, "rx_func", inp) def expected_output(): @@ -463,7 +463,7 @@ def test_vm_emit_te_constant_param_cpu(exec_mode): dev = tvm.cpu() vm = relax.VirtualMachine(exec, dev) - add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + add_res = check_saved_func(vm, "main", tvm.runtime.tensor(x_np, dev)) tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) @@ -490,7 +490,7 @@ def test_vm_emit_te_constant_param_gpu(exec_mode): dev = tvm.cuda() vm = relax.VirtualMachine(exec, dev) - add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + add_res = check_saved_func(vm, "main", tvm.runtime.tensor(x_np, dev)) tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) @@ -516,8 +516,8 @@ def te_func(A, B): vm = relax.VirtualMachine(ex, tvm.cpu()) shape1 = (5,) shape2 = (3,) - inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32)) - inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32)) + inp = tvm.runtime.tensor(np.random.rand(*shape1).astype(np.float32)) + inp2 = tvm.runtime.tensor(np.random.rand(*shape2).astype(np.float32)) res = check_saved_func(vm, "rx_func", inp, inp2) def expected_output(): @@ -667,8 +667,8 @@ def te_func(A): ex.export_library(temp.relpath("exec.so")) vm = relax.VirtualMachine(tvm.runtime.load_module(temp.relpath("exec.so")), tvm.cpu()) - inp = tvm.nd.array(np.random.rand(2).astype(np.float32)) - inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32)) + inp = tvm.runtime.tensor(np.random.rand(2).astype(np.float32)) + inp2 = tvm.runtime.tensor(np.random.rand(3).astype(np.float32)) res = check_saved_func(vm, "rx_func", inp, inp2) @@ -693,8 +693,8 @@ def test_vm_tuple(exec_mode): vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (5,) - inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) - inp2 = 
tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + inp = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) + inp2 = tvm.runtime.tensor(np.random.rand(*shape).astype(np.float32)) (res1, res2), res3 = vm["rx_func"](inp, inp2) tvm.testing.assert_allclose(res1.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) @@ -722,8 +722,8 @@ def tuple_get_item( target = tvm.target.Target("llvm", host="llvm") ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) - y_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + x_inp = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) + y_inp = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) res = check_saved_func(vm, "tuple_get_item", x_inp, y_inp) tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7) @@ -754,7 +754,7 @@ def copy(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): target = tvm.target.Target("llvm", host="llvm") ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + x = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) y = vm["main"](x) tvm.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-7, atol=1e-7) @@ -808,8 +808,8 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> target = tvm.target.Target("llvm", host="llvm") ex = relax.build(TestVMSubFunction, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) - y_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + x_inp = tvm.runtime.tensor(np.random.rand(32, 32).astype(np.float32)) + y_inp = tvm.runtime.tensor(np.random.rand(32, 32).astype(np.float32)) res = check_saved_func(vm, "main", x_inp, y_inp) product = np.dot(x_inp.numpy(), y_inp.numpy()) expected 
= product * product @@ -843,7 +843,7 @@ def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: inp = np.empty(1).astype("float32") recursion_runs = np.random.randint(1, 10) inp.fill(recursion_runs) - inp = tvm.nd.array(inp) + inp = tvm.runtime.tensor(inp) res = check_saved_func(vm, "recursion", inp) tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7) @@ -870,7 +870,7 @@ def foo2( target = tvm.target.Target("llvm", host="llvm") ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + x_inp = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) res_1 = check_saved_func(vm, "foo1", x_inp) res_2 = check_saved_func(vm, "foo2", x_inp) @@ -903,8 +903,8 @@ def main( target = tvm.target.Target("llvm", host="llvm") ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) - y_inp = tvm.nd.array(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32")) + x_inp = tvm.runtime.tensor(np.random.rand(2, 3).astype("float32")) + y_inp = tvm.runtime.tensor(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32")) res = check_saved_func(vm, "main", x_inp, y_inp) tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy()) @@ -921,8 +921,8 @@ def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")): target = tvm.target.Target("llvm", host="llvm") ex = relax.build(TestTimeEvaluator, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x = tvm.nd.array(np.random.rand(1).astype("float32")) - y = tvm.nd.array(np.random.rand(1).astype("float32")) + x = tvm.runtime.tensor(np.random.rand(1).astype("float32")) + y = tvm.runtime.tensor(np.random.rand(1).astype("float32")) # ensure we can use time_evaluator with the stateful API vm.set_input("main", x, y) @@ -1054,8 +1054,8 @@ def popen_check(): def 
set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: - a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) - b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) + b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) vm.set_input("main", a, b) vm.invoke_stateful("main") res0 = vm.get_outputs("main") @@ -1067,17 +1067,17 @@ def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> Non tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) tvm.testing.assert_allclose(res0.numpy(), res1.numpy(), rtol=1e-7, atol=1e-7) - # bug! If you don't bind the NDArray to a var, the memory will get corrupted. + # bug! If you don't bind the Tensor to a var, the memory will get corrupted. # Possibly due to object lifecycles and other FFI issues - a = tvm.nd.array(np.array(2).astype("int32"), device) + a = tvm.runtime.tensor(np.array(2).astype("int32"), device) vm.set_input("test_vm_tuple", a) vm.invoke_stateful("test_vm_tuple") res2 = vm.get_outputs("test_vm_tuple") - # the results are NDArrays wrapped around scalars, - # so we have to get the scalar out of the NDArray + # the results are Tensors wrapped around scalars, + # so we have to get the scalar out of the Tensor assert tuple(map(lambda a: int(a.numpy()), res2)) == (2, 2) - b = tvm.nd.array(np.array(1).astype("int32"), device) + b = tvm.runtime.tensor(np.array(1).astype("int32"), device) vm.set_input("test_vm_nested_tuple", b) vm.invoke_stateful("test_vm_nested_tuple") res3 = vm.get_outputs("test_vm_nested_tuple") @@ -1088,8 +1088,8 @@ def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> Non def set_input_attempt_stateless(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: # this should fail: once you set inputs, you cannot run statelessly - a = tvm.nd.array(np.random.rand(32, 
32).astype("float32"), device) - b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) + b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) vm.set_input("main", a, b) # must use invoke stateful! vm["main"]() @@ -1102,8 +1102,8 @@ def set_input_attempt_invoke(vm: relax.VirtualMachine, device: tvm.runtime.Devic def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: # this should fail: you can't get outputs without invoking the function first - a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) - b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) + b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) vm.set_input("main", a, b) _ = vm.get_outputs("main") @@ -1169,16 +1169,16 @@ def main(x: R.Tuple([R.Tensor((32,), "float32"), R.Tensor((32,), "float32")])) - temp = utils.tempdir() vm, device = make_vm(MyMod, exec_mode, temp) device = tvm.cpu(0) - a = tvm.nd.empty((32,), "float32", device=device) - b = tvm.nd.empty((32,), "float32", device=device) + a = tvm.runtime.empty((32,), "float32", device=device) + b = tvm.runtime.empty((32,), "float32", device=device) vm.set_input("main", (a, b)) vm.invoke_stateful("main") def save_function_kwargs_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: # just checking that we can use kwargs for the args when saving a function - a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) - b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) + b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) vm.save_function("main", "saved_main", x=a, w=b) res0 = vm["saved_main"]() tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), 
rtol=1e-7, atol=1e-7) @@ -1197,8 +1197,8 @@ def save_function_time_evaluator_trial( vm: relax.VirtualMachine, device: tvm.runtime.Device ) -> None: # just checking that the saved function can be called in the time evaluator - a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) - b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + a = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) + b = tvm.runtime.tensor(np.random.rand(32, 32).astype("float32"), device) vm.save_function("main", "saved_main", a, b) vm.time_evaluator("saved_main", device)() @@ -1292,16 +1292,16 @@ def func_llvm( dev_llvm = tvm.device("llvm") vm_llvm = tvm.relax.VirtualMachine(built, device=dev_llvm) llvm_output = vm_llvm["func_llvm"]( - tvm.nd.array(np_A, dev_llvm), - tvm.nd.array(np_B, dev_llvm), + tvm.runtime.tensor(np_A, dev_llvm), + tvm.runtime.tensor(np_B, dev_llvm), ) dev_cuda = tvm.device("cuda") vm_cuda = tvm.relax.VirtualMachine(built, device=dev_cuda) cuda_output = vm_cuda["func_cuda"]( - tvm.nd.array(np_A, dev_cuda), - tvm.nd.array(np_B, dev_cuda), + tvm.runtime.tensor(np_A, dev_cuda), + tvm.runtime.tensor(np_B, dev_cuda), ) np_C = np_A + np_B diff --git a/tests/python/relax/test_vm_builtin.py b/tests/python/relax/test_vm_builtin.py index 04e2ae1bf339..2bc5e9ea7030 100644 --- a/tests/python/relax/test_vm_builtin.py +++ b/tests/python/relax/test_vm_builtin.py @@ -44,9 +44,9 @@ def foo(x: R.Tensor((3, 5), "float32"), y: R.Tensor((3, 1), "float32")): np_rand = np.random.rand(3, 5).astype(np.float32) # normalize it to get the random prob np_prob = np_rand / np_rand.sum(axis=1, keepdims=True) - nd_prob = tvm.nd.array(np_prob) + nd_prob = tvm.runtime.tensor(np_prob) # special sample to get deterministic results - nd_sample = tvm.nd.array(np.array([[1.0], [0], [1]]).astype(np.float32)) + nd_sample = tvm.runtime.tensor(np.array([[1.0], [0], [1]]).astype(np.float32)) vm = relax.VirtualMachine(ex, tvm.cpu()) res = vm["foo"](nd_prob, nd_sample) 
diff --git a/tests/python/relax/test_vm_callback_function.py b/tests/python/relax/test_vm_callback_function.py index c8f3f2945ede..1014ed98a558 100644 --- a/tests/python/relax/test_vm_callback_function.py +++ b/tests/python/relax/test_vm_callback_function.py @@ -51,7 +51,7 @@ def custom_callback(arr): from_callback = arr np_A = np.arange(16, dtype="int32") - tvm_A = tvm.nd.array(np_A) + tvm_A = tvm.runtime.tensor(np_A) vm["relax_func"](tvm_A, custom_callback) @@ -78,7 +78,7 @@ def relax_func( np_A = np.arange(16, dtype="int32") def custom_callback(): - return tvm.nd.array(np_A) + return tvm.runtime.tensor(np_A) output = vm["relax_func"](custom_callback) diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index dac0f867cefb..9633244c67fb 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -51,7 +51,7 @@ def foo(x: R.Tensor((3, 4), "float32")): mod = TestVMMove target = tvm.target.Target("llvm", host="llvm") ex = codegen(mod, target, exec_mode) - inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) vm = relax.VirtualMachine(ex, tvm.cpu()) res = check_saved_func(vm, "foo", inp) tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) @@ -73,14 +73,14 @@ def foo(x: R.Tensor((3, 4), "float32")): mod = TestVMToDevice target = tvm.target.Target("llvm", host="llvm") ex = codegen(mod, target, exec_mode) - inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp = tvm.runtime.tensor(np.random.rand(3, 4).astype(np.float32)) vm = relax.VirtualMachine(ex, tvm.cpu()) res = check_saved_func(vm, "foo", inp) tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) # check the resulting tensor is on cpu:0 assert res.device == tvm.cpu(0) - assert res.device.device_type == 1 - assert res.device.device_id == 0 + assert res.device.dlpack_device_type() == 1 + 
assert res.device.index == 0 @pytest.mark.parametrize("exec_mode", EXEC_MODE) @@ -100,7 +100,7 @@ def main(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float3 target = tvm.target.Target("llvm", host="llvm") ex = codegen(mod, target, exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - inp = tvm.nd.array(np.random.rand(3, 4)) + inp = tvm.runtime.tensor(np.random.rand(3, 4)) res = vm["main"](inp) tvm.testing.assert_allclose(res.numpy(), inp.numpy()) @@ -145,14 +145,14 @@ def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor: target = tvm.target.Target("llvm", host="llvm") ex = codegen(mod, target, exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - inp = tvm.nd.array(np.random.rand(3, 4)) - res = vm["ife"](tvm.nd.array(1), inp) + inp = tvm.runtime.tensor(np.random.rand(3, 4)) + res = vm["ife"](tvm.runtime.tensor(1), inp) tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) - res = vm["ife"](tvm.nd.array(True), inp) + res = vm["ife"](tvm.runtime.tensor(True), inp) tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) - res = vm["ife"](tvm.nd.array(0), inp) + res = vm["ife"](tvm.runtime.tensor(0), inp) tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) - res = vm["ife"](tvm.nd.array(False), inp) + res = vm["ife"](tvm.runtime.tensor(False), inp) tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) @@ -171,7 +171,7 @@ def main(x: R.Tensor(ndim=2, dtype="float32")): target = tvm.target.Target("llvm", host="llvm") ex = codegen(mod, target, exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - inp = tvm.nd.array(np.random.rand(2, 3)) + inp = tvm.runtime.tensor(np.random.rand(2, 3)) res0, res1, res2 = vm["main"](inp) tvm.testing.assert_allclose(res0.numpy(), np.array([1, 2])) tvm.testing.assert_allclose(res1.numpy(), np.array([3, 4])) @@ -203,7 +203,7 @@ def main(x: 
R.Tensor(ndim=2, dtype="float32")): target = tvm.target.Target("llvm", host="llvm") ex = codegen(mod, target, exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - inp = tvm.nd.array(np.random.rand(1, 2)) + inp = tvm.runtime.tensor(np.random.rand(1, 2)) res = vm["main"](inp) tvm.testing.assert_allclose(res.numpy(), np.array([4, 6]) + inp.numpy()) @@ -262,7 +262,7 @@ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): target = tvm.target.Target("llvm", host="llvm") ex = codegen(mod, target, exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) - x = tvm.nd.array(np.zeros((1, 2)).astype("float32")) + x = tvm.runtime.tensor(np.zeros((1, 2)).astype("float32")) res = vm["main"](x) assert res == tvm.runtime.container.ShapeTuple([2, 1, 2]) @@ -272,11 +272,11 @@ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): # wrong ndim with pytest.raises(ValueError, match=r".*ndim.*"): - vm["main"](tvm.nd.array(np.zeros(1).astype("float32"))) + vm["main"](tvm.runtime.tensor(np.zeros(1).astype("float32"))) # wrong dtype with pytest.raises(ValueError, match=r".*dtype.*"): - vm["main"](tvm.nd.array(np.zeros((1, 2)).astype("int32"))) + vm["main"](tvm.runtime.tensor(np.zeros((1, 2)).astype("int32"))) @pytest.mark.parametrize("exec_mode", EXEC_MODE) @@ -352,7 +352,7 @@ def main(x: R.Tensor((3, 4), "float32")): vm = relax.VirtualMachine(ex, dev) input_np = np.random.rand(3, 4).astype("float32") - input = tvm.nd.array(input_np, dev) + input = tvm.runtime.tensor(input_np, dev) res = vm["main"](input) expected = input_np.reshape(6, 2) tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7) diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index 1026864e4f9b..d04fd6bdab1b 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -101,7 +101,7 @@ def test_vm_run(): dev = tvm.cuda(0) vm = relax.VirtualMachine(ex, dev) x_np = np.random.uniform(size=(16, 
16)).astype("float32") - x = tvm.nd.array(x_np, dev) + x = tvm.runtime.tensor(x_np, dev) y = vm["main"](x) y_np = x_np + 1.0 + 1.0 + 1.0 + 1.0 tvm.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) @@ -129,13 +129,13 @@ def test_capture_error_is_recoverable(): target = tvm.target.Target("cuda") dev = tvm.cuda() - @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True) + @tvm.register_global_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True) def invalid_impl_for_cudagraph(arg_tensor): # Memory allocation/deallocation may not be performed while # capturing a cudaGraph. This passes the warm-up run # performed by "vm.builtin.cuda_graph.run_or_capture", but # throws an exception when the cudaGraph is being captured. - _dummy_workspace = tvm.nd.empty([16], "float16", dev) + _dummy_workspace = tvm.runtime.empty([16], "float16", dev) return arg_tensor @I.ir_module @@ -171,7 +171,7 @@ def main(A: R.Tensor([16], "float16")): built = tvm.compile(Module, target=target) vm = tvm.relax.VirtualMachine(built, dev) - arg = tvm.nd.array(np.arange(16).astype("float16"), dev) + arg = tvm.runtime.tensor(np.arange(16).astype("float16"), dev) with pytest.raises(tvm.TVMError): vm["main"](arg) diff --git a/tests/python/relax/test_vm_execbuilder.py b/tests/python/relax/test_vm_execbuilder.py index 861ec9f8b041..44ca5c20498c 100644 --- a/tests/python/relax/test_vm_execbuilder.py +++ b/tests/python/relax/test_vm_execbuilder.py @@ -31,12 +31,12 @@ def test_vm_execute(): ib.emit_ret(ib.r(2)) ex = ib.get() vm = relax.VirtualMachine(ex, tvm.cpu()) - a = tvm.nd.array( + a = tvm.runtime.tensor( np.random.rand( 4, ) ) - b = tvm.nd.array( + b = tvm.runtime.tensor( np.random.rand( 4, ) @@ -56,12 +56,12 @@ def test_vm_multiple_func(): ib.emit_ret(ib.r(2)) ex = ib.get() vm = relax.VirtualMachine(ex, tvm.cpu()) - a = tvm.nd.array( + a = tvm.runtime.tensor( np.random.rand( 4, ) ) - b = tvm.nd.array( + b = tvm.runtime.tensor( np.random.rand( 4, ) @@ 
-108,8 +108,8 @@ def test_emit_cache(): s2 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 3])) assert s0 == s1 assert s1 != s2 - y0 = ib.convert_constant(tvm.nd.array(np.array([1, 2, 3]).astype("int32"))) - y1 = ib.convert_constant(tvm.nd.array(np.array([1, 2, 3]).astype("int32"))) + y0 = ib.convert_constant(tvm.runtime.tensor(np.array([1, 2, 3]).astype("int32"))) + y1 = ib.convert_constant(tvm.runtime.tensor(np.array([1, 2, 3]).astype("int32"))) assert y0 == y1 ib.emit_ret(ib.r(0)) @@ -153,7 +153,7 @@ def test_vm_operand(): def test_vm_shapeof(): ib = relax.ExecBuilder() shape = (32, 16) - arr = tvm.nd.array(np.random.rand(*shape)) + arr = tvm.runtime.tensor(np.random.rand(*shape)) with ib.function("main", num_inputs=0): ib.emit_call("vm.builtin.shape_of", args=[arr], dst=ib.r(0)) ib.emit_ret(ib.r(0)) @@ -200,12 +200,12 @@ def test_vm_goto(): ib.emit_ret(ib.r(2)) ex = ib.get() vm = relax.VirtualMachine(ex, tvm.cpu()) - a = tvm.nd.array( + a = tvm.runtime.tensor( np.random.rand( 4, ) ) - b = tvm.nd.array( + b = tvm.runtime.tensor( np.random.rand( 4, ) @@ -224,12 +224,12 @@ def test_vm_if(): ib.emit_ret(ib.r(3)) ex = ib.get() vm = relax.VirtualMachine(ex, tvm.cpu()) - a = tvm.nd.array( + a = tvm.runtime.tensor( np.random.rand( 4, ) ) - b = tvm.nd.array( + b = tvm.runtime.tensor( np.random.rand( 4, ) @@ -255,10 +255,10 @@ def test_vm_invoke_closure(): ex = ib.get() vm = relax.VirtualMachine(ex, tvm.cpu()) - w_inp = tvm.nd.array(np.random.rand(2, 3)) - x_inp = tvm.nd.array(np.random.rand(2, 3)) - y_inp = tvm.nd.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]]) - z_inp = tvm.nd.array(np.random.rand(2, 3)) + w_inp = tvm.runtime.tensor(np.random.rand(2, 3)) + x_inp = tvm.runtime.tensor(np.random.rand(2, 3)) + y_inp = tvm.runtime.tensor([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]]) + z_inp = tvm.runtime.tensor(np.random.rand(2, 3)) clo = vm["main"](w_inp, x_inp) res = vm.invoke_closure(clo, y_inp, z_inp) tvm.testing.assert_allclose( @@ -280,8 +280,8 @@ def main(inp: 
R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype=" ex = tvm.compile(Module, "llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) - correct_input = tvm.nd.array(np.random.normal(size=(10, 10)).astype("float32")) - incorrect_input = tvm.nd.array(np.random.normal(size=(12, 10)).astype("float32")) + correct_input = tvm.runtime.tensor(np.random.normal(size=(10, 10)).astype("float32")) + incorrect_input = tvm.runtime.tensor(np.random.normal(size=(12, 10)).astype("float32")) try: vm["main"](incorrect_input) diff --git a/tests/python/relax/test_vm_instrument.py b/tests/python/relax/test_vm_instrument.py index 8c4d728da18b..c4d24481ec2d 100644 --- a/tests/python/relax/test_vm_instrument.py +++ b/tests/python/relax/test_vm_instrument.py @@ -81,7 +81,7 @@ def instrument(func, name, before_run, ret_val, *args): return relax.VMInstrumentReturnKind.SKIP_RUN vm.set_instrument(instrument) - vm["main"](tvm.nd.array(data_np)) + vm["main"](tvm.runtime.tensor(data_np)) assert hit_count[("matmul", True)] == 2 assert ("matmul", False) not in hit_count assert hit_count[("relu", True)] == 2 @@ -95,7 +95,7 @@ def test_lib_comparator(): # compare against library module cmp = LibCompareVMInstrument(vm.module.imports[0], tvm.cpu(), verbose=False) vm.set_instrument(cmp) - vm["main"](tvm.nd.array(data_np)) + vm["main"](tvm.runtime.tensor(data_np)) if __name__ == "__main__": diff --git a/tests/python/relax/test_vm_multi_device.py b/tests/python/relax/test_vm_multi_device.py index 91ae8bf79256..018eb7bc3cc6 100644 --- a/tests/python/relax/test_vm_multi_device.py +++ b/tests/python/relax/test_vm_multi_device.py @@ -79,9 +79,9 @@ def foo( np_ipt2 = np.random.rand(4, 5).astype(np.float32) np_res = np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2) - ipt0 = tvm.nd.array(np_ipt0, devices[0]) - ipt1 = tvm.nd.array(np_ipt1, devices[0]) - ipt2 = tvm.nd.array(np_ipt2, devices[1]) + ipt0 = tvm.runtime.tensor(np_ipt0, devices[0]) + ipt1 = tvm.runtime.tensor(np_ipt1, devices[0]) + ipt2 = 
tvm.runtime.tensor(np_ipt2, devices[1]) res = vm["foo"](ipt0, ipt1, ipt2) tvm.testing.assert_allclose(res.numpy(), np_res) @@ -134,10 +134,10 @@ def foo( np_ipt3 = np.random.rand(5, 6).astype(np.float32) np_res = np.matmul(np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2), np_ipt3) - ipt0 = tvm.nd.array(np_ipt0, devices[0]) - ipt1 = tvm.nd.array(np_ipt1, devices[0]) - ipt2 = tvm.nd.array(np_ipt2, devices[1]) - ipt3 = tvm.nd.array(np_ipt3, devices[2]) + ipt0 = tvm.runtime.tensor(np_ipt0, devices[0]) + ipt1 = tvm.runtime.tensor(np_ipt1, devices[0]) + ipt2 = tvm.runtime.tensor(np_ipt2, devices[1]) + ipt3 = tvm.runtime.tensor(np_ipt3, devices[2]) res = vm["foo"](ipt0, ipt1, ipt2, ipt3) tvm.testing.assert_allclose(res.numpy(), np_res) @@ -179,9 +179,9 @@ def foo( np_ipt2 = np.random.rand(4, 5).astype(np.float32) np_res = np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2) - ipt0 = tvm.nd.array(np_ipt0, devices[1]) - ipt1 = tvm.nd.array(np_ipt1, devices[1]) - ipt2 = tvm.nd.array(np_ipt2, devices[0]) + ipt0 = tvm.runtime.tensor(np_ipt0, devices[1]) + ipt1 = tvm.runtime.tensor(np_ipt1, devices[1]) + ipt2 = tvm.runtime.tensor(np_ipt2, devices[0]) res = vm["foo"](ipt0, ipt1, ipt2) tvm.testing.assert_allclose(res.numpy(), np_res, rtol=1e-4, atol=1e-4) diff --git a/tests/python/relax/test_vm_profiler.py b/tests/python/relax/test_vm_profiler.py index eaf914560530..cdb27377a587 100644 --- a/tests/python/relax/test_vm_profiler.py +++ b/tests/python/relax/test_vm_profiler.py @@ -55,7 +55,7 @@ def test_conv2d_cpu(): ex = get_exec(data_np.shape) vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True) - report = vm.profile("main", tvm.nd.array(data_np)) + report = vm.profile("main", tvm.runtime.tensor(data_np)) print(report) assert "Duration" in str(report) @@ -76,7 +76,7 @@ def with_rpc(ex, f, data_np): device = remote.cpu() vm = relax.VirtualMachine(rexec, device=device, profile=True) - data = tvm.nd.array(data_np, device) + data = tvm.runtime.tensor(data_np, device) f(vm, data) diff 
--git a/tests/python/runtime/test_evaluator_with_preproc.py b/tests/python/runtime/test_evaluator_with_preproc.py index fd8f8e95b0bf..208d584e99a5 100644 --- a/tests/python/runtime/test_evaluator_with_preproc.py +++ b/tests/python/runtime/test_evaluator_with_preproc.py @@ -49,9 +49,9 @@ def test_time_evalutor_with_preproc(f_preproc: str): dev = tvm.cuda(0) evaluator = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1, f_preproc=f_preproc) - a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev) - b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev) - c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev) + a = tvm.runtime.tensor(np.random.rand(128, 128).astype("float32"), device=dev) + b = tvm.runtime.tensor(np.random.rand(128, 128).astype("float32"), device=dev) + c = tvm.runtime.tensor(np.zeros((128, 128)).astype("float32"), device=dev) args = [a, b, c] print("Evaluator (f_preproc={}):\t{:.5f}ms".format(f_preproc, evaluator(*args).mean * 1000)) diff --git a/tests/python/runtime/test_executable.py b/tests/python/runtime/test_executable.py index 571ce7adb2bf..4d6830b8b6a4 100644 --- a/tests/python/runtime/test_executable.py +++ b/tests/python/runtime/test_executable.py @@ -60,9 +60,9 @@ def test_executable_getitem(): add_func = executable["add"] # Verify the function works - a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) - b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) - c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + a = tvm.runtime.tensor(np.array([1.0] * 10, dtype="float32")) + b = tvm.runtime.tensor(np.array([2.0] * 10, dtype="float32")) + c = tvm.runtime.tensor(np.array([0.0] * 10, dtype="float32")) add_func(a, b, c) @@ -87,10 +87,10 @@ def test_executable_jit_already_jitted(): # The module might be different after force recompilation # Verify both modules work correctly - a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) - b = tvm.nd.array(np.array([2.0] * 10, 
dtype="float32")) - c1 = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) - c2 = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + a = tvm.runtime.tensor(np.array([1.0] * 10, dtype="float32")) + b = tvm.runtime.tensor(np.array([2.0] * 10, dtype="float32")) + c1 = tvm.runtime.tensor(np.array([0.0] * 10, dtype="float32")) + c2 = tvm.runtime.tensor(np.array([0.0] * 10, dtype="float32")) jitted_mod1["add"](a, b, c1) jitted_mod3["add"](a, b, c2) @@ -118,9 +118,9 @@ def test_executable_export_library(): assert loaded_mod is not None # Test the loaded module - a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) - b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) - c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + a = tvm.runtime.tensor(np.array([1.0] * 10, dtype="float32")) + b = tvm.runtime.tensor(np.array([2.0] * 10, dtype="float32")) + c = tvm.runtime.tensor(np.array([0.0] * 10, dtype="float32")) loaded_mod["add"](a, b, c) @@ -155,9 +155,9 @@ def test_executable_export_library_with_workspace(): assert loaded_mod is not None # Test the loaded module - a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) - b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) - c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + a = tvm.runtime.tensor(np.array([1.0] * 10, dtype="float32")) + b = tvm.runtime.tensor(np.array([2.0] * 10, dtype="float32")) + c = tvm.runtime.tensor(np.array([0.0] * 10, dtype="float32")) loaded_mod["add"](a, b, c) @@ -190,9 +190,9 @@ def test_executable_integration(): assert add_func is not None # Test the function works - a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) - b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) - c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + a = tvm.runtime.tensor(np.array([1.0] * 10, dtype="float32")) + b = tvm.runtime.tensor(np.array([2.0] * 10, dtype="float32")) + c = tvm.runtime.tensor(np.array([0.0] * 10, dtype="float32")) add_func(a, b, c) @@ -214,7 +214,7 @@ def 
test_executable_integration(): # Test the loaded module loaded_add = loaded_mod["add"] - c_loaded = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + c_loaded = tvm.runtime.tensor(np.array([0.0] * 10, dtype="float32")) loaded_add(a, b, c_loaded) # Check results @@ -249,9 +249,9 @@ def test_executable_jit_force_recompile(): assert jitted_mod3 is not jitted_mod1 # Test the function works - a = tvm.nd.array(np.array([1.0] * 10, dtype="float32")) - b = tvm.nd.array(np.array([2.0] * 10, dtype="float32")) - c = tvm.nd.array(np.array([0.0] * 10, dtype="float32")) + a = tvm.runtime.tensor(np.array([1.0] * 10, dtype="float32")) + b = tvm.runtime.tensor(np.array([2.0] * 10, dtype="float32")) + c = tvm.runtime.tensor(np.array([0.0] * 10, dtype="float32")) jitted_mod3["add"](a, b, c) diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py index 8ee483e5f148..49d1c36f50bc 100644 --- a/tests/python/runtime/test_runtime_container.py +++ b/tests/python/runtime/test_runtime_container.py @@ -22,7 +22,7 @@ import tvm import tvm.testing -from tvm import nd +import tvm.runtime from tvm.runtime import container as _container diff --git a/tests/python/runtime/test_runtime_dlpack.py b/tests/python/runtime/test_runtime_dlpack.py index 201037c6e469..a5d09ee465a1 100644 --- a/tests/python/runtime/test_runtime_dlpack.py +++ b/tests/python/runtime/test_runtime_dlpack.py @@ -29,7 +29,7 @@ def test_from_dlpack_shape_one(): tgt = tvm.target.Target(target="llvm", host="llvm") rows = 1 - a = tvm.runtime.ndarray.from_dlpack(to_dlpack(torch.randn(rows, 16))) + a = tvm.runtime.from_dlpack(to_dlpack(torch.randn(rows, 16))) A = te.placeholder((rows, 16), name="A") B = te.placeholder((rows, 16), name="B") @@ -39,8 +39,8 @@ def test_from_dlpack_shape_one(): dev = tvm.device(tgt.kind.name, 0) - b = tvm.nd.array(np.random.uniform(size=(rows, 16)).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((rows, 16), dtype=C.dtype), dev) + b = 
tvm.runtime.tensor(np.random.uniform(size=(rows, 16)).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.zeros((rows, 16), dtype=C.dtype), dev) fadd(a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) @@ -53,7 +53,7 @@ def test_from_dlpack_strided(): rows = 1 inp = torch.randn(rows, 16) - a = tvm.runtime.ndarray.from_dlpack(to_dlpack(inp)) + a = tvm.runtime.from_dlpack(to_dlpack(inp)) view = a._create_view((2, 8)) np.testing.assert_equal(inp.numpy().reshape(2, 8), view.numpy()) diff --git a/tests/python/runtime/test_runtime_extension.py b/tests/python/runtime/test_runtime_extension.py index 7c7dca51c728..44534a6b4703 100644 --- a/tests/python/runtime/test_runtime_extension.py +++ b/tests/python/runtime/test_runtime_extension.py @@ -32,7 +32,7 @@ def test_dltensor_compatible(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange")) f = tvm.compile(mod, target="llvm") - a = tvm.nd.array(np.zeros(10, dtype=dtype)) + a = tvm.runtime.tensor(np.zeros(10, dtype=dtype)) f(a) np.testing.assert_equal(a.numpy(), np.arange(a.shape[0])) diff --git a/tests/python/runtime/test_runtime_measure.py b/tests/python/runtime/test_runtime_measure.py index ef27feb26398..41271b1ba312 100644 --- a/tests/python/runtime/test_runtime_measure.py +++ b/tests/python/runtime/test_runtime_measure.py @@ -27,7 +27,7 @@ def test_min_repeat_ms(): tmp = tempdir() filename = tmp.relpath("log") - @tvm.register_func + @tvm.register_global_func def my_debug(filename): """one call lasts for 100 ms and writes one character to a file""" time.sleep(0.1) @@ -37,7 +37,7 @@ def my_debug(filename): X = te.compute((), lambda: tvm.tir.call_packed("my_debug", filename)) func = tvm.tir.build(te.create_prim_func([X])) - x = tvm.nd.empty((), dtype="int32") + x = tvm.runtime.empty((), dtype="int32") ftimer = func.time_evaluator(func.entry_name, tvm.cpu(), number=1, repeat=1) ftimer(x) diff --git a/tests/python/runtime/test_runtime_module_load.py 
b/tests/python/runtime/test_runtime_module_load.py index d22d40f6f2b1..edb7b4f79362 100644 --- a/tests/python/runtime/test_runtime_module_load.py +++ b/tests/python/runtime/test_runtime_module_load.py @@ -34,7 +34,7 @@ path_dso = sys.argv[1] dtype = sys.argv[2] ff = tvm.runtime.load_module(path_dso) -a = tvm.nd.array(np.zeros(10, dtype=dtype)) +a = tvm.runtime.tensor(np.zeros(10, dtype=dtype)) ff(a) np.testing.assert_equal(a.numpy(), np.arange(a.shape[0])) print("Finish runtime checking...") @@ -75,10 +75,10 @@ def save_object(names): f1 = tvm.runtime.load_module(path_dso) f2 = tvm.runtime.load_module(path_ll) - a = tvm.nd.array(np.zeros(10, dtype=dtype)) + a = tvm.runtime.tensor(np.zeros(10, dtype=dtype)) f1(a) np.testing.assert_equal(a.numpy(), np.arange(a.shape[0])) - a = tvm.nd.array(np.zeros(10, dtype=dtype)) + a = tvm.runtime.tensor(np.zeros(10, dtype=dtype)) f2(a) np.testing.assert_equal(a.numpy(), np.arange(a.shape[0])) @@ -124,8 +124,8 @@ def popen_check(): import tvm f1 = tvm.runtime.load_module(path_dso) - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=1024).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(1024, dtype=A.dtype), dev) f1(a, b) np.testing.assert_equal(b.numpy(), a.numpy() + 1) @@ -140,8 +140,8 @@ def check_c(device): print("Skip because %s is not enabled" % device) return f = tvm.compile(sch.mod, target=tvm.target.Target(device, host="c")) - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=1024).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(1024, dtype=A.dtype), dev) f["main"](a, b) np.testing.assert_equal(b.numpy(), a.numpy() + 1) @@ -176,8 +176,8 @@ def check_llvm(): m = tvm.runtime.load_module(path_dso) fadd1 = m["myadd1"] fadd2 = m["myadd2"] - a = 
tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=nn).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(nn, dtype=A.dtype), dev) fadd1(a, b) np.testing.assert_equal(b.numpy(), a.numpy() + 1) fadd2(a, b) @@ -207,8 +207,8 @@ def popen_check(): ctypes.CDLL(path_dso) # Load the system wide library mm = tvm.runtime.system_lib() - a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=nn).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(nn, dtype=A.dtype), dev) mm["myadd1"](a, b) np.testing.assert_equal(b.numpy(), a.numpy() + 1) mm["myadd2"](a, b) diff --git a/tests/python/runtime/test_runtime_nd_array.py b/tests/python/runtime/test_runtime_nd_array.py index 8b30b7bba05c..4ed81de55f0e 100644 --- a/tests/python/runtime/test_runtime_nd_array.py +++ b/tests/python/runtime/test_runtime_nd_array.py @@ -23,9 +23,9 @@ def test_1d_full_view_of_1d_arr(): - """NDArray::CreateView may return the same array""" + """Tensor::CreateView may return the same array""" np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_output = tvm_input._create_view([1024]) np_expected = np_input @@ -34,9 +34,9 @@ def test_1d_full_view_of_1d_arr(): def test_1d_view_of_first_half_of_1d_arr(): - """NDArray::CreateView may return a subset of an array""" + """Tensor::CreateView may return a subset of an array""" np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_output = tvm_input._create_view([512]) np_expected = np_input[0:512] @@ -45,9 +45,9 @@ def test_1d_view_of_first_half_of_1d_arr(): def test_1d_view_of_first_half_of_1d_arr(): - """Subset returned by NDArray::CreateView may have a byte offset""" + """Subset returned by 
Tensor::CreateView may have a byte offset""" np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_output = tvm_input._create_view([512], relative_byte_offset=512 * 4) np_expected = np_input[512:1024] @@ -58,16 +58,16 @@ def test_1d_view_of_first_half_of_1d_arr(): def test_view_larger_than_original_is_invalid(): """Subset may not be larger than the original array""" np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) - with pytest.raises(ValueError, match="the NDArray being viewed only contains 4096 bytes"): + with pytest.raises(ValueError, match="the Tensor being viewed only contains 4096 bytes"): tvm_input._create_view([2048]) def test_view_entirely_outside_bounds_of_original_is_invalid(): """The byte_offset may not place a view outside the original array""" np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) with pytest.raises(ValueError, match="would occupy bytes 8192 <= i_byte < 12288"): tvm_input._create_view([1024], relative_byte_offset=2048 * 4) @@ -76,14 +76,14 @@ def test_view_entirely_outside_bounds_of_original_is_invalid(): def test_view_partially_outside_bounds_of_original_is_invalid(): """The byte_offset may not place any elements of a view outside the original array""" np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) with pytest.raises(ValueError, match="would occupy bytes 2048 <= i_byte < 6144"): tvm_input._create_view([1024], relative_byte_offset=512 * 4) def test_subview_first_half_of_first_half(): - """NDArray::CreateView be applied to a view + """Tensor::CreateView be applied to a view The first view is at element offset 0 (byte offset 0). 
The second view is at element offset 0 (byte offset 0) relative to the first @@ -92,7 +92,7 @@ def test_subview_first_half_of_first_half(): """ np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_view = tvm_input._create_view( [512], @@ -108,7 +108,7 @@ def test_subview_first_half_of_first_half(): def test_subview_first_half_of_second_half(): - """NDArray::CreateView be applied to a view + """Tensor::CreateView be applied to a view The first view is at element offset 512 (byte offset 2048). The second view is at element offset 0 (byte offset 0) relative to the @@ -117,7 +117,7 @@ def test_subview_first_half_of_second_half(): """ np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_view = tvm_input._create_view( [512], @@ -133,7 +133,7 @@ def test_subview_first_half_of_second_half(): def test_subview_second_half_of_first_half(): - """NDArray::CreateView be applied to a view + """Tensor::CreateView be applied to a view The first view is at element offset 0 (byte offset 0). The second view is at element offset 256 (byte offset 1024) relative to the @@ -142,7 +142,7 @@ def test_subview_second_half_of_first_half(): """ np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_view = tvm_input._create_view( [512], @@ -158,7 +158,7 @@ def test_subview_second_half_of_first_half(): def test_subview_second_half_of_second_half(): - """NDArray::CreateView be applied to a view + """Tensor::CreateView be applied to a view The first view is at element offset 512 (byte offset 2048). 
The second view is at element offset 256 (byte offset 1024) relative @@ -167,7 +167,7 @@ def test_subview_second_half_of_second_half(): """ np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_view = tvm_input._create_view( [512], @@ -183,7 +183,7 @@ def test_subview_second_half_of_second_half(): def test_subview_must_be_in_range_of_immediate_parent(): - """Bounds-checking is applied relative to the NDArray + """Bounds-checking is applied relative to the Tensor The first view is at location and covers bytes [0,2048). The subview would occupy bytes [2048, 4096), and raises an error as @@ -191,7 +191,7 @@ def test_subview_must_be_in_range_of_immediate_parent(): """ np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_view = tvm_input._create_view( [512], @@ -206,9 +206,9 @@ def test_subview_must_be_in_range_of_immediate_parent(): def test_2d_view_into_1d_arr(): - """NDArray::CreateView may change the dimensionality of an array""" + """Tensor::CreateView may change the dimensionality of an array""" np_input = np.arange(1024, dtype="int32") - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_output = tvm_input._create_view([32, 32]) np_expected = np_input.reshape(32, 32) @@ -217,9 +217,9 @@ def test_2d_view_into_1d_arr(): def test_2d_full_view_into_2d_arr(): - """NDArray::CreateView may change the shape of an array""" + """Tensor::CreateView may change the shape of an array""" np_input = np.arange(1024, dtype="int32").reshape(32, 32) - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_output = tvm_input._create_view([16, 64]) np_expected = np_input.reshape(16, 64) @@ -228,9 +228,9 @@ def test_2d_full_view_into_2d_arr(): def test_2d_view_of_first_half_of_2d_arr(): - """NDArray::CreateView may return a multi-dimensional view""" + """Tensor::CreateView may return a 
multi-dimensional view""" np_input = np.arange(1024, dtype="int32").reshape(32, 32) - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_output = tvm_input._create_view([16, 32]) np_expected = np_input[0:16, :] @@ -239,9 +239,9 @@ def test_2d_view_of_first_half_of_2d_arr(): def test_2d_view_of_second_half_of_2d_arr(): - """NDArray::CreateView may return a multi-dimensional view with byte offset""" + """Tensor::CreateView may return a multi-dimensional view with byte offset""" np_input = np.arange(1024, dtype="int32").reshape(32, 32) - tvm_input = tvm.nd.array(np_input) + tvm_input = tvm.runtime.tensor(np_input) tvm_output = tvm_input._create_view([16, 32], relative_byte_offset=32 * 16 * 4) np_expected = np_input[16:32, :] diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index ac8653012ace..627ebbb7d62c 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -53,7 +53,7 @@ # Windows does not support fork so we can enable Windows for testing sys.platform.startswith("win") == False and multiprocessing.get_start_method() != "fork", reason=( - "pytest + multiprocessing spawn method causes tvm.register_func to " + "pytest + multiprocessing spawn method causes tvm.register_global_func to " "not work on the rpc.Server." 
), ) @@ -76,8 +76,8 @@ def verify_rpc(remote, target, shape, dtype): f = tvm.compile(te.create_prim_func([A, B]), target=target) dev = remote.cpu(0) - a = tvm.nd.array(np.random.randint(0, 256, size=shape).astype(A.dtype), device=dev) - b = tvm.nd.array(np.zeros(shape).astype(A.dtype), device=dev) + a = tvm.runtime.tensor(np.random.randint(0, 256, size=shape).astype(A.dtype), device=dev) + b = tvm.runtime.tensor(np.zeros(shape).astype(A.dtype), device=dev) temp = utils.tempdir() path_dso = temp.relpath("dev_lib.o") f.write_to_file(path_dso) @@ -133,10 +133,10 @@ def test_rpc_array(): def check_remote(): x = np.ones((3, 4)) - r_cpu = tvm.nd.array(x, remote.cpu(0)) + r_cpu = tvm.runtime.tensor(x, remote.cpu(0)) assert str(r_cpu.device).startswith("remote") np.testing.assert_equal(r_cpu.numpy(), x) - fremote = remote.get_function("rpc.test.remote_array_func") + fremote = remote.get_function("rpc.test.remote_tensor_func") fremote(r_cpu) check_remote() @@ -152,8 +152,8 @@ def check_remote(): dev = remote.cpu(0) a_np = np.ones((5041, 720)).astype("float32") b_np = np.ones((720, 192)).astype("float32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) np.testing.assert_equal(a.numpy(), a_np) np.testing.assert_equal(b.numpy(), b_np) @@ -251,8 +251,8 @@ def check_remote(remote): f.export_library(path_dso) remote.upload(path_dso) f1 = remote.load_module("dev_lib.so") - a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(102, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=102).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(102, dtype=A.dtype), dev) time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) cost = time_f(a, b).mean print("%g secs/op" % cost) @@ -266,8 +266,8 @@ def check_remote(remote): with open(local_download_path, "wb") as fo: fo.write(remote.download_linked_module("dev_lib.tar")) fupdated = 
tvm.runtime.load_module(local_download_path) - a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), tvm.cpu(0)) - b = tvm.nd.array(np.zeros(102, dtype=A.dtype), tvm.cpu(0)) + a = tvm.runtime.tensor(np.random.uniform(size=102).astype(A.dtype), tvm.cpu(0)) + b = tvm.runtime.tensor(np.zeros(102, dtype=A.dtype), tvm.cpu(0)) fupdated(a, b) np.testing.assert_equal(b.numpy(), a.numpy() + 1) @@ -289,8 +289,8 @@ def check_minrpc(): dev = remote.cpu(0) f1 = remote.system_lib() - a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(102, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=102).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(102, dtype=A.dtype), dev) time_f = f1.time_evaluator("myadd", remote.cpu(0), number=1) cost = time_f(a, b).mean np.testing.assert_equal(b.numpy(), a.numpy() + 1) @@ -325,8 +325,8 @@ def check_remote_link_cl(remote): f.export_library(path_tar) remote.upload(path_tar) fhost = remote.load_module("myadd.tar") - a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), dev) - b = tvm.nd.array(np.zeros(102, dtype=A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=102).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.zeros(102, dtype=A.dtype), dev) fhost(a, b) np.testing.assert_equal(b.numpy(), a.numpy() + 1) @@ -369,7 +369,7 @@ def check_multi_hop(): assert fecho("xyz") == "xyz" assert bytes(fecho(bytearray(b"123"))) == b"123" - nd = tvm.nd.array([1, 2, 3], device=client.cpu(0)) + nd = tvm.runtime.tensor([1, 2, 3], device=client.cpu(0)) assert nd.numpy()[1] == 2 def check_error_handling(): @@ -386,7 +386,7 @@ def check_error_handling(): @tvm.testing.requires_rpc -def test_rpc_return_ndarray(): +def test_rpc_return_tensor(): # start server server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") diff --git a/tests/python/runtime/test_runtime_trace.py b/tests/python/runtime/test_runtime_trace.py index 5093ce930ec3..146db5a06535 
100644 --- a/tests/python/runtime/test_runtime_trace.py +++ b/tests/python/runtime/test_runtime_trace.py @@ -24,13 +24,13 @@ def test_trace_default_action(): x = te.placeholder((n, n, n), name="X", dtype="float32") y = te.compute(x.shape, lambda i, j, k: tvm.tir.trace([i, j, k, x[i][j][k]])) f = tvm.compile(te.create_prim_func([x, y]), target="llvm") - xnd = tvm.nd.array(np.ones((n, n, n), dtype=x.dtype)) - ynd = tvm.nd.array(np.zeros((n, n, n), dtype=y.dtype)) + xnd = tvm.runtime.tensor(np.ones((n, n, n), dtype=x.dtype)) + ynd = tvm.runtime.tensor(np.zeros((n, n, n), dtype=y.dtype)) f(xnd, ynd) def test_trace_expr_assign(): - @tvm.register_func("tvm.tir.trace_callback2") + @tvm.register_global_func("tvm.tir.trace_callback2") def trace_buffer(x): return @@ -45,9 +45,9 @@ def check_assign(dtype): ) f = tvm.compile(te.create_prim_func([x, y, z]), "llvm") - xnd = tvm.nd.array(np.ones((n, n, n), dtype=x.dtype)) - ynd = tvm.nd.array(np.zeros((n, n, n), dtype=y.dtype)) - znd = tvm.nd.array(np.zeros((n, n, n), dtype=z.dtype)) + xnd = tvm.runtime.tensor(np.ones((n, n, n), dtype=x.dtype)) + ynd = tvm.runtime.tensor(np.zeros((n, n, n), dtype=y.dtype)) + znd = tvm.runtime.tensor(np.zeros((n, n, n), dtype=z.dtype)) f(xnd, ynd, znd) assert np.array_equal(xnd.numpy(), np.ones((n, n, n))) @@ -59,7 +59,7 @@ def check_assign(dtype): def test_trace_expr_sum_generated(): - @tvm.register_func("tvm.tir.trace_callback3") + @tvm.register_global_func("tvm.tir.trace_callback3") def trace_buffer(x): return @@ -73,9 +73,9 @@ def check_expr_sum(dtype): + tvm.tir.trace([b[i][j][k]], "tvm.tir.trace_callback3"), ) f = tvm.compile(te.create_prim_func([a, b, c])) - xnd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=a.dtype))) - ynd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=b.dtype))) - znd = tvm.nd.array(np.zeros((n, n, n), dtype=c.dtype)) + xnd = tvm.runtime.tensor(np.array(np.ones((n, n, n), dtype=a.dtype))) + ynd = tvm.runtime.tensor(np.array(np.ones((n, n, n), dtype=b.dtype))) + znd = 
tvm.runtime.tensor(np.zeros((n, n, n), dtype=c.dtype)) f(xnd, ynd, znd) assert np.array_equal(znd.numpy(), xnd.numpy() + ynd.numpy()) @@ -84,7 +84,7 @@ def check_expr_sum(dtype): def test_trace_expr_sum_args(): - @tvm.register_func("tvm.tir.trace_silent") + @tvm.register_global_func("tvm.tir.trace_silent") def silent(*args): return @@ -103,11 +103,11 @@ def check_expr_sum(dtype): + tvm.tir.trace([i, j, k, e[i][j][k]], "tvm.tir.trace_silent"), ) f = tvm.compile(te.create_prim_func([a, b, d, e, c])) - a_nd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=a.dtype))) - b_nd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=b.dtype))) - d_nd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=d.dtype))) - e_nd = tvm.nd.array(np.array(np.ones((n, n, n), dtype=e.dtype))) - c_nd = tvm.nd.array(np.zeros((n, n, n), dtype=c.dtype)) + a_nd = tvm.runtime.tensor(np.array(np.ones((n, n, n), dtype=a.dtype))) + b_nd = tvm.runtime.tensor(np.array(np.ones((n, n, n), dtype=b.dtype))) + d_nd = tvm.runtime.tensor(np.array(np.ones((n, n, n), dtype=d.dtype))) + e_nd = tvm.runtime.tensor(np.array(np.ones((n, n, n), dtype=e.dtype))) + c_nd = tvm.runtime.tensor(np.zeros((n, n, n), dtype=c.dtype)) f(a_nd, b_nd, d_nd, e_nd, c_nd) assert np.array_equal( c_nd.numpy(), a_nd.numpy() + b_nd.numpy() + d_nd.numpy() + e_nd.numpy() @@ -118,7 +118,7 @@ def check_expr_sum(dtype): def test_trace_expr_sum_custom(): - @tvm.register_func("tvm.tir.trace_callback4") + @tvm.register_global_func("tvm.tir.trace_callback4") def trace_buffer(x): return @@ -134,9 +134,9 @@ def check_expr_sum_custom(dtype): f = tvm.compile(te.create_prim_func([a, b, c])) npa = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=a.dtype) npb = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=a.dtype) - xnd = tvm.nd.array(npa) - ynd = tvm.nd.array(npb) - znd = tvm.nd.array(np.zeros((n, n), dtype=c.dtype)) + xnd = tvm.runtime.tensor(npa) + ynd = tvm.runtime.tensor(npb) + znd = 
tvm.runtime.tensor(np.zeros((n, n), dtype=c.dtype)) f(xnd, ynd, znd) assert np.array_equal(znd.numpy(), npa + npb) @@ -145,11 +145,11 @@ def check_expr_sum_custom(dtype): def test_trace_can_change_traced_value_int(): - @tvm.register_func("tvm.tir.trace_change_int_first") + @tvm.register_global_func("tvm.tir.trace_change_int_first") def trace_buffer(x): return 13 - @tvm.register_func("tvm.tir.trace_change_int_second") + @tvm.register_global_func("tvm.tir.trace_change_int_second") def trace_buffer(x): return 14 @@ -160,9 +160,9 @@ def check_assign(dtype): z = te.compute(x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_int_second")) f = tvm.compile(te.create_prim_func([x, y, z])) - xnd = tvm.nd.array(np.ones((n,), dtype=x.dtype)) - ynd = tvm.nd.array(np.zeros((n,), dtype=y.dtype)) - znd = tvm.nd.array(np.zeros((n,), dtype=z.dtype)) + xnd = tvm.runtime.tensor(np.ones((n,), dtype=x.dtype)) + ynd = tvm.runtime.tensor(np.zeros((n,), dtype=y.dtype)) + znd = tvm.runtime.tensor(np.zeros((n,), dtype=z.dtype)) f(xnd, ynd, znd) check_array_first = np.array([13, 13, 13, 13]) check_array_second = np.array([14, 14, 14, 14]) @@ -174,11 +174,11 @@ def check_assign(dtype): def test_trace_can_change_traced_value_float(): - @tvm.register_func("tvm.tir.trace_change_float_first") + @tvm.register_global_func("tvm.tir.trace_change_float_first") def trace_buffer(x): return 13.0 - @tvm.register_func("tvm.tir.trace_change_float_second") + @tvm.register_global_func("tvm.tir.trace_change_float_second") def trace_buffer(x): return 14.0 @@ -191,9 +191,9 @@ def check_assign(dtype): ) f = tvm.compile(te.create_prim_func([x, y, z]), target="llvm") - xnd = tvm.nd.array(np.ones((n,), dtype=x.dtype)) - ynd = tvm.nd.array(np.zeros((n,), dtype=y.dtype)) - znd = tvm.nd.array(np.zeros((n,), dtype=z.dtype)) + xnd = tvm.runtime.tensor(np.ones((n,), dtype=x.dtype)) + ynd = tvm.runtime.tensor(np.zeros((n,), dtype=y.dtype)) + znd = tvm.runtime.tensor(np.zeros((n,), dtype=z.dtype)) f(xnd, ynd, znd) 
check_array_first = np.array([13.0, 13.0, 13.0, 13.0]) check_array_second = np.array([14.0, 14.0, 14.0, 14.0]) diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py index 686954baade1..d656031ad9cb 100644 --- a/tests/python/target/test_arm_target.py +++ b/tests/python/target/test_arm_target.py @@ -84,7 +84,7 @@ def my_func(a: T.handle): mod = tvm.compile(my_func, target=target) - A_nd = tvm.nd.array(np.empty((1,), dtype="int32"), device=dev) + A_nd = tvm.runtime.tensor(np.empty((1,), dtype="int32"), device=dev) mod(A_nd) ref = 10000 // (sve_device_vector_length // 32) @@ -109,8 +109,8 @@ def my_func(a: T.handle, b: T.handle): A_np = np.random.uniform(size=(num_elements,)).astype("float32") B_np = np.zeros((num_elements,)).astype("float32") - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) tvm.testing.assert_allclose(B_nd.numpy(), A_np) @@ -137,8 +137,8 @@ def my_func(a: T.handle, b: T.handle): A_np = np.random.uniform(size=(num_elements,)).astype(dtype) B_np = np.zeros((num_elements,)).astype(dtype) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) tvm.testing.assert_allclose(B_nd.numpy(), A_np) @@ -159,7 +159,7 @@ def my_func(a: T.handle): mod = tvm.compile(my_func, target=target) A_np = np.zeros((num_elements,)).astype("float32") - A_nd = tvm.nd.array(A_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) mod(A_nd) ref = np.ones((num_elements,)) diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py index 4906b219c359..8aa314bd6293 100644 --- a/tests/python/target/test_target_target.py +++ b/tests/python/target/test_target_target.py @@ -24,15 +24,19 @@ def test_all_targets_device_type_verify(): 
"""Consistency verification for all targets' device type""" - all_targets = [tvm.target.Target(t) for t in tvm.target.Target.list_kinds()] + target_kind_set = set(tvm.target.Target.list_kinds()) + target_kind_set.remove("composite") + all_targets = [tvm.target.Target(t) for t in target_kind_set] for tgt in all_targets: - if tgt.kind.name not in tvm.runtime.Device.DEVICE_NAME_TO_TYPE: + if tgt.kind.name not in tvm.runtime.Device._DEVICE_NAME_TO_TYPE: raise KeyError( - "Cannot find target kind: %s in Device.DEVICE_NAME_TO_TYPE" % tgt.kind.name + "Cannot find target kind: %s in Device._DEVICE_NAME_TO_TYPE" % tgt.kind.name ) - assert tgt.get_target_device_type() == tvm.runtime.Device.DEVICE_NAME_TO_TYPE[tgt.kind.name] + assert ( + tgt.get_target_device_type() == tvm.runtime.Device._DEVICE_NAME_TO_TYPE[tgt.kind.name] + ) def test_target_string_parse(): @@ -347,7 +351,7 @@ def test_canon_multi_target_and_host_5(): def test_canon_multi_target_and_host_6(): """Test `canon_target_and_host` by using TVM Objects""" - cuda_device_type = tvm.device("cuda").device_type + cuda_device_type = tvm.device("cuda").dlpack_device_type() target = {cuda_device_type: Target(target="cuda", host="llvm")} host = None raw_targets_1 = Target.canon_multi_target_and_host(target, host) diff --git a/tests/python/target/test_virtual_device.py b/tests/python/target/test_virtual_device.py index a6434480fa83..4441bab128b8 100644 --- a/tests/python/target/test_virtual_device.py +++ b/tests/python/target/test_virtual_device.py @@ -21,7 +21,7 @@ def test_make_virtual_device_for_device(): virtual_device = tvm.target.VirtualDevice(tvm.device("cuda")) - assert virtual_device.device_type == 2 + assert virtual_device.dlpack_device_type() == 2 # ie kDLCUDA assert virtual_device.virtual_device_id == 0 assert virtual_device.target is None @@ -31,7 +31,7 @@ def test_make_virtual_device_for_device(): def test_make_virtual_device_for_device_and_target(): target = tvm.target.Target("cuda") virtual_device = 
tvm.target.VirtualDevice(tvm.device("cuda"), target) - assert virtual_device.device_type == 2 # ie kDLCUDA + assert virtual_device.dlpack_device_type() == 2 # ie kDLCUDA assert virtual_device.target == target assert virtual_device.memory_scope == "" @@ -40,7 +40,7 @@ def test_make_virtual_device_for_device_target_and_memory_scope(): target = tvm.target.Target("cuda") scope = "local" virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"), target, scope) - assert virtual_device.device_type == 2 # ie kDLCUDA + assert virtual_device.dlpack_device_type() == 2 # ie kDLCUDA assert virtual_device.target == target assert virtual_device.memory_scope == scope diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index b070371b8ac4..426272584bb5 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -352,8 +352,8 @@ def test_constant(): func = te.create_prim_func([C, A]) func = tvm.compile(func) a_np = np.random.uniform(size=(M,)).astype(A.dtype) - c = tvm.nd.array(np.zeros(M, dtype=C.dtype)) - x = func(c, tvm.nd.array(a_np)) + c = tvm.runtime.tensor(np.zeros(M, dtype=C.dtype)) + x = func(c, tvm.runtime.tensor(a_np)) tvm.testing.assert_allclose(a_np + 2, c.numpy()) @@ -367,8 +367,8 @@ def test_data_dependent_access(): a_np = np.random.uniform(size=(10,)).astype(A.dtype) b_np = np.arange(10, dtype=B.dtype) - c = tvm.nd.array(np.zeros(10, dtype=C.dtype)) - func(c, tvm.nd.array(a_np), tvm.nd.array(b_np)) + c = tvm.runtime.tensor(np.zeros(10, dtype=C.dtype)) + func(c, tvm.runtime.tensor(a_np), tvm.runtime.tensor(b_np)) tvm.testing.assert_allclose(a_np[b_np], c.numpy()) @@ -852,7 +852,7 @@ def tir_workload( v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)]) T.writes(adaptive_pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3]) - for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30): + for rv0, rv1 in T.grid((v_ax2 % 3 * 4 + 16) // 12 + 1, (v_ax3 % 3 * 10 + 40) // 30 + 1): with T.block("adaptive_pool_sum"): v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0) v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1) @@ -870,7 +870,7 @@ def tir_workload( T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) - adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30)) + adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", (v_ax2 % 3 * 4 + 16) // 12 + 1) * T.Cast("float32", (v_ax3 % 3 * 10 + 40) // 30 + 1)) # fmt: on def te_workload(): diff --git a/tests/python/testing/test_tvm_testing_features.py b/tests/python/testing/test_tvm_testing_features.py index 6d394ebeb649..9618113ae3a9 100644 --- a/tests/python/testing/test_tvm_testing_features.py +++ b/tests/python/testing/test_tvm_testing_features.py @@ -49,7 +49,7 @@ def test_all_targets_used(self): assert sorted(self.targets_used) == sorted(self.enabled_targets) def test_all_devices_used(self): - sort_key = lambda dev: (dev.device_type, dev.device_id) + sort_key = lambda dev: (dev.dlpack_device_type(), dev.index) assert sorted(self.devices_used, key=sort_key) == sorted(self.enabled_devices, key=sort_key) targets_with_explicit_list = [] diff --git 
a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index cddc9131f30f..f6e1d2eade24 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -345,5 +345,88 @@ def func(): tvm.tir.analysis.verify_well_formed(mod) +def test_error_message_without_previous_definition_location(): + """Test case 1: Error message without 'It was first defined at' + + This tests the scenario where it == end(), so the error message should contain + 'TIR is ill-formed, due to multiple definitions of variable' but should NOT + contain 'It was first defined at' since the iterator is invalid. + """ + + @T.prim_func(check_well_formed=False) + def func(): + x = T.int32() + + with T.LetStmt(42, var=x): + T.evaluate(x) + + with T.LetStmt(99, var=x): # This should trigger the error + T.evaluate(x) + + with pytest.raises(ValueError) as exc_info: + tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + + error_msg = str(exc_info.value) + + assert "TIR is ill-formed" in error_msg + assert "multiple definitions of variable" in error_msg + + +def test_error_message_with_previous_definition_location(): + """Test case 2: Error message with 'It was first defined at' + + This tests the scenario where it != end(), so the error message should contain + both 'TIR is ill-formed, due to multiple definitions of variable' and should also + contain 'It was first defined at' with the location information. 
+ """ + + @T.prim_func(check_well_formed=False) + def func(): + x = T.int32() + + with T.LetStmt(42, var=x): + with T.LetStmt(99, var=x): # This should trigger the error + T.evaluate(x) + + with pytest.raises(ValueError) as exc_info: + tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + + error_msg = str(exc_info.value) + + assert "TIR is ill-formed" in error_msg + assert "multiple nested definitions of variable" in error_msg + + # should contains location information since it != end() + assert "It was first defined at" in error_msg + assert "was re-defined at" in error_msg + + +def test_sequential_redefinition_with_location(): + """Test case 2b: Sequential redefinition that includes location info + + This tests the previously_defined_ path where it != end() + """ + + @T.prim_func(check_well_formed=False) + def func(): + x = T.int32() + + with T.LetStmt(1, var=x): + T.evaluate(x) + + with T.LetStmt(2, var=x): # This should trigger the error + T.evaluate(x) + + with pytest.raises(ValueError) as exc_info: + tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + + error_msg = str(exc_info.value) + + assert "TIR is ill-formed" in error_msg + assert "multiple definitions of variable" in error_msg + assert "It was first defined at" in error_msg + assert "later re-defined at" in error_msg + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_base.py b/tests/python/tir-base/test_tir_base.py index d204ebfb6084..b23c600b15b8 100644 --- a/tests/python/tir-base/test_tir_base.py +++ b/tests/python/tir-base/test_tir_base.py @@ -14,11 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import numpy as np import tvm import pytest from tvm import tir from tvm.base import TVMError from tvm.ir.transform import PassContext +from tvm.script import tir as T import itertools import pytest @@ -113,6 +115,61 @@ def test_control_flow_jump(): assert out == 1.0 +def test_break_loop(): + @T.prim_func + def func(In: T.Buffer[(2,), "int32"], Out: T.Buffer[(2,), "int32"]): + Out[0] = 0 + Out[1] = 1 + for i in range(10): + for j in range(10): + if i * 10 + j == In[0]: + Out[0] = i + j + break + if Out[0] > 0: + break + while Out[1] > 0: + Out[1] = Out[1] + 1 + if Out[1] > In[1]: + break + + func = build_tir_func(func) + a = np.asarray([49, 8], "int32") + b = np.zeros([2], "int32") + if not hasattr(b, "__dlpack__"): + return + func(a, b) + assert b[0] == 13 + assert b[1] == 9 + + +def test_continue_loop(): + @T.prim_func + def func(Out: T.Buffer[(2,), "int32"]): + T.func_attr({"global_symbol": "main"}) + Out[0] = 0 + Out[1] = 0 + for i in range(10): + for j in range(10): + if (i * 10 + j) % 3 != 0: + continue + Out[0] = Out[0] + 1 + k = T.decl_buffer([], "int32") + k[()] = 0 + while k[()] < Out[0]: + k[()] = k[()] + 1 + if k[()] % 6 == 0: + Out[1] = Out[1] + 1 + continue + + func = build_tir_func(func) + b = np.zeros([2], "int32") + if not hasattr(b, "__dlpack__"): + return + func(b) + assert b[0] == 34 + assert b[1] == 5 # 6, 12, 18, 24, 30 + + def test_exception(): with pytest.raises(TypeError): x = tir.Var(name=1, dtype="int") diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py index 42c2998e27a8..407607055787 100644 --- a/tests/python/tir-base/test_tir_constructor.py +++ b/tests/python/tir-base/test_tir_constructor.py @@ -140,7 +140,7 @@ def test_stmt_constructor(): assert isinstance(x, tvm.tir.AttrStmt) assert x.value.value == 1 - x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"), tvm.runtime.convert("hellow"), nop) + x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"), tvm.runtime.convert("hellow"), nop) 
assert isinstance(x, tvm.tir.AssertStmt) assert x.body == nop @@ -150,8 +150,8 @@ def test_stmt_constructor(): assert x.extent.value == 10 assert x.body == nop - buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("uint1"))) - buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var) + buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("bool"))) + buffer = tvm.tir.decl_buffer([16], "bool", data=buffer_var) x = tvm.tir.BufferStore(buffer, tvm.tir.IntImm("bool", 1), [10]) assert isinstance(x, tvm.tir.BufferStore) assert x.buffer == buffer @@ -160,7 +160,7 @@ def test_stmt_constructor(): assert x.value.value == 1 buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) - x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop) + x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "bool"), nop) assert isinstance(x, tvm.tir.Allocate) assert x.dtype == "float32" assert x.buffer_var == buffer_var @@ -168,7 +168,7 @@ def test_stmt_constructor(): storage_scope = "global.texture" buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope)) - x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop) + x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "bool"), nop) assert isinstance(x, tvm.tir.Allocate) assert x.dtype == "float32" assert x.buffer_var == buffer_var @@ -181,7 +181,7 @@ def test_stmt_constructor(): assert x.attr_key == "xyz" assert x.body == nop - x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), nop) + x = tvm.tir.IfThenElse(tvm.tir.const(1, "bool"), tvm.tir.Evaluate(11), nop) assert isinstance(x, tvm.tir.IfThenElse) assert x.then_case.value.value == 11 assert x.else_case == nop diff --git a/tests/python/tir-base/test_tir_imm_values.py b/tests/python/tir-base/test_tir_imm_values.py index 11213e35364c..4ec1674af203 100644 --- a/tests/python/tir-base/test_tir_imm_values.py 
+++ b/tests/python/tir-base/test_tir_imm_values.py @@ -271,7 +271,7 @@ def float_imm_div(x: T.float32, y: T.float32, z: T.Buffer((), "float32")): def __wrap_build(f): lib = tvm.compile(f, target="llvm") - z = tvm.nd.array(np.zeros([]).astype("float32")) + z = tvm.runtime.tensor(np.zeros([]).astype("float32")) def _func(x, y): lib(x, y, z) diff --git a/tests/python/tir-base/test_tir_index_map.py b/tests/python/tir-base/test_tir_index_map.py index 3ddbd2f69f59..8696a4062668 100644 --- a/tests/python/tir-base/test_tir_index_map.py +++ b/tests/python/tir-base/test_tir_index_map.py @@ -214,12 +214,12 @@ def expected_inverse(i0, i1, i2, i3): assert expected_map.is_equivalent_to(inverse_map) -def test_map_ndarray(): +def test_map_tensor(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) inp = np.arange(16).astype("int8") - out = index_map.map_ndarray(tvm.nd.array(inp)).numpy() + out = index_map.map_tensor(tvm.runtime.tensor(inp)).numpy() ref = np.zeros(out.shape).astype("int8") @@ -232,7 +232,7 @@ def test_map_ndarray(): inp = np.random.randn(10, 10, 10, 10).astype("float16") - out = index_map.map_ndarray(tvm.nd.array(inp)).numpy() + out = index_map.map_tensor(tvm.runtime.tensor(inp)).numpy() ref = np.transpose(inp, (3, 0, 1, 2)) @@ -254,8 +254,8 @@ def test_map_ndarray(): I = 64 O = 64 inp = np.random.randn(kH, kW, I, O).astype("float32") - arr = tvm.nd.array(inp) - out = index_map.map_ndarray(arr).numpy() + arr = tvm.runtime.tensor(inp) + out = index_map.map_tensor(arr).numpy() ref = np.zeros(out.shape).astype("float32") @@ -269,7 +269,7 @@ def test_map_ndarray(): np.testing.assert_equal(ref, out) inverse_map = index_map.inverse(inp.shape) - np.testing.assert_equal(inverse_map.map_ndarray(index_map.map_ndarray(arr)).numpy(), inp) + np.testing.assert_equal(inverse_map.map_tensor(index_map.map_tensor(arr)).numpy(), inp) if __name__ == "__main__": diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index 
55f8dbed6c3c..8dabdbb344f3 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -42,8 +42,8 @@ def test_nearbyint(): dev = tvm.cpu(0) n = 10 - a = tvm.nd.array(np.random.uniform(high=100, size=n).astype(A.dtype), dev) - a_rounded = tvm.nd.array(np.random.uniform(size=n).astype(A_rounded.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(high=100, size=n).astype(A.dtype), dev) + a_rounded = tvm.runtime.tensor(np.random.uniform(size=n).astype(A_rounded.dtype), dev) func(a, a_rounded) # Note that numpys rint rounds to nearest integer with # ties to halfway is broken by rounding to even. @@ -66,6 +66,7 @@ def test_round_intrinsics_on_int(): def test_unary_intrin(): test_funcs = [ + (tvm.tir.exp, lambda x: np.exp(x)), (tvm.tir.exp10, lambda x: np.power(10, x)), (tvm.tir.log2, lambda x: np.log2(x)), (tvm.tir.log10, lambda x: np.log10(x)), @@ -97,8 +98,8 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5): dev = tvm.cpu(0) n = 10 - a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) func(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=atol, rtol=rtol) @@ -113,11 +114,21 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5): np.random.uniform(-2.0, -1.1, size=n // 2), ] ).astype(A.dtype) - a2 = tvm.nd.array(out_np, dev) - b2 = tvm.nd.array(np.empty_like(out_np), dev) + a2 = tvm.runtime.tensor(out_np, dev) + b2 = tvm.runtime.tensor(np.empty_like(out_np), dev) func(a2, b2) # all outputs should be NaN assert np.all(np.isnan(b2.numpy())) + if name == "exp": + n = 8 + out_np = np.random.randint(-20, 20, size=n).astype(A.dtype) + a2 = tvm.runtime.tensor(out_np, dev) + b2 = tvm.runtime.tensor(np.empty_like(out_np), dev) + func(a2, b2) + assert 
b2.numpy().dtype == np.float32 + # Verify correctness against NumPy exp + expected = np.exp(out_np.astype(np.float32)) + tvm.testing.assert_allclose(b2.numpy(), expected, rtol=1e-5, atol=1e-5) for func in test_funcs: atol = rtol = 1e-3 if func[0].__name__ in ["asin", "acos", "atan"] else 1e-5 @@ -149,9 +160,9 @@ def run_test(tvm_intrin, np_func): dev = tvm.cpu(0) n = 10 - a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(B.dtype), dev) - c = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) func(a, b, c) tvm.testing.assert_allclose(c.numpy(), np_func(a.numpy(), b.numpy()), atol=1e-5, rtol=1e-5) @@ -176,9 +187,9 @@ def test_ldexp(): dev = tvm.cpu(0) n = 10 - a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.randint(0, 5, size=n).astype(B.dtype), dev) - c = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) + b = tvm.runtime.tensor(np.random.randint(0, 5, size=n).astype(B.dtype), dev) + c = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) func(a, b, c) tvm.testing.assert_allclose(c.numpy(), np.ldexp(a.numpy(), b.numpy()), atol=1e-5, rtol=1e-5) @@ -230,8 +241,8 @@ def clz_np(x, dtype): for high in highs: a_np = np.random.randint(1, high=high, size=(n,), dtype=dtype) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros((n,)).astype("int32"), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(np.zeros((n,)).astype("int32"), dev) func(a, b) ref = clz_np(a_np, dtype) np.testing.assert_equal(b.numpy(), ref) diff --git a/tests/python/tir-base/test_tir_nodes.py 
b/tests/python/tir-base/test_tir_nodes.py index 5e1d25e48b0d..85cd726dda7f 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -134,6 +134,7 @@ def test_basic(): def test_stmt(): x = tvm.tir.Evaluate(0) tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.SERIAL, x) + tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.UNROLLED, x, step=2) def test_dir(): @@ -302,7 +303,7 @@ def test_isnan(): z = te.var("z", "int32") assert str(tvm.tir.isnan(z)) == "T.bool(False)" k = te.var("k", "int8x2") - assert str(tvm.tir.isnan(k).dtype) == "uint1x2" + assert str(tvm.tir.isnan(k).dtype) == "boolx2" def test_equality(): diff --git a/tests/python/tir-base/test_tir_ops.py b/tests/python/tir-base/test_tir_ops.py index dfa5cbab80c0..cb7d8c597ab9 100644 --- a/tests/python/tir-base/test_tir_ops.py +++ b/tests/python/tir-base/test_tir_ops.py @@ -69,8 +69,8 @@ def test_const_fold3(): x = te.var("x") for val in [0, 1]: for func in [tvm.tir.all, tvm.tir.any]: - check_throws(lambda: func(tvm.tir.const(val, "uint1"), x)) - check_throws(lambda: func(x, tvm.tir.const(val, "uint1"))) + check_throws(lambda: func(tvm.tir.const(val, "bool"), x)) + check_throws(lambda: func(x, tvm.tir.const(val, "bool"))) # Test const folding when both arguments are const for tvm_func, py_func in [ @@ -80,13 +80,13 @@ def test_const_fold3(): for v1 in [0, 1]: for v2 in [0, 1]: tvm.ir.assert_structural_equal( - tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2, "uint1")), - tvm.tir.const(py_func(v1, v2), "uint1"), + tvm_func(tvm.tir.const(v1, "bool"), tvm.tir.const(v2, "bool")), + tvm.tir.const(py_func(v1, v2), "bool"), ) - x = te.var("x", "uint1") - true = tvm.tir.const(1, "uint1") - false = tvm.tir.const(0, "uint1") + x = te.var("x", "bool") + true = tvm.tir.const(1, "bool") + false = tvm.tir.const(0, "bool") assert tvm.tir.all(x, true).same_as(x) assert tvm.tir.all(true, x).same_as(x) diff --git a/tests/python/tir-base/test_tir_ptx_cp_async.py 
b/tests/python/tir-base/test_tir_ptx_cp_async.py index d5c029c10138..9e0e18c30781 100644 --- a/tests/python/tir-base/test_tir_ptx_cp_async.py +++ b/tests/python/tir-base/test_tir_ptx_cp_async.py @@ -18,6 +18,7 @@ from tvm.script import tir as T import numpy as np import tvm.testing +import pytest @T.prim_func @@ -55,8 +56,8 @@ def test_ptx_cp_async(): A_np = np.random.rand(32, 128).astype("float16") B_np = np.zeros((32, 128)).astype("float16") dev = tvm.cuda(0) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) tvm.testing.assert_allclose(B_nd.numpy(), A_np) @@ -94,7 +95,7 @@ def ptx_cp_async_barrier( B[tx, i] = A_shared[tx, i] -@tvm.testing.requires_cuda_compute_version(8) +@tvm.testing.requires_cuda_compute_version(9) def test_ptx_cp_async_barrier(): f = ptx_cp_async_barrier @@ -102,8 +103,8 @@ def test_ptx_cp_async_barrier(): A_np = np.random.rand(32, 128).astype("float16") B_np = np.zeros((32, 128)).astype("float16") dev = tvm.cuda(0) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) tvm.testing.assert_allclose(B_nd.numpy(), A_np) @@ -143,8 +144,8 @@ def test_ptx_cp_async_bulk(): A_np = np.random.rand(32, 128).astype("float16") B_np = np.zeros((32, 128)).astype("float16") dev = tvm.cuda(0) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) tvm.testing.assert_allclose(B_nd.numpy(), A_np) diff --git a/tests/python/tir-base/test_tir_ptx_ldmatrix.py b/tests/python/tir-base/test_tir_ptx_ldmatrix.py index 346f9c393fcd..8d4ed399b2e8 100644 --- a/tests/python/tir-base/test_tir_ptx_ldmatrix.py +++ b/tests/python/tir-base/test_tir_ptx_ldmatrix.py @@ -87,8 +87,8 @@ 
def test_ptx_ldmatrix(): A_mask_np[:16, :16] = A_np[:16, :16] B_np = np.zeros((16, 16)).astype("float16") dev = tvm.cuda(0) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) tvm.testing.assert_allclose(B_nd.numpy(), A_mask_np) diff --git a/tests/python/tir-base/test_tir_ptx_mma.py b/tests/python/tir-base/test_tir_ptx_mma.py index 8f221d95da32..ad38348efdb4 100644 --- a/tests/python/tir-base/test_tir_ptx_mma.py +++ b/tests/python/tir-base/test_tir_ptx_mma.py @@ -74,9 +74,9 @@ def test_gemm_mma_m8n8k4_row_col_fp64pf64fp64(): C_np = np.zeros([8, 8]).astype("float64") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -150,9 +150,9 @@ def test_gemm_mma_m8n8k4_row_row_fp16fp16fp16(): C_np = np.zeros([16, 16]).astype("float16") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -233,9 +233,9 @@ def test_gemm_mma_m8n8k4_row_row_fp16fp16fp32(): C_np = np.zeros([16, 16]).astype("float32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -304,9 +304,9 @@ def test_gemm_mma_m8n8k16_row_col_s8s8s32(): C_np = np.zeros([8, 8]).astype("int32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = 
tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -375,9 +375,9 @@ def test_gemm_mma_m8n8k16_row_col_s8u8s32(): C_np = np.zeros([8, 8]).astype("int32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -442,9 +442,9 @@ def test_gemm_mma_m8n8k32_row_col_s4s4s32(): cuda_mod = tvm.compile(sch.mod, target="cuda") ctx = tvm.cuda() - A_tvm = tvm.nd.empty([8, 32], "int4", ctx) - B_tvm = tvm.nd.empty([8, 32], "int4", ctx) - C_tvm = tvm.nd.empty([8, 8], "int32", ctx) + A_tvm = tvm.runtime.empty([8, 32], "int4", ctx) + B_tvm = tvm.runtime.empty([8, 32], "int4", ctx) + C_tvm = tvm.runtime.empty([8, 8], "int32", ctx) cuda_mod(A_tvm, B_tvm, C_tvm) # Currently the correctness is not checked. @@ -505,9 +505,9 @@ def test_gemm_mma_m8n8k32_row_col_s4u4s32(): cuda_mod = tvm.compile(sch.mod, target="cuda") ctx = tvm.cuda() - A_tvm = tvm.nd.empty([8, 32], "int4", ctx) - B_tvm = tvm.nd.empty([8, 32], "uint4", ctx) - C_tvm = tvm.nd.empty([8, 8], "int32", ctx) + A_tvm = tvm.runtime.empty([8, 32], "int4", ctx) + B_tvm = tvm.runtime.empty([8, 32], "uint4", ctx) + C_tvm = tvm.runtime.empty([8, 8], "int32", ctx) cuda_mod(A_tvm, B_tvm, C_tvm) # Currently the correctness is not checked. 
@@ -574,9 +574,9 @@ def test_gemm_mma_m16n8k8_row_col_fp16fp16fp32(): C_np = np.zeros([16, 8]).astype("float32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -650,9 +650,9 @@ def test_gemm_mma_m16n8k16_row_col_fp16fp16fp16(): C_np = np.zeros([16, 8]).astype("float16") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -726,9 +726,9 @@ def test_gemm_mma_m16n8k16_row_col_fp16fp16fp32(): C_np = np.zeros([16, 8]).astype("float32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -802,9 +802,9 @@ def test_gemm_mma_m16n8k16_row_col_s8s8s32(): C_np = np.zeros([16, 8]).astype("int32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -878,9 +878,9 @@ def test_gemm_mma_m16n8k16_row_col_s8u8s32(): C_np = np.zeros([16, 8]).astype("int32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -954,9 +954,9 @@ def test_gemm_mma_m16n8k32_row_col_s8s8s32(): C_np = np.zeros([16, 
8]).astype("int32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -1030,9 +1030,9 @@ def test_gemm_mma_m16n8k32_row_col_s8u8s32(): C_np = np.zeros([16, 8]).astype("int32") ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(C_np, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(C_np, ctx) cuda_mod(A_tvm, B_tvm, C_tvm) @@ -1102,9 +1102,9 @@ def test_gemm_mma_m16n8k64_row_col_s4s4s32(): cuda_mod = tvm.compile(sch.mod, target="cuda") ctx = tvm.cuda() - A_tvm = tvm.nd.empty([16, 64], "int4", ctx) - B_tvm = tvm.nd.empty([8, 64], "int4", ctx) - C_tvm = tvm.nd.empty([16, 8], "int32", ctx) + A_tvm = tvm.runtime.empty([16, 64], "int4", ctx) + B_tvm = tvm.runtime.empty([8, 64], "int4", ctx) + C_tvm = tvm.runtime.empty([16, 8], "int32", ctx) cuda_mod(A_tvm, B_tvm, C_tvm) # Currently the correctness is not checked. @@ -1170,9 +1170,9 @@ def test_gemm_mma_m16n8k64_row_col_s4u4s32(): cuda_mod = tvm.compile(sch.mod, target="cuda") ctx = tvm.cuda() - A_tvm = tvm.nd.empty([16, 64], "int4", ctx) - B_tvm = tvm.nd.empty([8, 64], "uint4", ctx) - C_tvm = tvm.nd.empty([16, 8], "int32", ctx) + A_tvm = tvm.runtime.empty([16, 64], "int4", ctx) + B_tvm = tvm.runtime.empty([8, 64], "uint4", ctx) + C_tvm = tvm.runtime.empty([16, 8], "int32", ctx) cuda_mod(A_tvm, B_tvm, C_tvm) # Currently the correctness is not checked. 
@@ -1239,9 +1239,9 @@ def test_gemm_mma_m16n8k256_row_col_b1b1s32(): cuda_mod = tvm.compile(sch.mod, target="cuda") ctx = tvm.cuda() - A_tvm = tvm.nd.empty([16, 256], "int1", ctx) - B_tvm = tvm.nd.empty([8, 256], "int1", ctx) - C_tvm = tvm.nd.empty([16, 8], "int32", ctx) + A_tvm = tvm.runtime.empty([16, 256], "int1", ctx) + B_tvm = tvm.runtime.empty([8, 256], "int1", ctx) + C_tvm = tvm.runtime.empty([16, 8], "int32", ctx) cuda_mod(A_tvm, B_tvm, C_tvm) # Currently the correctness is not checked. diff --git a/tests/python/tir-base/test_tir_ptx_mma_sp.py b/tests/python/tir-base/test_tir_ptx_mma_sp.py index d5c6c9a03b45..fef373799b2b 100644 --- a/tests/python/tir-base/test_tir_ptx_mma_sp.py +++ b/tests/python/tir-base/test_tir_ptx_mma_sp.py @@ -283,10 +283,10 @@ def get_meta_m16n8k16_half(mask): meta = get_meta_m16n8k16_half(mask) ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) - meta_tvm = tvm.nd.array(meta, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(np.zeros_like(C_np), ctx) + meta_tvm = tvm.runtime.tensor(meta, ctx) cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) @@ -322,10 +322,10 @@ def get_meta_m16n8k32_half(mask): meta = get_meta_m16n8k32_half(mask) ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) - meta_tvm = tvm.nd.array(meta, ctx) + A_tvm = tvm.runtime.tensor(A_np, ctx) + B_tvm = tvm.runtime.tensor(B_np, ctx) + C_tvm = tvm.runtime.tensor(np.zeros_like(C_np), ctx) + meta_tvm = tvm.runtime.tensor(meta, ctx) cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tir-base/test_tir_structural_equal_hash.py index 
601afd8f164f..01af60724cbb 100644 --- a/tests/python/tir-base/test_tir_structural_equal_hash.py +++ b/tests/python/tir-base/test_tir_structural_equal_hash.py @@ -18,7 +18,7 @@ import numpy as np import pytest from tvm import te -from tvm.ffi.access_path import AccessPath +from tvm_ffi.access_path import AccessPath from tvm.script import tir as T, ir as I @@ -120,8 +120,8 @@ def test_prim_func(): func1 = tvm.ir.load_json(tvm.ir.save_json(func0)) tvm.ir.assert_structural_equal(func0, func1) - data0 = tvm.nd.array([1, 2, 3]) - data1 = tvm.nd.array([1, 2, 3]) + data0 = tvm.runtime.tensor([1, 2, 3]) + data1 = tvm.runtime.tensor([1, 2, 3]) # attributes and ndarrays func0 = func0.with_attr("data", data0) func1 = func1.with_attr("data", data1) @@ -174,15 +174,15 @@ def test_prim_func_body_mismatch(): def test_array(): x = np.arange(10) - nx = tvm.nd.array(x) - ny = tvm.nd.array(x) - nz = tvm.nd.array(x.reshape(2, 5)) + nx = tvm.runtime.tensor(x) + ny = tvm.runtime.tensor(x) + nz = tvm.runtime.tensor(x.reshape(2, 5)) assert consistent_equal(nx, ny) assert not consistent_equal(nx, nz) def test_env_func(): - @tvm.register_func("test.sequal.env_func") + @tvm.register_global_func("test.sequal.env_func") def test(x): return x + 1 diff --git a/tests/python/tir-base/test_tir_te_extern_primfunc.py b/tests/python/tir-base/test_tir_te_extern_primfunc.py index 9c375481fe45..1408597fa22e 100644 --- a/tests/python/tir-base/test_tir_te_extern_primfunc.py +++ b/tests/python/tir-base/test_tir_te_extern_primfunc.py @@ -48,8 +48,8 @@ def func_1(A: T.Buffer((16,), "float32"), C: T.Buffer((1,), "float32")): def verify_func_1(module): a_np = np.random.randint(low=-128, high=127, size=(16,)).astype(np.float32) c_np = np.zeros((1,), dtype=np.float32) - a = tvm.nd.array(a_np, device=tvm.cpu(0)) - c = tvm.nd.array(c_np, device=tvm.cpu(0)) + a = tvm.runtime.tensor(a_np, device=tvm.cpu(0)) + c = tvm.runtime.tensor(c_np, device=tvm.cpu(0)) module(a, c) tvm.testing.assert_allclose(c_np + np.sum(3 * 
a_np + 1), c.numpy(), rtol=1e-4) @@ -78,9 +78,9 @@ def verify_func_2(module): a_np = np.random.randint(low=-128, high=127, size=(16,)).astype(np.float32) d_np = np.random.randint(low=-128, high=127, size=(2,)).astype(np.float32) c_np = np.zeros((1,), dtype=np.float32) - a = tvm.nd.array(a_np, device=tvm.cpu(0)) - d = tvm.nd.array(d_np, device=tvm.cpu(0)) - c = tvm.nd.array(c_np, device=tvm.cpu(0)) + a = tvm.runtime.tensor(a_np, device=tvm.cpu(0)) + d = tvm.runtime.tensor(d_np, device=tvm.cpu(0)) + c = tvm.runtime.tensor(c_np, device=tvm.cpu(0)) module(c, a, d) tvm.testing.assert_allclose(c_np + np.sum(3 * a_np + 1 + d_np[0]), c.numpy(), rtol=1e-4) @@ -116,11 +116,11 @@ def verify_func_3(module): c_np = np.zeros((1,), dtype=np.float32) e_np = np.zeros((16,), dtype=np.float32) f_np = np.zeros((16,), dtype=np.float32) - a = tvm.nd.array(a_np, device=tvm.cpu(0)) - d = tvm.nd.array(d_np, device=tvm.cpu(0)) - c = tvm.nd.array(c_np, device=tvm.cpu(0)) - e = tvm.nd.array(e_np, device=tvm.cpu(0)) - f = tvm.nd.array(f_np, device=tvm.cpu(0)) + a = tvm.runtime.tensor(a_np, device=tvm.cpu(0)) + d = tvm.runtime.tensor(d_np, device=tvm.cpu(0)) + c = tvm.runtime.tensor(c_np, device=tvm.cpu(0)) + e = tvm.runtime.tensor(e_np, device=tvm.cpu(0)) + f = tvm.runtime.tensor(f_np, device=tvm.cpu(0)) module(c, a, d, e, f) tvm.testing.assert_allclose(c_np + np.sum(3 * a_np + 1 + d_np[0]), c.numpy(), rtol=1e-4) @@ -158,11 +158,11 @@ def verify_func_4(module): c_np = np.zeros((1,), dtype=np.float32) e_np = np.zeros((16,), dtype=np.float32) f_np = np.zeros((16,), dtype=np.float32) - a = tvm.nd.array(a_np, device=tvm.cpu(0)) - d = tvm.nd.array(d_np, device=tvm.cpu(0)) - c = tvm.nd.array(c_np, device=tvm.cpu(0)) - e = tvm.nd.array(e_np, device=tvm.cpu(0)) - f = tvm.nd.array(f_np, device=tvm.cpu(0)) + a = tvm.runtime.tensor(a_np, device=tvm.cpu(0)) + d = tvm.runtime.tensor(d_np, device=tvm.cpu(0)) + c = tvm.runtime.tensor(c_np, device=tvm.cpu(0)) + e = tvm.runtime.tensor(e_np, device=tvm.cpu(0)) 
+ f = tvm.runtime.tensor(f_np, device=tvm.cpu(0)) module(c, a, f, d, e) tvm.testing.assert_allclose(c_np + np.sum(3 * a_np + 1 + d_np[0]), c.numpy(), rtol=1e-4) diff --git a/tests/python/tir-schedule/test_tir_schedule_decompose_padding.py b/tests/python/tir-schedule/test_tir_schedule_decompose_padding.py index c8679843dda6..882a5b72cefa 100644 --- a/tests/python/tir-schedule/test_tir_schedule_decompose_padding.py +++ b/tests/python/tir-schedule/test_tir_schedule_decompose_padding.py @@ -32,9 +32,9 @@ def check_decompose_padding(origin, scheduled, expected, check_run=False): out_buffer = origin.buffer_map[origin.params[1]] in_shape = [int(_) for _ in in_buffer.shape] out_shape = [int(_) for _ in out_buffer.shape] - x = tvm.nd.array(np.random.uniform(0, 64, in_shape).astype(in_buffer.dtype)) - y0 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) - y1 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) + x = tvm.runtime.tensor(np.random.uniform(0, 64, in_shape).astype(in_buffer.dtype)) + y0 = tvm.runtime.tensor(np.zeros(out_shape).astype(out_buffer.dtype)) + y1 = tvm.runtime.tensor(np.zeros(out_shape).astype(out_buffer.dtype)) f_origin = tvm.compile(origin) f_scheduled = tvm.compile(scheduled) f_origin(x, y0) diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py new file mode 100644 index 000000000000..dc89f9df56a7 --- /dev/null +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) +import numpy as np + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def matmul_bias_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_expected( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + + +@T.prim_func +def matmul_bias_fp32_before( + A: 
T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + D: T.Buffer((32, 32), "float32"), +) -> None: + temp = T.alloc_buffer((32, 32), dtype="float32") + for i, j, k in T.grid(32, 32, 32): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j in T.grid(32, 32): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_fp32_expected( + A: T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + D: T.Buffer((32, 32), "float32"), +) -> None: + temp = T.alloc_buffer((32, 32), dtype="float32") + for i, j, k in T.grid(32, 32, 32): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def matmul_bias_multiple_epilogue_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), + E: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + for i, j in T.grid(16, 16): + with T.block("add2"): + vi, vj = T.axis.remap("SS", [i, j]) + E[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_multiple_epilogue_expected( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: 
T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), + E: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add2"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(temp[vi, vj], C[vi, vj]) + T.writes(E[vi, vj]) + E[vi, vj] = temp[vi, vj] + C[vi, vj] + + +def test_fuse_reduction_epilogue_basic(): + sch = tir.Schedule(matmul_bias_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_before) + + +def test_fuse_reduction_epilogue_fp32(): + sch = tir.Schedule(matmul_bias_fp32_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_fp32_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_fp32_before) + + +def test_fuse_reduction_epilogue_numerical_correctness(): + sch_original = tir.Schedule(matmul_bias_before, debug_mask="all") + mod_original = tvm.compile(sch_original.mod["main"], target="llvm") + + sch_fused = tir.Schedule(matmul_bias_before, debug_mask="all") + sch_fused.fuse_reduction_epilogue("multiply", "add") + mod_fused = tvm.compile(sch_fused.mod["main"], target="llvm") + + A_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8") + B_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8") + C_np = np.random.randint(-1000, 1000, size=(16, 16), dtype="int32") + + expected = (A_np.astype("int32") @ B_np.T.astype("int32")) + C_np + + D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32")) + 
D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32")) + + mod_original( + tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), tvm.runtime.tensor(C_np), D_original_tvm + ) + + mod_fused( + tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), tvm.runtime.tensor(C_np), D_fused_tvm + ) + + D_original = D_original_tvm.numpy() + D_fused = D_fused_tvm.numpy() + + tvm.testing.assert_allclose(D_original, expected, rtol=1e-5) + tvm.testing.assert_allclose(D_fused, expected, rtol=1e-5) + tvm.testing.assert_allclose(D_fused, D_original, rtol=1e-5) + + +def test_fuse_reduction_epilogue_multiple_epilogue(): + sch = tir.Schedule(matmul_bias_multiple_epilogue_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol( + sch.mod["main"], matmul_bias_multiple_epilogue_expected + ) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_multiple_epilogue_before) + + mod = tvm.compile(sch.mod["main"], target="llvm") + assert mod is not None + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_rolling_buffer.py b/tests/python/tir-schedule/test_tir_schedule_rolling_buffer.py index 0ea51aaf83aa..6fdd830120ec 100644 --- a/tests/python/tir-schedule/test_tir_schedule_rolling_buffer.py +++ b/tests/python/tir-schedule/test_tir_schedule_rolling_buffer.py @@ -38,9 +38,9 @@ def check_rolling_buffer( out_buffer = origin.buffer_map[origin.params[1]] in_shape = [int(_) for _ in in_buffer.shape] out_shape = [int(_) for _ in out_buffer.shape] - x = tvm.nd.array(np.random.uniform(0, 64, in_shape).astype(in_buffer.dtype)) - y0 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) - y1 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) + x = tvm.runtime.tensor(np.random.uniform(0, 64, in_shape).astype(in_buffer.dtype)) + y0 = tvm.runtime.tensor(np.zeros(out_shape).astype(out_buffer.dtype)) + y1 = tvm.runtime.tensor(np.zeros(out_shape).astype(out_buffer.dtype)) 
f_origin = tvm.compile(origin) f_scheduled = tvm.compile(scheduled) f_origin(x, y0) diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py index d5646f60fb7a..203bf0fea222 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py @@ -169,9 +169,9 @@ def run_test( b_np = np.random.randint(-128, 128, (K, N)).astype("int8") c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) + c = tvm.runtime.tensor(np.zeros((M, N), dtype=out_dtype), dev) f(a, b, c) diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py index c8edaf30fca9..f98c10c8b9e6 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py @@ -146,9 +146,9 @@ def run_test( b_np = np.random.randint(-128, 128, (K, N)).astype("int8") c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) + c = tvm.runtime.tensor(np.zeros((M, N), dtype=out_dtype), dev) f(a, b, c) diff --git a/tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py b/tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py new file mode 100644 index 000000000000..fa46ef36403c --- /dev/null +++ b/tests/python/tir-transform/test_tir_transform_annotate_irregular_loop.py @@ 
-0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for AnnotateIrregularLoop""" + +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T + + +def test_handle_irrgular_unit_loop(): + """Dedicated testcase to check the unitloop with loop jump not simplified""" + + @T.prim_func + def before(A: T.Buffer((10,), "int32")): + for i in T.serial(1): + if A[i] > 5: + break + A[i] = A[i] + 1 + for j in T.serial(1): + if A[j] > 5: + continue + A[j] = A[j] + 1 + for k in T.serial(1): + A[k] = A[k] + 1 + + @T.prim_func + def expected(A: T.Buffer((10,), "int32")): + for i in T.serial(1, annotations={"irregular_loop_mark": 1}): + if A[i] > 5: + break + A[i] = A[i] + 1 + for j in T.serial(1, annotations={"irregular_loop_mark": 1}): + if A[j] > 5: + continue + A[j] = A[j] + 1 + A[0] = A[0] + 1 + + mod = tvm.IRModule.from_expr(before) + mod = tvm.tir.transform.AnnotateIrregularLoop()(mod) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + tvm.ir.assert_structural_equal(mod["before"].with_attr("global_symbol", "expected"), expected) + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tir.transform.AnnotateIrregularLoop() + + +class TestAnnotateLoopWithBreak(BaseCompare): + 
"""Test that loops containing break statements are annotated as irregular.""" + + def before(A: T.Buffer((10,), "int32")): + for i in T.serial(10): + if A[i] > 5: + break + A[i] = A[i] + 1 + + def expected(A: T.Buffer((10,), "int32")): + for i in T.serial(10, annotations={"irregular_loop_mark": 1}): + if A[i] > 5: + break + A[i] = A[i] + 1 + + +class TestAnnotateLoopWithContinue(BaseCompare): + """Test that loops containing continue statements are annotated as irregular.""" + + def before(A: T.Buffer((10,), "int32")): + for i in T.serial(10): + if A[i] < 0: + continue + A[i] = A[i] * 2 + + def expected(A: T.Buffer((10,), "int32")): + for i in T.serial(10, annotations={"irregular_loop_mark": 1}): + if A[i] < 0: + continue + A[i] = A[i] * 2 + + +class TestNestedIrregularBothLoops(BaseCompare): + """Test nested loops where both loops have break/continue.""" + + def before(A: T.Buffer((10, 10), "int32")): + for i in T.serial(10): + if i > 7: + break + for j in T.serial(10): + if A[i, j] < 0: + continue + A[i, j] = A[i, j] + 1 + + def expected(A: T.Buffer((10, 10), "int32")): + for i in T.serial(10, annotations={"irregular_loop_mark": 1}): + if i > 7: + break + for j in T.serial(10, annotations={"irregular_loop_mark": 1}): + if A[i, j] < 0: + continue + A[i, j] = A[i, j] + 1 + + +class TestWhileLoopWithBreak(BaseCompare): + """Test that while loops with break/continue are not annotated (while loops don't have annotations).""" + + def before(A: T.Buffer((10,), "int32")): + i = T.int32(0) + while i < 10: + if A[i] > 5: + break + A[i] = A[i] + 1 + i = i + 1 + + def expected(A: T.Buffer((10,), "int32")): + i = T.int32(0) + while i < 10: + if A[i] > 5: + break + A[i] = A[i] + 1 + i = i + 1 + + +class TestBreakInNestedConditional(BaseCompare): + """Test break statement deeply nested in conditional blocks.""" + + def before(A: T.Buffer((10,), "int32"), flag1: T.int32, flag2: T.int32): + for i in T.serial(10): + if flag1 > 0: + if flag2 > 0: + if A[i] > 5: + break + A[i] = A[i] 
+ 1 + + def expected(A: T.Buffer((10,), "int32"), flag1: T.int32, flag2: T.int32): + for i in T.serial(10, annotations={"irregular_loop_mark": 1}): + if flag1 > 0: + if flag2 > 0: + if A[i] > 5: + break + A[i] = A[i] + 1 + + +class TestWhileLoopWithBreakStandalone(BaseCompare): + """Test that while loops with break/continue are not annotated (while loops don't have annotations).""" + + def before(A: T.Buffer((10,), "int32")): + i = T.int32(0) + while i < 10: + if A[i] > 5: + break + A[i] = A[i] + 1 + i = i + 1 + + def expected(A: T.Buffer((10,), "int32")): + i = T.int32(0) + while i < 10: + if A[i] > 5: + break + A[i] = A[i] + 1 + i = i + 1 + + +class TestNestedIrregularLoopStandalone(BaseCompare): + """Test deeply nested loops with irregular control flow only in innermost loop.""" + + def before(A: T.Buffer((5, 5, 5), "int32")): + for i in T.serial(5): + for j in T.serial(5): + for k in T.serial(5): + if A[i, j, k] > 10: + break + if A[i, j, k] < 0: + continue + A[i, j, k] = A[i, j, k] + 1 + + def expected(A: T.Buffer((5, 5, 5), "int32")): + for i in T.serial(5): + for j in T.serial(5): + for k in T.serial(5, annotations={"irregular_loop_mark": 1}): + if A[i, j, k] > 10: + break + if A[i, j, k] < 0: + continue + A[i, j, k] = A[i, j, k] + 1 + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py index fa1aa558b6d0..37e3d34f8c8f 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py @@ -44,6 +44,69 @@ def f32tobf16(v): return T.reinterpret("bfloat16", f32tou16(v)) +def test_bf16_simple_store_will_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Cptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = 
T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16") + C = T.decl_buffer((100,), "bfloat16", data=Cptr) + for i in T.grid(100): + B[i] = A[i] + C[i] = T.exp(B[i]) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Cptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "float32") + C = T.decl_buffer((100,), "bfloat16", data=Cptr) + for i in T.grid(100): + B[i] = bf16tof32(A[i]) + C[i] = f32tobf16(T.exp(B[i])) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("uint16", storage_scope="shared"), + Cptr: T.handle("uint16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "uint16", data=Aptr) + B = T.decl_buffer((100,), "float32") + C = T.decl_buffer((100,), "uint16", data=Cptr) + for i in T.grid(100): + B[i] = u16tof32(A[i]) + C[i] = f32tou16(T.exp(B[i])) + + return After + + target = Target("nvidia/geforce-rtx-2080-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + + def test_bf16_storage_compute_scope_will_legalize(): def get_before(): @tvm.script.ir_module diff --git a/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py b/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py new file mode 100644 index 000000000000..6f6d88137c20 --- /dev/null +++ b/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py @@ -0,0 +1,88 @@ +# Licensed to the Apache 
Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import tir +from tvm.script import tir as T + + +def test_canonicalize_loop(): + @T.prim_func + def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in range(1, 128, 5): + B[i] = A[i] + 1.0 + + @T.prim_func + def expected(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 26): + B[i * 5 + 1] = A[i * 5 + 1] + 1.0 + + mod = tvm.IRModule.from_expr(before) + mod = tir.transform.CanonicalizeLoop()(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_canonicalize_nested_loop(): + @T.prim_func + def before(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in range(1, 128, 5): + for j in range(2, 128, 3): + B[i, j] = A[i, j] + 1.0 + + @T.prim_func + def expected(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 26): + for j in T.serial(0, 42): + B[i * 5 + 1, j * 3 + 2] = A[i * 5 + 1, j * 3 + 2] + 1.0 + + mod = tvm.IRModule.from_expr(before) + mod = 
tir.transform.CanonicalizeLoop()(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_canonicalize_negative_step(): + @T.prim_func + def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 127, step=-3): + B[i] = A[i] + 1.0 + + mod = tvm.IRModule.from_expr(before) + with pytest.raises(tvm.error.InternalError): + mod = tir.transform.CanonicalizeLoop()(mod) + + +def test_canonicalize_dynamic_step(): + """Currently we report error for dynamic step since we could not prove it is positive""" + + @T.prim_func + def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"], step: T.int32): + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 128, step=step): + B[i] = A[i] + 1.0 + + mod = tvm.IRModule.from_expr(before) + with pytest.raises(tvm.error.InternalError): + mod = tir.transform.CanonicalizeLoop()(mod) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index d4c93bb24ae9..0855afcfd64a 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -16,6 +16,7 @@ # under the License. 
import tvm +import tvm_ffi import tvm.testing from tvm.script import tir as T @@ -147,8 +148,8 @@ def test_inject_async_copy(): A_np = np.random.rand(32, 128).astype(dtype) B_np = np.zeros((32, 128)).astype(dtype) dev = tvm.cuda(0) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) tvm.testing.assert_allclose(B_nd.numpy(), A_np) @@ -176,9 +177,9 @@ def test_inject_async_copy_shared_dyn(): B_np = np.random.rand(32, 128).astype("float16") C_np = np.zeros((32, 128)).astype("float16") dev = tvm.cuda(0) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) - C_nd = tvm.nd.array(C_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) + C_nd = tvm.runtime.tensor(C_np, device=dev) mod(A_nd, B_nd, C_nd) tvm.testing.assert_allclose(C_nd.numpy(), A_np + B_np) @@ -213,7 +214,7 @@ def ptx_global_to_shared_copy_fp32x1_barrier( B[tx, i] = A_shared[tx, i] -@tvm.testing.requires_cuda +@tvm.testing.requires_cuda_compute_version(9) def test_inject_async_copy_barrier(): dtype = "float32" vec_size = 1 @@ -233,8 +234,8 @@ def test_inject_async_copy_barrier(): A_np = np.random.rand(32, 128).astype(dtype) B_np = np.zeros((32, 128)).astype(dtype) dev = tvm.cuda(0) - A_nd = tvm.nd.array(A_np, device=dev) - B_nd = tvm.nd.array(B_np, device=dev) + A_nd = tvm.runtime.tensor(A_np, device=dev) + B_nd = tvm.runtime.tensor(B_np, device=dev) mod(A_nd, B_nd) tvm.testing.assert_allclose(B_nd.numpy(), A_np) @@ -263,8 +264,8 @@ def test_inject_async_copy_barrier(): extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { __shared__ float A_shared[64]; __shared__ float B_shared[64]; - A_shared[((int)threadIdx.x)] = 0.000000e+00f; - B_shared[((int)threadIdx.x)] = 0.000000e+00f; + A_shared[((int)threadIdx.x)] = 
0x0p+0f/*0.000000e+00*/; + B_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/; __asm__ __volatile__("cp.async.commit_group;"); @@ -401,7 +402,7 @@ def get_original_code(): nonlocal original_code return original_code - @tvm.register_func(func_name, override=True) + @tvm.register_global_func(func_name, override=True) def tvm_callback_cuda_postproc(code, _): nonlocal original_code original_code = code @@ -421,9 +422,9 @@ def tvm_callback_cuda_postproc(code, _): # Restore previous postproc func to avoid impacting other tests if prev_postproc is None: - tvm.ffi.registry.remove_global_func(func_name) + tvm_ffi.registry.remove_global_func(func_name) else: - tvm.register_func(func_name, prev_postproc, override=True) + tvm.register_global_func(func_name, prev_postproc, override=True) @tvm.testing.requires_cuda diff --git a/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py b/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py index c4f2756251c5..697887dc8cbb 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/tir-transform/test_tir_transform_inject_software_pipeline.py @@ -1538,9 +1538,9 @@ def build_and_run(sch): a_np = np.random.uniform(size=(N, K)).astype("float16") b_np = np.random.uniform(size=(K, M)).astype("float16") c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) + a = tvm.runtime.tensor(a_np, dev) + b = tvm.runtime.tensor(b_np, dev) + c = tvm.runtime.tensor(np.zeros((N, M), dtype="float32"), dev) f(a, b, c) tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) diff --git a/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py b/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py index 5006efba50b2..e8fee40ec173 100644 --- 
a/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py +++ b/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py @@ -19,7 +19,7 @@ from tvm.script import tir as T -@tvm.register_func("tvm.info.mem.global.test_with_head_address") +@tvm.register_global_func("tvm.info.mem.global.test_with_head_address") def mem_info_with_head_address(): return tvm.ir.make_node( "target.MemoryInfo", @@ -30,7 +30,7 @@ def mem_info_with_head_address(): ) -@tvm.register_func("tvm.info.mem.global.test_without_head_address") +@tvm.register_global_func("tvm.info.mem.global.test_without_head_address") def mem_info_without_head_address(): return tvm.ir.make_node( "target.MemoryInfo", diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py b/tests/python/tir-transform/test_tir_transform_lower_intrin.py index f31cf559764d..864b24bc0f51 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py @@ -48,9 +48,9 @@ def make_binds(i): C = te.compute((n,), make_binds) f = tvm.compile(te.create_prim_func([A, B, C]), "llvm") - a = tvm.nd.array(np.array([x for x, y in data], dtype=expr.dtype)) - b = tvm.nd.array(np.array([y for x, y in data], dtype=expr.dtype)) - c = tvm.nd.array(np.zeros(len(data), dtype=expr.dtype)) + a = tvm.runtime.tensor(np.array([x for x, y in data], dtype=expr.dtype)) + b = tvm.runtime.tensor(np.array([y for x, y in data], dtype=expr.dtype)) + c = tvm.runtime.tensor(np.zeros(len(data), dtype=expr.dtype)) f(a, b, c) cref = np.array([fref(x, y) for x, y in data]) np.testing.assert_equal(c.numpy(), cref) diff --git a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py index 410269ffae5c..2ba658b73822 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py +++ 
b/tests/python/tir-transform/test_tir_transform_lower_match_buffer.py @@ -532,5 +532,36 @@ def test_fail_match_func_param(): _check_fail(fail_match_func_param) +@T.prim_func +def scalar_match_buffer_type_coercion(a: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + for i, j in T.grid(8, 8): + with T.block(""): + vi = T.axis.spatial(8, i) + vj = T.axis.spatial(8, j) + T.reads() + T.writes(A[vi, vj]) + # Create scalar match buffer from single element - this triggers type coercion + scalar_buf = T.match_buffer(A[vi, vj], (), offset_factor=1) + scalar_buf[()] = T.float32(1.0) + + +@T.prim_func +def transformed_scalar_match_buffer_type_coercion(a: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + for i, j in T.grid(8, 8): + with T.block(""): + vi = T.axis.spatial(8, i) + vj = T.axis.spatial(8, j) + T.reads() + T.writes(A[vi, vj]) + # Scalar match_buffer eliminated, direct assignment + A[vi, vj] = T.float32(1.0) + + +def test_scalar_match_buffer_type_coercion(): + _check(scalar_match_buffer_type_coercion, transformed_scalar_match_buffer_type_coercion) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index 08f377829f1e..180f76a67ecd 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -26,7 +26,7 @@ from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol -@tvm.register_func("tvm.test_matmul") +@tvm.register_global_func("tvm.test_matmul") def my_matmul(a, b, c): c.copyfrom(np.dot(a.numpy(), b.numpy())) @@ -143,7 +143,7 @@ def build_tir(): mod = build_tir() f = tvm.compile(mod, None) - a = tvm.nd.array(np.zeros(2, dtype="float32")) + a = tvm.runtime.tensor(np.zeros(2, dtype="float32")) f(a) tvm.testing.assert_allclose(a.numpy(), expected_value) diff --git 
a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index dd7bd3bf54a2..723584ff5576 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -214,8 +214,8 @@ def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")): built = tvm.compile(func, target="llvm") - A = tvm.nd.array(np.zeros([16], dtype="int32")) - B = tvm.nd.empty([16, 16], "int32", tvm.cpu()) + A = tvm.runtime.tensor(np.zeros([16], dtype="int32")) + B = tvm.runtime.empty([16, 16], "int32", tvm.cpu()) with pytest.raises(tvm.TVMError): built(A, B) @@ -231,8 +231,8 @@ def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")): built = tvm.compile(func, target="llvm") - A = tvm.nd.array(np.zeros([16], dtype="int32")) - B = tvm.nd.empty([16], "int32", tvm.cpu()) + A = tvm.runtime.tensor(np.zeros([16], dtype="int32")) + B = tvm.runtime.empty([16], "int32", tvm.cpu()) with pytest.raises(tvm.TVMError): built(A, B) @@ -261,6 +261,7 @@ def func_without_arg( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_func_without_arg", } ) assert num_args == 0, "func_without_arg: num_args should be 0" @@ -315,6 +316,7 @@ def main( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_main", } ) assert num_args == 1, "main: num_args should be 1" @@ -372,6 +374,7 @@ def main( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_main", } ) assert num_args == 1, "main: num_args should be 1" diff --git a/tests/python/tir-transform/test_tir_transform_make_unpacked_api.py b/tests/python/tir-transform/test_tir_transform_make_unpacked_api.py index 46fd4104544a..617d028c1332 100644 --- a/tests/python/tir-transform/test_tir_transform_make_unpacked_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_unpacked_api.py @@ -68,7 +68,7 @@ def 
test_device_setup(mod, target, dev): assert f.body.value == 0 assert f.body.body.node == "default" assert f.body.body.attr_key == "device_type" - assert f.body.body.value == dev.device_type + assert f.body.body.value == dev.dlpack_device_type() def test_no_buffers_no_device_setup(): diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index e8d21a8dc4f9..36500c4d9885 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -26,7 +26,7 @@ def register_mem(scope_tb, max_bits): # Register mem - @tvm.register_func("tvm.info.mem.%s" % scope_tb) + @tvm.register_global_func("tvm.info.mem.%s" % scope_tb) def mem_info_inp_buffer(): return tvm.ir.make_node( "target.MemoryInfo", diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 1dece07ed9dd..8352b116443a 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -20,9 +20,9 @@ import pytest import tvm import tvm.testing +import tvm.runtime from tvm import tir from tvm.ir.base import assert_structural_equal -from tvm.runtime import ndarray from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tir as T @@ -366,7 +366,7 @@ def test_ir_builder_tir_allocate(): # the expected allocate buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local")) ir_expected = tir.Allocate( - buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1) + buffer_var, "float32", [10], tvm.tir.const(1, "bool"), tir.Evaluate(1) ) # Check if the generated ir is expected @@ -388,7 +388,7 @@ def test_ir_builder_tir_allocate_const(): buffer_var, "int32", [10], - ndarray.array(np.asarray(data, "int32")), + tvm.runtime.tensor(np.asarray(data, "int32")), 
tir.Evaluate(1), annotations={}, ) diff --git a/tests/python/tvmscript/test_tvmscript_ops.py b/tests/python/tvmscript/test_tvmscript_ops.py index 7672b75ec126..0d6beabd7a40 100644 --- a/tests/python/tvmscript/test_tvmscript_ops.py +++ b/tests/python/tvmscript/test_tvmscript_ops.py @@ -81,10 +81,10 @@ def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, sco np_out2[i, j, k] = -1.0 np_out3[i, j] = -1 - in_data = tvm.nd.array(np_data, ctx) - out1 = tvm.nd.array(np_out1, ctx) - out2 = tvm.nd.array(np_out2, ctx) - out3 = tvm.nd.array(np_out3, ctx) + in_data = tvm.runtime.tensor(np_data, ctx) + out1 = tvm.runtime.tensor(np_out1, ctx) + out2 = tvm.runtime.tensor(np_out2, ctx) + out3 = tvm.runtime.tensor(np_out3, ctx) f(in_data, out1, out2, out3, score_threshold, id_index, score_index) tvm.testing.assert_allclose(out1.numpy(), np_out1, rtol=1e-5) tvm.testing.assert_allclose(out2.numpy(), np_out2, rtol=1e-5) @@ -134,8 +134,8 @@ def _check_alloc_zero_dim_buffer(f): np_data = np.zeros(shape=()).astype(dtype) np_out = np.zeros(shape=()).astype(dtype) - tvm_data = tvm.nd.array(np_data, ctx) - tvm_out = tvm.nd.array(np_out, ctx) + tvm_data = tvm.runtime.tensor(np_data, ctx) + tvm_out = tvm.runtime.tensor(np_out, ctx) # np func exection np_inter = np.array(1) @@ -175,7 +175,7 @@ def ceildiv_test(A: T.Buffer(16, "int32")): @tvm.testing.requires_llvm def test_ceildiv(): f = tvm.compile(ceildiv_test, "llvm") - a = tvm.nd.array(np.arange(16).astype("int32")) + a = tvm.runtime.tensor(np.arange(16).astype("int32")) f(a) ref = (np.arange(16) + 3) // 4 tvm.testing.assert_allclose(a.numpy(), ref) diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 68e9adeff267..cc285e9835de 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -18,6 +18,7 @@ import pytest import tvm.testing +import tvm_ffi from tvm.script.parser import tir 
as T from tvm import ir, tir @@ -326,6 +327,32 @@ def non_starred(a: T.handle, b: T.handle): tvm.ir.assert_structural_equal(starred, non_starred) +def test_tir_loop_steps(): + N = T.Var("N", "int32") + + @T.prim_func(private=True) + def loop_with_steps( + A: T.Buffer((N,)), B: T.Buffer((N,)), C: T.Buffer((N,)), tid: T.int32, v: T.int32 + ): + for i in T.serial(tid, N, step=2): + C[i] = A[i] + B[i] + for i in T.unroll(tid, N, step=3): + C[i] = A[i] + B[i] + for i in T.vectorized(tid, N, step=4): + C[i] = A[i] + B[i] + for i in T.parallel(tid, N, step=5): + C[i] = A[i] + B[i] + for i in T.serial(tid, N, step=v): + C[i] = A[i] + B[i] + + stmts = loop_with_steps.body.seq + assert stmts[0].step == 2 + assert stmts[1].step == 3 + assert stmts[2].step == 4 + assert stmts[3].step == 5 + assert stmts[4].step.name == "v" + + def test_tir_empty_tuple_index(): @T.macro def bar(val): @@ -545,10 +572,10 @@ def expected() -> None: def test_block_annotation_merge(): - def _to_dict(anno: tvm.ffi.container.Map): + def _to_dict(anno: tvm_ffi.container.Map): result = {} for k, v in anno.items(): - result[k] = _to_dict(v) if isinstance(v, tvm.ffi.container.Map) else v + result[k] = _to_dict(v) if isinstance(v, tvm_ffi.container.Map) else v return result @T.prim_func @@ -611,5 +638,71 @@ def expected() -> None: tvm.ir.assert_structural_equal(func, expected) +def test_tir_macro_block_name_suffix(): + @T.macro + def operation(A, idx): + with T.block("op"): + v = T.axis.remap("S", [idx]) + A[v] = A[v] * T.float32(2) + + @T.prim_func(private=True) + def func_w_macro(a: T.handle) -> None: + A = T.match_buffer(a, [10]) + for i in T.serial(0, 10): + operation(A, i) + operation(A, i) + operation(A, i) + + @T.prim_func(private=True) + def expected(a: T.handle) -> None: + A = T.match_buffer(a, [10]) + for i in T.serial(0, 10): + with T.block("op"): + v = T.axis.remap("S", [i]) + A[v] = A[v] * T.float32(2) + with T.block("op_1"): + v = T.axis.remap("S", [i]) + A[v] = A[v] * T.float32(2) + with 
T.block("op_2"): + v = T.axis.remap("S", [i]) + A[v] = A[v] * T.float32(2) + + tvm.ir.assert_structural_equal(func_w_macro, expected) + + +def test_ifexp(): + @T.prim_func(private=True) + def func(A: T.buffer((128, 128), "float32")): + for i, j in T.grid(128, 128): + A[i, j] = i if i < j else j + + @T.prim_func(private=True) + def expected(A: T.buffer((128, 128), "float32")): + for i, j in T.grid(128, 128): + A[i, j] = T.if_then_else(i < j, i, j) + + tvm.ir.assert_structural_equal(func, expected) + + +def test_sequence_compare(): + @T.prim_func(private=True) + def tir_func(A: T.Buffer((128, 128), "float32")): + for i, j in T.grid(128, 128): + if 0 < i < 128 and 0 < j < 128: + A[i, j] = 1 + else: + A[i, j] = 0 + + @T.prim_func(private=True) + def expected(A: T.buffer((128, 128), "float32")): + for i, j in T.grid(128, 128): + if (0 < i and i < 128) and (0 < j and j < 128): + A[i, j] = 1 + else: + A[i, j] = 0 + + tvm.ir.assert_structural_equal(tir_func, expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_printer_annotation.py b/tests/python/tvmscript/test_tvmscript_printer_annotation.py index c45c0a91c5c5..74c66fb94cdb 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_annotation.py +++ b/tests/python/tvmscript/test_tvmscript_printer_annotation.py @@ -18,7 +18,7 @@ from typing import Optional import pytest -from tvm.ffi.access_path import AccessPath +from tvm_ffi.access_path import AccessPath from tvm.script import tir as T diff --git a/tests/python/tvmscript/test_tvmscript_printer_doc.py b/tests/python/tvmscript/test_tvmscript_printer_doc.py index 20a705f9ff83..f8e20915fad0 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_doc.py +++ b/tests/python/tvmscript/test_tvmscript_printer_doc.py @@ -21,7 +21,7 @@ import pytest import tvm -from tvm.ffi.access_path import AccessPath +from tvm_ffi.access_path import AccessPath from tvm.script.printer.doc import ( AssertDoc, AssignDoc, diff --git 
a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py index f3a385ca0911..70473954eb9c 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py +++ b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py @@ -19,7 +19,7 @@ import tvm from tvm.ir import assert_structural_equal -from tvm.ffi.access_path import AccessPath +from tvm_ffi.access_path import AccessPath from tvm.script import ir as I, tir as T diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index be8b03357dde..e4af15807426 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -961,13 +961,13 @@ def test_predicated_buffer_load_store(): buffer_load = tir.BufferLoad( buffer=buffer_map[b], indices=[0, tir.Ramp(0, 4, 4)], - predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4), + predicate=tir.Broadcast(tir.IntImm("bool", 0), 4), ) body = tir.BufferStore( buffer=buffer_map[a], value=buffer_load, indices=[0, tir.Ramp(0, 2, 4)], - predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4), + predicate=tir.Broadcast(tir.IntImm("bool", 0), 4), ) func = tir.PrimFunc( params=[a, b], @@ -1046,5 +1046,34 @@ def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): _assert_print(main, expected_output) +def test_func_with_loop_jumps(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (4,), "float32") + B = T.match_buffer(b, (4,), "float32") + for i in range(1000): + if i % 13 == 0: + A[1] = A[1] + 1 + continue + if A[0] >= B[0]: + break + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + for i in range(1000): + if i % 13 == 0: + A[1] = A[1] + T.float32(1.0) + T.continue_loop() + if A[0] >= B[0]: + 
T.break_loop() + """ + _assert_print(main, expected_output) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_printer_underlining.py b/tests/python/tvmscript/test_tvmscript_printer_underlining.py index e36e96c77d7f..130bb7f23724 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_underlining.py +++ b/tests/python/tvmscript/test_tvmscript_printer_underlining.py @@ -18,7 +18,7 @@ from typing import Optional import pytest -from tvm.ffi.access_path import AccessPath +from tvm_ffi.access_path import AccessPath from tvm.script.printer.doc import ( ExprStmtDoc, IdDoc, diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 2be2e2e98d81..b3d459b2e67f 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -4002,6 +4002,41 @@ def func( return func +def func_with_loop_jumps(): + @T.prim_func + def func(In: T.Buffer((1,), "int32"), Out: T.Buffer((2,), "int32")): + Out[0] = 0 + Out[1] = 0 + for i in range(1000): + if i % 13 == 0: + Out[1] = Out[1] + 1 + continue + Out[0] = Out[0] + 1 + if Out[0] >= In[0]: + break + + return func + + +def func_with_loop_steps(): + @T.prim_func + def func( + A: T.Buffer((1024,)), B: T.Buffer((1024,)), C: T.Buffer((1024,)), tid: T.int32, v: T.int32 + ): + for i in T.serial(tid, 1024, step=2): + C[i] = A[i] + B[i] + for i in T.unroll(tid, 1024, step=3): + C[i] = A[i] + B[i] + for i in T.vectorized(tid, 1024, step=4): + C[i] = A[i] + B[i] + for i in T.parallel(tid, 1024, step=5): + C[i] = A[i] + B[i] + for i in range(tid, 1024, 6): + C[i] = A[i] + B[i] + + return func + + def op_of_literal(): op_list = [ (T.exp, 0), @@ -4220,6 +4255,8 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return_zero_private, return_zero_private_with_attr, func_attr_with_list, + func_with_loop_jumps, + func_with_loop_steps, *op_of_literal(), 
*relax_match_cast_struct_info_proxy(), relax_symbolic_size_var, diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index 33880539eb5f..df8675704b67 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -506,5 +506,27 @@ def implicit(): assert_structural_equal_ignore_global_symbol(implicit, explicit) +def test_loop_jump_statement(): + """`break` and `continue` evaluates to TIR intrinsics""" + + @T.prim_func + def explicit(): + for i in range(16): + if i % 2 == 0: + T.evaluate(T.continue_loop()) + if i < 15: + T.evaluate(T.break_loop()) + + @T.prim_func + def implicit(): + for i in range(16): + if i % 2 == 0: + continue + if i < 15: + break + + assert_structural_equal_ignore_global_symbol(implicit, explicit) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh index e5775c10ec34..8b85a27277e0 100755 --- a/tests/scripts/task_build_adreno_bins.sh +++ b/tests/scripts/task_build_adreno_bins.sh @@ -39,7 +39,7 @@ echo set\(USE_OPENCL ON\) >> config.cmake fi echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_CPP_RPC ON\) >> config.cmake -echo set\(USE_CPP_RTVM ON\) >> config.cmake +#echo set\(USE_CPP_RTVM ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake @@ -51,8 +51,7 @@ echo set\(USE_OPENCL_GTEST ON\) >> config.cmake echo set\(USE_OPENCL_EXTN_QCOM ON\) >> config.cmake -cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \ - -DANDROID_ABI=arm64-v8a \ +cmake -DANDROID_ABI=arm64-v8a \ -DANDROID_PLATFORM=android-28 \ -DCMAKE_SYSTEM_VERSION=1 \ -DCMAKE_FIND_ROOT_PATH="${ADRENO_OPENCL}" \ @@ -62,4 +61,4 @@ cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain. 
-DCMAKE_C_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" \ -DMACHINE_NAME="aarch64-linux-gnu" .. -make -j$(nproc) tvm_rpc rtvm opencl-cpptest +make -j$(nproc) tvm_rpc opencl-cpptest diff --git a/tests/scripts/task_python_adreno.sh b/tests/scripts/task_python_adreno.sh index f019cd1eccb1..1714a3c06358 100755 --- a/tests/scripts/task_python_adreno.sh +++ b/tests/scripts/task_python_adreno.sh @@ -57,8 +57,8 @@ trap "{ kill ${TRACKER_PID}; kill ${DEVICE_PID}; cleanup; }" 0 # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f -# setup cython -cd python; python3 setup.py build_ext --inplace; cd .. +# setup tvm-ffi into python folder +python3 -m pip install --target=python -v ./3rdparty/tvm-ffi/ exit 0 diff --git a/tests/scripts/task_python_arm_compute_library.sh b/tests/scripts/task_python_arm_compute_library.sh index 1423fb198543..b67724308fce 100755 --- a/tests/scripts/task_python_arm_compute_library.sh +++ b/tests/scripts/task_python_arm_compute_library.sh @@ -23,5 +23,5 @@ source tests/scripts/setup-pytest-env.sh find . -type f -path "*.pyc" | xargs rm -f -# setup cython -cd python; python3 setup.py build_ext --inplace; cd .. +# setup tvm-ffi into python folder +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 7b58658bd7c7..bb1fd2d95b8d 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -47,8 +47,8 @@ sphinx_precheck() { clean_files echo "PreCheck sphinx doc generation WARNINGS.." - # setup cython - cd python; python3 setup.py build_ext --inplace; cd .. + # setup tvm-ffi into python folder + python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ pushd docs make clean @@ -126,8 +126,8 @@ clean_files find . -type f -path "*.log" | xargs rm -f find . -type f -path "*.pyc" | xargs rm -f -# setup cython -cd python; python3 setup.py build_ext --inplace; cd .. 
+# setup tvm-ffi into python folder +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ cd docs diff --git a/tests/scripts/task_python_hexagon.sh b/tests/scripts/task_python_hexagon.sh index fd53007a37ce..6d91759805b7 100755 --- a/tests/scripts/task_python_hexagon.sh +++ b/tests/scripts/task_python_hexagon.sh @@ -27,8 +27,8 @@ fi source tests/scripts/setup-pytest-env.sh -# setup cython -cd python; python3 setup.py build_ext --inplace; cd .. +# setup tvm-ffi into python folder +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ # disable hexagon tests for now exit 0 diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 326743394d2a..a1a0068ac972 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -33,5 +33,5 @@ fi # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f -# setup cython -cd python; python3 setup.py build_ext --inplace; cd .. +# setup tvm-ffi into python folder +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ diff --git a/tests/scripts/task_python_nightly.sh b/tests/scripts/task_python_nightly.sh index 42cf343e71ad..af1b6ec3d212 100755 --- a/tests/scripts/task_python_nightly.sh +++ b/tests/scripts/task_python_nightly.sh @@ -20,8 +20,8 @@ set -euxo pipefail source tests/scripts/setup-pytest-env.sh -# setup cython -cd python; python3 setup.py build_ext --inplace; cd .. +# setup tvm-ffi into python folder +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index 54170133530d..569ad9b2de4b 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -23,8 +23,8 @@ source tests/scripts/setup-pytest-env.sh # cleanup pycache find . 
-type f -path "*.pyc" | xargs rm -f -# setup cython -cd python; python3 setup.py build_ext --inplace; cd .. +# setup tvm-ffi into python folder +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ # NOTE: also set by task_python_unittest_gpuonly.sh. if [ -z "${TVM_UNITTEST_TESTSUITE_NAME:-}" ]; then diff --git a/tests/scripts/task_web_wasm.sh b/tests/scripts/task_web_wasm.sh index 8a08c1ecb58d..c43215549788 100755 --- a/tests/scripts/task_web_wasm.sh +++ b/tests/scripts/task_web_wasm.sh @@ -20,6 +20,9 @@ set -euxo pipefail export PYTHONPATH=`pwd`/python +# setup tvm-ffi into python folder +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ + rm -rf .emscripten_cache cd web make clean diff --git a/tests/scripts/unity/task_python_relax.sh b/tests/scripts/unity/task_python_relax.sh index 5a72254924e1..c25cc6ec6597 100755 --- a/tests/scripts/unity/task_python_relax.sh +++ b/tests/scripts/unity/task_python_relax.sh @@ -25,8 +25,8 @@ export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" export TVM_BIND_THREADS=0 export TVM_NUM_THREADS=2 -# setup cython -cd python; python3 setup.py build_ext --inplace; cd .. +# setup tvm-ffi into python folder +python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ # Run Relax tests TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax diff --git a/version.py b/version.py index cf37e645c4a2..a5bc19164c70 100644 --- a/version.py +++ b/version.py @@ -21,6 +21,7 @@ List of affected files: - tvm-root/python/tvm/libinfo.py +- tvm-root/pyproject.toml - tvm-root/include/tvm/runtime/base.h - tvm-root/conda/recipe/meta.yaml - tvm-root/web/package.json @@ -44,7 +45,7 @@ # Two tag formats are supported: # - vMAJ.MIN.PATCH (e.g. v0.8.0) or # - vMAJ.MIN.devN (e.g. 
v0.8.dev0) -__version__ = "0.22.dev0" +__version__ = "0.23.dev0" # --------------------------------------------------- @@ -175,6 +176,13 @@ def sync_version(pub_ver, local_ver, dry_run): local_ver, dry_run, ) + # pyproject.toml + update( + os.path.join(PROJ_ROOT, "pyproject.toml"), + r"(?<=version = \")[.0-9a-z\+]+", + pub_ver, + dry_run, + ) # Use public version for other parts for now # Note that full git hash is already available in libtvm # C++ header diff --git a/web/.gitignore b/web/.gitignore index 17d59ed10d4b..a746034d5aa4 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ -4,5 +4,5 @@ out node_modules build debug -.ndarray_cache +.tensor_cache src/tvmjs_runtime_wasi.js diff --git a/web/Makefile b/web/Makefile index e9d1375fc76c..9f8a7e94b42f 100644 --- a/web/Makefile +++ b/web/Makefile @@ -18,8 +18,8 @@ TVM_ROOT=$(realpath $(shell dirname $(firstword $(MAKEFILE_LIST))))/../ INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ - -I$(TVM_ROOT)/ffi/include\ - -I$(TVM_ROOT)/ffi/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include\ + -I$(TVM_ROOT)/3rdparty/tvm-ffi/include\ + -I$(TVM_ROOT)/3rdparty/tvm-ffi/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include\ -I$(TVM_ROOT)/3rdparty/compiler-rt -I$(TVM_ROOT)/3rdparty/picojson .PHONY: clean all rmtypedep preparetest diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html index 07e6fe87fc95..6bcecfe8661c 100644 --- a/web/apps/browser/rpc_server.html +++ b/web/apps/browser/rpc_server.html @@ -51,12 +51,12 @@ function connectRPC() { const proxyUrl = document.getElementById("proxyUrl").value; const key = document.getElementById("proxyKey").value; - const ndarrayCacheName = document.getElementById("cache-select").value; - let ndarrayCacheUrl = new URL(ndarrayCacheName + "/", document.URL).href; - let ndarrayCacheDevice = document.getElementById("ndarrayCacheDevice").value; + const tensorCacheName = document.getElementById("cache-select").value; + let tensorCacheUrl = 
new URL(tensorCacheName + "/", document.URL).href; + let tensorCacheDevice = document.getElementById("tensorCacheDevice").value; - if (ndarrayCacheName == "none" || ndarrayCacheName === undefined) { - ndarrayCacheUrl = ""; + if (tensorCacheName == "none" || tensorCacheName === undefined) { + tensorCacheUrl = ""; } // only works for once. @@ -66,7 +66,7 @@ new tvmjs.RPCServer( proxyUrl, key, getImports, customLog, - ndarrayCacheUrl, ndarrayCacheDevice, initProgressCallback, + tensorCacheUrl, tensorCacheDevice, initProgressCallback, tvmjsGlobalEnv.asyncOnRPCServerLoad); } @@ -117,12 +117,12 @@

Options

type="text" value="wasm" />
- NDArrayCache - + TensorCache - CacheDevice - - diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index b33724c722d7..d6b94bda32fe 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -31,6 +31,7 @@ #define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY +#include #include #include #include @@ -252,10 +253,10 @@ class AsyncLocalSession : public LocalSession { std::optional async_wait_; // time evaluator - ffi::Function GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, - int device_id, int number, int repeat, int min_repeat_ms, - int limit_zero_time_iterations, int cooldown_interval_ms, - int repeats_to_cooldown) { + ffi::Function GetTimeEvaluator(ffi::Optional opt_mod, std::string name, + int device_type, int device_id, int number, int repeat, + int min_repeat_ms, int limit_zero_time_iterations, + int cooldown_interval_ms, int repeats_to_cooldown) { Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; @@ -302,12 +303,12 @@ class AsyncLocalSession : public LocalSession { } }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("wasm.LocalSession", []() { return CreateRPCSessionModule(std::make_shared()); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 31f494322684..c1839947bacf 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -30,6 +30,7 @@ #define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY +#include #include #include @@ -38,7 +39,6 @@ #include "src/runtime/device_api.cc" #include "src/runtime/file_utils.cc" #include "src/runtime/logging.cc" -#include "src/runtime/ndarray.cc" #include "src/runtime/profiling.cc" #include "src/runtime/rpc/rpc_channel.cc" #include "src/runtime/rpc/rpc_endpoint.cc" @@ -46,19 +46,22 @@ #include "src/runtime/rpc/rpc_local_session.cc" #include 
"src/runtime/rpc/rpc_module.cc" #include "src/runtime/rpc/rpc_session.cc" +#include "src/runtime/tensor.cc" #include "src/runtime/workspace_pool.cc" // relax setup -#include "ffi/src/ffi/container.cc" -#include "ffi/src/ffi/dtype.cc" -#include "ffi/src/ffi/error.cc" -#include "ffi/src/ffi/extra/library_module.cc" -#include "ffi/src/ffi/extra/library_module_system_lib.cc" -#include "ffi/src/ffi/extra/module.cc" -#include "ffi/src/ffi/extra/testing.cc" -#include "ffi/src/ffi/function.cc" -#include "ffi/src/ffi/ndarray.cc" -#include "ffi/src/ffi/object.cc" -#include "ffi/src/ffi/traceback.cc" +#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc" +#include "3rdparty/tvm-ffi/src/ffi/container.cc" +#include "3rdparty/tvm-ffi/src/ffi/dtype.cc" +#include "3rdparty/tvm-ffi/src/ffi/error.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc" +#include "3rdparty/tvm-ffi/src/ffi/function.cc" +#include "3rdparty/tvm-ffi/src/ffi/object.cc" +#include "3rdparty/tvm-ffi/src/ffi/tensor.cc" +#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc" #include "src/runtime/memory/memory_manager.cc" #include "src/runtime/nvtx.cc" #include "src/runtime/vm/attn_backend.cc" @@ -67,9 +70,9 @@ #include "src/runtime/vm/executable.cc" #include "src/runtime/vm/kv_state.cc" #include "src/runtime/vm/lm_support.cc" -#include "src/runtime/vm/ndarray_cache_support.cc" #include "src/runtime/vm/paged_kv_cache.cc" #include "src/runtime/vm/rnn_state.cc" +#include "src/runtime/vm/tensor_cache_support.cc" #include "src/runtime/vm/vm.cc" // --- Implementations of backend and wasm runtime API. 
--- @@ -105,50 +108,59 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvmjs.testing.call", [](ffi::PackedArgs args, ffi::Any* ret) { (args[0].cast()).CallPacked(args.Slice(1), ret); }) - .def_packed("tvmjs.testing.log_info_str", - [](ffi::PackedArgs args, ffi::Any* ret) { LOG(INFO) << args[0].cast(); }) + .def_packed( + "tvmjs.testing.log_info_str", + [](ffi::PackedArgs args, ffi::Any* ret) { LOG(INFO) << args[0].cast(); }) .def("tvmjs.testing.add_one", [](int x) { return x + 1; }) .def_packed("tvmjs.testing.wrap_callback", [](ffi::PackedArgs args, ffi::Any* ret) { ffi::Function pf = args[0].cast(); *ret = ffi::TypedFunction([pf]() { pf(); }); }); -}); +} -void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, std::string dtype) { +void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::string& format, + const std::string& dtype) { + ICHECK_NE(bytes, nullptr); + const char* byte_data = bytes->data; + const size_t byte_size = bytes->size; if (format == "f32-to-bf16" && dtype == "float32") { - std::vector buffer(bytes.length() / 2); - std::memcpy(buffer.data(), bytes.data(), buffer.size() * 2); - // decode bf16 to f32 - const uint16_t* bf16 = reinterpret_cast(buffer.data()); + const uint16_t* bf16 = reinterpret_cast(byte_data); uint32_t* data = static_cast(cpu_arr->data); ICHECK(cpu_arr.IsContiguous()); size_t size = 1; for (int i = 0; i < cpu_arr->ndim; ++i) { size *= cpu_arr->shape[i]; } - ICHECK_EQ(size, bytes.length() / 2); + ICHECK_EQ(size, byte_size / 2); for (size_t i = 0; i < size; ++i) { data[i] = static_cast(bf16[i]) << 16; } } else { - cpu_arr.CopyFromBytes(bytes.data(), bytes.length()); + cpu_arr.CopyFromBytes(byte_data, byte_size); } } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = 
tvm::ffi::reflection; - refl::GlobalDef().def("tvmjs.array.decode_storage", ArrayDecodeStorage); -}); + refl::GlobalDef().def_packed( + "tvmjs.array.decode_storage", [](ffi::PackedArgs args, ffi::Any* ret) { + Tensor cpu_arr = args[0].cast(); + TVMFFIByteArray* bytes = args[1].cast(); + std::string format = args[2].cast().operator std::string(); + std::string dtype = args[3].cast().operator std::string(); + ArrayDecodeStorage(cpu_arr, bytes, format, dtype); + }); +} // Concatenate n TVMArrays -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvmjs.runtime.ArrayConcat", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -162,11 +174,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ data.push_back(arr_i->at(j)); } } - *ret = Array(data); + *ret = ffi::Array(data); }); -}); +} -NDArray ConcatEmbeddings(const std::vector& embeddings) { +Tensor ConcatEmbeddings(const std::vector& embeddings) { // Get output shape int64_t hidden_size = embeddings[0]->shape[1]; DLDataType dtype = embeddings[0]->dtype; @@ -182,7 +194,7 @@ NDArray ConcatEmbeddings(const std::vector& embeddings) { std::vector shape; shape.push_back(seqLen); shape.push_back(hidden_size); - NDArray result = NDArray::Empty(shape, dtype, device); + Tensor result = Tensor::Empty(shape, dtype, device); // Copy int offset = 0; @@ -193,36 +205,36 @@ NDArray ConcatEmbeddings(const std::vector& embeddings) { copy_dst.shape = embeddings[i]->shape; copy_dst.byte_offset = offset * hidden_size * ((embeddings[i]->dtype.bits * embeddings[i]->dtype.lanes + 7) / 8); - NDArray::CopyFromTo(©_src, ©_dst); + Tensor::CopyFromTo(©_src, ©_dst); offset += embeddings[i]->shape[0]; } return result; } -// Concatenate n NDArrays -TVM_FFI_STATIC_INIT_BLOCK({ +// Concatenate n Tensors +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tvmjs.runtime.ConcatEmbeddings", [](ffi::PackedArgs args, ffi::Any* ret) { - std::vector embeddings; 
+ std::vector embeddings; for (int i = 0; i < args.size(); ++i) { - embeddings.push_back(args[i].cast()); + embeddings.push_back(args[i].cast()); } - NDArray result = ConcatEmbeddings(std::move(embeddings)); + Tensor result = ConcatEmbeddings(std::move(embeddings)); *ret = result; }) - .def("tvmjs.runtime.NDArrayCopyFromBytes", - [](NDArray nd, TVMFFIByteArray* bytes) { nd.CopyFromBytes(bytes->data, bytes->size); }) - .def("tvmjs.runtime.NDArrayCopyToBytes", [](NDArray nd) -> ffi::Bytes { - size_t size = GetDataSize(*(nd.operator->())); + .def("tvmjs.runtime.TensorCopyFromBytes", + [](Tensor nd, TVMFFIByteArray* bytes) { nd.CopyFromBytes(bytes->data, bytes->size); }) + .def("tvmjs.runtime.TensorCopyToBytes", [](Tensor nd) -> ffi::Bytes { + size_t size = ffi::GetDataSize(*(nd.operator->())); std::string bytes; bytes.resize(size); nd.CopyToBytes(bytes.data(), size); return ffi::Bytes(bytes); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index eb14a7b7d7ee..03d08f731b95 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -165,7 +165,7 @@ class WebGPUModuleNode final : public ffi::ModuleObj { const char* kind() const final { return "webgpu"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { // special function if (name == "webgpu.get_fmap") { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { @@ -211,7 +211,7 @@ class WebGPUModuleNode final : public ffi::ModuleObj { ffi::Bytes SaveToBytes() const final { LOG(FATAL) << "Not implemented"; } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { // can only return source code. 
return source_; } @@ -237,11 +237,11 @@ ffi::Module WebGPUModuleLoadFromBytes(const ffi::Bytes& bytes) { stream->Read(&fmap); stream->Read(&smap); - return ffi::Module(make_object(smap, fmap)); + return ffi::Module(ffi::make_object(smap, fmap)); } // for now webgpu is hosted via a vulkan module. -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ffi.Module.load_from_bytes.webgpu", WebGPUModuleLoadFromBytes) @@ -249,7 +249,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ DeviceAPI* ptr = WebGPUDeviceAPI::Global(); *rv = static_cast(ptr); }); -}); +} } // namespace runtime } // namespace tvm diff --git a/web/package-lock.json b/web/package-lock.json index 79ea7dfecd62..5297ab6104a9 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,13 +1,17 @@ { "name": "tvmjs", - "version": "0.22.0-dev0", + "version": "0.23.0-dev1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.22.0-dev0", + "version": "0.23.0-dev1", "license": "Apache-2.0", + "dependencies": { + "audit": "^0.0.6", + "fix": "^0.0.6" + }, "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", "@rollup/plugin-node-resolve": "^13.0.4", @@ -26,61 +30,50 @@ "ws": "^7.2.5" } }, - "node_modules/@ampproject/remapping": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.2.0.tgz", - "integrity": "sha512-qRmjj8nj9qmLTQXXmaR1cck3UXSRMPrbsLJAasZpF+t3riI71BXed5ebIOYwQntykeZuhjsdweEc9BxH5Jc26w==", - "dev": true, - "dependencies": { - "@jridgewell/gen-mapping": "^0.1.0", - "@jridgewell/trace-mapping": "^0.3.9" - }, - "engines": { - "node": ">=6.0.0" - } - }, "node_modules/@babel/code-frame": { - "version": "7.18.6", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.18.6.tgz", - "integrity": "sha512-TDCmlK5eOvH+eH7cdAFlNXeVJqWIQ7gW9tY1GJIpUtFb6CmjVyq2VM3u71bOyR8CRihcCgMUYoDNyLXao3+70Q==", + "version": "7.27.1", + "resolved": 
"https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", + "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", "dev": true, "dependencies": { - "@babel/highlight": "^7.18.6" + "@babel/helper-validator-identifier": "^7.27.1", + "js-tokens": "^4.0.0", + "picocolors": "^1.1.1" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/compat-data": { - "version": "7.20.5", - "resolved": "https://registry.npmjs.org/@babel/compat-data/-/compat-data-7.20.5.tgz", - "integrity": "sha512-KZXo2t10+/jxmkhNXc7pZTqRvSOIvVv/+lJwHS+B2rErwOyjuVRh60yVpb7liQ1U5t7lLJ1bz+t8tSypUZdm0g==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/compat-data/-/compat-data-7.28.5.tgz", + "integrity": "sha512-6uFXyCayocRbqhZOB+6XcuZbkMNimwfVGFji8CTZnCzOHVGvDqzvitu1re2AU5LROliz7eQPhB8CpAMvnx9EjA==", "dev": true, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/core": { - "version": "7.20.5", - "resolved": "https://registry.npmjs.org/@babel/core/-/core-7.20.5.tgz", - "integrity": "sha512-UdOWmk4pNWTm/4DlPUl/Pt4Gz4rcEMb7CY0Y3eJl5Yz1vI8ZJGmHWaVE55LoxRjdpx0z259GE9U5STA9atUinQ==", - "dev": true, - "dependencies": { - "@ampproject/remapping": "^2.1.0", - "@babel/code-frame": "^7.18.6", - "@babel/generator": "^7.20.5", - "@babel/helper-compilation-targets": "^7.20.0", - "@babel/helper-module-transforms": "^7.20.2", - "@babel/helpers": "^7.20.5", - "@babel/parser": "^7.20.5", - "@babel/template": "^7.18.10", - "@babel/traverse": "^7.20.5", - "@babel/types": "^7.20.5", - "convert-source-map": "^1.7.0", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/core/-/core-7.28.5.tgz", + "integrity": "sha512-e7jT4DxYvIDLk1ZHmU/m/mB19rex9sv0c2ftBtjSBv+kVM/902eh0fINUzD7UwLLNR+jU585GxUJ8/EBfAM5fw==", + "dev": true, + "dependencies": { + "@babel/code-frame": "^7.27.1", + "@babel/generator": "^7.28.5", + "@babel/helper-compilation-targets": "^7.27.2", + "@babel/helper-module-transforms": 
"^7.28.3", + "@babel/helpers": "^7.28.4", + "@babel/parser": "^7.28.5", + "@babel/template": "^7.27.2", + "@babel/traverse": "^7.28.5", + "@babel/types": "^7.28.5", + "@jridgewell/remapping": "^2.3.5", + "convert-source-map": "^2.0.0", "debug": "^4.1.0", "gensync": "^1.0.0-beta.2", - "json5": "^2.2.1", - "semver": "^6.3.0" + "json5": "^2.2.3", + "semver": "^6.3.1" }, "engines": { "node": ">=6.9.0" @@ -90,228 +83,158 @@ "url": "https://opencollective.com/babel" } }, + "node_modules/@babel/core/node_modules/convert-source-map": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz", + "integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==", + "dev": true + }, "node_modules/@babel/core/node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true, "bin": { "semver": "bin/semver.js" } }, "node_modules/@babel/generator": { - "version": "7.20.5", - "resolved": "https://registry.npmjs.org/@babel/generator/-/generator-7.20.5.tgz", - "integrity": "sha512-jl7JY2Ykn9S0yj4DQP82sYvPU+T3g0HFcWTqDLqiuA9tGRNIj9VfbtXGAYTTkyNEnQk1jkMGOdYka8aG/lulCA==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/generator/-/generator-7.28.5.tgz", + "integrity": "sha512-3EwLFhZ38J4VyIP6WNtt2kUdW9dokXA9Cr4IVIFHuCpZ3H8/YFOl5JjZHisrn1fATPBmKKqXzDFvh9fUwHz6CQ==", "dev": true, "dependencies": { - "@babel/types": "^7.20.5", - "@jridgewell/gen-mapping": "^0.3.2", - "jsesc": "^2.5.1" + "@babel/parser": "^7.28.5", + "@babel/types": "^7.28.5", + "@jridgewell/gen-mapping": "^0.3.12", + 
"@jridgewell/trace-mapping": "^0.3.28", + "jsesc": "^3.0.2" }, "engines": { "node": ">=6.9.0" } }, - "node_modules/@babel/generator/node_modules/@jridgewell/gen-mapping": { - "version": "0.3.2", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz", - "integrity": "sha512-mh65xKQAzI6iBcFzwv28KVWSmCkdRBWoOh+bYQGW3+6OZvbbN3TqMGo5hqYxQniRcH9F2VZIoJCm4pa3BPDK/A==", - "dev": true, - "dependencies": { - "@jridgewell/set-array": "^1.0.1", - "@jridgewell/sourcemap-codec": "^1.4.10", - "@jridgewell/trace-mapping": "^0.3.9" - }, - "engines": { - "node": ">=6.0.0" - } - }, "node_modules/@babel/helper-compilation-targets": { - "version": "7.20.0", - "resolved": "https://registry.npmjs.org/@babel/helper-compilation-targets/-/helper-compilation-targets-7.20.0.tgz", - "integrity": "sha512-0jp//vDGp9e8hZzBc6N/KwA5ZK3Wsm/pfm4CrY7vzegkVxc65SgSn6wYOnwHe9Js9HRQ1YTCKLGPzDtaS3RoLQ==", + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/helper-compilation-targets/-/helper-compilation-targets-7.27.2.tgz", + "integrity": "sha512-2+1thGUUWWjLTYTHZWK1n8Yga0ijBz1XAhUXcKy81rd5g6yh7hGqMp45v7cadSbEHc9G3OTv45SyneRN3ps4DQ==", "dev": true, "dependencies": { - "@babel/compat-data": "^7.20.0", - "@babel/helper-validator-option": "^7.18.6", - "browserslist": "^4.21.3", - "semver": "^6.3.0" + "@babel/compat-data": "^7.27.2", + "@babel/helper-validator-option": "^7.27.1", + "browserslist": "^4.24.0", + "lru-cache": "^5.1.1", + "semver": "^6.3.1" }, "engines": { "node": ">=6.9.0" - }, - "peerDependencies": { - "@babel/core": "^7.0.0" } }, "node_modules/@babel/helper-compilation-targets/node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": 
"sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true, "bin": { "semver": "bin/semver.js" } }, - "node_modules/@babel/helper-environment-visitor": { - "version": "7.18.9", - "resolved": "https://registry.npmjs.org/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.9.tgz", - "integrity": "sha512-3r/aACDJ3fhQ/EVgFy0hpj8oHyHpQc+LPtJoY9SzTThAsStm4Ptegq92vqKoE3vD706ZVFWITnMnxucw+S9Ipg==", - "dev": true, - "engines": { - "node": ">=6.9.0" - } - }, - "node_modules/@babel/helper-function-name": { - "version": "7.19.0", - "resolved": "https://registry.npmjs.org/@babel/helper-function-name/-/helper-function-name-7.19.0.tgz", - "integrity": "sha512-WAwHBINyrpqywkUH0nTnNgI5ina5TFn85HKS0pbPDfxFfhyR/aNQEn4hGi1P1JyT//I0t4OgXUlofzWILRvS5w==", - "dev": true, - "dependencies": { - "@babel/template": "^7.18.10", - "@babel/types": "^7.19.0" - }, - "engines": { - "node": ">=6.9.0" - } - }, - "node_modules/@babel/helper-hoist-variables": { - "version": "7.18.6", - "resolved": "https://registry.npmjs.org/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz", - "integrity": "sha512-UlJQPkFqFULIcyW5sbzgbkxn2FKRgwWiRexcuaR8RNJRy8+LLveqPjwZV/bwrLZCN0eUHD/x8D0heK1ozuoo6Q==", + "node_modules/@babel/helper-globals": { + "version": "7.28.0", + "resolved": "https://registry.npmjs.org/@babel/helper-globals/-/helper-globals-7.28.0.tgz", + "integrity": "sha512-+W6cISkXFa1jXsDEdYA8HeevQT/FULhxzR99pxphltZcVaugps53THCeiWA8SguxxpSp3gKPiuYfSWopkLQ4hw==", "dev": true, - "dependencies": { - "@babel/types": "^7.18.6" - }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/helper-module-imports": { - "version": "7.18.6", - "resolved": "https://registry.npmjs.org/@babel/helper-module-imports/-/helper-module-imports-7.18.6.tgz", - "integrity": "sha512-0NFvs3VkuSYbFi1x2Vd6tKrywq+z/cLeYC/RJNFrIX/30Bf5aiGYbtvGXolEktzJH8o5E5KJ3tT+nkxuuZFVlA==", + "version": "7.27.1", + "resolved": 
"https://registry.npmjs.org/@babel/helper-module-imports/-/helper-module-imports-7.27.1.tgz", + "integrity": "sha512-0gSFWUPNXNopqtIPQvlD5WgXYI5GY2kP2cCvoT8kczjbfcfuIljTbcWrulD1CIPIX2gt1wghbDy08yE1p+/r3w==", "dev": true, "dependencies": { - "@babel/types": "^7.18.6" + "@babel/traverse": "^7.27.1", + "@babel/types": "^7.27.1" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/helper-module-transforms": { - "version": "7.20.2", - "resolved": "https://registry.npmjs.org/@babel/helper-module-transforms/-/helper-module-transforms-7.20.2.tgz", - "integrity": "sha512-zvBKyJXRbmK07XhMuujYoJ48B5yvvmM6+wcpv6Ivj4Yg6qO7NOZOSnvZN9CRl1zz1Z4cKf8YejmCMh8clOoOeA==", + "version": "7.28.3", + "resolved": "https://registry.npmjs.org/@babel/helper-module-transforms/-/helper-module-transforms-7.28.3.tgz", + "integrity": "sha512-gytXUbs8k2sXS9PnQptz5o0QnpLL51SwASIORY6XaBKF88nsOT0Zw9szLqlSGQDP/4TljBAD5y98p2U1fqkdsw==", "dev": true, "dependencies": { - "@babel/helper-environment-visitor": "^7.18.9", - "@babel/helper-module-imports": "^7.18.6", - "@babel/helper-simple-access": "^7.20.2", - "@babel/helper-split-export-declaration": "^7.18.6", - "@babel/helper-validator-identifier": "^7.19.1", - "@babel/template": "^7.18.10", - "@babel/traverse": "^7.20.1", - "@babel/types": "^7.20.2" + "@babel/helper-module-imports": "^7.27.1", + "@babel/helper-validator-identifier": "^7.27.1", + "@babel/traverse": "^7.28.3" }, "engines": { "node": ">=6.9.0" - } - }, - "node_modules/@babel/helper-plugin-utils": { - "version": "7.20.2", - "resolved": "https://registry.npmjs.org/@babel/helper-plugin-utils/-/helper-plugin-utils-7.20.2.tgz", - "integrity": "sha512-8RvlJG2mj4huQ4pZ+rU9lqKi9ZKiRmuvGuM2HlWmkmgOhbs6zEAw6IEiJ5cQqGbDzGZOhwuOQNtZMi/ENLjZoQ==", - "dev": true, - "engines": { - "node": ">=6.9.0" - } - }, - "node_modules/@babel/helper-simple-access": { - "version": "7.20.2", - "resolved": "https://registry.npmjs.org/@babel/helper-simple-access/-/helper-simple-access-7.20.2.tgz", - "integrity": 
"sha512-+0woI/WPq59IrqDYbVGfshjT5Dmk/nnbdpcF8SnMhhXObpTq2KNBdLFRFrkVdbDOyUmHBCxzm5FHV1rACIkIbA==", - "dev": true, - "dependencies": { - "@babel/types": "^7.20.2" }, - "engines": { - "node": ">=6.9.0" + "peerDependencies": { + "@babel/core": "^7.0.0" } }, - "node_modules/@babel/helper-split-export-declaration": { - "version": "7.18.6", - "resolved": "https://registry.npmjs.org/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.18.6.tgz", - "integrity": "sha512-bde1etTx6ZyTmobl9LLMMQsaizFVZrquTEHOqKeQESMKo4PlObf+8+JA25ZsIpZhT/WEd39+vOdLXAFG/nELpA==", + "node_modules/@babel/helper-plugin-utils": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-plugin-utils/-/helper-plugin-utils-7.27.1.tgz", + "integrity": "sha512-1gn1Up5YXka3YYAHGKpbideQ5Yjf1tDa9qYcgysz+cNCXukyLl6DjPXhD3VRwSb8c0J9tA4b2+rHEZtc6R0tlw==", "dev": true, - "dependencies": { - "@babel/types": "^7.18.6" - }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/helper-string-parser": { - "version": "7.19.4", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.19.4.tgz", - "integrity": "sha512-nHtDoQcuqFmwYNYPz3Rah5ph2p8PFeFCsZk9A/48dPc/rGocJ5J3hAAZ7pb76VWX3fZKu+uEr/FhH5jLx7umrw==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", "dev": true, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.19.1", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.19.1.tgz", - "integrity": "sha512-awrNfaMtnHUr653GgGEs++LlAvW6w+DcPrOliSMXWCKo597CwL5Acf/wWdNkf/tfEQE3mjkeD1YOVZOUV/od1w==", + "version": "7.28.5", + "resolved": 
"https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", + "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", "dev": true, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/helper-validator-option": { - "version": "7.18.6", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-option/-/helper-validator-option-7.18.6.tgz", - "integrity": "sha512-XO7gESt5ouv/LRJdrVjkShckw6STTaB7l9BrpBaAHDeF5YZT+01PCwmR0SJHnkW6i8OwW/EVWRShfi4j2x+KQw==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-option/-/helper-validator-option-7.27.1.tgz", + "integrity": "sha512-YvjJow9FxbhFFKDSuFnVCe2WxXk1zWc22fFePVNEaWJEu8IrZVlda6N0uHwzZrUM1il7NC9Mlp4MaJYbYd9JSg==", "dev": true, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/helpers": { - "version": "7.20.6", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.20.6.tgz", - "integrity": "sha512-Pf/OjgfgFRW5bApskEz5pvidpim7tEDPlFtKcNRXWmfHGn9IEI2W2flqRQXTFb7gIPTyK++N6rVHuwKut4XK6w==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.4.tgz", + "integrity": "sha512-HFN59MmQXGHVyYadKLVumYsA9dBFun/ldYxipEjzA4196jpLZd8UjEEBLkbEkvfYreDqJhZxYAWFPtrfhNpj4w==", "dev": true, "dependencies": { - "@babel/template": "^7.18.10", - "@babel/traverse": "^7.20.5", - "@babel/types": "^7.20.5" + "@babel/template": "^7.27.2", + "@babel/types": "^7.28.4" }, "engines": { "node": ">=6.9.0" } }, - "node_modules/@babel/highlight": { - "version": "7.18.6", - "resolved": "https://registry.npmjs.org/@babel/highlight/-/highlight-7.18.6.tgz", - "integrity": "sha512-u7stbOuYjaPezCuLj29hNW1v64M2Md2qupEKP1fHc7WdOA3DgLh37suiSrZYY7haUB7iBeQZ9P1uiRF359do3g==", + "node_modules/@babel/parser": { + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", + "integrity": 
"sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", "dev": true, "dependencies": { - "@babel/helper-validator-identifier": "^7.18.6", - "chalk": "^2.0.0", - "js-tokens": "^4.0.0" + "@babel/types": "^7.28.5" }, - "engines": { - "node": ">=6.9.0" - } - }, - "node_modules/@babel/parser": { - "version": "7.20.5", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.20.5.tgz", - "integrity": "sha512-r27t/cy/m9uKLXQNWWebeCUHgnAZq0CpG1OwKRxzJMP1vpSU4bSIK2hq+/cp0bQxetkXx38n09rNu8jVkcK/zA==", - "dev": true, "bin": { "parser": "bin/babel-parser.js" }, @@ -355,6 +278,36 @@ "@babel/core": "^7.0.0-0" } }, + "node_modules/@babel/plugin-syntax-class-static-block": { + "version": "7.14.5", + "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-class-static-block/-/plugin-syntax-class-static-block-7.14.5.tgz", + "integrity": "sha512-b+YyPmr6ldyNnM6sqYeMWE+bgJcJpO6yS4QD7ymxgH34GBPNDM/THBh8iunyvKIZztiwLH4CJZ0RxTk9emgpjw==", + "dev": true, + "dependencies": { + "@babel/helper-plugin-utils": "^7.14.5" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0-0" + } + }, + "node_modules/@babel/plugin-syntax-import-attributes": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-import-attributes/-/plugin-syntax-import-attributes-7.27.1.tgz", + "integrity": "sha512-oFT0FrKHgF53f4vOsZGi2Hh3I35PfSmVs4IBFLFj4dnafP+hIWDLg3VyKmUHfLoLHlyxY4C7DGtmHuJgn+IGww==", + "dev": true, + "dependencies": { + "@babel/helper-plugin-utils": "^7.27.1" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0-0" + } + }, "node_modules/@babel/plugin-syntax-import-meta": { "version": "7.10.4", "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-import-meta/-/plugin-syntax-import-meta-7.10.4.tgz", @@ -451,6 +404,21 @@ "@babel/core": "^7.0.0-0" } }, + "node_modules/@babel/plugin-syntax-private-property-in-object": { + "version": 
"7.14.5", + "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-private-property-in-object/-/plugin-syntax-private-property-in-object-7.14.5.tgz", + "integrity": "sha512-0wVnp9dxJ72ZUJDV27ZfbSj6iHLoytYZmh3rFcxNnvsJF3ktkzLDZPy/mA17HGsaQT3/DQsWYX1f1QGWkCoVUg==", + "dev": true, + "dependencies": { + "@babel/helper-plugin-utils": "^7.14.5" + }, + "engines": { + "node": ">=6.9.0" + }, + "peerDependencies": { + "@babel/core": "^7.0.0-0" + } + }, "node_modules/@babel/plugin-syntax-top-level-await": { "version": "7.14.5", "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-top-level-await/-/plugin-syntax-top-level-await-7.14.5.tgz", @@ -467,58 +435,45 @@ } }, "node_modules/@babel/template": { - "version": "7.18.10", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.18.10.tgz", - "integrity": "sha512-TI+rCtooWHr3QJ27kJxfjutghu44DLnasDMwpDqCXVTal9RLp3RSYNh4NdBrRP2cQAoG9A8juOQl6P6oZG4JxA==", + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", + "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", "dev": true, "dependencies": { - "@babel/code-frame": "^7.18.6", - "@babel/parser": "^7.18.10", - "@babel/types": "^7.18.10" + "@babel/code-frame": "^7.27.1", + "@babel/parser": "^7.27.2", + "@babel/types": "^7.27.1" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/traverse": { - "version": "7.20.5", - "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.20.5.tgz", - "integrity": "sha512-WM5ZNN3JITQIq9tFZaw1ojLU3WgWdtkxnhM1AegMS+PvHjkM5IXjmYEGY7yukz5XS4sJyEf2VzWjI8uAavhxBQ==", - "dev": true, - "dependencies": { - "@babel/code-frame": "^7.18.6", - "@babel/generator": "^7.20.5", - "@babel/helper-environment-visitor": "^7.18.9", - "@babel/helper-function-name": "^7.19.0", - "@babel/helper-hoist-variables": "^7.18.6", - "@babel/helper-split-export-declaration": "^7.18.6", - "@babel/parser": "^7.20.5", - 
"@babel/types": "^7.20.5", - "debug": "^4.1.0", - "globals": "^11.1.0" + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.28.5.tgz", + "integrity": "sha512-TCCj4t55U90khlYkVV/0TfkJkAkUg3jZFA3Neb7unZT8CPok7iiRfaX0F+WnqWqt7OxhOn0uBKXCw4lbL8W0aQ==", + "dev": true, + "dependencies": { + "@babel/code-frame": "^7.27.1", + "@babel/generator": "^7.28.5", + "@babel/helper-globals": "^7.28.0", + "@babel/parser": "^7.28.5", + "@babel/template": "^7.27.2", + "@babel/types": "^7.28.5", + "debug": "^4.3.1" }, "engines": { "node": ">=6.9.0" } }, - "node_modules/@babel/traverse/node_modules/globals": { - "version": "11.12.0", - "resolved": "https://registry.npmjs.org/globals/-/globals-11.12.0.tgz", - "integrity": "sha512-WOBp/EEGUiIsJSp7wcv/y6MO+lV9UoncWqxuFfm8eBwzWNgyfBd6Gz+IeKQ9jCmyhoH99g15M3T+QaVHFjizVA==", - "dev": true, - "engines": { - "node": ">=4" - } - }, "node_modules/@babel/types": { - "version": "7.20.5", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.20.5.tgz", - "integrity": "sha512-c9fst/h2/dcF7H+MJKZ2T0KjEQ8hY/BNnDk/H3XY8C4Aw/eWQXWn/lWntHF9ooUBnGmEvbfGrTgLWc+um0YDUg==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", + "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", "dev": true, "dependencies": { - "@babel/helper-string-parser": "^7.19.4", - "@babel/helper-validator-identifier": "^7.19.1", - "to-fast-properties": "^2.0.0" + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.28.5" }, "engines": { "node": ">=6.9.0" @@ -547,38 +502,41 @@ } }, "node_modules/@eslint-community/eslint-utils": { - "version": "4.4.0", - "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz", - "integrity": "sha512-1/sA4dwrzBAyeUoQ6oxahHKmrZvsnLCg4RfxW3ZFGGmQkSNQPFNLV9CUEFQP1x9EYXHTo5p6xdhZM1Ne9p/AfA==", + "version": "4.9.0", + "resolved": 
"https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.9.0.tgz", + "integrity": "sha512-ayVFHdtZ+hsq1t2Dy24wCmGXGe4q9Gu3smhLYALJrr473ZH27MsnSL+LKUlimp4BWJqMDMLmPpx/Q9R3OAlL4g==", "dev": true, "dependencies": { - "eslint-visitor-keys": "^3.3.0" + "eslint-visitor-keys": "^3.4.3" }, "engines": { "node": "^12.22.0 || ^14.17.0 || >=16.0.0" }, + "funding": { + "url": "https://opencollective.com/eslint" + }, "peerDependencies": { "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" } }, "node_modules/@eslint-community/regexpp": { - "version": "4.5.1", - "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.5.1.tgz", - "integrity": "sha512-Z5ba73P98O1KUYCCJTUeVpja9RcGoMdncZ6T49FCUl2lN38JtCJ+3WgIDBv0AuY4WChU5PmtJmOCTlN6FZTFKQ==", + "version": "4.12.2", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.12.2.tgz", + "integrity": "sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==", "dev": true, "engines": { "node": "^12.0.0 || ^14.0.0 || >=16.0.0" } }, "node_modules/@eslint/eslintrc": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.0.3.tgz", - "integrity": "sha512-+5gy6OQfk+xx3q0d6jGZZC3f3KzAkXc/IanVxd1is/VIIziRqqt3ongQz0FiTUXqTk0c7aDB3OaFuKnuSoJicQ==", + "version": "2.1.4", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.1.4.tgz", + "integrity": "sha512-269Z39MS6wVJtsoUl10L60WdkhJVdPG24Q4eZTH3nnF6lpvSShEK3wQjDX9JRWAUPvPh7COouPpU9IrqaZFvtQ==", "dev": true, "dependencies": { "ajv": "^6.12.4", "debug": "^4.3.2", - "espree": "^9.5.2", + "espree": "^9.6.0", "globals": "^13.19.0", "ignore": "^5.2.0", "import-fresh": "^3.2.1", @@ -593,41 +551,24 @@ "url": "https://opencollective.com/eslint" } }, - "node_modules/@eslint/eslintrc/node_modules/argparse": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", - "integrity": 
"sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", - "dev": true - }, - "node_modules/@eslint/eslintrc/node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", - "dev": true, - "dependencies": { - "argparse": "^2.0.1" - }, - "bin": { - "js-yaml": "bin/js-yaml.js" - } - }, "node_modules/@eslint/js": { - "version": "8.41.0", - "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.41.0.tgz", - "integrity": "sha512-LxcyMGxwmTh2lY9FwHPGWOHmYFCZvbrFCBZL4FzSSsxsRPuhrYUg/49/0KDfW8tnIEaEHtfmn6+NPN+1DqaNmA==", + "version": "8.57.1", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.57.1.tgz", + "integrity": "sha512-d9zaMRSTIKDLhctzH12MtXvJKSSUhaHcjV+2Z+GK+EEY7XKpP5yR4x+N3TAcHTcu963nIr+TMcCb4DBCYX1z6Q==", "dev": true, "engines": { "node": "^12.22.0 || ^14.17.0 || >=16.0.0" } }, "node_modules/@humanwhocodes/config-array": { - "version": "0.11.8", - "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz", - "integrity": "sha512-UybHIJzJnR5Qc/MsD9Kr+RpO2h+/P1GhOwdiLPXK5TWk5sgTdu88bTD9UP+CKbPPh5Rni1u0GjAdYQLemG8g+g==", + "version": "0.13.0", + "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.13.0.tgz", + "integrity": "sha512-DZLEEqFWQFiyK6h5YIeynKx7JlvCYWL0cImfSRXZ9l4Sg2efkFGTuFf6vzXjK1cq6IYkU+Eg/JizXw+TD2vRNw==", + "deprecated": "Use @eslint/config-array instead", "dev": true, "dependencies": { - "@humanwhocodes/object-schema": "^1.2.1", - "debug": "^4.1.1", + "@humanwhocodes/object-schema": "^2.0.3", + "debug": "^4.3.1", "minimatch": "^3.0.5" }, "engines": { @@ -648,9 +589,10 @@ } }, "node_modules/@humanwhocodes/object-schema": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/@humanwhocodes/object-schema/-/object-schema-1.2.1.tgz", - "integrity": 
"sha512-ZnQMnLV4e7hDlUvw8H+U8ASL02SS2Gn6+9Ac3wGGLIe7+je2AeAOxPY+izIPJDfFDb7eDjev0Us8MO1iFRN8hA==", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/@humanwhocodes/object-schema/-/object-schema-2.0.3.tgz", + "integrity": "sha512-93zYdMES/c1D69yZiKDBj0V24vqNzB/koF26KPaagAfd3P/4gUlh3Dys5ogAK+Exi9QyzlD8x/08Zt7wIKcDcA==", + "deprecated": "Use @eslint/object-schema instead", "dev": true }, "node_modules/@istanbuljs/load-nyc-config": { @@ -669,109 +611,113 @@ "node": ">=8" } }, - "node_modules/@istanbuljs/load-nyc-config/node_modules/resolve-from": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", - "integrity": "sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==", + "node_modules/@istanbuljs/load-nyc-config/node_modules/argparse": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz", + "integrity": "sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==", "dev": true, - "engines": { - "node": ">=8" + "dependencies": { + "sprintf-js": "~1.0.2" } }, - "node_modules/@istanbuljs/schema": { - "version": "0.1.3", - "resolved": "https://registry.npmjs.org/@istanbuljs/schema/-/schema-0.1.3.tgz", - "integrity": "sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==", + "node_modules/@istanbuljs/load-nyc-config/node_modules/find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", "dev": true, + "dependencies": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + }, "engines": { "node": ">=8" } }, - "node_modules/@jest/console": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/@jest/console/-/console-26.6.2.tgz", - "integrity": 
"sha512-IY1R2i2aLsLr7Id3S6p2BA82GNWryt4oSvEXLAKc+L2zdi89dSkE8xC1C+0kpATG4JhBJREnQOH7/zmccM2B0g==", + "node_modules/@istanbuljs/load-nyc-config/node_modules/js-yaml": { + "version": "3.14.2", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", + "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", "dev": true, "dependencies": { - "@jest/types": "^26.6.2", - "@types/node": "*", - "chalk": "^4.0.0", - "jest-message-util": "^26.6.2", - "jest-util": "^26.6.2", - "slash": "^3.0.0" + "argparse": "^1.0.7", + "esprima": "^4.0.0" }, - "engines": { - "node": ">= 10.14.2" + "bin": { + "js-yaml": "bin/js-yaml.js" } }, - "node_modules/@jest/console/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/@istanbuljs/load-nyc-config/node_modules/locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "p-locate": "^4.1.0" }, "engines": { "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" } }, - "node_modules/@jest/console/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "node_modules/@istanbuljs/load-nyc-config/node_modules/p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", "dev": true, 
"dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" + "p-try": "^2.0.0" }, "engines": { - "node": ">=10" + "node": ">=6" }, "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" + "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/@jest/console/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/@istanbuljs/load-nyc-config/node_modules/p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", "dev": true, "dependencies": { - "color-name": "~1.1.4" + "p-limit": "^2.2.0" }, "engines": { - "node": ">=7.0.0" + "node": ">=8" } }, - "node_modules/@jest/console/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true + "node_modules/@istanbuljs/load-nyc-config/node_modules/resolve-from": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", + "integrity": "sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==", + "dev": true, + "engines": { + "node": ">=8" + } }, - "node_modules/@jest/console/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "node_modules/@istanbuljs/schema": { + "version": "0.1.3", + "resolved": 
"https://registry.npmjs.org/@istanbuljs/schema/-/schema-0.1.3.tgz", + "integrity": "sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==", "dev": true, "engines": { "node": ">=8" } }, - "node_modules/@jest/console/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/@jest/console": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/@jest/console/-/console-26.6.2.tgz", + "integrity": "sha512-IY1R2i2aLsLr7Id3S6p2BA82GNWryt4oSvEXLAKc+L2zdi89dSkE8xC1C+0kpATG4JhBJREnQOH7/zmccM2B0g==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "@jest/types": "^26.6.2", + "@types/node": "*", + "chalk": "^4.0.0", + "jest-message-util": "^26.6.2", + "jest-util": "^26.6.2", + "slash": "^3.0.0" }, "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, "node_modules/@jest/core": { @@ -813,86 +759,16 @@ "node": ">= 10.14.2" } }, - "node_modules/@jest/core/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/@jest/environment": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/@jest/environment/-/environment-26.6.2.tgz", + "integrity": "sha512-nFy+fHl28zUrRsCeMB61VDThV1pVTtlEokBRgqPrcT1JNq4yRNIyTHfyht6PqtUvY9IsuLGTrbG8kPXjSZIZwA==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/@jest/core/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": 
"sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/@jest/core/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/@jest/core/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/@jest/core/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/@jest/core/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/@jest/environment": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/@jest/environment/-/environment-26.6.2.tgz", - "integrity": "sha512-nFy+fHl28zUrRsCeMB61VDThV1pVTtlEokBRgqPrcT1JNq4yRNIyTHfyht6PqtUvY9IsuLGTrbG8kPXjSZIZwA==", - "dev": true, - "dependencies": { - 
"@jest/fake-timers": "^26.6.2", - "@jest/types": "^26.6.2", - "@types/node": "*", - "jest-mock": "^26.6.2" + "@jest/fake-timers": "^26.6.2", + "@jest/types": "^26.6.2", + "@types/node": "*", + "jest-mock": "^26.6.2" }, "engines": { "node": ">= 10.14.2" @@ -967,76 +843,6 @@ "node-notifier": "^8.0.0" } }, - "node_modules/@jest/reporters/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/@jest/reporters/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/@jest/reporters/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/@jest/reporters/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/@jest/reporters/node_modules/has-flag": { - "version": 
"4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/@jest/reporters/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/@jest/source-map": { "version": "26.6.2", "resolved": "https://registry.npmjs.org/@jest/source-map/-/source-map-26.6.2.tgz", @@ -1108,271 +914,123 @@ "node": ">= 10.14.2" } }, - "node_modules/@jest/transform/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/@jest/types": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/@jest/types/-/types-26.6.2.tgz", + "integrity": "sha512-fC6QCp7Sc5sX6g8Tvbmj4XUTbyrik0akgRy03yjXbQaBWWNWGE7SGtJk98m0N8nzegD/7SggrUlivxo5ax4KWQ==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "@types/istanbul-lib-coverage": "^2.0.0", + "@types/istanbul-reports": "^3.0.0", + "@types/node": "*", + "@types/yargs": "^15.0.0", + "chalk": "^4.0.0" }, "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "node": ">= 10.14.2" } }, - "node_modules/@jest/transform/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + 
"node_modules/@jridgewell/gen-mapping": { + "version": "0.3.13", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", + "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", "dev": true, "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" + "@jridgewell/sourcemap-codec": "^1.5.0", + "@jridgewell/trace-mapping": "^0.3.24" } }, - "node_modules/@jest/transform/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/@jridgewell/remapping": { + "version": "2.3.5", + "resolved": "https://registry.npmjs.org/@jridgewell/remapping/-/remapping-2.3.5.tgz", + "integrity": "sha512-LI9u/+laYG4Ds1TDKSJW2YPrIlcVYOwi2fUC6xB43lueCjgxV4lffOCZCtYFiH6TNOX+tQKXx97T4IKHbhyHEQ==", "dev": true, "dependencies": { - "color-name": "~1.1.4" - }, + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.24" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "dev": true, "engines": { - "node": ">=7.0.0" + "node": ">=6.0.0" } }, - "node_modules/@jest/transform/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.5", + "resolved": 
"https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", + "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", "dev": true }, - "node_modules/@jest/transform/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", + "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", "dev": true, - "engines": { - "node": ">=8" + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" } }, - "node_modules/@jest/transform/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" }, "engines": { - "node": ">=8" + "node": ">= 8" } }, - "node_modules/@jest/types": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/@jest/types/-/types-26.6.2.tgz", - "integrity": "sha512-fC6QCp7Sc5sX6g8Tvbmj4XUTbyrik0akgRy03yjXbQaBWWNWGE7SGtJk98m0N8nzegD/7SggrUlivxo5ax4KWQ==", + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": 
"https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", "dev": true, - "dependencies": { - "@types/istanbul-lib-coverage": "^2.0.0", - "@types/istanbul-reports": "^3.0.0", - "@types/node": "*", - "@types/yargs": "^15.0.0", - "chalk": "^4.0.0" - }, "engines": { - "node": ">= 10.14.2" + "node": ">= 8" } }, - "node_modules/@jest/types/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" }, "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "node": ">= 8" } }, - "node_modules/@jest/types/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "node_modules/@rollup/plugin-commonjs": { + "version": "20.0.0", + "resolved": "https://registry.npmjs.org/@rollup/plugin-commonjs/-/plugin-commonjs-20.0.0.tgz", + "integrity": "sha512-5K0g5W2Ol8hAcTHqcTBHiA7M58tfmYi1o9KxeJuuRNpGaTa5iLjcyemBitCBcKXaHamOBBEH2dGom6v6Unmqjg==", "dev": true, "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" + "@rollup/pluginutils": "^3.1.0", + "commondir": "^1.0.1", + "estree-walker": "^2.0.1", + "glob": "^7.1.6", + "is-reference": "^1.2.1", + "magic-string": "^0.25.7", + 
"resolve": "^1.17.0" }, "engines": { - "node": ">=10" + "node": ">= 8.0.0" }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" + "peerDependencies": { + "rollup": "^2.38.3" } }, - "node_modules/@jest/types/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/@jest/types/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/@jest/types/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/@jest/types/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/@jridgewell/gen-mapping": { - "version": "0.1.1", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.1.1.tgz", - "integrity": "sha512-sQXCasFk+U8lWYEe66WxRDOE9PjVz4vSM51fTu3Hw+ClTpUSQb718772vH3pyS5pShp6lvQM7SxgIDXXXmOX7w==", - "dev": true, - "dependencies": { - "@jridgewell/set-array": "^1.0.0", - "@jridgewell/sourcemap-codec": "^1.4.10" - }, - "engines": { - "node": 
">=6.0.0" - } - }, - "node_modules/@jridgewell/resolve-uri": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz", - "integrity": "sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w==", - "dev": true, - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/@jridgewell/set-array": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz", - "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==", - "dev": true, - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/@jridgewell/sourcemap-codec": { - "version": "1.4.14", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz", - "integrity": "sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==", - "dev": true - }, - "node_modules/@jridgewell/trace-mapping": { - "version": "0.3.17", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.17.tgz", - "integrity": "sha512-MCNzAp77qzKca9+W/+I0+sEpaUnZoeasnghNeVc41VZCEKaCH73Vq3BZZ/SzWIgrqE4H4ceI+p+b6C0mHf9T4g==", - "dev": true, - "dependencies": { - "@jridgewell/resolve-uri": "3.1.0", - "@jridgewell/sourcemap-codec": "1.4.14" - } - }, - "node_modules/@nodelib/fs.scandir": { - "version": "2.1.5", - "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", - "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", - "dev": true, - "dependencies": { - "@nodelib/fs.stat": "2.0.5", - "run-parallel": "^1.1.9" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/@nodelib/fs.stat": { - "version": "2.0.5", - "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", - "integrity": 
"sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", - "dev": true, - "engines": { - "node": ">= 8" - } - }, - "node_modules/@nodelib/fs.walk": { - "version": "1.2.8", - "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", - "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", - "dev": true, - "dependencies": { - "@nodelib/fs.scandir": "2.1.5", - "fastq": "^1.6.0" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/@rollup/plugin-commonjs": { - "version": "20.0.0", - "resolved": "https://registry.npmjs.org/@rollup/plugin-commonjs/-/plugin-commonjs-20.0.0.tgz", - "integrity": "sha512-5K0g5W2Ol8hAcTHqcTBHiA7M58tfmYi1o9KxeJuuRNpGaTa5iLjcyemBitCBcKXaHamOBBEH2dGom6v6Unmqjg==", - "dev": true, - "dependencies": { - "@rollup/pluginutils": "^3.1.0", - "commondir": "^1.0.1", - "estree-walker": "^2.0.1", - "glob": "^7.1.6", - "is-reference": "^1.2.1", - "magic-string": "^0.25.7", - "resolve": "^1.17.0" - }, - "engines": { - "node": ">= 8.0.0" - }, - "peerDependencies": { - "rollup": "^2.38.3" - } - }, - "node_modules/@rollup/plugin-commonjs/node_modules/estree-walker": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", - "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", - "dev": true - }, "node_modules/@rollup/plugin-node-resolve": { "version": "13.3.0", "resolved": "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-13.3.0.tgz", @@ -1410,6 +1068,12 @@ "rollup": "^1.20.0||^2.0.0" } }, + "node_modules/@rollup/pluginutils/node_modules/estree-walker": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-1.0.1.tgz", + "integrity": "sha512-1fMXF3YP4pZZVozF8j/ZLfvnR8NSIljt56UhbZ5PeeDmmGHpgpdwQt7ITlGvYaQukCvuBRMLEiKiYC+oeIg4cg==", + "dev": true + }, 
"node_modules/@sinonjs/commons": { "version": "1.8.6", "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-1.8.6.tgz", @@ -1438,31 +1102,31 @@ } }, "node_modules/@types/babel__core": { - "version": "7.1.20", - "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.1.20.tgz", - "integrity": "sha512-PVb6Bg2QuscZ30FvOU7z4guG6c926D9YRvOxEaelzndpMsvP+YM74Q/dAFASpg2l6+XLalxSGxcq/lrgYWZtyQ==", + "version": "7.20.5", + "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", + "integrity": "sha512-qoQprZvz5wQFJwMDqeseRXWv3rqMvhgpbXFfVyWhbx9X47POIA6i/+dXefEmZKoAgOaTdaIgNSMqMIU61yRyzA==", "dev": true, "dependencies": { - "@babel/parser": "^7.1.0", - "@babel/types": "^7.0.0", + "@babel/parser": "^7.20.7", + "@babel/types": "^7.20.7", "@types/babel__generator": "*", "@types/babel__template": "*", "@types/babel__traverse": "*" } }, "node_modules/@types/babel__generator": { - "version": "7.6.4", - "resolved": "https://registry.npmjs.org/@types/babel__generator/-/babel__generator-7.6.4.tgz", - "integrity": "sha512-tFkciB9j2K755yrTALxD44McOrk+gfpIpvC3sxHjRawj6PfnQxrse4Clq5y/Rq+G3mrBurMax/lG8Qn2t9mSsg==", + "version": "7.27.0", + "resolved": "https://registry.npmjs.org/@types/babel__generator/-/babel__generator-7.27.0.tgz", + "integrity": "sha512-ufFd2Xi92OAVPYsy+P4n7/U7e68fex0+Ee8gSG9KX7eo084CWiQ4sdxktvdl0bOPupXtVJPY19zk6EwWqUQ8lg==", "dev": true, "dependencies": { "@babel/types": "^7.0.0" } }, "node_modules/@types/babel__template": { - "version": "7.4.1", - "resolved": "https://registry.npmjs.org/@types/babel__template/-/babel__template-7.4.1.tgz", - "integrity": "sha512-azBFKemX6kMg5Io+/rdGT0dkGreboUVR0Cdm3fz9QJWpaQGJRQXl7C+6hOTCZcMll7KFyEQpgbYI2lHdsS4U7g==", + "version": "7.4.4", + "resolved": "https://registry.npmjs.org/@types/babel__template/-/babel__template-7.4.4.tgz", + "integrity": "sha512-h/NUaSyG5EyxBIp8YRxo4RMe2/qQgvyowRwVMzhYhBCONbW8PUsg4lkFMrhgZhUe5z3L3MiLDuvyJ/CaPa2A8A==", "dev": true, "dependencies": { 
"@babel/parser": "^7.1.0", @@ -1470,12 +1134,12 @@ } }, "node_modules/@types/babel__traverse": { - "version": "7.18.2", - "resolved": "https://registry.npmjs.org/@types/babel__traverse/-/babel__traverse-7.18.2.tgz", - "integrity": "sha512-FcFaxOr2V5KZCviw1TnutEMVUVsGt4D2hP1TAfXZAMKuHYW3xQhe3jTxNPWutgCJ3/X1c5yX8ZoGVEItxKbwBg==", + "version": "7.28.0", + "resolved": "https://registry.npmjs.org/@types/babel__traverse/-/babel__traverse-7.28.0.tgz", + "integrity": "sha512-8PvcXf70gTDZBgt9ptxJ8elBeBjcLOAcOtoO/mPJjtji1+CdGbHgm77om1GrsPxsiE+uXIpNSK64UYaIwQXd4Q==", "dev": true, "dependencies": { - "@babel/types": "^7.3.0" + "@babel/types": "^7.28.2" } }, "node_modules/@types/estree": { @@ -1485,60 +1149,63 @@ "dev": true }, "node_modules/@types/graceful-fs": { - "version": "4.1.5", - "resolved": "https://registry.npmjs.org/@types/graceful-fs/-/graceful-fs-4.1.5.tgz", - "integrity": "sha512-anKkLmZZ+xm4p8JWBf4hElkM4XR+EZeA2M9BAkkTldmcyDY4mbdIJnRghDJH3Ov5ooY7/UAoENtmdMSkaAd7Cw==", + "version": "4.1.9", + "resolved": "https://registry.npmjs.org/@types/graceful-fs/-/graceful-fs-4.1.9.tgz", + "integrity": "sha512-olP3sd1qOEe5dXTSaFvQG+02VdRXcdytWLAZsAq1PecU8uqQAhkrnbli7DagjtXKW/Bl7YJbUsa8MPcuc8LHEQ==", "dev": true, "dependencies": { "@types/node": "*" } }, "node_modules/@types/istanbul-lib-coverage": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.4.tgz", - "integrity": "sha512-z/QT1XN4K4KYuslS23k62yDIDLwLFkzxOuMplDtObz0+y7VqJCaO2o+SPwHCvLFZh7xazvvoor2tA/hPz9ee7g==", + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.6.tgz", + "integrity": "sha512-2QF/t/auWm0lsy8XtKVPG19v3sSOQlJe/YHZgfjb/KBBHOGSV+J2q/S671rcq9uTBrLAXmZpqJiaQbMT+zNU1w==", "dev": true }, "node_modules/@types/istanbul-lib-report": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/@types/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz", - "integrity": 
"sha512-plGgXAPfVKFoYfa9NpYDAkseG+g6Jr294RqeqcqDixSbU34MZVJRi/P+7Y8GDpzkEwLaGZZOpKIEmeVZNtKsrg==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/@types/istanbul-lib-report/-/istanbul-lib-report-3.0.3.tgz", + "integrity": "sha512-NQn7AHQnk/RSLOxrBbGyJM/aVQ+pjj5HCgasFxc0K/KhoATfQ/47AyUl15I2yBUpihjmas+a+VJBOqecrFH+uA==", "dev": true, "dependencies": { "@types/istanbul-lib-coverage": "*" } }, "node_modules/@types/istanbul-reports": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/@types/istanbul-reports/-/istanbul-reports-3.0.1.tgz", - "integrity": "sha512-c3mAZEuK0lvBp8tmuL74XRKn1+y2dcwOUpH7x4WrF6gk1GIgiluDRgMYQtw2OFcBvAJWlt6ASU3tSqxp0Uu0Aw==", + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/istanbul-reports/-/istanbul-reports-3.0.4.tgz", + "integrity": "sha512-pk2B1NWalF9toCRu6gjBzR69syFjP4Od8WRAX+0mmf9lAjCRicLOWc+ZrxZHx/0XRjotgkF9t6iaMJ+aXcOdZQ==", "dev": true, "dependencies": { "@types/istanbul-lib-report": "*" } }, "node_modules/@types/json-schema": { - "version": "7.0.11", - "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.11.tgz", - "integrity": "sha512-wOuvG1SN4Us4rez+tylwwwCV1psiNVOkJeM3AUWUNWg/jDQY2+HE/444y5gc+jBmRqASOm2Oeh5c1axHobwRKQ==", + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, "node_modules/@types/node": { - "version": "20.4.5", - "resolved": "https://registry.npmmirror.com/@types/node/-/node-20.4.5.tgz", - "integrity": "sha512-rt40Nk13II9JwQBdeYqmbn2Q6IVTA5uPhvSO+JVqdXw/6/4glI6oR9ezty/A9Hg5u7JH4OmYmuQ+XvjKm0Datg==", - "dev": true + "version": "20.19.25", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.19.25.tgz", + "integrity": "sha512-ZsJzA5thDQMSQO788d7IocwwQbI8B5OPzmqNvpf3NY/+MHDAS759Wo0gd2WQeXYt5AAAQjzcrTVC6SKCuYgoCQ==", + "dev": true, + "dependencies": { 
+ "undici-types": "~6.21.0" + } }, "node_modules/@types/normalize-package-data": { - "version": "2.4.1", - "resolved": "https://registry.npmjs.org/@types/normalize-package-data/-/normalize-package-data-2.4.1.tgz", - "integrity": "sha512-Gj7cI7z+98M282Tqmp2K5EIsoouUEzbBJhQQzDE3jSIRk6r9gsz0oUokqIUR4u1R3dMHo0pDHM7sNOHyhulypw==", + "version": "2.4.4", + "resolved": "https://registry.npmjs.org/@types/normalize-package-data/-/normalize-package-data-2.4.4.tgz", + "integrity": "sha512-37i+OaWTh9qeK4LSHPsyRC7NahnGotNuZvjLSgcPzblpHB3rrCJxAOgI5gCdKm7coonsaX1Of0ILiTcnZjbfxA==", "dev": true }, "node_modules/@types/prettier": { - "version": "2.7.1", - "resolved": "https://registry.npmjs.org/@types/prettier/-/prettier-2.7.1.tgz", - "integrity": "sha512-ri0UmynRRvZiiUJdiz38MmIblKK+oH30MztdBVR95dv/Ubw6neWSb8u1XpRb72L4qsZOhz+L+z9JD40SJmfWow==", + "version": "2.7.3", + "resolved": "https://registry.npmjs.org/@types/prettier/-/prettier-2.7.3.tgz", + "integrity": "sha512-+68kP9yzs4LMp7VNh8gdzMSPZFL44MLGqiHWvttYJe+6qnuVr4Ek9wSBQoveqY/r+LwjCcU29kNVkidwim+kYA==", "dev": true }, "node_modules/@types/resolve": { @@ -1551,44 +1218,44 @@ } }, "node_modules/@types/semver": { - "version": "7.5.0", - "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.5.0.tgz", - "integrity": "sha512-G8hZ6XJiHnuhQKR7ZmysCeJWE08o8T0AXtk5darsCaTVsYZhhgUrq53jizaR2FvsoeCwJhlmwTjkXBY5Pn/ZHw==", + "version": "7.7.1", + "resolved": "https://registry.npmjs.org/@types/semver/-/semver-7.7.1.tgz", + "integrity": "sha512-FmgJfu+MOcQ370SD0ev7EI8TlCAfKYU+B4m5T3yXc1CiRN94g/SZPtsCkk506aUDtlMnFZvasDwHHUcZUEaYuA==", "dev": true }, "node_modules/@types/stack-utils": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@types/stack-utils/-/stack-utils-2.0.1.tgz", - "integrity": "sha512-Hl219/BT5fLAaz6NDkSuhzasy49dwQS/DSdu4MdggFB8zcXv7vflBI3xp7FEmkmdDkBUI2bPUNeMttp2knYdxw==", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/@types/stack-utils/-/stack-utils-2.0.3.tgz", + "integrity": 
"sha512-9aEbYZ3TbYMznPdcdr3SmIrLXwC/AKZXQeCf9Pgao5CKb8CyHuEX5jzWPTkvregvhRJHcpRO6BFoGW9ycaOkYw==", "dev": true }, "node_modules/@types/yargs": { - "version": "15.0.14", - "resolved": "https://registry.npmjs.org/@types/yargs/-/yargs-15.0.14.tgz", - "integrity": "sha512-yEJzHoxf6SyQGhBhIYGXQDSCkJjB6HohDShto7m8vaKg9Yp0Yn8+71J9eakh2bnPg6BfsH9PRMhiRTZnd4eXGQ==", + "version": "15.0.20", + "resolved": "https://registry.npmjs.org/@types/yargs/-/yargs-15.0.20.tgz", + "integrity": "sha512-KIkX+/GgfFitlASYCGoSF+T4XRXhOubJLhkLVtSfsRTe9jWMmuM2g28zQ41BtPTG7TRBb2xHW+LCNVE9QR/vsg==", "dev": true, "dependencies": { "@types/yargs-parser": "*" } }, "node_modules/@types/yargs-parser": { - "version": "21.0.0", - "resolved": "https://registry.npmjs.org/@types/yargs-parser/-/yargs-parser-21.0.0.tgz", - "integrity": "sha512-iO9ZQHkZxHn4mSakYV0vFHAVDyEOIJQrV2uZ06HxEPcx+mt8swXoZHIbaaJ2crJYFfErySgktuTZ3BeLz+XmFA==", + "version": "21.0.3", + "resolved": "https://registry.npmjs.org/@types/yargs-parser/-/yargs-parser-21.0.3.tgz", + "integrity": "sha512-I4q9QU9MQv4oEOz4tAHJtNz1cwuLxn2F3xcc2iV5WdqLPpUnj30aUuxt1mAxYTG+oe8CZMV/+6rU4S4gRDzqtQ==", "dev": true }, "node_modules/@typescript-eslint/eslint-plugin": { - "version": "5.59.6", - "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.6.tgz", - "integrity": "sha512-sXtOgJNEuRU5RLwPUb1jxtToZbgvq3M6FPpY4QENxoOggK+UpTxUBpj6tD8+Qh2g46Pi9We87E+eHnUw8YcGsw==", + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.62.0.tgz", + "integrity": "sha512-TiZzBSJja/LbhNPvk6yc0JrX9XqhQ0hdh6M2svYfsHGejaKFIAGd9MQ+ERIMzLGlN/kZoYIgdxFV0PuljTKXag==", "dev": true, "dependencies": { "@eslint-community/regexpp": "^4.4.0", - "@typescript-eslint/scope-manager": "5.59.6", - "@typescript-eslint/type-utils": "5.59.6", - "@typescript-eslint/utils": "5.59.6", + "@typescript-eslint/scope-manager": "5.62.0", + "@typescript-eslint/type-utils": "5.62.0", + 
"@typescript-eslint/utils": "5.62.0", "debug": "^4.3.4", - "grapheme-splitter": "^1.0.4", + "graphemer": "^1.4.0", "ignore": "^5.2.0", "natural-compare-lite": "^1.4.0", "semver": "^7.3.7", @@ -1612,14 +1279,14 @@ } }, "node_modules/@typescript-eslint/parser": { - "version": "5.59.6", - "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-5.59.6.tgz", - "integrity": "sha512-7pCa6al03Pv1yf/dUg/s1pXz/yGMUBAw5EeWqNTFiSueKvRNonze3hma3lhdsOrQcaOXhbk5gKu2Fludiho9VA==", + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-5.62.0.tgz", + "integrity": "sha512-VlJEV0fOQ7BExOsHYAGrgbEiZoi8D+Bl2+f6V2RrXerRSylnp+ZBHmPvaIa8cz0Ajx7WO7Z5RqfgYg7ED1nRhA==", "dev": true, "dependencies": { - "@typescript-eslint/scope-manager": "5.59.6", - "@typescript-eslint/types": "5.59.6", - "@typescript-eslint/typescript-estree": "5.59.6", + "@typescript-eslint/scope-manager": "5.62.0", + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/typescript-estree": "5.62.0", "debug": "^4.3.4" }, "engines": { @@ -1639,13 +1306,13 @@ } }, "node_modules/@typescript-eslint/scope-manager": { - "version": "5.59.6", - "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-5.59.6.tgz", - "integrity": "sha512-gLbY3Le9Dxcb8KdpF0+SJr6EQ+hFGYFl6tVY8VxLPFDfUZC7BHFw+Vq7bM5lE9DwWPfx4vMWWTLGXgpc0mAYyQ==", + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-5.62.0.tgz", + "integrity": "sha512-VXuvVvZeQCQb5Zgf4HAxc04q5j+WrNAtNh9OwCsCgpKqESMTu3tF/jhZ3xG6T4NZwWl65Bg8KuS2uEvhSfLl0w==", "dev": true, "dependencies": { - "@typescript-eslint/types": "5.59.6", - "@typescript-eslint/visitor-keys": "5.59.6" + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/visitor-keys": "5.62.0" }, "engines": { "node": "^12.22.0 || ^14.17.0 || >=16.0.0" @@ -1656,13 +1323,13 @@ } }, "node_modules/@typescript-eslint/type-utils": { - "version": "5.59.6", - "resolved": 
"https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-5.59.6.tgz", - "integrity": "sha512-A4tms2Mp5yNvLDlySF+kAThV9VTBPCvGf0Rp8nl/eoDX9Okun8byTKoj3fJ52IJitjWOk0fKPNQhXEB++eNozQ==", + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-5.62.0.tgz", + "integrity": "sha512-xsSQreu+VnfbqQpW5vnCJdq1Z3Q0U31qiWmRhr98ONQmcp/yhiPJFPq8MXiJVLiksmOKSjIldZzkebzHuCGzew==", "dev": true, "dependencies": { - "@typescript-eslint/typescript-estree": "5.59.6", - "@typescript-eslint/utils": "5.59.6", + "@typescript-eslint/typescript-estree": "5.62.0", + "@typescript-eslint/utils": "5.62.0", "debug": "^4.3.4", "tsutils": "^3.21.0" }, @@ -1683,9 +1350,9 @@ } }, "node_modules/@typescript-eslint/types": { - "version": "5.59.6", - "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-5.59.6.tgz", - "integrity": "sha512-tH5lBXZI7T2MOUgOWFdVNUILsI02shyQvfzG9EJkoONWugCG77NDDa1EeDGw7oJ5IvsTAAGVV8I3Tk2PNu9QfA==", + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-5.62.0.tgz", + "integrity": "sha512-87NVngcbVXUahrRTqIK27gD2t5Cu1yuCXxbLcFtCzZGlfyVWWh8mLHkoxzjsB6DDNnvdL+fW8MiwPEJyGJQDgQ==", "dev": true, "engines": { "node": "^12.22.0 || ^14.17.0 || >=16.0.0" @@ -1696,13 +1363,13 @@ } }, "node_modules/@typescript-eslint/typescript-estree": { - "version": "5.59.6", - "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-5.59.6.tgz", - "integrity": "sha512-vW6JP3lMAs/Tq4KjdI/RiHaaJSO7IUsbkz17it/Rl9Q+WkQ77EOuOnlbaU8kKfVIOJxMhnRiBG+olE7f3M16DA==", + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-5.62.0.tgz", + "integrity": "sha512-CmcQ6uY7b9y694lKdRB8FEel7JbU/40iSAPomu++SjLMntB+2Leay2LO6i8VnJk58MtE9/nQSFIH6jpyRWyYzA==", "dev": true, "dependencies": { - "@typescript-eslint/types": "5.59.6", - "@typescript-eslint/visitor-keys": "5.59.6", + 
"@typescript-eslint/types": "5.62.0", + "@typescript-eslint/visitor-keys": "5.62.0", "debug": "^4.3.4", "globby": "^11.1.0", "is-glob": "^4.0.3", @@ -1723,17 +1390,17 @@ } }, "node_modules/@typescript-eslint/utils": { - "version": "5.59.6", - "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-5.59.6.tgz", - "integrity": "sha512-vzaaD6EXbTS29cVH0JjXBdzMt6VBlv+hE31XktDRMX1j3462wZCJa7VzO2AxXEXcIl8GQqZPcOPuW/Z1tZVogg==", + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-5.62.0.tgz", + "integrity": "sha512-n8oxjeb5aIbPFEtmQxQYOLI0i9n5ySBEY/ZEHHZqKQSFnxio1rv6dthascc9dLuwrL0RC5mPCxB7vnAVGAYWAQ==", "dev": true, "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@types/json-schema": "^7.0.9", "@types/semver": "^7.3.12", - "@typescript-eslint/scope-manager": "5.59.6", - "@typescript-eslint/types": "5.59.6", - "@typescript-eslint/typescript-estree": "5.59.6", + "@typescript-eslint/scope-manager": "5.62.0", + "@typescript-eslint/types": "5.62.0", + "@typescript-eslint/typescript-estree": "5.62.0", "eslint-scope": "^5.1.1", "semver": "^7.3.7" }, @@ -1749,12 +1416,12 @@ } }, "node_modules/@typescript-eslint/visitor-keys": { - "version": "5.59.6", - "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-5.59.6.tgz", - "integrity": "sha512-zEfbFLzB9ETcEJ4HZEEsCR9HHeNku5/Qw1jSS5McYJv5BR+ftYXwFFAH5Al+xkGaZEqowMwl7uoJjQb1YSPF8Q==", + "version": "5.62.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-5.62.0.tgz", + "integrity": "sha512-07ny+LHRzQXepkGg6w0mFY41fVUNBrL2Roj/++7V1txKugfjm/Ci/qSND03r2RhlJhJYMcTn9AhhSSqQp0Ysyw==", "dev": true, "dependencies": { - "@typescript-eslint/types": "5.59.6", + "@typescript-eslint/types": "5.62.0", "eslint-visitor-keys": "^3.3.0" }, "engines": { @@ -1765,22 +1432,29 @@ "url": "https://opencollective.com/typescript-eslint" } }, + "node_modules/@ungap/structured-clone": { + "version": "1.3.0", + 
"resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.0.tgz", + "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==", + "dev": true + }, "node_modules/@webgpu/types": { - "version": "0.1.46", - "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.46.tgz", - "integrity": "sha512-2iogO6Zh0pTbKLGZuuGWEmJpF/fTABGs7G9wXxpn7s24XSJchSUIiMqIJHURi5zsMZRRTuXrV/3GLOkmOFjq5w==", + "version": "0.1.66", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.66.tgz", + "integrity": "sha512-YA2hLrwLpDsRueNDXIMqN9NTzD6bCDkuXbOSe0heS+f8YE8usA6Gbv1prj81pzVHrbaAma7zObnIC+I6/sXJgA==", "dev": true }, "node_modules/abab": { "version": "2.0.6", "resolved": "https://registry.npmjs.org/abab/-/abab-2.0.6.tgz", "integrity": "sha512-j2afSsaIENvHZN2B8GOpF566vZ5WVk5opAiMTvWgaQT8DkbOqsTfvNAvHoRGU2zzP8cPoqys+xHTRDWW8L+/BA==", + "deprecated": "Use your platform's native atob() and btoa() methods instead", "dev": true }, "node_modules/acorn": { - "version": "7.4.1", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-7.4.1.tgz", - "integrity": "sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==", + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "bin": { "acorn": "bin/acorn" @@ -1799,6 +1473,18 @@ "acorn-walk": "^7.1.1" } }, + "node_modules/acorn-globals/node_modules/acorn": { + "version": "7.4.1", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-7.4.1.tgz", + "integrity": "sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==", + "dev": true, + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/acorn-jsx": { "version": "5.3.2", "resolved": 
"https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", @@ -1882,21 +1568,24 @@ } }, "node_modules/ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.3.tgz", + "integrity": "sha512-+fksAx9eG3Ab6LDnLs3ZqZa8KVJ/jYnX+D4Qe1azX+LFGFAXqynCQLOdLpNYN/l9e7l6hMWwZbrnctqr6eSQSw==", "dev": true }, "node_modules/ansi-styles": { - "version": "3.2.1", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz", - "integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", "dev": true, "dependencies": { - "color-convert": "^1.9.0" + "color-convert": "^2.0.1" }, "engines": { - "node": ">=4" + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" } }, "node_modules/anymatch": { @@ -1913,13 +1602,10 @@ } }, "node_modules/argparse": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz", - "integrity": "sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==", - "dev": true, - "dependencies": { - "sprintf-js": "~1.0.2" - } + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "dev": true }, "node_modules/arr-diff": { "version": "4.0.0", @@ -1993,6 +1679,14 @@ "node": ">= 4.5.0" } }, + "node_modules/audit": { 
+ "version": "0.0.6", + "resolved": "https://registry.npmjs.org/audit/-/audit-0.0.6.tgz", + "integrity": "sha512-xgv3Y3RIYE00N2/xk10VLlwFd1kjc7FRaX1vC8+CsOfDRe53a06vOSkp91BOSNijZfddYum47a1Fvju/2+JPcw==", + "engines": { + "node": ">= 0.5.0" + } + }, "node_modules/babel-jest": { "version": "26.6.3", "resolved": "https://registry.npmjs.org/babel-jest/-/babel-jest-26.6.3.tgz", @@ -2015,76 +1709,6 @@ "@babel/core": "^7.0.0" } }, - "node_modules/babel-jest/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/babel-jest/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/babel-jest/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/babel-jest/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": 
"sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/babel-jest/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/babel-jest/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/babel-plugin-istanbul": { "version": "6.1.1", "resolved": "https://registry.npmjs.org/babel-plugin-istanbul/-/babel-plugin-istanbul-6.1.1.tgz", @@ -2118,9 +1742,9 @@ } }, "node_modules/babel-plugin-istanbul/node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true, "bin": { "semver": "bin/semver.js" @@ -2142,26 +1766,29 @@ } }, "node_modules/babel-preset-current-node-syntax": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz", - "integrity": "sha512-M7LQ0bxarkxQoN+vz5aJPsLBn77n8QgTFmo8WK0/44auK2xlCXrYcUxHFxgU7qW5Yzw/CjmLRK2uJzaCd7LvqQ==", + "version": "1.2.0", + "resolved": 
"https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.2.0.tgz", + "integrity": "sha512-E/VlAEzRrsLEb2+dv8yp3bo4scof3l9nR4lrld+Iy5NyVqgVYUJnDAmunkhPMisRI32Qc4iRiz425d8vM++2fg==", "dev": true, "dependencies": { "@babel/plugin-syntax-async-generators": "^7.8.4", "@babel/plugin-syntax-bigint": "^7.8.3", - "@babel/plugin-syntax-class-properties": "^7.8.3", - "@babel/plugin-syntax-import-meta": "^7.8.3", + "@babel/plugin-syntax-class-properties": "^7.12.13", + "@babel/plugin-syntax-class-static-block": "^7.14.5", + "@babel/plugin-syntax-import-attributes": "^7.24.7", + "@babel/plugin-syntax-import-meta": "^7.10.4", "@babel/plugin-syntax-json-strings": "^7.8.3", - "@babel/plugin-syntax-logical-assignment-operators": "^7.8.3", + "@babel/plugin-syntax-logical-assignment-operators": "^7.10.4", "@babel/plugin-syntax-nullish-coalescing-operator": "^7.8.3", - "@babel/plugin-syntax-numeric-separator": "^7.8.3", + "@babel/plugin-syntax-numeric-separator": "^7.10.4", "@babel/plugin-syntax-object-rest-spread": "^7.8.3", "@babel/plugin-syntax-optional-catch-binding": "^7.8.3", "@babel/plugin-syntax-optional-chaining": "^7.8.3", - "@babel/plugin-syntax-top-level-await": "^7.8.3" + "@babel/plugin-syntax-private-property-in-object": "^7.14.5", + "@babel/plugin-syntax-top-level-await": "^7.14.5" }, "peerDependencies": { - "@babel/core": "^7.0.0" + "@babel/core": "^7.0.0 || ^8.0.0-0" } }, "node_modules/babel-preset-jest": { @@ -2216,48 +1843,19 @@ "node": ">=0.10.0" } }, - "node_modules/base/node_modules/is-accessor-descriptor": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-accessor-descriptor/-/is-accessor-descriptor-1.0.0.tgz", - "integrity": "sha512-m5hnHTkcVsPfqx3AKlyttIPb7J+XykHvJP2B9bZDjlhLIoEq4XoK64Vg7boZlVWYK6LUY94dYPEE7Lh0ZkZKcQ==", + "node_modules/baseline-browser-mapping": { + "version": "2.8.30", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.8.30.tgz", + 
"integrity": "sha512-aTUKW4ptQhS64+v2d6IkPzymEzzhw+G0bA1g3uBRV3+ntkH+svttKseW5IOR4Ed6NUVKqnY7qT3dKvzQ7io4AA==", "dev": true, - "dependencies": { - "kind-of": "^6.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/base/node_modules/is-data-descriptor": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-data-descriptor/-/is-data-descriptor-1.0.0.tgz", - "integrity": "sha512-jbRXy1FmtAoCjQkVmIVYwuuqDFUbaOeDjmed1tOGPrsMhtJA4rD9tkgA0F1qJ3gRFRXcHYVkdeaP50Q5rE/jLQ==", - "dev": true, - "dependencies": { - "kind-of": "^6.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/base/node_modules/is-descriptor": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-1.0.2.tgz", - "integrity": "sha512-2eis5WqQGV7peooDyLmNEPUrps9+SXX5c9pL3xEB+4e9HnGuDa7mB7kHxHw4CbqS9k1T2hOH3miL8n8WtiYVtg==", - "dev": true, - "dependencies": { - "is-accessor-descriptor": "^1.0.0", - "is-data-descriptor": "^1.0.0", - "kind-of": "^6.0.2" - }, - "engines": { - "node": ">=0.10.0" + "bin": { + "baseline-browser-mapping": "dist/cli.js" } }, "node_modules/brace-expansion": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", - "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", "dev": true, "dependencies": { "balanced-match": "^1.0.0", @@ -2265,12 +1863,12 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": 
"https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -2283,9 +1881,9 @@ "dev": true }, "node_modules/browserslist": { - "version": "4.21.4", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.4.tgz", - "integrity": "sha512-CBHJJdDmgjl3daYjN5Cp5kbTf1mUhZoS+beLklHIvkOWscs83YAhLlF3Wsh/lciQYAcbBJgTOD44VtG31ZM4Hw==", + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.28.0.tgz", + "integrity": "sha512-tbydkR/CxfMwelN0vwdP/pLkDwyAASZ+VfWm4EOwlB6SWhx1sYnWLqo8N5j0rAzPfzfRaxt0mM/4wPU/Su84RQ==", "dev": true, "funding": [ { @@ -2295,13 +1893,18 @@ { "type": "tidelift", "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" } ], "dependencies": { - "caniuse-lite": "^1.0.30001400", - "electron-to-chromium": "^1.4.251", - "node-releases": "^2.0.6", - "update-browserslist-db": "^1.0.9" + "baseline-browser-mapping": "^2.8.25", + "caniuse-lite": "^1.0.30001754", + "electron-to-chromium": "^1.5.249", + "node-releases": "^2.0.27", + "update-browserslist-db": "^1.1.4" }, "bin": { "browserslist": "cli.js" @@ -2357,6 +1960,19 @@ "node": ">=0.10.0" } }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "dev": true, + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/callsites": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", @@ -2376,9 +1992,9 @@ } }, 
"node_modules/caniuse-lite": { - "version": "1.0.30001434", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001434.tgz", - "integrity": "sha512-aOBHrLmTQw//WFa2rcF1If9fa3ypkC1wzqqiKHgfdrXTWcU8C4gKVZT77eQAPWN1APys3+uQ0Df07rKauXGEYA==", + "version": "1.0.30001756", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001756.tgz", + "integrity": "sha512-4HnCNKbMLkLdhJz3TToeVWHSnfJvPaq6vu/eRP0Ahub/07n484XHhBF5AJoSGHdVrS8tKFauUQz8Bp9P7LVx7A==", "dev": true, "funding": [ { @@ -2388,6 +2004,10 @@ { "type": "tidelift", "url": "https://tidelift.com/funding/github/npm/caniuse-lite" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" } ] }, @@ -2404,17 +2024,19 @@ } }, "node_modules/chalk": { - "version": "2.4.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz", - "integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==", + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", "dev": true, "dependencies": { - "ansi-styles": "^3.2.1", - "escape-string-regexp": "^1.0.5", - "supports-color": "^5.3.0" + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" }, "engines": { - "node": ">=4" + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" } }, "node_modules/char-regex": { @@ -2465,6 +2087,19 @@ "node": ">=0.10.0" } }, + "node_modules/class-utils/node_modules/is-descriptor": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-0.1.7.tgz", + "integrity": "sha512-C3grZTvObeN1xud4cRWl366OMXZTj0+HGyk4hvfpx4ZHt1Pb60ANSXqCK7pdOTeUQpRzECBSTphqvD7U+l22Eg==", + "dev": true, + "dependencies": { + "is-accessor-descriptor": "^1.0.1", + "is-data-descriptor": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + } + }, 
"node_modules/cliui": { "version": "6.0.0", "resolved": "https://registry.npmjs.org/cliui/-/cliui-6.0.0.tgz", @@ -2487,9 +2122,9 @@ } }, "node_modules/collect-v8-coverage": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/collect-v8-coverage/-/collect-v8-coverage-1.0.1.tgz", - "integrity": "sha512-iBPtljfCNcTKNAto0KEtDfZ3qzjJvqE3aTGZsbhjSBlorqpXJlaWWtPO35D+ZImoC3KWejX64o+yPGxhWSTzfg==", + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/collect-v8-coverage/-/collect-v8-coverage-1.0.3.tgz", + "integrity": "sha512-1L5aqIkwPfiodaMgQunkF1zRhNqifHBmtbbbxcr6yVxxBnliw4TDOW6NxpO8DJLgJ16OT+Y4ztZqP6p/FtXnAw==", "dev": true }, "node_modules/collection-visit": { @@ -2506,18 +2141,21 @@ } }, "node_modules/color-convert": { - "version": "1.9.3", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz", - "integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", "dev": true, "dependencies": { - "color-name": "1.1.3" + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" } }, "node_modules/color-name": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz", - "integrity": "sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==", + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", "dev": true }, "node_modules/combined-stream": { @@ -2539,10 +2177,13 @@ "dev": true }, "node_modules/component-emitter": { - "version": "1.3.0", - "resolved": 
"https://registry.npmjs.org/component-emitter/-/component-emitter-1.3.0.tgz", - "integrity": "sha512-Rd3se6QB+sO1TwqZjscQrurpEPIfO0/yYnSin6Q/rD3mOutHvUrCAhJub3r90uNb+SESBuE0QYoB90YdfatsRg==", - "dev": true + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/component-emitter/-/component-emitter-1.3.1.tgz", + "integrity": "sha512-T0+barUSQRTUQASh8bx02dl+DhF54GtIDY13Y3m9oWTklKbb3Wv974meRpeZ3lp1JpLVECWWNHC4vaG2XHXouQ==", + "dev": true, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } }, "node_modules/concat-map": { "version": "0.0.1", @@ -2566,28 +2207,17 @@ } }, "node_modules/cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dev": true, "dependencies": { - "nice-try": "^1.0.4", - "path-key": "^2.0.1", - "semver": "^5.5.0", - "shebang-command": "^1.2.0", - "which": "^1.2.9" + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" }, "engines": { - "node": ">=4.8" - } - }, - "node_modules/cross-spawn/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", - "dev": true, - "bin": { - "semver": "bin/semver" + "node": ">= 8" } }, "node_modules/cssom": { @@ -2629,12 +2259,12 @@ } }, "node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.4.3", + "resolved": 
"https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", + "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", "dev": true, "dependencies": { - "ms": "2.1.2" + "ms": "^2.1.3" }, "engines": { "node": ">=6.0" @@ -2655,9 +2285,9 @@ } }, "node_modules/decimal.js": { - "version": "10.4.2", - "resolved": "https://registry.npmjs.org/decimal.js/-/decimal.js-10.4.2.tgz", - "integrity": "sha512-ic1yEvwT6GuvaYwBLLY6/aFFgjZdySKTE8en/fkU3QICTmRtgtSlFn0u0BXN06InZwtfCelR7j8LRiDI/02iGA==", + "version": "10.6.0", + "resolved": "https://registry.npmjs.org/decimal.js/-/decimal.js-10.6.0.tgz", + "integrity": "sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==", "dev": true }, "node_modules/decode-uri-component": { @@ -2676,9 +2306,9 @@ "dev": true }, "node_modules/deepmerge": { - "version": "4.2.2", - "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.2.2.tgz", - "integrity": "sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg==", + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz", + "integrity": "sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==", "dev": true, "engines": { "node": ">=0.10.0" @@ -2697,44 +2327,6 @@ "node": ">=0.10.0" } }, - "node_modules/define-property/node_modules/is-accessor-descriptor": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-accessor-descriptor/-/is-accessor-descriptor-1.0.0.tgz", - "integrity": "sha512-m5hnHTkcVsPfqx3AKlyttIPb7J+XykHvJP2B9bZDjlhLIoEq4XoK64Vg7boZlVWYK6LUY94dYPEE7Lh0ZkZKcQ==", - "dev": true, - "dependencies": { - "kind-of": "^6.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/define-property/node_modules/is-data-descriptor": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-data-descriptor/-/is-data-descriptor-1.0.0.tgz", - 
"integrity": "sha512-jbRXy1FmtAoCjQkVmIVYwuuqDFUbaOeDjmed1tOGPrsMhtJA4rD9tkgA0F1qJ3gRFRXcHYVkdeaP50Q5rE/jLQ==", - "dev": true, - "dependencies": { - "kind-of": "^6.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/define-property/node_modules/is-descriptor": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-1.0.2.tgz", - "integrity": "sha512-2eis5WqQGV7peooDyLmNEPUrps9+SXX5c9pL3xEB+4e9HnGuDa7mB7kHxHw4CbqS9k1T2hOH3miL8n8WtiYVtg==", - "dev": true, - "dependencies": { - "is-accessor-descriptor": "^1.0.0", - "is-data-descriptor": "^1.0.0", - "kind-of": "^6.0.2" - }, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/delayed-stream": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", @@ -2790,6 +2382,7 @@ "version": "2.0.1", "resolved": "https://registry.npmjs.org/domexception/-/domexception-2.0.1.tgz", "integrity": "sha512-yxJ2mFy/sibVQlu5qHjOkf9J3K6zgmCxgJ94u2EdvDOV09H+32LtRswEcUsmUWN72pVLOEnTSRaIVVzVQgS0dg==", + "deprecated": "Use your platform's native DOMException instead", "dev": true, "dependencies": { "webidl-conversions": "^5.0.0" @@ -2807,10 +2400,24 @@ "node": ">=8" } }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "dev": true, + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/electron-to-chromium": { - "version": "1.4.284", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.284.tgz", - "integrity": "sha512-M8WEXFuKXMYMVr45fo8mq0wUrrJHheiKZf6BArTKk9ZBYCKJEOU5H8cdWgDT+qCVZf7Na4lVUaZsA+h6uA9+PA==", + "version": "1.5.259", + "resolved": 
"https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.259.tgz", + "integrity": "sha512-I+oLXgpEJzD6Cwuwt1gYjxsDmu/S/Kd41mmLA3O+/uH2pFRO/DvOjUyGozL8j3KeLV6WyZ7ssPwELMsXCcsJAQ==", "dev": true }, "node_modules/emittery": { @@ -2832,51 +2439,98 @@ "dev": true }, "node_modules/end-of-stream": { - "version": "1.4.4", - "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz", - "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==", + "version": "1.4.5", + "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.5.tgz", + "integrity": "sha512-ooEGc6HP26xXq/N+GCGOT0JKCLDGrq2bQUZrQ7gyrJiZANJ/8YDTxTpQBXGMn+WbIQXNVpyWymm7KYVICQnyOg==", "dev": true, "dependencies": { "once": "^1.4.0" } }, "node_modules/error-ex": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/error-ex/-/error-ex-1.3.2.tgz", - "integrity": "sha512-7dFHNmqeFSEt2ZBsCriorKnn3Z2pj+fd9kmI6QoWw4//DL+icEBfc0U7qJCisqrTsKTjw4fNFy2pW9OqStD84g==", + "version": "1.3.4", + "resolved": "https://registry.npmjs.org/error-ex/-/error-ex-1.3.4.tgz", + "integrity": "sha512-sqQamAnR14VgCr1A618A3sGrygcpK+HEbenA/HiEAkkUwcZIIB/tgWqHFxWgOyDh4nB4JCRimh79dR5Ywc9MDQ==", "dev": true, "dependencies": { "is-arrayish": "^0.2.1" } }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "dev": true, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + 
"resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "dev": true, + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "dev": true, + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/escalade": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", - "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", "dev": true, "engines": { "node": ">=6" } }, "node_modules/escape-string-regexp": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", - "integrity": "sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==", + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", "dev": true, "engines": { - "node": ">=0.8.0" + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/escodegen": { - "version": "2.0.0", - "resolved": 
"https://registry.npmjs.org/escodegen/-/escodegen-2.0.0.tgz", - "integrity": "sha512-mmHKys/C8BFUGI+MAWNcSYoORYLMdPzjrknd2Vc+bUsjN5bXcr8EhrNB+UTqfL1y3I9c4fw2ihgtMPQLBRiQxw==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/escodegen/-/escodegen-2.1.0.tgz", + "integrity": "sha512-2NlIDTwUWJN0mRPQOdtQBzbUHvdGY2P1VXSyU83Q3xKxM7WHX2Ql8dKq782Q9TgQUNOLEzEYu9bzLNj1q88I5w==", "dev": true, "dependencies": { "esprima": "^4.0.1", "estraverse": "^5.2.0", - "esutils": "^2.0.2", - "optionator": "^0.8.1" + "esutils": "^2.0.2" }, "bin": { "escodegen": "bin/escodegen.js", @@ -2899,27 +2553,29 @@ } }, "node_modules/eslint": { - "version": "8.41.0", - "resolved": "https://registry.npmjs.org/eslint/-/eslint-8.41.0.tgz", - "integrity": "sha512-WQDQpzGBOP5IrXPo4Hc0814r4/v2rrIsB0rhT7jtunIalgg6gYXWhRMOejVO8yH21T/FGaxjmFjBMNqcIlmH1Q==", + "version": "8.57.1", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-8.57.1.tgz", + "integrity": "sha512-ypowyDxpVSYpkXr9WPv2PAZCtNip1Mv5KTW0SCurXv/9iOpcrH9PaqUElksqEB6pChqHGDRCFTyrZlGhnLNGiA==", + "deprecated": "This version is no longer supported. 
Please see https://eslint.org/version-support for other options.", "dev": true, "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", - "@eslint-community/regexpp": "^4.4.0", - "@eslint/eslintrc": "^2.0.3", - "@eslint/js": "8.41.0", - "@humanwhocodes/config-array": "^0.11.8", + "@eslint-community/regexpp": "^4.6.1", + "@eslint/eslintrc": "^2.1.4", + "@eslint/js": "8.57.1", + "@humanwhocodes/config-array": "^0.13.0", "@humanwhocodes/module-importer": "^1.0.1", "@nodelib/fs.walk": "^1.2.8", - "ajv": "^6.10.0", + "@ungap/structured-clone": "^1.2.0", + "ajv": "^6.12.4", "chalk": "^4.0.0", "cross-spawn": "^7.0.2", "debug": "^4.3.2", "doctrine": "^3.0.0", "escape-string-regexp": "^4.0.0", - "eslint-scope": "^7.2.0", - "eslint-visitor-keys": "^3.4.1", - "espree": "^9.5.2", + "eslint-scope": "^7.2.2", + "eslint-visitor-keys": "^3.4.3", + "espree": "^9.6.1", "esquery": "^1.4.2", "esutils": "^2.0.2", "fast-deep-equal": "^3.1.3", @@ -2929,7 +2585,6 @@ "globals": "^13.19.0", "graphemer": "^1.4.0", "ignore": "^5.2.0", - "import-fresh": "^3.0.0", "imurmurhash": "^0.1.4", "is-glob": "^4.0.0", "is-path-inside": "^3.0.3", @@ -2939,9 +2594,8 @@ "lodash.merge": "^4.6.2", "minimatch": "^3.1.2", "natural-compare": "^1.4.0", - "optionator": "^0.9.1", + "optionator": "^0.9.3", "strip-ansi": "^6.0.1", - "strip-json-comments": "^3.1.0", "text-table": "^0.2.0" }, "bin": { @@ -2968,9 +2622,9 @@ } }, "node_modules/eslint-visitor-keys": { - "version": "3.4.1", - "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.1.tgz", - "integrity": "sha512-pZnmmLwYzf+kWaM/Qgrvpen51upAktaaiI01nsJD/Yr3lMOdNtq0cxkrrg16w64VtisN6okbs7Q8AfGqj4c9fA==", + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", "dev": true, "engines": { "node": "^12.22.0 || ^14.17.0 || >=16.0.0" @@ -2979,104 +2633,95 @@ 
"url": "https://opencollective.com/eslint" } }, - "node_modules/eslint/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/eslint/node_modules/eslint-scope": { + "version": "7.2.2", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz", + "integrity": "sha512-dOt21O7lTMhDM+X9mB4GX+DZrZtCUJPL/wlcTqxyrx5IvO0IYtILdtrQGQp+8n5S0gwSVmOf9NQrjMOgfQZlIg==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "esrecurse": "^4.3.0", + "estraverse": "^5.2.0" }, "engines": { - "node": ">=8" + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" }, "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "url": "https://opencollective.com/eslint" } }, - "node_modules/eslint/node_modules/argparse": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", - "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", - "dev": true + "node_modules/eslint/node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "dev": true, + "engines": { + "node": ">=4.0" + } }, - "node_modules/eslint/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "node_modules/espree": { + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/espree/-/espree-9.6.1.tgz", + "integrity": "sha512-oruZaFkjorTpF32kDSI5/75ViwGeZginGGy2NoOSg3Q9bnwlnmDm4HLnkl0RE3n+njDXR037aY1+x58Z/zFdwQ==", "dev": true, 
"dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" + "acorn": "^8.9.0", + "acorn-jsx": "^5.3.2", + "eslint-visitor-keys": "^3.4.1" }, "engines": { - "node": ">=10" + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" }, "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" + "url": "https://opencollective.com/eslint" } }, - "node_modules/eslint/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/esprima": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", + "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", "dev": true, - "dependencies": { - "color-name": "~1.1.4" + "bin": { + "esparse": "bin/esparse.js", + "esvalidate": "bin/esvalidate.js" }, "engines": { - "node": ">=7.0.0" + "node": ">=4" } }, - "node_modules/eslint/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/eslint/node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "node_modules/esquery": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.6.0.tgz", + "integrity": "sha512-ca9pw9fomFcKPvFLXhBKUK90ZvGibiGOvRJNbjljY7s7uq/5YO4BOzcYtJqExdx99rF6aAcnRxHmcUHcz6sQsg==", "dev": true, "dependencies": { - "path-key": "^3.1.0", - "shebang-command": "^2.0.0", - "which": "^2.0.1" + "estraverse": "^5.1.0" }, "engines": { - "node": 
">= 8" + "node": ">=0.10" } }, - "node_modules/eslint/node_modules/escape-string-regexp": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", - "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "node_modules/esquery/node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", "dev": true, "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=4.0" } }, - "node_modules/eslint/node_modules/eslint-scope": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.0.tgz", - "integrity": "sha512-DYj5deGlHBfMt15J7rdtyKNq/Nqlv5KfU4iodrQ019XESsRnwXH9KAE0y3cwtUHDo2ob7CypAnCqefh6vioWRw==", + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", "dev": true, "dependencies": { - "esrecurse": "^4.3.0", "estraverse": "^5.2.0" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint" + "node": ">=4.0" } }, - "node_modules/eslint/node_modules/estraverse": { + "node_modules/esrecurse/node_modules/estraverse": { "version": "5.3.0", "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", @@ -3085,836 +2730,893 @@ "node": ">=4.0" } }, - "node_modules/eslint/node_modules/find-up": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", - "integrity": 
"sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==", + "node_modules/estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", "dev": true, - "dependencies": { - "locate-path": "^6.0.0", - "path-exists": "^4.0.0" - }, "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=4.0" } }, - "node_modules/eslint/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "node_modules/estree-walker": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", + "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", + "dev": true + }, + "node_modules/esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", "dev": true, "engines": { - "node": ">=8" + "node": ">=0.10.0" } }, - "node_modules/eslint/node_modules/js-yaml": { + "node_modules/exec-sh": { + "version": "0.3.6", + "resolved": "https://registry.npmjs.org/exec-sh/-/exec-sh-0.3.6.tgz", + "integrity": "sha512-nQn+hI3yp+oD0huYhKwvYI32+JFeq+XkNcD1GAo3Y/MjxsfVGmrrzrnzjWiNY6f+pUCP440fThsFh5gZrRAU/w==", + "dev": true + }, + "node_modules/execa": { "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "resolved": "https://registry.npmjs.org/execa/-/execa-4.1.0.tgz", 
+ "integrity": "sha512-j5W0//W7f8UxAn8hXVnwG8tLwdiUy4FJLcSupCg6maBYZDpyBvTApK7KyuI4bKj8KOh1r2YH+6ucuYtJv1bTZA==", "dev": true, "dependencies": { - "argparse": "^2.0.1" + "cross-spawn": "^7.0.0", + "get-stream": "^5.0.0", + "human-signals": "^1.1.1", + "is-stream": "^2.0.0", + "merge-stream": "^2.0.0", + "npm-run-path": "^4.0.0", + "onetime": "^5.1.0", + "signal-exit": "^3.0.2", + "strip-final-newline": "^2.0.0" }, - "bin": { - "js-yaml": "bin/js-yaml.js" + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sindresorhus/execa?sponsor=1" } }, - "node_modules/eslint/node_modules/levn": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", - "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "node_modules/exit": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/exit/-/exit-0.1.2.tgz", + "integrity": "sha512-Zk/eNKV2zbjpKzrsQ+n1G6poVbErQxJ0LBOJXaKZ1EViLzH+hrLu9cdXI4zw9dBQJslwBEpbQ2P1oS7nDxs6jQ==", "dev": true, - "dependencies": { - "prelude-ls": "^1.2.1", - "type-check": "~0.4.0" - }, "engines": { "node": ">= 0.8.0" } }, - "node_modules/eslint/node_modules/locate-path": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", - "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", + "node_modules/expand-brackets": { + "version": "2.1.4", + "resolved": "https://registry.npmjs.org/expand-brackets/-/expand-brackets-2.1.4.tgz", + "integrity": "sha512-w/ozOKR9Obk3qoWeY/WDi6MFta9AoMR+zud60mdnbniMcBxRuFJyDt2LdX/14A1UABeqk+Uk+LDfUpvoGKppZA==", "dev": true, "dependencies": { - "p-locate": "^5.0.0" + "debug": "^2.3.3", + "define-property": "^0.2.5", + "extend-shallow": "^2.0.1", + "posix-character-classes": "^0.1.0", + "regex-not": "^1.0.0", + "snapdragon": "^0.8.1", + "to-regex": "^3.0.1" }, "engines": { - "node": ">=10" - }, - 
"funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=0.10.0" } }, - "node_modules/eslint/node_modules/optionator": { - "version": "0.9.1", - "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.1.tgz", - "integrity": "sha512-74RlY5FCnhq4jRxVUPKDaRwrVNXMqsGsiW6AJw4XK8hmtm10wC0ypZBLw5IIp85NZMr91+qd1RvvENwg7jjRFw==", + "node_modules/expand-brackets/node_modules/debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, "dependencies": { - "deep-is": "^0.1.3", - "fast-levenshtein": "^2.0.6", - "levn": "^0.4.1", - "prelude-ls": "^1.2.1", - "type-check": "^0.4.0", - "word-wrap": "^1.2.3" - }, - "engines": { - "node": ">= 0.8.0" + "ms": "2.0.0" } }, - "node_modules/eslint/node_modules/p-limit": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", - "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "node_modules/expand-brackets/node_modules/define-property": { + "version": "0.2.5", + "resolved": "https://registry.npmjs.org/define-property/-/define-property-0.2.5.tgz", + "integrity": "sha512-Rr7ADjQZenceVOAKop6ALkkRAmH1A4Gx9hV/7ZujPUN2rkATqFO0JZLZInbAjpZYoJ1gUx8MRMQVkYemcbMSTA==", "dev": true, "dependencies": { - "yocto-queue": "^0.1.0" + "is-descriptor": "^0.1.0" }, "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=0.10.0" } }, - "node_modules/eslint/node_modules/p-locate": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-5.0.0.tgz", - "integrity": "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==", + "node_modules/expand-brackets/node_modules/extend-shallow": { + "version": "2.0.1", + "resolved": 
"https://registry.npmjs.org/extend-shallow/-/extend-shallow-2.0.1.tgz", + "integrity": "sha512-zCnTtlxNoAiDc3gqY2aYAWFx7XWWiasuF2K8Me5WbN8otHKTUKBwjPtNpRs/rbUZm7KxWAaNj7P1a/p52GbVug==", "dev": true, "dependencies": { - "p-limit": "^3.0.2" + "is-extendable": "^0.1.0" }, "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=0.10.0" } }, - "node_modules/eslint/node_modules/path-key": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", - "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "node_modules/expand-brackets/node_modules/is-descriptor": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-0.1.7.tgz", + "integrity": "sha512-C3grZTvObeN1xud4cRWl366OMXZTj0+HGyk4hvfpx4ZHt1Pb60ANSXqCK7pdOTeUQpRzECBSTphqvD7U+l22Eg==", "dev": true, + "dependencies": { + "is-accessor-descriptor": "^1.0.1", + "is-data-descriptor": "^1.0.1" + }, "engines": { - "node": ">=8" + "node": ">= 0.4" } }, - "node_modules/eslint/node_modules/prelude-ls": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", - "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "node_modules/expand-brackets/node_modules/is-extendable": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-0.1.1.tgz", + "integrity": "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==", "dev": true, "engines": { - "node": ">= 0.8.0" + "node": ">=0.10.0" } }, - "node_modules/eslint/node_modules/shebang-command": { + "node_modules/expand-brackets/node_modules/ms": { "version": "2.0.0", - "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", - "integrity": 
"sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", + "integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==", + "dev": true + }, + "node_modules/expect": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/expect/-/expect-26.6.2.tgz", + "integrity": "sha512-9/hlOBkQl2l/PLHJx6JjoDF6xPKcJEsUlWKb23rKE7KzeDqUZKXKNMW27KIue5JMdBV9HgmoJPcc8HtO85t9IA==", "dev": true, "dependencies": { - "shebang-regex": "^3.0.0" + "@jest/types": "^26.6.2", + "ansi-styles": "^4.0.0", + "jest-get-type": "^26.3.0", + "jest-matcher-utils": "^26.6.2", + "jest-message-util": "^26.6.2", + "jest-regex-util": "^26.0.0" }, "engines": { - "node": ">=8" - } - }, - "node_modules/eslint/node_modules/shebang-regex": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", - "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", - "dev": true, - "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/eslint/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/extend-shallow": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/extend-shallow/-/extend-shallow-3.0.2.tgz", + "integrity": "sha512-BwY5b5Ql4+qZoefgMj2NUmx+tehVTH/Kf4k1ZEtOHNFcm2wSxMRo992l6X3TIgni2eZVTZ85xMOjF31fwZAj6Q==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "assign-symbols": "^1.0.0", + "is-extendable": "^1.0.1" }, "engines": { - "node": ">=8" + "node": ">=0.10.0" } }, - "node_modules/eslint/node_modules/type-check": { - "version": "0.4.0", - "resolved": 
"https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", - "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "node_modules/extglob": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/extglob/-/extglob-2.0.4.tgz", + "integrity": "sha512-Nmb6QXkELsuBr24CJSkilo6UHHgbekK5UiZgfE6UHD3Eb27YC6oD+bhcT+tJ6cl8dmsgdQxnWlcry8ksBIBLpw==", "dev": true, "dependencies": { - "prelude-ls": "^1.2.1" + "array-unique": "^0.3.2", + "define-property": "^1.0.0", + "expand-brackets": "^2.1.4", + "extend-shallow": "^2.0.1", + "fragment-cache": "^0.2.1", + "regex-not": "^1.0.0", + "snapdragon": "^0.8.1", + "to-regex": "^3.0.1" }, "engines": { - "node": ">= 0.8.0" + "node": ">=0.10.0" } }, - "node_modules/eslint/node_modules/which": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", - "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "node_modules/extglob/node_modules/define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/define-property/-/define-property-1.0.0.tgz", + "integrity": "sha512-cZTYKFWspt9jZsMscWo8sc/5lbPC9Q0N5nBLgb+Yd915iL3udB1uFgS3B8YCx66UVHq018DAVFoee7x+gxggeA==", "dev": true, "dependencies": { - "isexe": "^2.0.0" - }, - "bin": { - "node-which": "bin/node-which" + "is-descriptor": "^1.0.0" }, "engines": { - "node": ">= 8" + "node": ">=0.10.0" } }, - "node_modules/espree": { - "version": "9.5.2", - "resolved": "https://registry.npmjs.org/espree/-/espree-9.5.2.tgz", - "integrity": "sha512-7OASN1Wma5fum5SrNhFMAMJxOUAbhyfQ8dQ//PJaJbNw0URTPWqIghHWt1MmAANKhHZIYOHruW4Kw4ruUWOdGw==", + "node_modules/extglob/node_modules/extend-shallow": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/extend-shallow/-/extend-shallow-2.0.1.tgz", + "integrity": "sha512-zCnTtlxNoAiDc3gqY2aYAWFx7XWWiasuF2K8Me5WbN8otHKTUKBwjPtNpRs/rbUZm7KxWAaNj7P1a/p52GbVug==", "dev": true, 
"dependencies": { - "acorn": "^8.8.0", - "acorn-jsx": "^5.3.2", - "eslint-visitor-keys": "^3.4.1" + "is-extendable": "^0.1.0" }, "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint" + "node": ">=0.10.0" } }, - "node_modules/espree/node_modules/acorn": { - "version": "8.8.2", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz", - "integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw==", + "node_modules/extglob/node_modules/is-extendable": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-0.1.1.tgz", + "integrity": "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==", "dev": true, - "bin": { - "acorn": "bin/acorn" - }, "engines": { - "node": ">=0.4.0" + "node": ">=0.10.0" } }, - "node_modules/esprima": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", - "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true + }, + "node_modules/fast-glob": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", + "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", "dev": true, - "bin": { - "esparse": "bin/esparse.js", - "esvalidate": "bin/esvalidate.js" + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.8" }, "engines": { - "node": ">=4" + "node": ">=8.6.0" } }, - 
"node_modules/esquery": { - "version": "1.5.0", - "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.5.0.tgz", - "integrity": "sha512-YQLXUplAwJgCydQ78IMJywZCceoqk1oH01OERdSAJc/7U2AylwjhSCLDEtqwg811idIS/9fIU5GjG73IgjKMVg==", + "node_modules/fast-glob/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", "dev": true, "dependencies": { - "estraverse": "^5.1.0" + "is-glob": "^4.0.1" }, "engines": { - "node": ">=0.10" + "node": ">= 6" } }, - "node_modules/esquery/node_modules/estraverse": { - "version": "5.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", - "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true + }, + "node_modules/fast-levenshtein": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", + "dev": true + }, + "node_modules/fastq": { + "version": "1.19.1", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.1.tgz", + "integrity": "sha512-GwLTyxkCXjXbxqIhTsMI2Nui8huMPtnxg7krajPJAjnEG/iiOS7i+zCtWGZR9G0NBKbXKh6X9m9UIsYX/N6vvQ==", "dev": true, - "engines": { - "node": ">=4.0" + "dependencies": { + "reusify": "^1.0.4" } }, - "node_modules/esrecurse": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", - "integrity": 
"sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "node_modules/fb-watchman": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/fb-watchman/-/fb-watchman-2.0.2.tgz", + "integrity": "sha512-p5161BqbuCaSnB8jIbzQHOlpgsPmK5rJVDfDKO91Axs5NC1uu3HRQm6wt9cd9/+GtQQIO53JdGXXoyDpTAsgYA==", "dev": true, "dependencies": { - "estraverse": "^5.2.0" - }, - "engines": { - "node": ">=4.0" + "bser": "2.1.1" } }, - "node_modules/esrecurse/node_modules/estraverse": { - "version": "5.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", - "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "node_modules/file-entry-cache": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz", + "integrity": "sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==", "dev": true, + "dependencies": { + "flat-cache": "^3.0.4" + }, "engines": { - "node": ">=4.0" + "node": "^10.12.0 || >=12.0.0" } }, - "node_modules/estraverse": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", - "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, + "dependencies": { + "to-regex-range": "^5.0.1" + }, "engines": { - "node": ">=4.0" + "node": ">=8" } }, - "node_modules/estree-walker": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-1.0.1.tgz", - "integrity": "sha512-1fMXF3YP4pZZVozF8j/ZLfvnR8NSIljt56UhbZ5PeeDmmGHpgpdwQt7ITlGvYaQukCvuBRMLEiKiYC+oeIg4cg==", - "dev": 
true - }, - "node_modules/esutils": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", - "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "node_modules/find-cache-dir": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/find-cache-dir/-/find-cache-dir-3.3.2.tgz", + "integrity": "sha512-wXZV5emFEjrridIgED11OoUKLxiYjAcqot/NJdAkOhlJ+vGzwhOAfcG5OX1jP+S0PcjEn8bdMJv+g2jwQ3Onig==", "dev": true, + "dependencies": { + "commondir": "^1.0.1", + "make-dir": "^3.0.2", + "pkg-dir": "^4.1.0" + }, "engines": { - "node": ">=0.10.0" + "node": ">=8" + }, + "funding": { + "url": "https://github.com/avajs/find-cache-dir?sponsor=1" } }, - "node_modules/exec-sh": { - "version": "0.3.6", - "resolved": "https://registry.npmjs.org/exec-sh/-/exec-sh-0.3.6.tgz", - "integrity": "sha512-nQn+hI3yp+oD0huYhKwvYI32+JFeq+XkNcD1GAo3Y/MjxsfVGmrrzrnzjWiNY6f+pUCP440fThsFh5gZrRAU/w==", - "dev": true - }, - "node_modules/execa": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/execa/-/execa-1.0.0.tgz", - "integrity": "sha512-adbxcyWV46qiHyvSp50TKt05tB4tK3HcmF7/nxfAdhnox83seTDbwnaqKO4sXRy7roHAIFqJP/Rw/AuEbX61LA==", + "node_modules/find-cache-dir/node_modules/make-dir": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-3.1.0.tgz", + "integrity": "sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw==", "dev": true, "dependencies": { - "cross-spawn": "^6.0.0", - "get-stream": "^4.0.0", - "is-stream": "^1.1.0", - "npm-run-path": "^2.0.0", - "p-finally": "^1.0.0", - "signal-exit": "^3.0.0", - "strip-eof": "^1.0.0" + "semver": "^6.0.0" }, "engines": { - "node": ">=6" + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/exit": { - "version": "0.1.2", - "resolved": "https://registry.npmjs.org/exit/-/exit-0.1.2.tgz", - "integrity": 
"sha512-Zk/eNKV2zbjpKzrsQ+n1G6poVbErQxJ0LBOJXaKZ1EViLzH+hrLu9cdXI4zw9dBQJslwBEpbQ2P1oS7nDxs6jQ==", + "node_modules/find-cache-dir/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true, - "engines": { - "node": ">= 0.8.0" + "bin": { + "semver": "bin/semver.js" } }, - "node_modules/expand-brackets": { - "version": "2.1.4", - "resolved": "https://registry.npmjs.org/expand-brackets/-/expand-brackets-2.1.4.tgz", - "integrity": "sha512-w/ozOKR9Obk3qoWeY/WDi6MFta9AoMR+zud60mdnbniMcBxRuFJyDt2LdX/14A1UABeqk+Uk+LDfUpvoGKppZA==", + "node_modules/find-up": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", + "integrity": "sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==", "dev": true, "dependencies": { - "debug": "^2.3.3", - "define-property": "^0.2.5", - "extend-shallow": "^2.0.1", - "posix-character-classes": "^0.1.0", - "regex-not": "^1.0.0", - "snapdragon": "^0.8.1", - "to-regex": "^3.0.1" + "locate-path": "^6.0.0", + "path-exists": "^4.0.0" }, "engines": { - "node": ">=0.10.0" + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/expand-brackets/node_modules/debug": { - "version": "2.6.9", - "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", - "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", - "dev": true, + "node_modules/fix": { + "version": "0.0.6", + "resolved": "https://registry.npmjs.org/fix/-/fix-0.0.6.tgz", + "integrity": "sha512-UQ+8m0GnIakgpY+92a9y+pYoX3Y6eaW7WNTkPolQ7r58Fjzq7NhyRLMrZ6J6U1u4y7H7APugjRmZ+i6CAn4+Dg==", "dependencies": { - "ms": "2.0.0" + "pipe": "0.0.2", + "underscore": "1.1.6", + "underscore.string": "1.1.4" + }, + "engines": { + "node": 
">=0.4.8" } }, - "node_modules/expand-brackets/node_modules/define-property": { - "version": "0.2.5", - "resolved": "https://registry.npmjs.org/define-property/-/define-property-0.2.5.tgz", - "integrity": "sha512-Rr7ADjQZenceVOAKop6ALkkRAmH1A4Gx9hV/7ZujPUN2rkATqFO0JZLZInbAjpZYoJ1gUx8MRMQVkYemcbMSTA==", + "node_modules/flat-cache": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-3.2.0.tgz", + "integrity": "sha512-CYcENa+FtcUKLmhhqyctpclsq7QF38pKjZHsGNiSQF5r4FtoKDWabFDl3hzaEQMvT1LHEysw5twgLvpYYb4vbw==", "dev": true, "dependencies": { - "is-descriptor": "^0.1.0" + "flatted": "^3.2.9", + "keyv": "^4.5.3", + "rimraf": "^3.0.2" }, "engines": { - "node": ">=0.10.0" + "node": "^10.12.0 || >=12.0.0" } }, - "node_modules/expand-brackets/node_modules/extend-shallow": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/extend-shallow/-/extend-shallow-2.0.1.tgz", - "integrity": "sha512-zCnTtlxNoAiDc3gqY2aYAWFx7XWWiasuF2K8Me5WbN8otHKTUKBwjPtNpRs/rbUZm7KxWAaNj7P1a/p52GbVug==", + "node_modules/flatted": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", + "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "dev": true + }, + "node_modules/for-in": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/for-in/-/for-in-1.0.2.tgz", + "integrity": "sha512-7EwmXrOjyL+ChxMhmG5lnW9MPt1aIeZEwKhQzoBUdTV0N3zuwWDZYVJatDvZ2OyzPUvdIAZDsCetk3coyMfcnQ==", "dev": true, - "dependencies": { - "is-extendable": "^0.1.0" - }, "engines": { "node": ">=0.10.0" } }, - "node_modules/expand-brackets/node_modules/ms": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", - "integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==", - "dev": true - }, - "node_modules/expect": { - "version": "26.6.2", - "resolved": 
"https://registry.npmjs.org/expect/-/expect-26.6.2.tgz", - "integrity": "sha512-9/hlOBkQl2l/PLHJx6JjoDF6xPKcJEsUlWKb23rKE7KzeDqUZKXKNMW27KIue5JMdBV9HgmoJPcc8HtO85t9IA==", + "node_modules/form-data": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-3.0.4.tgz", + "integrity": "sha512-f0cRzm6dkyVYV3nPoooP8XlccPQukegwhAnpoLcXy+X+A8KfpGOoXwDr9FLZd3wzgLaBGQBE3lY93Zm/i1JvIQ==", "dev": true, "dependencies": { - "@jest/types": "^26.6.2", - "ansi-styles": "^4.0.0", - "jest-get-type": "^26.3.0", - "jest-matcher-utils": "^26.6.2", - "jest-message-util": "^26.6.2", - "jest-regex-util": "^26.0.0" + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", + "mime-types": "^2.1.35" }, "engines": { - "node": ">= 10.14.2" + "node": ">= 6" } }, - "node_modules/expect/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/fragment-cache": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/fragment-cache/-/fragment-cache-0.2.1.tgz", + "integrity": "sha512-GMBAbW9antB8iZRHLoGw0b3HANt57diZYFO/HL1JGIC1MjKrdmhxvrJbupnVvpys0zsz7yBApXdQyfepKly2kA==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "map-cache": "^0.2.2" }, "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "node": ">=0.10.0" } }, - "node_modules/expect/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/fs-extra": { + "version": "10.1.0", + "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-10.1.0.tgz", + "integrity": 
"sha512-oRXApq54ETRj4eMiFzGnHWGy+zo5raudjuxN0b8H7s/RU2oW0Wvsx9O0ACRN/kRq9E8Vu/ReskGB5o3ji+FzHQ==", "dev": true, "dependencies": { - "color-name": "~1.1.4" + "graceful-fs": "^4.2.0", + "jsonfile": "^6.0.1", + "universalify": "^2.0.0" }, "engines": { - "node": ">=7.0.0" + "node": ">=12" } }, - "node_modules/expect/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "node_modules/fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", "dev": true }, - "node_modules/extend-shallow": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/extend-shallow/-/extend-shallow-3.0.2.tgz", - "integrity": "sha512-BwY5b5Ql4+qZoefgMj2NUmx+tehVTH/Kf4k1ZEtOHNFcm2wSxMRo992l6X3TIgni2eZVTZ85xMOjF31fwZAj6Q==", + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", "dev": true, - "dependencies": { - "assign-symbols": "^1.0.0", - "is-extendable": "^1.0.1" - }, + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], "engines": { - "node": ">=0.10.0" + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" } }, - "node_modules/extend-shallow/node_modules/is-extendable": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-1.0.1.tgz", - "integrity": "sha512-arnXMxT1hhoKo9k1LZdmlNyJdDDfy2v0fXjFlmok4+i8ul/6WlbVge9bhM74OpNPQPMGUToDtz+KXa1PneJxOA==", + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + 
"integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "dev": true, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/gensync": { + "version": "1.0.0-beta.2", + "resolved": "https://registry.npmjs.org/gensync/-/gensync-1.0.0-beta.2.tgz", + "integrity": "sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==", "dev": true, - "dependencies": { - "is-plain-object": "^2.0.4" - }, "engines": { - "node": ">=0.10.0" + "node": ">=6.9.0" } }, - "node_modules/extglob": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/extglob/-/extglob-2.0.4.tgz", - "integrity": "sha512-Nmb6QXkELsuBr24CJSkilo6UHHgbekK5UiZgfE6UHD3Eb27YC6oD+bhcT+tJ6cl8dmsgdQxnWlcry8ksBIBLpw==", + "node_modules/get-caller-file": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", + "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", "dev": true, - "dependencies": { - "array-unique": "^0.3.2", - "define-property": "^1.0.0", - "expand-brackets": "^2.1.4", - "extend-shallow": "^2.0.1", - "fragment-cache": "^0.2.1", - "regex-not": "^1.0.0", - "snapdragon": "^0.8.1", - "to-regex": "^3.0.1" - }, "engines": { - "node": ">=0.10.0" + "node": "6.* || 8.* || >= 10.*" } }, - "node_modules/extglob/node_modules/define-property": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/define-property/-/define-property-1.0.0.tgz", - "integrity": "sha512-cZTYKFWspt9jZsMscWo8sc/5lbPC9Q0N5nBLgb+Yd915iL3udB1uFgS3B8YCx66UVHq018DAVFoee7x+gxggeA==", + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", "dev": true, "dependencies": { - "is-descriptor": "^1.0.0" 
+ "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" }, "engines": { - "node": ">=0.10.0" + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/extglob/node_modules/extend-shallow": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/extend-shallow/-/extend-shallow-2.0.1.tgz", - "integrity": "sha512-zCnTtlxNoAiDc3gqY2aYAWFx7XWWiasuF2K8Me5WbN8otHKTUKBwjPtNpRs/rbUZm7KxWAaNj7P1a/p52GbVug==", + "node_modules/get-package-type": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/get-package-type/-/get-package-type-0.1.0.tgz", + "integrity": "sha512-pjzuKtY64GYfWizNAJ0fr9VqttZkNiK2iS430LtIHzjBEr6bX8Am2zm4sW4Ro5wjWW5cAlRL1qAMTcXbjNAO2Q==", + "dev": true, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", "dev": true, "dependencies": { - "is-extendable": "^0.1.0" + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" }, "engines": { - "node": ">=0.10.0" + "node": ">= 0.4" } }, - "node_modules/extglob/node_modules/is-accessor-descriptor": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-accessor-descriptor/-/is-accessor-descriptor-1.0.0.tgz", - "integrity": "sha512-m5hnHTkcVsPfqx3AKlyttIPb7J+XykHvJP2B9bZDjlhLIoEq4XoK64Vg7boZlVWYK6LUY94dYPEE7Lh0ZkZKcQ==", + "node_modules/get-stream": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-5.2.0.tgz", + "integrity": "sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA==", "dev": true, 
"dependencies": { - "kind-of": "^6.0.0" + "pump": "^3.0.0" }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/get-value": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/get-value/-/get-value-2.0.6.tgz", + "integrity": "sha512-Ln0UQDlxH1BapMu3GPtf7CuYNwRZf2gwCuPqbyG6pB8WfmFpzqcy4xtAaAMUhnNqjMKTiCPZG2oMT3YSx8U2NA==", + "dev": true, "engines": { "node": ">=0.10.0" } }, - "node_modules/extglob/node_modules/is-data-descriptor": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-data-descriptor/-/is-data-descriptor-1.0.0.tgz", - "integrity": "sha512-jbRXy1FmtAoCjQkVmIVYwuuqDFUbaOeDjmed1tOGPrsMhtJA4rD9tkgA0F1qJ3gRFRXcHYVkdeaP50Q5rE/jLQ==", + "node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "deprecated": "Glob versions prior to v9 are no longer supported", "dev": true, "dependencies": { - "kind-of": "^6.0.0" + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" }, "engines": { - "node": ">=0.10.0" + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" } }, - "node_modules/extglob/node_modules/is-descriptor": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-1.0.2.tgz", - "integrity": "sha512-2eis5WqQGV7peooDyLmNEPUrps9+SXX5c9pL3xEB+4e9HnGuDa7mB7kHxHw4CbqS9k1T2hOH3miL8n8WtiYVtg==", + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", "dev": true, "dependencies": { - "is-accessor-descriptor": "^1.0.0", - "is-data-descriptor": 
"^1.0.0", - "kind-of": "^6.0.2" + "is-glob": "^4.0.3" }, "engines": { - "node": ">=0.10.0" + "node": ">=10.13.0" } }, - "node_modules/fast-deep-equal": { - "version": "3.1.3", - "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", - "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", - "dev": true - }, - "node_modules/fast-glob": { - "version": "3.2.12", - "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz", - "integrity": "sha512-DVj4CQIYYow0BlaelwK1pHl5n5cRSJfM60UA0zK891sVInoPri2Ekj7+e1CT3/3qxXenpI+nBBmQAcJPJgaj4w==", + "node_modules/globals": { + "version": "13.24.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-13.24.0.tgz", + "integrity": "sha512-AhO5QUcj8llrbG09iWhPU2B204J1xnPeL8kQmVorSsy+Sjj1sk8gIyh6cUocGmH4L0UuhAJy+hJMRA4mgA4mFQ==", "dev": true, "dependencies": { - "@nodelib/fs.stat": "^2.0.2", - "@nodelib/fs.walk": "^1.2.3", - "glob-parent": "^5.1.2", - "merge2": "^1.3.0", - "micromatch": "^4.0.4" + "type-fest": "^0.20.2" }, "engines": { - "node": ">=8.6.0" + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/fast-glob/node_modules/glob-parent": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", - "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "node_modules/globby": { + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", + "integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==", "dev": true, "dependencies": { - "is-glob": "^4.0.1" + "array-union": "^2.1.0", + "dir-glob": "^3.0.1", + "fast-glob": "^3.2.9", + "ignore": "^5.2.0", + "merge2": "^1.4.1", + "slash": "^3.0.0" }, "engines": { - "node": ">= 6" + "node": ">=10" + }, + "funding": { + "url": 
"https://github.com/sponsors/sindresorhus" } }, - "node_modules/fast-json-stable-stringify": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", - "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "dev": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", "dev": true }, - "node_modules/fast-levenshtein": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", - "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", + "node_modules/graphemer": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", + "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", "dev": true }, - "node_modules/fastq": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", - "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", + "node_modules/growly": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/growly/-/growly-1.3.0.tgz", + "integrity": "sha512-+xGQY0YyAWCnqy7Cd++hc2JqMYzlm0dG30Jd0beaA64sROr8C4nt8Yc9V5Ro3avlSUDTN0ulqP/VBKi1/lLygw==", "dev": true, - "dependencies": { - "reusify": "^1.0.4" + 
"optional": true + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "engines": { + "node": ">=8" } }, - "node_modules/fb-watchman": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/fb-watchman/-/fb-watchman-2.0.2.tgz", - "integrity": "sha512-p5161BqbuCaSnB8jIbzQHOlpgsPmK5rJVDfDKO91Axs5NC1uu3HRQm6wt9cd9/+GtQQIO53JdGXXoyDpTAsgYA==", + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", "dev": true, - "dependencies": { - "bser": "2.1.1" + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/file-entry-cache": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz", - "integrity": "sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==", + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", "dev": true, "dependencies": { - "flat-cache": "^3.0.4" + "has-symbols": "^1.0.3" }, "engines": { - "node": "^10.12.0 || >=12.0.0" + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "node_modules/has-value": { + "version": 
"1.0.0", + "resolved": "https://registry.npmjs.org/has-value/-/has-value-1.0.0.tgz", + "integrity": "sha512-IBXk4GTsLYdQ7Rvt+GRBrFSVEkmuOUy4re0Xjd9kJSUQpnTrWR4/y9RpfexN9vkAPMFuQoeWKwqzPozRTlasGw==", "dev": true, "dependencies": { - "to-regex-range": "^5.0.1" + "get-value": "^2.0.6", + "has-values": "^1.0.0", + "isobject": "^3.0.0" }, "engines": { - "node": ">=8" + "node": ">=0.10.0" } }, - "node_modules/find-cache-dir": { - "version": "3.3.2", - "resolved": "https://registry.npmjs.org/find-cache-dir/-/find-cache-dir-3.3.2.tgz", - "integrity": "sha512-wXZV5emFEjrridIgED11OoUKLxiYjAcqot/NJdAkOhlJ+vGzwhOAfcG5OX1jP+S0PcjEn8bdMJv+g2jwQ3Onig==", + "node_modules/has-values": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/has-values/-/has-values-1.0.0.tgz", + "integrity": "sha512-ODYZC64uqzmtfGMEAX/FvZiRyWLpAC3vYnNunURUnkGVTS+mI0smVsWaPydRBsE3g+ok7h960jChO8mFcWlHaQ==", "dev": true, "dependencies": { - "commondir": "^1.0.1", - "make-dir": "^3.0.2", - "pkg-dir": "^4.1.0" + "is-number": "^3.0.0", + "kind-of": "^4.0.0" }, "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/avajs/find-cache-dir?sponsor=1" + "node": ">=0.10.0" } }, - "node_modules/find-up": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", - "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "node_modules/has-values/node_modules/is-number": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-3.0.0.tgz", + "integrity": "sha512-4cboCqIpliH+mAvFNegjZQ4kgKc3ZUhQVr3HvWbSh5q3WH2v82ct+T2Y1hdU5Gdtorx/cLifQjqCbL7bpznLTg==", "dev": true, "dependencies": { - "locate-path": "^5.0.0", - "path-exists": "^4.0.0" + "kind-of": "^3.0.2" }, "engines": { - "node": ">=8" + "node": ">=0.10.0" } }, - "node_modules/flat-cache": { - "version": "3.0.4", - "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-3.0.4.tgz", - "integrity": 
"sha512-dm9s5Pw7Jc0GvMYbshN6zchCA9RgQlzzEZX3vylR9IqFfS8XciblUXOKfW6SiuJ0e13eDYZoZV5wdrev7P3Nwg==", + "node_modules/has-values/node_modules/is-number/node_modules/kind-of": { + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", + "integrity": "sha512-NOW9QQXMoZGg/oqnVNoNTTIFEIid1627WCffUBJEdMxYApq7mNE7CpzucIPc+ZQg25Phej7IJSmX3hO+oblOtQ==", "dev": true, "dependencies": { - "flatted": "^3.1.0", - "rimraf": "^3.0.2" + "is-buffer": "^1.1.5" }, - "engines": { - "node": "^10.12.0 || >=12.0.0" - } - }, - "node_modules/flatted": { - "version": "3.2.7", - "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.2.7.tgz", - "integrity": "sha512-5nqDSxl8nn5BSNxyR3n4I6eDmbolI6WT+QqR547RwxQapgjQBmtktdP+HTBb/a/zLsbzERTONyUB5pefh5TtjQ==", - "dev": true - }, - "node_modules/for-in": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/for-in/-/for-in-1.0.2.tgz", - "integrity": "sha512-7EwmXrOjyL+ChxMhmG5lnW9MPt1aIeZEwKhQzoBUdTV0N3zuwWDZYVJatDvZ2OyzPUvdIAZDsCetk3coyMfcnQ==", - "dev": true, "engines": { "node": ">=0.10.0" } }, - "node_modules/form-data": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-3.0.1.tgz", - "integrity": "sha512-RHkBKtLWUVwd7SqRIvCZMEvAMoGUp0XU+seQiZejj0COz3RI3hWP4sCv3gZWWLjJTd7rGwcsF5eKZGii0r/hbg==", + "node_modules/has-values/node_modules/kind-of": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-4.0.0.tgz", + "integrity": "sha512-24XsCxmEbRwEDbz/qz3stgin8TTzZ1ESR56OMCN0ujYg+vRutNSiOj9bHH9u85DKgXguraugV5sFuvbD4FW/hw==", "dev": true, "dependencies": { - "asynckit": "^0.4.0", - "combined-stream": "^1.0.8", - "mime-types": "^2.1.12" + "is-buffer": "^1.1.5" }, "engines": { - "node": ">= 6" + "node": ">=0.10.0" } }, - "node_modules/fragment-cache": { - "version": "0.2.1", - "resolved": "https://registry.npmjs.org/fragment-cache/-/fragment-cache-0.2.1.tgz", - "integrity": 
"sha512-GMBAbW9antB8iZRHLoGw0b3HANt57diZYFO/HL1JGIC1MjKrdmhxvrJbupnVvpys0zsz7yBApXdQyfepKly2kA==", + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", "dev": true, "dependencies": { - "map-cache": "^0.2.2" + "function-bind": "^1.1.2" }, "engines": { - "node": ">=0.10.0" + "node": ">= 0.4" } }, - "node_modules/fs.realpath": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", - "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", + "node_modules/hosted-git-info": { + "version": "2.8.9", + "resolved": "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.9.tgz", + "integrity": "sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==", "dev": true }, - "node_modules/fsevents": { - "version": "2.3.3", - "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", - "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "node_modules/html-encoding-sniffer": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-2.0.1.tgz", + "integrity": "sha512-D5JbOMBIR/TVZkubHT+OyT2705QvogUW4IBn6nHd756OwieSF9aDYFj4dv6HHEVGYbHaLETa3WggZYWWMyy3ZQ==", "dev": true, - "hasInstallScript": true, - "license": "MIT", - "optional": true, - "os": [ - "darwin" - ], + "dependencies": { + "whatwg-encoding": "^1.0.5" + }, "engines": { - "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + "node": ">=10" } }, - "node_modules/function-bind": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", - "integrity": 
"sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", + "node_modules/html-escaper": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/html-escaper/-/html-escaper-2.0.2.tgz", + "integrity": "sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==", "dev": true }, - "node_modules/gensync": { - "version": "1.0.0-beta.2", - "resolved": "https://registry.npmjs.org/gensync/-/gensync-1.0.0-beta.2.tgz", - "integrity": "sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==", + "node_modules/http-proxy-agent": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-4.0.1.tgz", + "integrity": "sha512-k0zdNgqWTGA6aeIRVpvfVob4fL52dTfaehylg0Y4UvSySvOq/Y+BOyPrgpUrA7HylqvU8vIZGsRuXmspskV0Tg==", "dev": true, + "dependencies": { + "@tootallnate/once": "1", + "agent-base": "6", + "debug": "4" + }, "engines": { - "node": ">=6.9.0" + "node": ">= 6" } }, - "node_modules/get-caller-file": { - "version": "2.0.5", - "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", - "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", + "node_modules/https-proxy-agent": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-5.0.1.tgz", + "integrity": "sha512-dFcAjpTQFgoLMzC2VwU+C/CbS7uRL0lWmxDITmqm7C+7F0Odmj6s9l6alZc6AELXhrnggM2CeWSXHGOdX2YtwA==", "dev": true, + "dependencies": { + "agent-base": "6", + "debug": "4" + }, "engines": { - "node": "6.* || 8.* || >= 10.*" + "node": ">= 6" } }, - "node_modules/get-package-type": { - "version": "0.1.0", - "resolved": "https://registry.npmjs.org/get-package-type/-/get-package-type-0.1.0.tgz", - "integrity": "sha512-pjzuKtY64GYfWizNAJ0fr9VqttZkNiK2iS430LtIHzjBEr6bX8Am2zm4sW4Ro5wjWW5cAlRL1qAMTcXbjNAO2Q==", + "node_modules/human-signals": 
{ + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-1.1.1.tgz", + "integrity": "sha512-SEQu7vl8KjNL2eoGBLF3+wAjpsNfA9XMlXAYj/3EdaNfAlxKthD1xjEQfGOUhllCGGJVNY34bRr6lPINhNjyZw==", "dev": true, "engines": { - "node": ">=8.0.0" + "node": ">=8.12.0" } }, - "node_modules/get-stream": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-4.1.0.tgz", - "integrity": "sha512-GMat4EJ5161kIy2HevLlr4luNjBgvmj413KaQA7jt4V8B4RDsfpHk7WQ9GVqfYyyx8OS/L66Kox+rJRNklLK7w==", + "node_modules/iconv-lite": { + "version": "0.4.24", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz", + "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==", "dev": true, "dependencies": { - "pump": "^3.0.0" + "safer-buffer": ">= 2.1.2 < 3" }, "engines": { - "node": ">=6" + "node": ">=0.10.0" } }, - "node_modules/get-value": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/get-value/-/get-value-2.0.6.tgz", - "integrity": "sha512-Ln0UQDlxH1BapMu3GPtf7CuYNwRZf2gwCuPqbyG6pB8WfmFpzqcy4xtAaAMUhnNqjMKTiCPZG2oMT3YSx8U2NA==", + "node_modules/ignore": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", + "integrity": "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==", "dev": true, "engines": { - "node": ">=0.10.0" + "node": ">= 4" } }, - "node_modules/glob": { - "version": "7.2.3", - "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", - "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "node_modules/import-fresh": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", "dev": true, "dependencies": { - "fs.realpath": 
"^1.0.0", - "inflight": "^1.0.4", - "inherits": "2", - "minimatch": "^3.1.1", - "once": "^1.3.0", - "path-is-absolute": "^1.0.0" + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" }, "engines": { - "node": "*" + "node": ">=6" }, "funding": { - "url": "https://github.com/sponsors/isaacs" + "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/glob-parent": { - "version": "6.0.2", - "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", - "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "node_modules/import-local": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/import-local/-/import-local-3.2.0.tgz", + "integrity": "sha512-2SPlun1JUPWoM6t3F0dw0FkCF/jWY8kttcY4f599GLTSjh2OCuuhdTkJQsEcZzBqbXZGKMK2OqW1oZsjtf/gQA==", "dev": true, "dependencies": { - "is-glob": "^4.0.3" + "pkg-dir": "^4.2.0", + "resolve-cwd": "^3.0.0" }, - "engines": { - "node": ">=10.13.0" - } - }, - "node_modules/globals": { - "version": "13.20.0", - "resolved": "https://registry.npmjs.org/globals/-/globals-13.20.0.tgz", - "integrity": "sha512-Qg5QtVkCy/kv3FUSlu4ukeZDVf9ee0iXLAUYX13gbR17bnejFTzr4iS9bY7kwCf1NztRNm1t91fjOiyx4CSwPQ==", - "dev": true, - "dependencies": { - "type-fest": "^0.20.2" + "bin": { + "import-local-fixture": "fixtures/cli.js" }, "engines": { "node": ">=8" @@ -3923,256 +3625,246 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/globals/node_modules/type-fest": { - "version": "0.20.2", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", - "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", "dev": true, 
"engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">=0.8.19" } }, - "node_modules/globby": { - "version": "11.1.0", - "resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", - "integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==", + "node_modules/inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "deprecated": "This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful.", "dev": true, "dependencies": { - "array-union": "^2.1.0", - "dir-glob": "^3.0.1", - "fast-glob": "^3.2.9", - "ignore": "^5.2.0", - "merge2": "^1.4.1", - "slash": "^3.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "once": "^1.3.0", + "wrappy": "1" } }, - "node_modules/graceful-fs": { - "version": "4.2.10", - "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.10.tgz", - "integrity": "sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA==", + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", "dev": true }, - "node_modules/grapheme-splitter": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/grapheme-splitter/-/grapheme-splitter-1.0.4.tgz", - "integrity": "sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==", - "dev": true + "node_modules/is-accessor-descriptor": { + "version": "1.0.1", + 
"resolved": "https://registry.npmjs.org/is-accessor-descriptor/-/is-accessor-descriptor-1.0.1.tgz", + "integrity": "sha512-YBUanLI8Yoihw923YeFUS5fs0fF2f5TSFTNiYAAzhhDscDa3lEqYuz1pDOEP5KvX94I9ey3vsqjJcLVFVU+3QA==", + "dev": true, + "dependencies": { + "hasown": "^2.0.0" + }, + "engines": { + "node": ">= 0.10" + } }, - "node_modules/graphemer": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", - "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", + "node_modules/is-arrayish": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.2.1.tgz", + "integrity": "sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==", "dev": true }, - "node_modules/growly": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/growly/-/growly-1.3.0.tgz", - "integrity": "sha512-+xGQY0YyAWCnqy7Cd++hc2JqMYzlm0dG30Jd0beaA64sROr8C4nt8Yc9V5Ro3avlSUDTN0ulqP/VBKi1/lLygw==", - "dev": true, - "optional": true + "node_modules/is-buffer": { + "version": "1.1.6", + "resolved": "https://registry.npmjs.org/is-buffer/-/is-buffer-1.1.6.tgz", + "integrity": "sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w==", + "dev": true }, - "node_modules/has": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", - "integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", + "node_modules/is-builtin-module": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/is-builtin-module/-/is-builtin-module-3.2.1.tgz", + "integrity": "sha512-BSLE3HnV2syZ0FK0iMA/yUGplUeMmNz4AW5fnTunbCIqZi4vG3WjJT9FHMy5D69xmAYBHXQhJdALdpwVxV501A==", "dev": true, "dependencies": { - "function-bind": "^1.1.1" + "builtin-modules": "^3.3.0" }, "engines": { - "node": ">= 0.4.0" + "node": ">=6" + }, + "funding": { 
+ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/has-flag": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", - "integrity": "sha512-sKJf1+ceQBr4SMkvQnBDNDtf4TXpVhVGateu0t918bl30FnbE2m4vNLX+VWe/dpjlb+HugGYzW7uQXH98HPEYw==", + "node_modules/is-ci": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/is-ci/-/is-ci-2.0.0.tgz", + "integrity": "sha512-YfJT7rkpQB0updsdHLGWrvhBJfcfzNNawYDNIyQXJz0IViGf75O8EBPKSdvw2rF+LGCsX4FZ8tcr3b19LcZq4w==", "dev": true, - "engines": { - "node": ">=4" + "dependencies": { + "ci-info": "^2.0.0" + }, + "bin": { + "is-ci": "bin.js" } }, - "node_modules/has-value": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/has-value/-/has-value-1.0.0.tgz", - "integrity": "sha512-IBXk4GTsLYdQ7Rvt+GRBrFSVEkmuOUy4re0Xjd9kJSUQpnTrWR4/y9RpfexN9vkAPMFuQoeWKwqzPozRTlasGw==", + "node_modules/is-core-module": { + "version": "2.16.1", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.16.1.tgz", + "integrity": "sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", "dev": true, "dependencies": { - "get-value": "^2.0.6", - "has-values": "^1.0.0", - "isobject": "^3.0.0" + "hasown": "^2.0.2" }, "engines": { - "node": ">=0.10.0" + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/has-values": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/has-values/-/has-values-1.0.0.tgz", - "integrity": "sha512-ODYZC64uqzmtfGMEAX/FvZiRyWLpAC3vYnNunURUnkGVTS+mI0smVsWaPydRBsE3g+ok7h960jChO8mFcWlHaQ==", + "node_modules/is-data-descriptor": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-data-descriptor/-/is-data-descriptor-1.0.1.tgz", + "integrity": "sha512-bc4NlCDiCr28U4aEsQ3Qs2491gVq4V8G7MQyws968ImqjKuYtTJXrl7Vq7jsN7Ly/C3xj5KWFrY7sHNeDkAzXw==", "dev": true, "dependencies": { - "is-number": "^3.0.0", - "kind-of": "^4.0.0" 
+ "hasown": "^2.0.0" }, "engines": { - "node": ">=0.10.0" + "node": ">= 0.4" } }, - "node_modules/has-values/node_modules/is-number": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/is-number/-/is-number-3.0.0.tgz", - "integrity": "sha512-4cboCqIpliH+mAvFNegjZQ4kgKc3ZUhQVr3HvWbSh5q3WH2v82ct+T2Y1hdU5Gdtorx/cLifQjqCbL7bpznLTg==", + "node_modules/is-descriptor": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-1.0.3.tgz", + "integrity": "sha512-JCNNGbwWZEVaSPtS45mdtrneRWJFp07LLmykxeFV5F6oBvNF8vHSfJuJgoT472pSfk+Mf8VnlrspaFBHWM8JAw==", "dev": true, "dependencies": { - "kind-of": "^3.0.2" + "is-accessor-descriptor": "^1.0.1", + "is-data-descriptor": "^1.0.1" }, "engines": { - "node": ">=0.10.0" + "node": ">= 0.4" } }, - "node_modules/has-values/node_modules/is-number/node_modules/kind-of": { - "version": "3.2.2", - "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", - "integrity": "sha512-NOW9QQXMoZGg/oqnVNoNTTIFEIid1627WCffUBJEdMxYApq7mNE7CpzucIPc+ZQg25Phej7IJSmX3hO+oblOtQ==", + "node_modules/is-docker": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-2.2.1.tgz", + "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", "dev": true, - "dependencies": { - "is-buffer": "^1.1.5" + "optional": true, + "bin": { + "is-docker": "cli.js" }, "engines": { - "node": ">=0.10.0" + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/has-values/node_modules/kind-of": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-4.0.0.tgz", - "integrity": "sha512-24XsCxmEbRwEDbz/qz3stgin8TTzZ1ESR56OMCN0ujYg+vRutNSiOj9bHH9u85DKgXguraugV5sFuvbD4FW/hw==", + "node_modules/is-extendable": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-1.0.1.tgz", + "integrity": 
"sha512-arnXMxT1hhoKo9k1LZdmlNyJdDDfy2v0fXjFlmok4+i8ul/6WlbVge9bhM74OpNPQPMGUToDtz+KXa1PneJxOA==", "dev": true, "dependencies": { - "is-buffer": "^1.1.5" + "is-plain-object": "^2.0.4" }, "engines": { "node": ">=0.10.0" } }, - "node_modules/hosted-git-info": { - "version": "2.8.9", - "resolved": "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.9.tgz", - "integrity": "sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==", - "dev": true + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } }, - "node_modules/html-encoding-sniffer": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-2.0.1.tgz", - "integrity": "sha512-D5JbOMBIR/TVZkubHT+OyT2705QvogUW4IBn6nHd756OwieSF9aDYFj4dv6HHEVGYbHaLETa3WggZYWWMyy3ZQ==", + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", "dev": true, - "dependencies": { - "whatwg-encoding": "^1.0.5" - }, "engines": { - "node": ">=10" + "node": ">=8" } }, - "node_modules/html-escaper": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/html-escaper/-/html-escaper-2.0.2.tgz", - "integrity": "sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==", - "dev": true + "node_modules/is-generator-fn": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-generator-fn/-/is-generator-fn-2.1.0.tgz", + "integrity": 
"sha512-cTIB4yPYL/Grw0EaSzASzg6bBy9gqCofvWN8okThAYIxKJZC+udlRAmGbM0XLeniEJSs8uEgHPGuHSe1XsOLSQ==", + "dev": true, + "engines": { + "node": ">=6" + } }, - "node_modules/http-proxy-agent": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-4.0.1.tgz", - "integrity": "sha512-k0zdNgqWTGA6aeIRVpvfVob4fL52dTfaehylg0Y4UvSySvOq/Y+BOyPrgpUrA7HylqvU8vIZGsRuXmspskV0Tg==", + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", "dev": true, "dependencies": { - "@tootallnate/once": "1", - "agent-base": "6", - "debug": "4" + "is-extglob": "^2.1.1" }, "engines": { - "node": ">= 6" + "node": ">=0.10.0" } }, - "node_modules/https-proxy-agent": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-5.0.1.tgz", - "integrity": "sha512-dFcAjpTQFgoLMzC2VwU+C/CbS7uRL0lWmxDITmqm7C+7F0Odmj6s9l6alZc6AELXhrnggM2CeWSXHGOdX2YtwA==", + "node_modules/is-module": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-module/-/is-module-1.0.0.tgz", + "integrity": "sha512-51ypPSPCoTEIN9dy5Oy+h4pShgJmPCygKfyRCISBI+JoWT/2oJvK8QPxmwv7b/p239jXrm9M1mlQbyKJ5A152g==", + "dev": true + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", "dev": true, - "dependencies": { - "agent-base": "6", - "debug": "4" - }, "engines": { - "node": ">= 6" + "node": ">=0.12.0" } }, - "node_modules/human-signals": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-1.1.1.tgz", - "integrity": "sha512-SEQu7vl8KjNL2eoGBLF3+wAjpsNfA9XMlXAYj/3EdaNfAlxKthD1xjEQfGOUhllCGGJVNY34bRr6lPINhNjyZw==", + 
"node_modules/is-path-inside": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/is-path-inside/-/is-path-inside-3.0.3.tgz", + "integrity": "sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ==", "dev": true, "engines": { - "node": ">=8.12.0" + "node": ">=8" } }, - "node_modules/iconv-lite": { - "version": "0.4.24", - "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz", - "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==", + "node_modules/is-plain-object": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/is-plain-object/-/is-plain-object-2.0.4.tgz", + "integrity": "sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og==", "dev": true, "dependencies": { - "safer-buffer": ">= 2.1.2 < 3" + "isobject": "^3.0.1" }, "engines": { "node": ">=0.10.0" } }, - "node_modules/ignore": { - "version": "5.2.4", - "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.4.tgz", - "integrity": "sha512-MAb38BcSbH0eHNBxn7ql2NH/kX33OkB3lZ1BNdh7ENeRChHTYsTvWrMubiIAMNS2llXEEgZ1MUOBtXChP3kaFQ==", - "dev": true, - "engines": { - "node": ">= 4" - } + "node_modules/is-potential-custom-element-name": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-potential-custom-element-name/-/is-potential-custom-element-name-1.0.1.tgz", + "integrity": "sha512-bCYeRA2rVibKZd+s2625gGnGF/t7DSqDs4dP7CrLA1m7jKWz6pps0LpYLJN8Q64HtmPKJ1hrN3nzPNKFEKOUiQ==", + "dev": true }, - "node_modules/import-fresh": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", - "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", + "node_modules/is-reference": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/is-reference/-/is-reference-1.2.1.tgz", + "integrity": 
"sha512-U82MsXXiFIrjCK4otLT+o2NA2Cd2g5MLoOVXUZjIOhLurrRxpEXzI8O0KZHr3IjLvlAH1kTPYSuqer5T9ZVBKQ==", "dev": true, "dependencies": { - "parent-module": "^1.0.0", - "resolve-from": "^4.0.0" - }, - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "@types/estree": "*" } }, - "node_modules/import-local": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/import-local/-/import-local-3.1.0.tgz", - "integrity": "sha512-ASB07uLtnDs1o6EHjKpX34BKYDSqnFerfTOJL2HvMqF70LnxpjkzDB8J44oT9pu4AMPkQwf8jl6szgvNd2tRIg==", + "node_modules/is-stream": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", + "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", "dev": true, - "dependencies": { - "pkg-dir": "^4.2.0", - "resolve-cwd": "^3.0.0" - }, - "bin": { - "import-local-fixture": "fixtures/cli.js" - }, "engines": { "node": ">=8" }, @@ -4180,2089 +3872,675 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/imurmurhash": { - "version": "0.1.4", - "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", - "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "node_modules/is-typedarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-typedarray/-/is-typedarray-1.0.0.tgz", + "integrity": "sha512-cyA56iCMHAh5CdzjJIa4aohJyeO1YbwLi3Jc35MmRU6poroFjIGZzUzupGiRPOjgHg9TLu43xbpwXk523fMxKA==", + "dev": true + }, + "node_modules/is-windows": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/is-windows/-/is-windows-1.0.2.tgz", + "integrity": "sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA==", "dev": true, "engines": { - "node": ">=0.8.19" + "node": ">=0.10.0" } }, - "node_modules/inflight": { - "version": "1.0.6", - "resolved": 
"https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", - "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "node_modules/is-wsl": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", + "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", "dev": true, + "optional": true, "dependencies": { - "once": "^1.3.0", - "wrappy": "1" + "is-docker": "^2.0.0" + }, + "engines": { + "node": ">=8" } }, - "node_modules/inherits": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", - "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "node_modules/isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==", "dev": true }, - "node_modules/is-accessor-descriptor": { - "version": "0.1.6", - "resolved": "https://registry.npmjs.org/is-accessor-descriptor/-/is-accessor-descriptor-0.1.6.tgz", - "integrity": "sha512-e1BM1qnDbMRG3ll2U9dSK0UMHuWOs3pY3AtcFsmvwPtKL3MML/Q86i+GilLfvqEs4GW+ExB91tQ3Ig9noDIZ+A==", + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true + }, + "node_modules/isobject": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/isobject/-/isobject-3.0.1.tgz", + "integrity": "sha512-WhB9zCku7EGTj/HQQRz5aUQEUeoQZH2bWcltRErOpymJ4boYE6wL9Tbr23krRPSZ+C5zqNSrSw+Cc7sZZ4b7vg==", "dev": true, - "dependencies": { - "kind-of": "^3.0.2" - }, "engines": { "node": ">=0.10.0" } }, - "node_modules/is-accessor-descriptor/node_modules/kind-of": { + 
"node_modules/istanbul-lib-coverage": { "version": "3.2.2", - "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", - "integrity": "sha512-NOW9QQXMoZGg/oqnVNoNTTIFEIid1627WCffUBJEdMxYApq7mNE7CpzucIPc+ZQg25Phej7IJSmX3hO+oblOtQ==", + "resolved": "https://registry.npmjs.org/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.2.tgz", + "integrity": "sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/istanbul-lib-instrument": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/istanbul-lib-instrument/-/istanbul-lib-instrument-4.0.3.tgz", + "integrity": "sha512-BXgQl9kf4WTCPCCpmFGoJkz/+uhvm7h7PFKUYxh7qarQd3ER33vHG//qaE8eN25l07YqZPpHXU9I09l/RD5aGQ==", "dev": true, "dependencies": { - "is-buffer": "^1.1.5" + "@babel/core": "^7.7.5", + "@istanbuljs/schema": "^0.1.2", + "istanbul-lib-coverage": "^3.0.0", + "semver": "^6.3.0" }, "engines": { - "node": ">=0.10.0" + "node": ">=8" } }, - "node_modules/is-arrayish": { - "version": "0.2.1", - "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.2.1.tgz", - "integrity": "sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==", - "dev": true - }, - "node_modules/is-buffer": { - "version": "1.1.6", - "resolved": "https://registry.npmjs.org/is-buffer/-/is-buffer-1.1.6.tgz", - "integrity": "sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w==", - "dev": true - }, - "node_modules/is-builtin-module": { - "version": "3.2.1", - "resolved": "https://registry.npmjs.org/is-builtin-module/-/is-builtin-module-3.2.1.tgz", - "integrity": "sha512-BSLE3HnV2syZ0FK0iMA/yUGplUeMmNz4AW5fnTunbCIqZi4vG3WjJT9FHMy5D69xmAYBHXQhJdALdpwVxV501A==", - "dev": true, - "dependencies": { - "builtin-modules": "^3.3.0" - }, - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - 
} - }, - "node_modules/is-ci": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/is-ci/-/is-ci-2.0.0.tgz", - "integrity": "sha512-YfJT7rkpQB0updsdHLGWrvhBJfcfzNNawYDNIyQXJz0IViGf75O8EBPKSdvw2rF+LGCsX4FZ8tcr3b19LcZq4w==", - "dev": true, - "dependencies": { - "ci-info": "^2.0.0" - }, - "bin": { - "is-ci": "bin.js" - } - }, - "node_modules/is-core-module": { - "version": "2.11.0", - "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.11.0.tgz", - "integrity": "sha512-RRjxlvLDkD1YJwDbroBHMb+cukurkDWNyHx7D3oNB5x9rb5ogcksMC5wHCadcXoo67gVr/+3GFySh3134zi6rw==", - "dev": true, - "dependencies": { - "has": "^1.0.3" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-data-descriptor": { - "version": "0.1.4", - "resolved": "https://registry.npmjs.org/is-data-descriptor/-/is-data-descriptor-0.1.4.tgz", - "integrity": "sha512-+w9D5ulSoBNlmw9OHn3U2v51SyoCd0he+bB3xMl62oijhrspxowjU+AIcDY0N3iEJbUEkB15IlMASQsxYigvXg==", - "dev": true, - "dependencies": { - "kind-of": "^3.0.2" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-data-descriptor/node_modules/kind-of": { - "version": "3.2.2", - "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", - "integrity": "sha512-NOW9QQXMoZGg/oqnVNoNTTIFEIid1627WCffUBJEdMxYApq7mNE7CpzucIPc+ZQg25Phej7IJSmX3hO+oblOtQ==", - "dev": true, - "dependencies": { - "is-buffer": "^1.1.5" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-descriptor": { - "version": "0.1.6", - "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-0.1.6.tgz", - "integrity": "sha512-avDYr0SB3DwO9zsMov0gKCESFYqCnE4hq/4z3TdUlukEy5t9C0YRq7HLrsN52NAcqXKaepeCD0n+B0arnVG3Hg==", - "dev": true, - "dependencies": { - "is-accessor-descriptor": "^0.1.6", - "is-data-descriptor": "^0.1.4", - "kind-of": "^5.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-descriptor/node_modules/kind-of": { - "version": "5.1.0", - 
"resolved": "https://registry.npmjs.org/kind-of/-/kind-of-5.1.0.tgz", - "integrity": "sha512-NGEErnH6F2vUuXDh+OlbcKW7/wOcfdRHaZ7VWtqCztfHri/++YKmP51OdWeGPuqCOba6kk2OTe5d02VmTB80Pw==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-docker": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-2.2.1.tgz", - "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", - "dev": true, - "optional": true, - "bin": { - "is-docker": "cli.js" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/is-extendable": { - "version": "0.1.1", - "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-0.1.1.tgz", - "integrity": "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-extglob": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", - "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-fullwidth-code-point": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", - "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/is-generator-fn": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/is-generator-fn/-/is-generator-fn-2.1.0.tgz", - "integrity": "sha512-cTIB4yPYL/Grw0EaSzASzg6bBy9gqCofvWN8okThAYIxKJZC+udlRAmGbM0XLeniEJSs8uEgHPGuHSe1XsOLSQ==", - "dev": true, - "engines": { - "node": ">=6" - } - }, - "node_modules/is-glob": { - "version": "4.0.3", - 
"resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", - "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", - "dev": true, - "dependencies": { - "is-extglob": "^2.1.1" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-module": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-module/-/is-module-1.0.0.tgz", - "integrity": "sha512-51ypPSPCoTEIN9dy5Oy+h4pShgJmPCygKfyRCISBI+JoWT/2oJvK8QPxmwv7b/p239jXrm9M1mlQbyKJ5A152g==", - "dev": true - }, - "node_modules/is-number": { - "version": "7.0.0", - "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", - "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", - "dev": true, - "engines": { - "node": ">=0.12.0" - } - }, - "node_modules/is-path-inside": { - "version": "3.0.3", - "resolved": "https://registry.npmjs.org/is-path-inside/-/is-path-inside-3.0.3.tgz", - "integrity": "sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/is-plain-object": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/is-plain-object/-/is-plain-object-2.0.4.tgz", - "integrity": "sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og==", - "dev": true, - "dependencies": { - "isobject": "^3.0.1" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-potential-custom-element-name": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/is-potential-custom-element-name/-/is-potential-custom-element-name-1.0.1.tgz", - "integrity": "sha512-bCYeRA2rVibKZd+s2625gGnGF/t7DSqDs4dP7CrLA1m7jKWz6pps0LpYLJN8Q64HtmPKJ1hrN3nzPNKFEKOUiQ==", - "dev": true - }, - "node_modules/is-reference": { - "version": "1.2.1", - "resolved": 
"https://registry.npmjs.org/is-reference/-/is-reference-1.2.1.tgz", - "integrity": "sha512-U82MsXXiFIrjCK4otLT+o2NA2Cd2g5MLoOVXUZjIOhLurrRxpEXzI8O0KZHr3IjLvlAH1kTPYSuqer5T9ZVBKQ==", - "dev": true, - "dependencies": { - "@types/estree": "*" - } - }, - "node_modules/is-stream": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-1.1.0.tgz", - "integrity": "sha512-uQPm8kcs47jx38atAcWTVxyltQYoPT68y9aWYdV6yWXSyW8mzSat0TL6CiWdZeCdF3KrAvpVtnHbTv4RN+rqdQ==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-typedarray": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-typedarray/-/is-typedarray-1.0.0.tgz", - "integrity": "sha512-cyA56iCMHAh5CdzjJIa4aohJyeO1YbwLi3Jc35MmRU6poroFjIGZzUzupGiRPOjgHg9TLu43xbpwXk523fMxKA==", - "dev": true - }, - "node_modules/is-windows": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/is-windows/-/is-windows-1.0.2.tgz", - "integrity": "sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-wsl": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", - "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", - "dev": true, - "optional": true, - "dependencies": { - "is-docker": "^2.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/isarray": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", - "integrity": "sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==", - "dev": true - }, - "node_modules/isexe": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", - "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", - "dev": true - }, - 
"node_modules/isobject": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/isobject/-/isobject-3.0.1.tgz", - "integrity": "sha512-WhB9zCku7EGTj/HQQRz5aUQEUeoQZH2bWcltRErOpymJ4boYE6wL9Tbr23krRPSZ+C5zqNSrSw+Cc7sZZ4b7vg==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/istanbul-lib-coverage": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.0.tgz", - "integrity": "sha512-eOeJ5BHCmHYvQK7xt9GkdHuzuCGS1Y6g9Gvnx3Ym33fz/HpLRYxiS0wHNr+m/MBC8B647Xt608vCDEvhl9c6Mw==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/istanbul-lib-instrument": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/istanbul-lib-instrument/-/istanbul-lib-instrument-4.0.3.tgz", - "integrity": "sha512-BXgQl9kf4WTCPCCpmFGoJkz/+uhvm7h7PFKUYxh7qarQd3ER33vHG//qaE8eN25l07YqZPpHXU9I09l/RD5aGQ==", - "dev": true, - "dependencies": { - "@babel/core": "^7.7.5", - "@istanbuljs/schema": "^0.1.2", - "istanbul-lib-coverage": "^3.0.0", - "semver": "^6.3.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/istanbul-lib-instrument/node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", - "dev": true, - "bin": { - "semver": "bin/semver.js" - } - }, - "node_modules/istanbul-lib-report": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz", - "integrity": "sha512-wcdi+uAKzfiGT2abPpKZ0hSU1rGQjUQnLvtY5MpQ7QCTahD3VODhcu4wcfY1YtkGaDD5yuydOLINXsfbus9ROw==", - "dev": true, - "dependencies": { - "istanbul-lib-coverage": "^3.0.0", - "make-dir": "^3.0.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/istanbul-lib-report/node_modules/has-flag": { - "version": "4.0.0", - "resolved": 
"https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/istanbul-lib-report/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/istanbul-lib-source-maps": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/istanbul-lib-source-maps/-/istanbul-lib-source-maps-4.0.1.tgz", - "integrity": "sha512-n3s8EwkdFIJCG3BPKBYvskgXGoy88ARzvegkitk60NxRdwltLOTaH7CUiMRXvwYorl0Q712iEjcWB+fK/MrWVw==", - "dev": true, - "dependencies": { - "debug": "^4.1.1", - "istanbul-lib-coverage": "^3.0.0", - "source-map": "^0.6.1" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/istanbul-reports": { - "version": "3.1.5", - "resolved": "https://registry.npmjs.org/istanbul-reports/-/istanbul-reports-3.1.5.tgz", - "integrity": "sha512-nUsEMa9pBt/NOHqbcbeJEgqIlY/K7rVWUX6Lql2orY5e9roQOthbR3vtY4zzf2orPELg80fnxxk9zUyPlgwD1w==", - "dev": true, - "dependencies": { - "html-escaper": "^2.0.0", - "istanbul-lib-report": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest": { - "version": "26.6.3", - "resolved": "https://registry.npmjs.org/jest/-/jest-26.6.3.tgz", - "integrity": "sha512-lGS5PXGAzR4RF7V5+XObhqz2KZIDUA1yD0DG6pBVmy10eh0ZIXQImRuzocsI/N2XZ1GrLFwTS27In2i2jlpq1Q==", - "dev": true, - "dependencies": { - "@jest/core": "^26.6.3", - "import-local": "^3.0.2", - "jest-cli": "^26.6.3" - }, - "bin": { - "jest": "bin/jest.js" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-changed-files": { - "version": "26.6.2", - "resolved": 
"https://registry.npmjs.org/jest-changed-files/-/jest-changed-files-26.6.2.tgz", - "integrity": "sha512-fDS7szLcY9sCtIip8Fjry9oGf3I2ht/QT21bAHm5Dmf0mD4X3ReNUf17y+bO6fR8WgbIZTlbyG1ak/53cbRzKQ==", - "dev": true, - "dependencies": { - "@jest/types": "^26.6.2", - "execa": "^4.0.0", - "throat": "^5.0.0" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-changed-files/node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", - "dev": true, - "dependencies": { - "path-key": "^3.1.0", - "shebang-command": "^2.0.0", - "which": "^2.0.1" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/jest-changed-files/node_modules/execa": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/execa/-/execa-4.1.0.tgz", - "integrity": "sha512-j5W0//W7f8UxAn8hXVnwG8tLwdiUy4FJLcSupCg6maBYZDpyBvTApK7KyuI4bKj8KOh1r2YH+6ucuYtJv1bTZA==", - "dev": true, - "dependencies": { - "cross-spawn": "^7.0.0", - "get-stream": "^5.0.0", - "human-signals": "^1.1.1", - "is-stream": "^2.0.0", - "merge-stream": "^2.0.0", - "npm-run-path": "^4.0.0", - "onetime": "^5.1.0", - "signal-exit": "^3.0.2", - "strip-final-newline": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sindresorhus/execa?sponsor=1" - } - }, - "node_modules/jest-changed-files/node_modules/get-stream": { - "version": "5.2.0", - "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-5.2.0.tgz", - "integrity": "sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA==", - "dev": true, - "dependencies": { - "pump": "^3.0.0" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/jest-changed-files/node_modules/is-stream": { - "version": "2.0.1", - 
"resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", - "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", - "dev": true, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/jest-changed-files/node_modules/npm-run-path": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", - "integrity": "sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", - "dev": true, - "dependencies": { - "path-key": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-changed-files/node_modules/path-key": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", - "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-changed-files/node_modules/shebang-command": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", - "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", - "dev": true, - "dependencies": { - "shebang-regex": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-changed-files/node_modules/shebang-regex": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", - "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-changed-files/node_modules/which": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", - "integrity": 
"sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", - "dev": true, - "dependencies": { - "isexe": "^2.0.0" - }, - "bin": { - "node-which": "bin/node-which" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/jest-config": { - "version": "26.6.3", - "resolved": "https://registry.npmjs.org/jest-config/-/jest-config-26.6.3.tgz", - "integrity": "sha512-t5qdIj/bCj2j7NFVHb2nFB4aUdfucDn3JRKgrZnplb8nieAirAzRSHP8uDEd+qV6ygzg9Pz4YG7UTJf94LPSyg==", - "dev": true, - "dependencies": { - "@babel/core": "^7.1.0", - "@jest/test-sequencer": "^26.6.3", - "@jest/types": "^26.6.2", - "babel-jest": "^26.6.3", - "chalk": "^4.0.0", - "deepmerge": "^4.2.2", - "glob": "^7.1.1", - "graceful-fs": "^4.2.4", - "jest-environment-jsdom": "^26.6.2", - "jest-environment-node": "^26.6.2", - "jest-get-type": "^26.3.0", - "jest-jasmine2": "^26.6.3", - "jest-regex-util": "^26.0.0", - "jest-resolve": "^26.6.2", - "jest-util": "^26.6.2", - "jest-validate": "^26.6.2", - "micromatch": "^4.0.2", - "pretty-format": "^26.6.2" - }, - "engines": { - "node": ">= 10.14.2" - }, - "peerDependencies": { - "ts-node": ">=9.0.0" - }, - "peerDependenciesMeta": { - "ts-node": { - "optional": true - } - } - }, - "node_modules/jest-config/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/jest-config/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": 
"^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-config/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-config/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-config/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-config/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-diff": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-diff/-/jest-diff-26.6.2.tgz", - "integrity": "sha512-6m+9Z3Gv9wN0WFVasqjCL/06+EFCMTqDEUl/b87HYK2rAPTyfz4ZIuSlPhY51PIQRWx5TaxeF1qmXKe9gfN3sA==", - "dev": true, - "dependencies": { - "chalk": "^4.0.0", - "diff-sequences": "^26.6.2", - "jest-get-type": "^26.3.0", - "pretty-format": "^26.6.2" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - 
"node_modules/jest-diff/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/jest-diff/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-diff/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-diff/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-diff/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-diff/node_modules/supports-color": { - "version": "7.2.0", - 
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-docblock": { - "version": "26.0.0", - "resolved": "https://registry.npmjs.org/jest-docblock/-/jest-docblock-26.0.0.tgz", - "integrity": "sha512-RDZ4Iz3QbtRWycd8bUEPxQsTlYazfYn/h5R65Fc6gOfwozFhoImx+affzky/FFBuqISPTqjXomoIGJVKBWoo0w==", - "dev": true, - "dependencies": { - "detect-newline": "^3.0.0" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-each": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-each/-/jest-each-26.6.2.tgz", - "integrity": "sha512-Mer/f0KaATbjl8MCJ+0GEpNdqmnVmDYqCTJYTvoo7rqmRiDllmp2AYN+06F93nXcY3ur9ShIjS+CO/uD+BbH4A==", - "dev": true, - "dependencies": { - "@jest/types": "^26.6.2", - "chalk": "^4.0.0", - "jest-get-type": "^26.3.0", - "jest-util": "^26.6.2", - "pretty-format": "^26.6.2" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-each/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/jest-each/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - 
"url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-each/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-each/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-each/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-each/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-environment-jsdom": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-environment-jsdom/-/jest-environment-jsdom-26.6.2.tgz", - "integrity": "sha512-jgPqCruTlt3Kwqg5/WVFyHIOJHsiAvhcp2qiR2QQstuG9yWox5+iHpU3ZrcBxW14T4fe5Z68jAfLRh7joCSP2Q==", - "dev": true, - "dependencies": { - "@jest/environment": "^26.6.2", - "@jest/fake-timers": "^26.6.2", - "@jest/types": "^26.6.2", - "@types/node": "*", - "jest-mock": "^26.6.2", - "jest-util": "^26.6.2", - "jsdom": "^16.4.0" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - 
"node_modules/jest-environment-node": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-environment-node/-/jest-environment-node-26.6.2.tgz", - "integrity": "sha512-zhtMio3Exty18dy8ee8eJ9kjnRyZC1N4C1Nt/VShN1apyXc8rWGtJ9lI7vqiWcyyXS4BVSEn9lxAM2D+07/Tag==", - "dev": true, - "dependencies": { - "@jest/environment": "^26.6.2", - "@jest/fake-timers": "^26.6.2", - "@jest/types": "^26.6.2", - "@types/node": "*", - "jest-mock": "^26.6.2", - "jest-util": "^26.6.2" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-get-type": { - "version": "26.3.0", - "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-26.3.0.tgz", - "integrity": "sha512-TpfaviN1R2pQWkIihlfEanwOXK0zcxrKEE4MlU6Tn7keoXdN6/3gK/xl0yEh8DOunn5pOVGKf8hB4R9gVh04ig==", - "dev": true, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-haste-map": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-haste-map/-/jest-haste-map-26.6.2.tgz", - "integrity": "sha512-easWIJXIw71B2RdR8kgqpjQrbMRWQBgiBwXYEhtGUTaX+doCjBheluShdDMeR8IMfJiTqH4+zfhtg29apJf/8w==", - "dev": true, - "dependencies": { - "@jest/types": "^26.6.2", - "@types/graceful-fs": "^4.1.2", - "@types/node": "*", - "anymatch": "^3.0.3", - "fb-watchman": "^2.0.0", - "graceful-fs": "^4.2.4", - "jest-regex-util": "^26.0.0", - "jest-serializer": "^26.6.2", - "jest-util": "^26.6.2", - "jest-worker": "^26.6.2", - "micromatch": "^4.0.2", - "sane": "^4.0.3", - "walker": "^1.0.7" - }, - "engines": { - "node": ">= 10.14.2" - }, - "optionalDependencies": { - "fsevents": "^2.1.2" - } - }, - "node_modules/jest-jasmine2": { - "version": "26.6.3", - "resolved": "https://registry.npmjs.org/jest-jasmine2/-/jest-jasmine2-26.6.3.tgz", - "integrity": "sha512-kPKUrQtc8aYwBV7CqBg5pu+tmYXlvFlSFYn18ev4gPFtrRzB15N2gW/Roew3187q2w2eHuu0MU9TJz6w0/nPEg==", - "dev": true, - "dependencies": { - "@babel/traverse": "^7.1.0", - "@jest/environment": "^26.6.2", - "@jest/source-map": "^26.6.2", - 
"@jest/test-result": "^26.6.2", - "@jest/types": "^26.6.2", - "@types/node": "*", - "chalk": "^4.0.0", - "co": "^4.6.0", - "expect": "^26.6.2", - "is-generator-fn": "^2.0.0", - "jest-each": "^26.6.2", - "jest-matcher-utils": "^26.6.2", - "jest-message-util": "^26.6.2", - "jest-runtime": "^26.6.3", - "jest-snapshot": "^26.6.2", - "jest-util": "^26.6.2", - "pretty-format": "^26.6.2", - "throat": "^5.0.0" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-jasmine2/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/jest-jasmine2/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-jasmine2/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-jasmine2/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": 
"sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-jasmine2/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-jasmine2/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-leak-detector": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-leak-detector/-/jest-leak-detector-26.6.2.tgz", - "integrity": "sha512-i4xlXpsVSMeKvg2cEKdfhh0H39qlJlP5Ex1yQxwF9ubahboQYMgTtz5oML35AVA3B4Eu+YsmwaiKVev9KCvLxg==", - "dev": true, - "dependencies": { - "jest-get-type": "^26.3.0", - "pretty-format": "^26.6.2" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-matcher-utils": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-matcher-utils/-/jest-matcher-utils-26.6.2.tgz", - "integrity": "sha512-llnc8vQgYcNqDrqRDXWwMr9i7rS5XFiCwvh6DTP7Jqa2mqpcCBBlpCbn+trkG0KNhPu/h8rzyBkriOtBstvWhw==", - "dev": true, - "dependencies": { - "chalk": "^4.0.0", - "jest-diff": "^26.6.2", - "jest-get-type": "^26.3.0", - "pretty-format": "^26.6.2" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-matcher-utils/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": 
true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/jest-matcher-utils/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-matcher-utils/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-matcher-utils/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-matcher-utils/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-matcher-utils/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - 
"has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-message-util": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-message-util/-/jest-message-util-26.6.2.tgz", - "integrity": "sha512-rGiLePzQ3AzwUshu2+Rn+UMFk0pHN58sOG+IaJbk5Jxuqo3NYO1U2/MIR4S1sKgsoYSXSzdtSa0TgrmtUwEbmA==", - "dev": true, - "dependencies": { - "@babel/code-frame": "^7.0.0", - "@jest/types": "^26.6.2", - "@types/stack-utils": "^2.0.0", - "chalk": "^4.0.0", - "graceful-fs": "^4.2.4", - "micromatch": "^4.0.2", - "pretty-format": "^26.6.2", - "slash": "^3.0.0", - "stack-utils": "^2.0.2" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-message-util/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/jest-message-util/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-message-util/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - 
"node_modules/jest-message-util/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-message-util/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-message-util/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-mock": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-mock/-/jest-mock-26.6.2.tgz", - "integrity": "sha512-YyFjePHHp1LzpzYcmgqkJ0nm0gg/lJx2aZFzFy1S6eUqNjXsOqTK10zNRff2dNfssgokjkG65OlWNcIlgd3zew==", - "dev": true, - "dependencies": { - "@jest/types": "^26.6.2", - "@types/node": "*" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-pnp-resolver": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/jest-pnp-resolver/-/jest-pnp-resolver-1.2.3.tgz", - "integrity": "sha512-+3NpwQEnRoIBtx4fyhblQDPgJI0H1IEIkX7ShLUjPGA7TtUTvI1oiKi3SR4oBR0hQhQR80l4WAe5RrXBwWMA8w==", - "dev": true, - "engines": { - "node": ">=6" - }, - "peerDependencies": { - "jest-resolve": "*" - }, - "peerDependenciesMeta": { - "jest-resolve": { - "optional": true - } - } - }, - "node_modules/jest-regex-util": { - "version": "26.0.0", - "resolved": "https://registry.npmjs.org/jest-regex-util/-/jest-regex-util-26.0.0.tgz", - "integrity": 
"sha512-Gv3ZIs/nA48/Zvjrl34bf+oD76JHiGDUxNOVgUjh3j890sblXryjY4rss71fPtD/njchl6PSE2hIhvyWa1eT0A==", - "dev": true, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-resolve": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-resolve/-/jest-resolve-26.6.2.tgz", - "integrity": "sha512-sOxsZOq25mT1wRsfHcbtkInS+Ek7Q8jCHUB0ZUTP0tc/c41QHriU/NunqMfCUWsL4H3MHpvQD4QR9kSYhS7UvQ==", - "dev": true, - "dependencies": { - "@jest/types": "^26.6.2", - "chalk": "^4.0.0", - "graceful-fs": "^4.2.4", - "jest-pnp-resolver": "^1.2.2", - "jest-util": "^26.6.2", - "read-pkg-up": "^7.0.1", - "resolve": "^1.18.1", - "slash": "^3.0.0" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-resolve-dependencies": { - "version": "26.6.3", - "resolved": "https://registry.npmjs.org/jest-resolve-dependencies/-/jest-resolve-dependencies-26.6.3.tgz", - "integrity": "sha512-pVwUjJkxbhe4RY8QEWzN3vns2kqyuldKpxlxJlzEYfKSvY6/bMvxoFrYYzUO1Gx28yKWN37qyV7rIoIp2h8fTg==", - "dev": true, - "dependencies": { - "@jest/types": "^26.6.2", - "jest-regex-util": "^26.0.0", - "jest-snapshot": "^26.6.2" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-resolve/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/jest-resolve/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": 
{ - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-resolve/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-resolve/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-resolve/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-resolve/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-runner": { - "version": "26.6.3", - "resolved": "https://registry.npmjs.org/jest-runner/-/jest-runner-26.6.3.tgz", - "integrity": "sha512-atgKpRHnaA2OvByG/HpGA4g6CSPS/1LK0jK3gATJAoptC1ojltpmVlYC3TYgdmGp+GLuhzpH30Gvs36szSL2JQ==", - "dev": true, - "dependencies": { - "@jest/console": "^26.6.2", - "@jest/environment": "^26.6.2", - "@jest/test-result": "^26.6.2", - "@jest/types": "^26.6.2", - "@types/node": "*", - "chalk": "^4.0.0", - "emittery": "^0.7.1", - "exit": "^0.1.2", - "graceful-fs": 
"^4.2.4", - "jest-config": "^26.6.3", - "jest-docblock": "^26.0.0", - "jest-haste-map": "^26.6.2", - "jest-leak-detector": "^26.6.2", - "jest-message-util": "^26.6.2", - "jest-resolve": "^26.6.2", - "jest-runtime": "^26.6.3", - "jest-util": "^26.6.2", - "jest-worker": "^26.6.2", - "source-map-support": "^0.5.6", - "throat": "^5.0.0" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-runner/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "node_modules/istanbul-lib-instrument/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "bin": { + "semver": "bin/semver.js" } }, - "node_modules/jest-runner/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "node_modules/istanbul-lib-report": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/istanbul-lib-report/-/istanbul-lib-report-3.0.1.tgz", + "integrity": "sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==", "dev": true, "dependencies": { - "ansi-styles": "^4.1.0", + "istanbul-lib-coverage": "^3.0.0", + "make-dir": "^4.0.0", "supports-color": "^7.1.0" }, "engines": { "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" } }, - 
"node_modules/jest-runner/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/istanbul-lib-source-maps": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/istanbul-lib-source-maps/-/istanbul-lib-source-maps-4.0.1.tgz", + "integrity": "sha512-n3s8EwkdFIJCG3BPKBYvskgXGoy88ARzvegkitk60NxRdwltLOTaH7CUiMRXvwYorl0Q712iEjcWB+fK/MrWVw==", "dev": true, "dependencies": { - "color-name": "~1.1.4" + "debug": "^4.1.1", + "istanbul-lib-coverage": "^3.0.0", + "source-map": "^0.6.1" }, "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-runner/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-runner/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" + "node": ">=10" } }, - "node_modules/jest-runner/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/istanbul-reports": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/istanbul-reports/-/istanbul-reports-3.2.0.tgz", + "integrity": "sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "html-escaper": "^2.0.0", + 
"istanbul-lib-report": "^3.0.0" }, "engines": { "node": ">=8" } }, - "node_modules/jest-runtime": { + "node_modules/jest": { "version": "26.6.3", - "resolved": "https://registry.npmjs.org/jest-runtime/-/jest-runtime-26.6.3.tgz", - "integrity": "sha512-lrzyR3N8sacTAMeonbqpnSka1dHNux2uk0qqDXVkMv2c/A3wYnvQ4EXuI013Y6+gSKSCxdaczvf4HF0mVXHRdw==", + "resolved": "https://registry.npmjs.org/jest/-/jest-26.6.3.tgz", + "integrity": "sha512-lGS5PXGAzR4RF7V5+XObhqz2KZIDUA1yD0DG6pBVmy10eh0ZIXQImRuzocsI/N2XZ1GrLFwTS27In2i2jlpq1Q==", "dev": true, "dependencies": { - "@jest/console": "^26.6.2", - "@jest/environment": "^26.6.2", - "@jest/fake-timers": "^26.6.2", - "@jest/globals": "^26.6.2", - "@jest/source-map": "^26.6.2", - "@jest/test-result": "^26.6.2", - "@jest/transform": "^26.6.2", - "@jest/types": "^26.6.2", - "@types/yargs": "^15.0.0", - "chalk": "^4.0.0", - "cjs-module-lexer": "^0.6.0", - "collect-v8-coverage": "^1.0.0", - "exit": "^0.1.2", - "glob": "^7.1.3", - "graceful-fs": "^4.2.4", - "jest-config": "^26.6.3", - "jest-haste-map": "^26.6.2", - "jest-message-util": "^26.6.2", - "jest-mock": "^26.6.2", - "jest-regex-util": "^26.0.0", - "jest-resolve": "^26.6.2", - "jest-snapshot": "^26.6.2", - "jest-util": "^26.6.2", - "jest-validate": "^26.6.2", - "slash": "^3.0.0", - "strip-bom": "^4.0.0", - "yargs": "^15.4.1" + "@jest/core": "^26.6.3", + "import-local": "^3.0.2", + "jest-cli": "^26.6.3" }, "bin": { - "jest-runtime": "bin/jest-runtime.js" - }, - "engines": { - "node": ">= 10.14.2" - } - }, - "node_modules/jest-runtime/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - 
"node_modules/jest-runtime/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-runtime/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-runtime/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-runtime/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-runtime/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-serializer": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-serializer/-/jest-serializer-26.6.2.tgz", - 
"integrity": "sha512-S5wqyz0DXnNJPd/xfIzZ5Xnp1HrJWBczg8mMfMpN78OJ5eDxXyf+Ygld9wX1DnUWbIbhM1YDY95NjR4CBXkb2g==", - "dev": true, - "dependencies": { - "@types/node": "*", - "graceful-fs": "^4.2.4" + "jest": "bin/jest.js" }, "engines": { "node": ">= 10.14.2" } }, - "node_modules/jest-snapshot": { + "node_modules/jest-changed-files": { "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-snapshot/-/jest-snapshot-26.6.2.tgz", - "integrity": "sha512-OLhxz05EzUtsAmOMzuupt1lHYXCNib0ECyuZ/PZOx9TrZcC8vL0x+DUG3TL+GLX3yHG45e6YGjIm0XwDc3q3og==", + "resolved": "https://registry.npmjs.org/jest-changed-files/-/jest-changed-files-26.6.2.tgz", + "integrity": "sha512-fDS7szLcY9sCtIip8Fjry9oGf3I2ht/QT21bAHm5Dmf0mD4X3ReNUf17y+bO6fR8WgbIZTlbyG1ak/53cbRzKQ==", "dev": true, "dependencies": { - "@babel/types": "^7.0.0", "@jest/types": "^26.6.2", - "@types/babel__traverse": "^7.0.4", - "@types/prettier": "^2.0.0", - "chalk": "^4.0.0", - "expect": "^26.6.2", - "graceful-fs": "^4.2.4", - "jest-diff": "^26.6.2", - "jest-get-type": "^26.3.0", - "jest-haste-map": "^26.6.2", - "jest-matcher-utils": "^26.6.2", - "jest-message-util": "^26.6.2", - "jest-resolve": "^26.6.2", - "natural-compare": "^1.4.0", - "pretty-format": "^26.6.2", - "semver": "^7.3.2" + "execa": "^4.0.0", + "throat": "^5.0.0" }, "engines": { "node": ">= 10.14.2" - } - }, - "node_modules/jest-snapshot/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/jest-snapshot/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": 
"sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", - "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/jest-snapshot/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-snapshot/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-snapshot/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-snapshot/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/jest-util": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-util/-/jest-util-26.6.2.tgz", - "integrity": "sha512-MDW0fKfsn0OI7MS7Euz6h8HNDXVQ0gaM9uW6RjfDmd1DAFcaxX9OqIakHIqhbnmF08Cf2DLDG+ulq8YQQ0Lp0Q==", + } + }, + "node_modules/jest-cli": { + "version": 
"26.6.3", + "resolved": "https://registry.npmjs.org/jest-cli/-/jest-cli-26.6.3.tgz", + "integrity": "sha512-GF9noBSa9t08pSyl3CY4frMrqp+aQXFGFkf5hEPbh/pIUFYWMK6ZLTfbmadxJVcJrdRoChlWQsA2VkJcDFK8hg==", "dev": true, "dependencies": { + "@jest/core": "^26.6.3", + "@jest/test-result": "^26.6.2", "@jest/types": "^26.6.2", - "@types/node": "*", "chalk": "^4.0.0", + "exit": "^0.1.2", "graceful-fs": "^4.2.4", + "import-local": "^3.0.2", "is-ci": "^2.0.0", - "micromatch": "^4.0.2" + "jest-config": "^26.6.3", + "jest-util": "^26.6.2", + "jest-validate": "^26.6.2", + "prompts": "^2.0.1", + "yargs": "^15.4.1" + }, + "bin": { + "jest": "bin/jest.js" }, "engines": { "node": ">= 10.14.2" } }, - "node_modules/jest-util/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/jest-config": { + "version": "26.6.3", + "resolved": "https://registry.npmjs.org/jest-config/-/jest-config-26.6.3.tgz", + "integrity": "sha512-t5qdIj/bCj2j7NFVHb2nFB4aUdfucDn3JRKgrZnplb8nieAirAzRSHP8uDEd+qV6ygzg9Pz4YG7UTJf94LPSyg==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "@babel/core": "^7.1.0", + "@jest/test-sequencer": "^26.6.3", + "@jest/types": "^26.6.2", + "babel-jest": "^26.6.3", + "chalk": "^4.0.0", + "deepmerge": "^4.2.2", + "glob": "^7.1.1", + "graceful-fs": "^4.2.4", + "jest-environment-jsdom": "^26.6.2", + "jest-environment-node": "^26.6.2", + "jest-get-type": "^26.3.0", + "jest-jasmine2": "^26.6.3", + "jest-regex-util": "^26.0.0", + "jest-resolve": "^26.6.2", + "jest-util": "^26.6.2", + "jest-validate": "^26.6.2", + "micromatch": "^4.0.2", + "pretty-format": "^26.6.2" }, "engines": { - "node": ">=8" + "node": ">= 10.14.2" }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "peerDependencies": { + "ts-node": ">=9.0.0" + }, + "peerDependenciesMeta": { + 
"ts-node": { + "optional": true + } } }, - "node_modules/jest-util/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "node_modules/jest-diff": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-diff/-/jest-diff-26.6.2.tgz", + "integrity": "sha512-6m+9Z3Gv9wN0WFVasqjCL/06+EFCMTqDEUl/b87HYK2rAPTyfz4ZIuSlPhY51PIQRWx5TaxeF1qmXKe9gfN3sA==", "dev": true, "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" + "chalk": "^4.0.0", + "diff-sequences": "^26.6.2", + "jest-get-type": "^26.3.0", + "pretty-format": "^26.6.2" }, "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" + "node": ">= 10.14.2" } }, - "node_modules/jest-util/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/jest-docblock": { + "version": "26.0.0", + "resolved": "https://registry.npmjs.org/jest-docblock/-/jest-docblock-26.0.0.tgz", + "integrity": "sha512-RDZ4Iz3QbtRWycd8bUEPxQsTlYazfYn/h5R65Fc6gOfwozFhoImx+affzky/FFBuqISPTqjXomoIGJVKBWoo0w==", "dev": true, "dependencies": { - "color-name": "~1.1.4" + "detect-newline": "^3.0.0" }, "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/jest-util/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-util/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": 
"sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/jest-util/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/jest-each": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-each/-/jest-each-26.6.2.tgz", + "integrity": "sha512-Mer/f0KaATbjl8MCJ+0GEpNdqmnVmDYqCTJYTvoo7rqmRiDllmp2AYN+06F93nXcY3ur9ShIjS+CO/uD+BbH4A==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "@jest/types": "^26.6.2", + "chalk": "^4.0.0", + "jest-get-type": "^26.3.0", + "jest-util": "^26.6.2", + "pretty-format": "^26.6.2" }, "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/jest-validate": { + "node_modules/jest-environment-jsdom": { "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-validate/-/jest-validate-26.6.2.tgz", - "integrity": "sha512-NEYZ9Aeyj0i5rQqbq+tpIOom0YS1u2MVu6+euBsvpgIme+FOfRmoC4R5p0JiAUpaFvFy24xgrpMknarR/93XjQ==", + "resolved": "https://registry.npmjs.org/jest-environment-jsdom/-/jest-environment-jsdom-26.6.2.tgz", + "integrity": "sha512-jgPqCruTlt3Kwqg5/WVFyHIOJHsiAvhcp2qiR2QQstuG9yWox5+iHpU3ZrcBxW14T4fe5Z68jAfLRh7joCSP2Q==", "dev": true, "dependencies": { + "@jest/environment": "^26.6.2", + "@jest/fake-timers": "^26.6.2", "@jest/types": "^26.6.2", - "camelcase": "^6.0.0", - "chalk": "^4.0.0", - "jest-get-type": "^26.3.0", - "leven": "^3.1.0", - "pretty-format": "^26.6.2" + "@types/node": "*", + "jest-mock": "^26.6.2", + "jest-util": "^26.6.2", + "jsdom": "^16.4.0" }, "engines": { "node": ">= 10.14.2" } }, - "node_modules/jest-validate/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": 
"https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/jest-environment-node": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-environment-node/-/jest-environment-node-26.6.2.tgz", + "integrity": "sha512-zhtMio3Exty18dy8ee8eJ9kjnRyZC1N4C1Nt/VShN1apyXc8rWGtJ9lI7vqiWcyyXS4BVSEn9lxAM2D+07/Tag==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "@jest/environment": "^26.6.2", + "@jest/fake-timers": "^26.6.2", + "@jest/types": "^26.6.2", + "@types/node": "*", + "jest-mock": "^26.6.2", + "jest-util": "^26.6.2" }, "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "node": ">= 10.14.2" } }, - "node_modules/jest-validate/node_modules/camelcase": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-6.3.0.tgz", - "integrity": "sha512-Gmy6FhYlCY7uOElZUSbxo2UCDH8owEk996gkbrpsgGtrJLM3J7jGxl9Ic7Qwwj4ivOE5AWZWRMecDdF7hqGjFA==", + "node_modules/jest-get-type": { + "version": "26.3.0", + "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-26.3.0.tgz", + "integrity": "sha512-TpfaviN1R2pQWkIihlfEanwOXK0zcxrKEE4MlU6Tn7keoXdN6/3gK/xl0yEh8DOunn5pOVGKf8hB4R9gVh04ig==", "dev": true, "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "node": ">= 10.14.2" } }, - "node_modules/jest-validate/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "node_modules/jest-haste-map": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-haste-map/-/jest-haste-map-26.6.2.tgz", + "integrity": 
"sha512-easWIJXIw71B2RdR8kgqpjQrbMRWQBgiBwXYEhtGUTaX+doCjBheluShdDMeR8IMfJiTqH4+zfhtg29apJf/8w==", "dev": true, "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" + "@jest/types": "^26.6.2", + "@types/graceful-fs": "^4.1.2", + "@types/node": "*", + "anymatch": "^3.0.3", + "fb-watchman": "^2.0.0", + "graceful-fs": "^4.2.4", + "jest-regex-util": "^26.0.0", + "jest-serializer": "^26.6.2", + "jest-util": "^26.6.2", + "jest-worker": "^26.6.2", + "micromatch": "^4.0.2", + "sane": "^4.0.3", + "walker": "^1.0.7" }, "engines": { - "node": ">=10" + "node": ">= 10.14.2" }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" + "optionalDependencies": { + "fsevents": "^2.1.2" } }, - "node_modules/jest-validate/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/jest-jasmine2": { + "version": "26.6.3", + "resolved": "https://registry.npmjs.org/jest-jasmine2/-/jest-jasmine2-26.6.3.tgz", + "integrity": "sha512-kPKUrQtc8aYwBV7CqBg5pu+tmYXlvFlSFYn18ev4gPFtrRzB15N2gW/Roew3187q2w2eHuu0MU9TJz6w0/nPEg==", "dev": true, "dependencies": { - "color-name": "~1.1.4" + "@babel/traverse": "^7.1.0", + "@jest/environment": "^26.6.2", + "@jest/source-map": "^26.6.2", + "@jest/test-result": "^26.6.2", + "@jest/types": "^26.6.2", + "@types/node": "*", + "chalk": "^4.0.0", + "co": "^4.6.0", + "expect": "^26.6.2", + "is-generator-fn": "^2.0.0", + "jest-each": "^26.6.2", + "jest-matcher-utils": "^26.6.2", + "jest-message-util": "^26.6.2", + "jest-runtime": "^26.6.3", + "jest-snapshot": "^26.6.2", + "jest-util": "^26.6.2", + "pretty-format": "^26.6.2", + "throat": "^5.0.0" }, "engines": { - "node": ">=7.0.0" + "node": ">= 10.14.2" } }, - "node_modules/jest-validate/node_modules/color-name": { - "version": "1.1.4", - "resolved": 
"https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-validate/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "node_modules/jest-leak-detector": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-leak-detector/-/jest-leak-detector-26.6.2.tgz", + "integrity": "sha512-i4xlXpsVSMeKvg2cEKdfhh0H39qlJlP5Ex1yQxwF9ubahboQYMgTtz5oML35AVA3B4Eu+YsmwaiKVev9KCvLxg==", "dev": true, + "dependencies": { + "jest-get-type": "^26.3.0", + "pretty-format": "^26.6.2" + }, "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/jest-validate/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/jest-matcher-utils": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-matcher-utils/-/jest-matcher-utils-26.6.2.tgz", + "integrity": "sha512-llnc8vQgYcNqDrqRDXWwMr9i7rS5XFiCwvh6DTP7Jqa2mqpcCBBlpCbn+trkG0KNhPu/h8rzyBkriOtBstvWhw==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "chalk": "^4.0.0", + "jest-diff": "^26.6.2", + "jest-get-type": "^26.3.0", + "pretty-format": "^26.6.2" }, "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/jest-watcher": { + "node_modules/jest-message-util": { "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-watcher/-/jest-watcher-26.6.2.tgz", - "integrity": "sha512-WKJob0P/Em2csiVthsI68p6aGKTIcsfjH9Gsx1f0A3Italz43e3ho0geSAVsmj09RWOELP1AZ/DXyJgOgDKxXQ==", + "resolved": 
"https://registry.npmjs.org/jest-message-util/-/jest-message-util-26.6.2.tgz", + "integrity": "sha512-rGiLePzQ3AzwUshu2+Rn+UMFk0pHN58sOG+IaJbk5Jxuqo3NYO1U2/MIR4S1sKgsoYSXSzdtSa0TgrmtUwEbmA==", "dev": true, "dependencies": { - "@jest/test-result": "^26.6.2", + "@babel/code-frame": "^7.0.0", "@jest/types": "^26.6.2", - "@types/node": "*", - "ansi-escapes": "^4.2.1", + "@types/stack-utils": "^2.0.0", "chalk": "^4.0.0", - "jest-util": "^26.6.2", - "string-length": "^4.0.1" + "graceful-fs": "^4.2.4", + "micromatch": "^4.0.2", + "pretty-format": "^26.6.2", + "slash": "^3.0.0", + "stack-utils": "^2.0.2" }, "engines": { "node": ">= 10.14.2" } }, - "node_modules/jest-watcher/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/jest-mock": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-mock/-/jest-mock-26.6.2.tgz", + "integrity": "sha512-YyFjePHHp1LzpzYcmgqkJ0nm0gg/lJx2aZFzFy1S6eUqNjXsOqTK10zNRff2dNfssgokjkG65OlWNcIlgd3zew==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "@jest/types": "^26.6.2", + "@types/node": "*" }, "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "node": ">= 10.14.2" } }, - "node_modules/jest-watcher/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "node_modules/jest-pnp-resolver": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/jest-pnp-resolver/-/jest-pnp-resolver-1.2.3.tgz", + "integrity": "sha512-+3NpwQEnRoIBtx4fyhblQDPgJI0H1IEIkX7ShLUjPGA7TtUTvI1oiKi3SR4oBR0hQhQR80l4WAe5RrXBwWMA8w==", "dev": true, - "dependencies": { - "ansi-styles": "^4.1.0", - 
"supports-color": "^7.1.0" - }, "engines": { - "node": ">=10" + "node": ">=6" }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" + "peerDependencies": { + "jest-resolve": "*" + }, + "peerDependenciesMeta": { + "jest-resolve": { + "optional": true + } } }, - "node_modules/jest-watcher/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/jest-regex-util": { + "version": "26.0.0", + "resolved": "https://registry.npmjs.org/jest-regex-util/-/jest-regex-util-26.0.0.tgz", + "integrity": "sha512-Gv3ZIs/nA48/Zvjrl34bf+oD76JHiGDUxNOVgUjh3j890sblXryjY4rss71fPtD/njchl6PSE2hIhvyWa1eT0A==", "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, "engines": { - "node": ">=7.0.0" + "node": ">= 10.14.2" } }, - "node_modules/jest-watcher/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest-watcher/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "node_modules/jest-resolve": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-resolve/-/jest-resolve-26.6.2.tgz", + "integrity": "sha512-sOxsZOq25mT1wRsfHcbtkInS+Ek7Q8jCHUB0ZUTP0tc/c41QHriU/NunqMfCUWsL4H3MHpvQD4QR9kSYhS7UvQ==", "dev": true, + "dependencies": { + "@jest/types": "^26.6.2", + "chalk": "^4.0.0", + "graceful-fs": "^4.2.4", + "jest-pnp-resolver": "^1.2.2", + "jest-util": "^26.6.2", + "read-pkg-up": "^7.0.1", + "resolve": "^1.18.1", + "slash": "^3.0.0" + }, 
"engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/jest-watcher/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/jest-resolve-dependencies": { + "version": "26.6.3", + "resolved": "https://registry.npmjs.org/jest-resolve-dependencies/-/jest-resolve-dependencies-26.6.3.tgz", + "integrity": "sha512-pVwUjJkxbhe4RY8QEWzN3vns2kqyuldKpxlxJlzEYfKSvY6/bMvxoFrYYzUO1Gx28yKWN37qyV7rIoIp2h8fTg==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "@jest/types": "^26.6.2", + "jest-regex-util": "^26.0.0", + "jest-snapshot": "^26.6.2" }, "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/jest-worker": { - "version": "26.6.2", - "resolved": "https://registry.npmjs.org/jest-worker/-/jest-worker-26.6.2.tgz", - "integrity": "sha512-KWYVV1c4i+jbMpaBC+U++4Va0cp8OisU185o73T1vo99hqi7w8tSJfUXYswwqqrjzwxa6KpRK54WhPvwf5w6PQ==", + "node_modules/jest-runner": { + "version": "26.6.3", + "resolved": "https://registry.npmjs.org/jest-runner/-/jest-runner-26.6.3.tgz", + "integrity": "sha512-atgKpRHnaA2OvByG/HpGA4g6CSPS/1LK0jK3gATJAoptC1ojltpmVlYC3TYgdmGp+GLuhzpH30Gvs36szSL2JQ==", "dev": true, "dependencies": { + "@jest/console": "^26.6.2", + "@jest/environment": "^26.6.2", + "@jest/test-result": "^26.6.2", + "@jest/types": "^26.6.2", "@types/node": "*", - "merge-stream": "^2.0.0", - "supports-color": "^7.0.0" + "chalk": "^4.0.0", + "emittery": "^0.7.1", + "exit": "^0.1.2", + "graceful-fs": "^4.2.4", + "jest-config": "^26.6.3", + "jest-docblock": "^26.0.0", + "jest-haste-map": "^26.6.2", + "jest-leak-detector": "^26.6.2", + "jest-message-util": "^26.6.2", + "jest-resolve": "^26.6.2", + "jest-runtime": "^26.6.3", + "jest-util": "^26.6.2", + "jest-worker": "^26.6.2", + "source-map-support": "^0.5.6", + "throat": "^5.0.0" }, "engines": 
{ - "node": ">= 10.13.0" + "node": ">= 10.14.2" } }, - "node_modules/jest-worker/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "node_modules/jest-runtime": { + "version": "26.6.3", + "resolved": "https://registry.npmjs.org/jest-runtime/-/jest-runtime-26.6.3.tgz", + "integrity": "sha512-lrzyR3N8sacTAMeonbqpnSka1dHNux2uk0qqDXVkMv2c/A3wYnvQ4EXuI013Y6+gSKSCxdaczvf4HF0mVXHRdw==", "dev": true, + "dependencies": { + "@jest/console": "^26.6.2", + "@jest/environment": "^26.6.2", + "@jest/fake-timers": "^26.6.2", + "@jest/globals": "^26.6.2", + "@jest/source-map": "^26.6.2", + "@jest/test-result": "^26.6.2", + "@jest/transform": "^26.6.2", + "@jest/types": "^26.6.2", + "@types/yargs": "^15.0.0", + "chalk": "^4.0.0", + "cjs-module-lexer": "^0.6.0", + "collect-v8-coverage": "^1.0.0", + "exit": "^0.1.2", + "glob": "^7.1.3", + "graceful-fs": "^4.2.4", + "jest-config": "^26.6.3", + "jest-haste-map": "^26.6.2", + "jest-message-util": "^26.6.2", + "jest-mock": "^26.6.2", + "jest-regex-util": "^26.0.0", + "jest-resolve": "^26.6.2", + "jest-snapshot": "^26.6.2", + "jest-util": "^26.6.2", + "jest-validate": "^26.6.2", + "slash": "^3.0.0", + "strip-bom": "^4.0.0", + "yargs": "^15.4.1" + }, + "bin": { + "jest-runtime": "bin/jest-runtime.js" + }, "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/jest-worker/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/jest-serializer": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-serializer/-/jest-serializer-26.6.2.tgz", + "integrity": 
"sha512-S5wqyz0DXnNJPd/xfIzZ5Xnp1HrJWBczg8mMfMpN78OJ5eDxXyf+Ygld9wX1DnUWbIbhM1YDY95NjR4CBXkb2g==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "@types/node": "*", + "graceful-fs": "^4.2.4" }, "engines": { - "node": ">=8" + "node": ">= 10.14.2" } }, - "node_modules/jest/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/jest-snapshot": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-snapshot/-/jest-snapshot-26.6.2.tgz", + "integrity": "sha512-OLhxz05EzUtsAmOMzuupt1lHYXCNib0ECyuZ/PZOx9TrZcC8vL0x+DUG3TL+GLX3yHG45e6YGjIm0XwDc3q3og==", "dev": true, "dependencies": { - "color-convert": "^2.0.1" + "@babel/types": "^7.0.0", + "@jest/types": "^26.6.2", + "@types/babel__traverse": "^7.0.4", + "@types/prettier": "^2.0.0", + "chalk": "^4.0.0", + "expect": "^26.6.2", + "graceful-fs": "^4.2.4", + "jest-diff": "^26.6.2", + "jest-get-type": "^26.3.0", + "jest-haste-map": "^26.6.2", + "jest-matcher-utils": "^26.6.2", + "jest-message-util": "^26.6.2", + "jest-resolve": "^26.6.2", + "natural-compare": "^1.4.0", + "pretty-format": "^26.6.2", + "semver": "^7.3.2" }, "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "node": ">= 10.14.2" } }, - "node_modules/jest/node_modules/chalk": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", - "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "node_modules/jest-util": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-util/-/jest-util-26.6.2.tgz", + "integrity": "sha512-MDW0fKfsn0OI7MS7Euz6h8HNDXVQ0gaM9uW6RjfDmd1DAFcaxX9OqIakHIqhbnmF08Cf2DLDG+ulq8YQQ0Lp0Q==", "dev": true, "dependencies": { - "ansi-styles": "^4.1.0", - 
"supports-color": "^7.1.0" + "@jest/types": "^26.6.2", + "@types/node": "*", + "chalk": "^4.0.0", + "graceful-fs": "^4.2.4", + "is-ci": "^2.0.0", + "micromatch": "^4.0.2" }, "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" + "node": ">= 10.14.2" } }, - "node_modules/jest/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "node_modules/jest-validate": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-validate/-/jest-validate-26.6.2.tgz", + "integrity": "sha512-NEYZ9Aeyj0i5rQqbq+tpIOom0YS1u2MVu6+euBsvpgIme+FOfRmoC4R5p0JiAUpaFvFy24xgrpMknarR/93XjQ==", "dev": true, "dependencies": { - "color-name": "~1.1.4" + "@jest/types": "^26.6.2", + "camelcase": "^6.0.0", + "chalk": "^4.0.0", + "jest-get-type": "^26.3.0", + "leven": "^3.1.0", + "pretty-format": "^26.6.2" }, "engines": { - "node": ">=7.0.0" + "node": ">= 10.14.2" } }, - "node_modules/jest/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, - "node_modules/jest/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "node_modules/jest-validate/node_modules/camelcase": { + "version": "6.3.0", + "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-6.3.0.tgz", + "integrity": "sha512-Gmy6FhYlCY7uOElZUSbxo2UCDH8owEk996gkbrpsgGtrJLM3J7jGxl9Ic7Qwwj4ivOE5AWZWRMecDdF7hqGjFA==", "dev": true, "engines": { - "node": ">=8" + "node": ">=10" + }, + "funding": { + 
"url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/jest/node_modules/jest-cli": { - "version": "26.6.3", - "resolved": "https://registry.npmjs.org/jest-cli/-/jest-cli-26.6.3.tgz", - "integrity": "sha512-GF9noBSa9t08pSyl3CY4frMrqp+aQXFGFkf5hEPbh/pIUFYWMK6ZLTfbmadxJVcJrdRoChlWQsA2VkJcDFK8hg==", + "node_modules/jest-watcher": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-watcher/-/jest-watcher-26.6.2.tgz", + "integrity": "sha512-WKJob0P/Em2csiVthsI68p6aGKTIcsfjH9Gsx1f0A3Italz43e3ho0geSAVsmj09RWOELP1AZ/DXyJgOgDKxXQ==", "dev": true, "dependencies": { - "@jest/core": "^26.6.3", "@jest/test-result": "^26.6.2", "@jest/types": "^26.6.2", + "@types/node": "*", + "ansi-escapes": "^4.2.1", "chalk": "^4.0.0", - "exit": "^0.1.2", - "graceful-fs": "^4.2.4", - "import-local": "^3.0.2", - "is-ci": "^2.0.0", - "jest-config": "^26.6.3", "jest-util": "^26.6.2", - "jest-validate": "^26.6.2", - "prompts": "^2.0.1", - "yargs": "^15.4.1" - }, - "bin": { - "jest": "bin/jest.js" + "string-length": "^4.0.1" }, "engines": { "node": ">= 10.14.2" } }, - "node_modules/jest/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "node_modules/jest-worker": { + "version": "26.6.2", + "resolved": "https://registry.npmjs.org/jest-worker/-/jest-worker-26.6.2.tgz", + "integrity": "sha512-KWYVV1c4i+jbMpaBC+U++4Va0cp8OisU185o73T1vo99hqi7w8tSJfUXYswwqqrjzwxa6KpRK54WhPvwf5w6PQ==", "dev": true, "dependencies": { - "has-flag": "^4.0.0" + "@types/node": "*", + "merge-stream": "^2.0.0", + "supports-color": "^7.0.0" }, "engines": { - "node": ">=8" + "node": ">= 10.13.0" } }, "node_modules/js-tokens": { @@ -6272,13 +4550,12 @@ "dev": true }, "node_modules/js-yaml": { - "version": "3.14.1", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.1.tgz", - 
"integrity": "sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "dependencies": { - "argparse": "^1.0.7", - "esprima": "^4.0.0" + "argparse": "^2.0.1" }, "bin": { "js-yaml": "bin/js-yaml.js" @@ -6330,30 +4607,24 @@ } } }, - "node_modules/jsdom/node_modules/acorn": { - "version": "8.8.1", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.1.tgz", - "integrity": "sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==", - "dev": true, - "bin": { - "acorn": "bin/acorn" - }, - "engines": { - "node": ">=0.4.0" - } - }, "node_modules/jsesc": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-2.5.2.tgz", - "integrity": "sha512-OYu7XEzjkCQ3C5Ps3QIZsQfNpqoJyZZA99wd9aWd05NCtC5pWOkShK2mkL6HXQR6/Cy2lbNdPlZBpuQHXE63gA==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-3.1.0.tgz", + "integrity": "sha512-/sM3dO2FOzXjKQhJuo0Q173wf2KOo8t4I8vHy6lF9poUp7bKT0/NHE8fPX23PwfhnykfqnC2xRxOnVw5XuGIaA==", "dev": true, "bin": { "jsesc": "bin/jsesc" }, "engines": { - "node": ">=4" + "node": ">=6" } }, + "node_modules/json-buffer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.1.tgz", + "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==", + "dev": true + }, "node_modules/json-parse-even-better-errors": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", @@ -6385,11 +4656,32 @@ } }, "node_modules/jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": 
"sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.3.1.tgz", + "integrity": "sha512-HUgH65KyejrUFPvHFPbqOY0rsFip3Bo5wb4ngvdi1EpCYWUQDC5V+Y7mZws+DLkr4M//zQJoanu1SP+87Dv1oQ==", "dev": true }, + "node_modules/jsonfile": { + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-6.2.0.tgz", + "integrity": "sha512-FGuPw30AdOIUTRMC2OMRtQV+jkVj2cfPqSeWXv1NEAJ1qZ5zb1X6z1mFhbfOB/iy3ssJCD+3KuZ8r8C3uVFlAg==", + "dev": true, + "dependencies": { + "universalify": "^2.0.0" + }, + "optionalDependencies": { + "graceful-fs": "^4.1.6" + } + }, + "node_modules/keyv": { + "version": "4.5.4", + "resolved": "https://registry.npmjs.org/keyv/-/keyv-4.5.4.tgz", + "integrity": "sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==", + "dev": true, + "dependencies": { + "json-buffer": "3.0.1" + } + }, "node_modules/kind-of": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-6.0.3.tgz", @@ -6418,13 +4710,13 @@ } }, "node_modules/levn": { - "version": "0.3.0", - "resolved": "https://registry.npmjs.org/levn/-/levn-0.3.0.tgz", - "integrity": "sha512-0OO4y2iOHix2W6ujICbKIaEQXvFQHue65vUG3pb5EUomzPI90z9hsA1VsO/dbIIpC53J8gxM9Q4Oho0jrCM/yA==", + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", "dev": true, "dependencies": { - "prelude-ls": "~1.1.2", - "type-check": "~0.3.2" + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" }, "engines": { "node": ">= 0.8.0" @@ -6437,15 +4729,18 @@ "dev": true }, "node_modules/locate-path": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", - "integrity": 
"sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", + "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", "dev": true, "dependencies": { - "p-locate": "^4.1.0" + "p-locate": "^5.0.0" }, "engines": { - "node": ">=8" + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/lodash": { @@ -6461,15 +4756,12 @@ "dev": true }, "node_modules/lru-cache": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", - "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-5.1.1.tgz", + "integrity": "sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w==", "dev": true, "dependencies": { - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=10" + "yallist": "^3.0.2" } }, "node_modules/lunr": { @@ -6488,29 +4780,20 @@ } }, "node_modules/make-dir": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-3.1.0.tgz", - "integrity": "sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw==", + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-4.0.0.tgz", + "integrity": "sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==", "dev": true, "dependencies": { - "semver": "^6.0.0" + "semver": "^7.5.3" }, "engines": { - "node": ">=8" + "node": ">=10" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/make-dir/node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": 
"sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", - "dev": true, - "bin": { - "semver": "bin/semver.js" - } - }, "node_modules/makeerror": { "version": "1.0.12", "resolved": "https://registry.npmjs.org/makeerror/-/makeerror-1.0.12.tgz", @@ -6553,6 +4836,15 @@ "node": ">= 12" } }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "dev": true, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/merge-stream": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/merge-stream/-/merge-stream-2.0.0.tgz", @@ -6569,12 +4861,12 @@ } }, "node_modules/micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "dependencies": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" }, "engines": { @@ -6624,9 +4916,9 @@ } }, "node_modules/minimist": { - "version": "1.2.7", - "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.7.tgz", - "integrity": "sha512-bzfL1YUZsP41gmu/qjrEk0Q6i2ix/cVeAhbCbqH9u3zYutS1cLg00qhrD0M2MVdCcx4Sc0UpP2eBWo9rotpq6g==", + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", + "integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==", "dev": true, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -6645,22 +4937,10 @@ "node": ">=0.10.0" } }, - 
"node_modules/mixin-deep/node_modules/is-extendable": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-1.0.1.tgz", - "integrity": "sha512-arnXMxT1hhoKo9k1LZdmlNyJdDDfy2v0fXjFlmok4+i8ul/6WlbVge9bhM74OpNPQPMGUToDtz+KXa1PneJxOA==", - "dev": true, - "dependencies": { - "is-plain-object": "^2.0.4" - }, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true }, "node_modules/nanomatch": { @@ -6724,26 +5004,10 @@ "which": "^2.0.2" } }, - "node_modules/node-notifier/node_modules/which": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", - "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", - "dev": true, - "optional": true, - "dependencies": { - "isexe": "^2.0.0" - }, - "bin": { - "node-which": "bin/node-which" - }, - "engines": { - "node": ">= 8" - } - }, "node_modules/node-releases": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.6.tgz", - "integrity": "sha512-PiVXnNuFm5+iYkLBNeq5211hvO38y63T0i2KKh2KnUs3RpzJ+JtODFjkD8yjLwnDkTYF1eKXheUwdssR+NRZdg==", + "version": "2.0.27", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.27.tgz", + "integrity": "sha512-nmh3lCkYZ3grZvqcCH+fjmQ7X+H0OeZgP40OierEaAptX4XofMh5kwNbWh7lBduUzCcV/8kZ+NDLCwm2iorIlA==", "dev": true }, "node_modules/normalize-package-data": { @@ -6759,9 +5023,9 @@ } }, "node_modules/normalize-package-data/node_modules/semver": { - "version": "5.7.1", - "resolved": 
"https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -6777,21 +5041,21 @@ } }, "node_modules/npm-run-path": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-2.0.2.tgz", - "integrity": "sha512-lJxZYlT4DW/bRUtFh1MQIWqmLwQfAxnqWG4HhEdjMlkrJYnJn0Jrr2u3mgxqaWsdiBc76TYkTG/mhrnYTuzfHw==", + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", + "integrity": "sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", "dev": true, "dependencies": { - "path-key": "^2.0.0" + "path-key": "^3.0.0" }, "engines": { - "node": ">=4" + "node": ">=8" } }, "node_modules/nwsapi": { - "version": "2.2.2", - "resolved": "https://registry.npmjs.org/nwsapi/-/nwsapi-2.2.2.tgz", - "integrity": "sha512-90yv+6538zuvUMnN+zCr8LuV6bPFdq50304114vJYJ8RDyK8D5O9Phpbd6SZWgI7PwzmmfN1upeOJlvybDSgCw==", + "version": "2.2.22", + "resolved": "https://registry.npmjs.org/nwsapi/-/nwsapi-2.2.22.tgz", + "integrity": "sha512-ujSMe1OWVn55euT1ihwCI1ZcAaAU3nxUiDwfDQldc51ZXaB9m2AyOn6/jh1BLe2t/G8xd6uKG1UBF2aZJeg2SQ==", "dev": true }, "node_modules/object-copy": { @@ -6820,6 +5084,19 @@ "node": ">=0.10.0" } }, + "node_modules/object-copy/node_modules/is-descriptor": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-0.1.7.tgz", + "integrity": "sha512-C3grZTvObeN1xud4cRWl366OMXZTj0+HGyk4hvfpx4ZHt1Pb60ANSXqCK7pdOTeUQpRzECBSTphqvD7U+l22Eg==", + "dev": true, + "dependencies": { + "is-accessor-descriptor": "^1.0.1", + "is-data-descriptor": "^1.0.1" + }, + "engines": { + "node": ">= 
0.4" + } + }, "node_modules/object-copy/node_modules/kind-of": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", @@ -6881,17 +5158,17 @@ } }, "node_modules/optionator": { - "version": "0.8.3", - "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.8.3.tgz", - "integrity": "sha512-+IW9pACdk3XWmmTXG8m3upGUJst5XRGzxMRjXzAuJ1XnIFNvfhjjIuYkDvysnPQ7qzqVzLt78BCruntqRhWQbA==", + "version": "0.9.4", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", + "integrity": "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==", "dev": true, "dependencies": { - "deep-is": "~0.1.3", - "fast-levenshtein": "~2.0.6", - "levn": "~0.3.0", - "prelude-ls": "~1.1.2", - "type-check": "~0.3.2", - "word-wrap": "~1.2.3" + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0", + "word-wrap": "^1.2.5" }, "engines": { "node": ">= 0.8.0" @@ -6919,30 +5196,33 @@ } }, "node_modules/p-limit": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", - "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", "dev": true, "dependencies": { - "p-try": "^2.0.0" + "yocto-queue": "^0.1.0" }, "engines": { - "node": ">=6" + "node": ">=10" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/p-locate": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", - "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "version": "5.0.0", + "resolved": 
"https://registry.npmjs.org/p-locate/-/p-locate-5.0.0.tgz", + "integrity": "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==", "dev": true, "dependencies": { - "p-limit": "^2.2.0" + "p-limit": "^3.0.2" }, "engines": { - "node": ">=8" + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/p-try": { @@ -7018,12 +5298,12 @@ } }, "node_modules/path-key": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/path-key/-/path-key-2.0.1.tgz", - "integrity": "sha512-fEHGKCSmUSDPv4uoj8AlD+joPlq3peND+HRYyxFz4KPw4z926S/b8rIuFs2FYJg3BwsxJf6A9/3eIdLaYC+9Dw==", + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", "dev": true, "engines": { - "node": ">=4" + "node": ">=8" } }, "node_modules/path-parse": { @@ -7042,9 +5322,9 @@ } }, "node_modules/picocolors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", - "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "dev": true }, "node_modules/picomatch": { @@ -7059,10 +5339,18 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/pipe": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/pipe/-/pipe-0.0.2.tgz", + "integrity": "sha512-67s0/X7rv2PX1sl64FQqC0qQuSpd1tv8Wh6c+U1lprj6Q7NxDYulCxZTbVbDvc/HSpZLYh7Oo821xReXSCZikQ==", + "engines": { + "node": ">=0.4.8" + } + }, "node_modules/pirates": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.5.tgz", - "integrity": 
"sha512-8V9+HQPupnaXMA23c5hvl69zXvTwTzyAYasnkb0Tts4XvO4CliqONMOnvlq26rkhLC3nWDFBJf73LU1e1VZLaQ==", + "version": "4.0.7", + "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.7.tgz", + "integrity": "sha512-TfySrs/5nm8fQJDcBDuUng3VOUKsd7S+zqvbOTiGXHfxX4wK31ard+hoNuvkicM/2YFzlpDgABOevKSsB4G/FA==", "dev": true, "engines": { "node": ">= 6" @@ -7080,6 +5368,58 @@ "node": ">=8" } }, + "node_modules/pkg-dir/node_modules/find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "dev": true, + "dependencies": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/pkg-dir/node_modules/locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "dev": true, + "dependencies": { + "p-locate": "^4.1.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/pkg-dir/node_modules/p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "dev": true, + "dependencies": { + "p-try": "^2.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/pkg-dir/node_modules/p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "dev": true, + "dependencies": { + "p-limit": "^2.2.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/posix-character-classes": { "version": 
"0.1.1", "resolved": "https://registry.npmjs.org/posix-character-classes/-/posix-character-classes-0.1.1.tgz", @@ -7090,9 +5430,9 @@ } }, "node_modules/prelude-ls": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.1.2.tgz", - "integrity": "sha512-ESF23V4SKG6lVSGZgYNpbsiaAkdab6ZgOxe52p7+Kid3W3u3bxR4Vfd/o21dmN7jSt0IwgZ4v5MUd26FEtXE9w==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", "dev": true, "engines": { "node": ">= 0.8.0" @@ -7113,39 +5453,6 @@ "node": ">= 10" } }, - "node_modules/pretty-format/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/pretty-format/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/pretty-format/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, "node_modules/prompts": { "version": "2.4.2", "resolved": "https://registry.npmjs.org/prompts/-/prompts-2.4.2.tgz", @@ -7160,15 +5467,21 @@ } }, "node_modules/psl": { - 
"version": "1.9.0", - "resolved": "https://registry.npmjs.org/psl/-/psl-1.9.0.tgz", - "integrity": "sha512-E/ZsdU4HLs/68gYzgGTkMicWTLPdAftJLfJFlLUAAKZGkStNU72sZjT66SnMDVOfOWY/YAoiD7Jxa9iHvngcag==", - "dev": true + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/psl/-/psl-1.15.0.tgz", + "integrity": "sha512-JZd3gMVBAVQkSs6HdNZo9Sdo0LNcQeMNP3CozBJb3JYC/QUYZTnKxP+f8oWRX4rHP5EurWxqAHTSwUCjlNKa1w==", + "dev": true, + "dependencies": { + "punycode": "^2.3.1" + }, + "funding": { + "url": "https://github.com/sponsors/lupomontero" + } }, "node_modules/pump": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz", - "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.3.tgz", + "integrity": "sha512-todwxLMY7/heScKmntwQG8CXVkWUOdYxIvY2s0VWAAMh/nd8SoYiRaKjlr7+iCs984f2P8zvrfWcDDYVb73NfA==", "dev": true, "dependencies": { "end-of-stream": "^1.1.0", @@ -7176,9 +5489,9 @@ } }, "node_modules/punycode": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.1.1.tgz", - "integrity": "sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A==", + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", "dev": true, "engines": { "node": ">=6" @@ -7248,6 +5561,67 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/read-pkg-up/node_modules/find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "dev": true, + "dependencies": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + 
"node": ">=8" + } + }, + "node_modules/read-pkg-up/node_modules/locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "dev": true, + "dependencies": { + "p-locate": "^4.1.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/read-pkg-up/node_modules/p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "dev": true, + "dependencies": { + "p-try": "^2.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/read-pkg-up/node_modules/p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "dev": true, + "dependencies": { + "p-limit": "^2.2.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/read-pkg-up/node_modules/type-fest": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.8.1.tgz", + "integrity": "sha512-4dbzIzqvjtgiM5rw1k5rEHtBANKmdudhGyBEajN01fEyhaAIhsoKNy6y7+IN93IfpFtwY9iqi7kD+xwKhQsNJA==", + "dev": true, + "engines": { + "node": ">=8" + } + }, "node_modules/read-pkg/node_modules/type-fest": { "version": "0.6.0", "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.6.0.tgz", @@ -7316,18 +5690,21 @@ "dev": true }, "node_modules/resolve": { - "version": "1.22.1", - "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.1.tgz", - "integrity": "sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw==", + "version": "1.22.11", + "resolved": 
"https://registry.npmjs.org/resolve/-/resolve-1.22.11.tgz", + "integrity": "sha512-RfqAvLnMl313r7c9oclB1HhUEAezcpLjz95wFH4LVuhk9JF/r22qmVP9AMmOU4vMX7Q8pN8jwNg/CSpdFnMjTQ==", "dev": true, "dependencies": { - "is-core-module": "^2.9.0", + "is-core-module": "^2.16.1", "path-parse": "^1.0.7", "supports-preserve-symlinks-flag": "^1.0.0" }, "bin": { "resolve": "bin/resolve" }, + "engines": { + "node": ">= 0.4" + }, "funding": { "url": "https://github.com/sponsors/ljharb" } @@ -7379,9 +5756,9 @@ } }, "node_modules/reusify": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", - "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.1.0.tgz", + "integrity": "sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw==", "dev": true, "engines": { "iojs": ">=1.0.0", @@ -7392,6 +5769,7 @@ "version": "3.0.2", "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "deprecated": "Rimraf versions prior to v4 are no longer supported", "dev": true, "dependencies": { "glob": "^7.1.3" @@ -7408,7 +5786,6 @@ "resolved": "https://registry.npmjs.org/rollup/-/rollup-2.79.2.tgz", "integrity": "sha512-fS6iqSPZDs3dr/y7Od6y5nha8dW1YnbgtsyotCVvoFGKbERG++CVRFv1meyGDE1SNItQA8BrnCw7ScdAhRJ3XQ==", "dev": true, - "license": "MIT", "bin": { "rollup": "dist/bin/rollup" }, @@ -7455,47 +5832,6 @@ "node": ">= 8.0.0" } }, - "node_modules/rollup-plugin-typescript2/node_modules/estree-walker": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", - "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", - "dev": true - }, - 
"node_modules/rollup-plugin-typescript2/node_modules/fs-extra": { - "version": "10.1.0", - "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-10.1.0.tgz", - "integrity": "sha512-oRXApq54ETRj4eMiFzGnHWGy+zo5raudjuxN0b8H7s/RU2oW0Wvsx9O0ACRN/kRq9E8Vu/ReskGB5o3ji+FzHQ==", - "dev": true, - "dependencies": { - "graceful-fs": "^4.2.0", - "jsonfile": "^6.0.1", - "universalify": "^2.0.0" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/rollup-plugin-typescript2/node_modules/jsonfile": { - "version": "6.1.0", - "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-6.1.0.tgz", - "integrity": "sha512-5dgndWOriYSm5cnYaJNhalLNDKOqFwyDB/rr1E9ZsGciGvKPs8R2xYGCacuf3z6K1YKDz182fd+fY3cn3pMqXQ==", - "dev": true, - "dependencies": { - "universalify": "^2.0.0" - }, - "optionalDependencies": { - "graceful-fs": "^4.1.6" - } - }, - "node_modules/rollup-plugin-typescript2/node_modules/universalify": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.0.tgz", - "integrity": "sha512-hAZsKq7Yy11Zu1DE0OzWjw7nnLZmJZYTDZZyEFHZdUhV8FkH5MCfoU1XMaxXovpyW5nq5scPqq0ZDP9Zyl04oQ==", - "dev": true, - "engines": { - "node": ">= 10.0.0" - } - }, "node_modules/rsvp": { "version": "4.8.5", "resolved": "https://registry.npmjs.org/rsvp/-/rsvp-4.8.5.tgz", @@ -7610,6 +5946,40 @@ "node": ">=0.10.0" } }, + "node_modules/sane/node_modules/cross-spawn": { + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", + "dev": true, + "dependencies": { + "nice-try": "^1.0.4", + "path-key": "^2.0.1", + "semver": "^5.5.0", + "shebang-command": "^1.2.0", + "which": "^1.2.9" + }, + "engines": { + "node": ">=4.8" + } + }, + "node_modules/sane/node_modules/execa": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/execa/-/execa-1.0.0.tgz", + "integrity": 
"sha512-adbxcyWV46qiHyvSp50TKt05tB4tK3HcmF7/nxfAdhnox83seTDbwnaqKO4sXRy7roHAIFqJP/Rw/AuEbX61LA==", + "dev": true, + "dependencies": { + "cross-spawn": "^6.0.0", + "get-stream": "^4.0.0", + "is-stream": "^1.1.0", + "npm-run-path": "^2.0.0", + "p-finally": "^1.0.0", + "signal-exit": "^3.0.0", + "strip-eof": "^1.0.0" + }, + "engines": { + "node": ">=6" + } + }, "node_modules/sane/node_modules/fill-range": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-4.0.0.tgz", @@ -7637,6 +6007,27 @@ "node": ">=0.10.0" } }, + "node_modules/sane/node_modules/get-stream": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-4.1.0.tgz", + "integrity": "sha512-GMat4EJ5161kIy2HevLlr4luNjBgvmj413KaQA7jt4V8B4RDsfpHk7WQ9GVqfYyyx8OS/L66Kox+rJRNklLK7w==", + "dev": true, + "dependencies": { + "pump": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/sane/node_modules/is-extendable": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-0.1.1.tgz", + "integrity": "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/sane/node_modules/is-number": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-3.0.0.tgz", @@ -7661,6 +6052,15 @@ "node": ">=0.10.0" } }, + "node_modules/sane/node_modules/is-stream": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-1.1.0.tgz", + "integrity": "sha512-uQPm8kcs47jx38atAcWTVxyltQYoPT68y9aWYdV6yWXSyW8mzSat0TL6CiWdZeCdF3KrAvpVtnHbTv4RN+rqdQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/sane/node_modules/micromatch": { "version": "3.1.10", "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-3.1.10.tgz", @@ -7697,6 +6097,57 @@ "node": ">=0.10.0" } }, + "node_modules/sane/node_modules/npm-run-path": { + 
"version": "2.0.2", + "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-2.0.2.tgz", + "integrity": "sha512-lJxZYlT4DW/bRUtFh1MQIWqmLwQfAxnqWG4HhEdjMlkrJYnJn0Jrr2u3mgxqaWsdiBc76TYkTG/mhrnYTuzfHw==", + "dev": true, + "dependencies": { + "path-key": "^2.0.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/sane/node_modules/path-key": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-2.0.1.tgz", + "integrity": "sha512-fEHGKCSmUSDPv4uoj8AlD+joPlq3peND+HRYyxFz4KPw4z926S/b8rIuFs2FYJg3BwsxJf6A9/3eIdLaYC+9Dw==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/sane/node_modules/semver": { + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", + "dev": true, + "bin": { + "semver": "bin/semver" + } + }, + "node_modules/sane/node_modules/shebang-command": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-1.2.0.tgz", + "integrity": "sha512-EV3L1+UQWGor21OmnvojK36mhg+TyIKDh3iFBKBohr5xeXIhNBcx8oWdgkTEEQ+BEFFYdLRuqMfd5L84N1V5Vg==", + "dev": true, + "dependencies": { + "shebang-regex": "^1.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/sane/node_modules/shebang-regex": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-1.0.0.tgz", + "integrity": "sha512-wpoSFAxys6b2a2wHZ1XpDSgD7N9iVjg29Ph9uV/uaP9Ex/KXlkTZTeddxDPSYQpgvzKLGJke2UU0AzoGCjNIvQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/sane/node_modules/to-regex-range": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-2.1.1.tgz", @@ -7710,6 +6161,18 @@ "node": ">=0.10.0" } }, + "node_modules/sane/node_modules/which": { + "version": "1.3.1", + "resolved": 
"https://registry.npmjs.org/which/-/which-1.3.1.tgz", + "integrity": "sha512-HxJdYWq1MTIQbJ3nw0cqssHoTNU267KlrDuGZ1WYlxDStUtKUhOaJmh112/TZmHxxUfuJqPXSOm7tDyas0OSIQ==", + "dev": true, + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "which": "bin/which" + } + }, "node_modules/saxes": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/saxes/-/saxes-5.0.1.tgz", @@ -7723,13 +6186,10 @@ } }, "node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", "dev": true, - "dependencies": { - "lru-cache": "^6.0.0" - }, "bin": { "semver": "bin/semver.js" }, @@ -7770,25 +6230,34 @@ "node": ">=0.10.0" } }, + "node_modules/set-value/node_modules/is-extendable": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-0.1.1.tgz", + "integrity": "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/shebang-command": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-1.2.0.tgz", - "integrity": "sha512-EV3L1+UQWGor21OmnvojK36mhg+TyIKDh3iFBKBohr5xeXIhNBcx8oWdgkTEEQ+BEFFYdLRuqMfd5L84N1V5Vg==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", "dev": true, "dependencies": { - "shebang-regex": "^1.0.0" + "shebang-regex": "^3.0.0" }, "engines": { - "node": ">=0.10.0" + "node": ">=8" } }, "node_modules/shebang-regex": { - 
"version": "1.0.0", - "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-1.0.0.tgz", - "integrity": "sha512-wpoSFAxys6b2a2wHZ1XpDSgD7N9iVjg29Ph9uV/uaP9Ex/KXlkTZTeddxDPSYQpgvzKLGJke2UU0AzoGCjNIvQ==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", "dev": true, "engines": { - "node": ">=0.10.0" + "node": ">=8" } }, "node_modules/shellwords": { @@ -7799,9 +6268,9 @@ "optional": true }, "node_modules/shiki": { - "version": "0.14.2", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.2.tgz", - "integrity": "sha512-ltSZlSLOuSY0M0Y75KA+ieRaZ0Trf5Wl3gutE7jzLuIcWxLp5i/uEnLoQWNvgKXQ5OMpGkJnVMRLAuzjc0LJ2A==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "dependencies": { "ansi-sequence-parser": "^1.1.0", @@ -7876,44 +6345,6 @@ "node": ">=0.10.0" } }, - "node_modules/snapdragon-node/node_modules/is-accessor-descriptor": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-accessor-descriptor/-/is-accessor-descriptor-1.0.0.tgz", - "integrity": "sha512-m5hnHTkcVsPfqx3AKlyttIPb7J+XykHvJP2B9bZDjlhLIoEq4XoK64Vg7boZlVWYK6LUY94dYPEE7Lh0ZkZKcQ==", - "dev": true, - "dependencies": { - "kind-of": "^6.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/snapdragon-node/node_modules/is-data-descriptor": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/is-data-descriptor/-/is-data-descriptor-1.0.0.tgz", - "integrity": "sha512-jbRXy1FmtAoCjQkVmIVYwuuqDFUbaOeDjmed1tOGPrsMhtJA4rD9tkgA0F1qJ3gRFRXcHYVkdeaP50Q5rE/jLQ==", - "dev": true, - "dependencies": { - "kind-of": "^6.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - 
"node_modules/snapdragon-node/node_modules/is-descriptor": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-1.0.2.tgz", - "integrity": "sha512-2eis5WqQGV7peooDyLmNEPUrps9+SXX5c9pL3xEB+4e9HnGuDa7mB7kHxHw4CbqS9k1T2hOH3miL8n8WtiYVtg==", - "dev": true, - "dependencies": { - "is-accessor-descriptor": "^1.0.0", - "is-data-descriptor": "^1.0.0", - "kind-of": "^6.0.2" - }, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/snapdragon-util": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/snapdragon-util/-/snapdragon-util-3.0.1.tgz", @@ -7971,6 +6402,28 @@ "node": ">=0.10.0" } }, + "node_modules/snapdragon/node_modules/is-descriptor": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-0.1.7.tgz", + "integrity": "sha512-C3grZTvObeN1xud4cRWl366OMXZTj0+HGyk4hvfpx4ZHt1Pb60ANSXqCK7pdOTeUQpRzECBSTphqvD7U+l22Eg==", + "dev": true, + "dependencies": { + "is-accessor-descriptor": "^1.0.1", + "is-data-descriptor": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/snapdragon/node_modules/is-extendable": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-0.1.1.tgz", + "integrity": "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/snapdragon/node_modules/ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", @@ -8034,9 +6487,9 @@ "dev": true }, "node_modules/spdx-correct": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/spdx-correct/-/spdx-correct-3.1.1.tgz", - "integrity": "sha512-cOYcUWwhCuHCXi49RhFRCyJEK3iPj1Ziz9DpViV3tbZOwXD49QzIN3MpOLJNxh2qwq2lJJZaKMVw9qNi4jTC0w==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/spdx-correct/-/spdx-correct-3.2.0.tgz", + "integrity": 
"sha512-kN9dJbvnySHULIluDHy32WHRUu3Og7B9sbY7tsFLctQkIqnMh3hErYgdMjTYuqmcXX+lK5T1lnUt3G7zNswmZA==", "dev": true, "dependencies": { "spdx-expression-parse": "^3.0.0", @@ -8044,9 +6497,9 @@ } }, "node_modules/spdx-exceptions": { - "version": "2.3.0", - "resolved": "https://registry.npmjs.org/spdx-exceptions/-/spdx-exceptions-2.3.0.tgz", - "integrity": "sha512-/tTrYOC7PPI1nUAgx34hUpqXuyJG+DTHJTnIULG4rDygi4xu/tfgmq1e1cIRwRzwZgo4NLySi+ricLkZkw4i5A==", + "version": "2.5.0", + "resolved": "https://registry.npmjs.org/spdx-exceptions/-/spdx-exceptions-2.5.0.tgz", + "integrity": "sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==", "dev": true }, "node_modules/spdx-expression-parse": { @@ -8060,9 +6513,9 @@ } }, "node_modules/spdx-license-ids": { - "version": "3.0.12", - "resolved": "https://registry.npmjs.org/spdx-license-ids/-/spdx-license-ids-3.0.12.tgz", - "integrity": "sha512-rr+VVSXtRhO4OHbXUiAF7xW3Bo9DuuF6C5jH+q/x15j2jniycgKbxU09Hr0WqlSLUs4i4ltHGXqTe7VHclYWyA==", + "version": "3.0.22", + "resolved": "https://registry.npmjs.org/spdx-license-ids/-/spdx-license-ids-3.0.22.tgz", + "integrity": "sha512-4PRT4nh1EImPbt2jASOKHX7PB7I+e4IWNLvkKFDxNhJlfjbYlleYQh285Z/3mPTHSAK/AvdMmw5BNNuYH8ShgQ==", "dev": true }, "node_modules/split-string": { @@ -8129,6 +6582,19 @@ "node": ">=0.10.0" } }, + "node_modules/static-extend/node_modules/is-descriptor": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/is-descriptor/-/is-descriptor-0.1.7.tgz", + "integrity": "sha512-C3grZTvObeN1xud4cRWl366OMXZTj0+HGyk4hvfpx4ZHt1Pb60ANSXqCK7pdOTeUQpRzECBSTphqvD7U+l22Eg==", + "dev": true, + "dependencies": { + "is-accessor-descriptor": "^1.0.1", + "is-data-descriptor": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/string-length": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/string-length/-/string-length-4.0.2.tgz", @@ -8208,15 +6674,15 @@ } }, "node_modules/supports-color": { - "version": "5.5.0", - 
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz", - "integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==", + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", "dev": true, "dependencies": { - "has-flag": "^3.0.0" + "has-flag": "^4.0.0" }, "engines": { - "node": ">=4" + "node": ">=8" } }, "node_modules/supports-hyperlinks": { @@ -8232,27 +6698,6 @@ "node": ">=8" } }, - "node_modules/supports-hyperlinks/node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", - "dev": true, - "engines": { - "node": ">=8" - } - }, - "node_modules/supports-hyperlinks/node_modules/supports-color": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", - "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", - "dev": true, - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/supports-preserve-symlinks-flag": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", @@ -8319,15 +6764,6 @@ "integrity": "sha512-3f0uOEAQwIqGuWW2MVzYg8fV/QNnc/IpuJNG837rLuczAaLVHslWHZQj4IGiEl5Hs3kkbhwL9Ab7Hrsmuj+Smw==", "dev": true }, - "node_modules/to-fast-properties": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/to-fast-properties/-/to-fast-properties-2.0.0.tgz", - "integrity": "sha512-/OaKK0xYrs3DmxRYqL/yDc+FxFUVYhDlXMhRmv3z915w2HF1tnN1omB354j8VUGO/hbRzyD6Y3sA7v7GS/ceog==", - "dev": true, - "engines": { - "node": 
">=4" - } - }, "node_modules/to-object-path": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/to-object-path/-/to-object-path-0.3.0.tgz", @@ -8380,9 +6816,9 @@ } }, "node_modules/tough-cookie": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-4.1.2.tgz", - "integrity": "sha512-G9fqXWoYFZgTc2z8Q5zaHy/vJMjm+WV0AkAeHxVCQiEB1b+dGvWzFW6QV07cY5jQ5gRkeid2qIkzkxUnmoQZUQ==", + "version": "4.1.4", + "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-4.1.4.tgz", + "integrity": "sha512-Loo5UUvLD9ScZ6jh8beX1T6sO1w2/MpCRpEP7V280GKMVUQ0Jzar2U3UJPsrdbziLEMMhu3Ujnq//rhiFuIeag==", "dev": true, "dependencies": { "psl": "^1.1.33", @@ -8394,6 +6830,15 @@ "node": ">=6" } }, + "node_modules/tough-cookie/node_modules/universalify": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.2.0.tgz", + "integrity": "sha512-CJ1QgKmNg3CwvAv/kOFmtnEN05f0D/cn9QntgNOQlQF9dgvVTHj3t+8JPdjqawCHk7V/KA+fbUqzZ9XWhcqPUg==", + "dev": true, + "engines": { + "node": ">= 4.0.0" + } + }, "node_modules/tr46": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/tr46/-/tr46-2.1.0.tgz", @@ -8407,9 +6852,9 @@ } }, "node_modules/tslib": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.5.2.tgz", - "integrity": "sha512-5svOrSA2w3iGFDs1HibEVBGbDrAY82bFQ3HZ3ixB+88nsbsWQoKqDRb5UBYAUPEzbBn6dAp5gRNXglySbx1MlA==", + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", "dev": true }, "node_modules/tsutils": { @@ -8434,12 +6879,12 @@ "dev": true }, "node_modules/type-check": { - "version": "0.3.2", - "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.3.2.tgz", - "integrity": "sha512-ZCmOJdvOWDBYJlzAoFkC+Q0+bUyEOS1ltgp1MGU03fqHG+dbi9tBFU2Rd9QKiDZFAYrhPh2JUf7rZRIuHRKtOg==", + "version": "0.4.0", + "resolved": 
"https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", "dev": true, "dependencies": { - "prelude-ls": "~1.1.2" + "prelude-ls": "^1.2.1" }, "engines": { "node": ">= 0.8.0" @@ -8455,12 +6900,15 @@ } }, "node_modules/type-fest": { - "version": "0.8.1", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.8.1.tgz", - "integrity": "sha512-4dbzIzqvjtgiM5rw1k5rEHtBANKmdudhGyBEajN01fEyhaAIhsoKNy6y7+IN93IfpFtwY9iqi7kD+xwKhQsNJA==", + "version": "0.20.2", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", + "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", "dev": true, "engines": { - "node": ">=8" + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/typedarray-to-buffer": { @@ -8473,9 +6921,9 @@ } }, "node_modules/typedoc": { - "version": "0.24.7", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.24.7.tgz", - "integrity": "sha512-zzfKDFIZADA+XRIp2rMzLe9xZ6pt12yQOhCr7cD7/PBTjhPmMyMvGrkZ2lPNJitg3Hj1SeiYFNzCsSDrlpxpKw==", + "version": "0.24.8", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.24.8.tgz", + "integrity": "sha512-ahJ6Cpcvxwaxfu4KtjA8qZNqS43wYt6JL27wYiIgl1vd38WW/KWX11YuAeZhuz9v+ttrutSsgK+XO1CjL1kA3w==", "dev": true, "dependencies": { "lunr": "^2.3.9", @@ -8490,7 +6938,7 @@ "node": ">= 14.14" }, "peerDependencies": { - "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x" + "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x" } }, "node_modules/typedoc-plugin-missing-exports": { @@ -8503,18 +6951,18 @@ } }, "node_modules/typedoc/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": 
"sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "dependencies": { "balanced-match": "^1.0.0" } }, "node_modules/typedoc/node_modules/minimatch": { - "version": "9.0.1", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.1.tgz", - "integrity": "sha512-0jWhJpD/MdhPXwPuiRkCbfYfSKp2qnn2eOc279qI7f+osl/l+prKSrvhg157zSYvx/1nmgn2NqdT6k2Z7zSH9w==", + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", "dev": true, "dependencies": { "brace-expansion": "^2.0.1" @@ -8539,6 +6987,31 @@ "node": ">=4.2.0" } }, + "node_modules/underscore": { + "version": "1.1.6", + "resolved": "https://registry.npmjs.org/underscore/-/underscore-1.1.6.tgz", + "integrity": "sha512-aqSzrO92Cjmeo8G7F49+ZHWBo3IJpjpsUZZaqfOHJGN61flbpLxQw/sP91p4kf/2+nkFrG6AG2WHlJh6RCf+/g==", + "engines": { + "node": "*" + } + }, + "node_modules/underscore.string": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/underscore.string/-/underscore.string-1.1.4.tgz", + "integrity": "sha512-WsF8NWzIbTvxUaSOpSLq+AiO0tzweXdWQZ4w9Op8S/1BT9Fh7hCS7bfrF17vZu9kJg3pcqO+8WXfQSr1ah0f2g==", + "dependencies": { + "underscore": "1.1.6" + }, + "engines": { + "node": "*" + } + }, + "node_modules/undici-types": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", + "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", + "dev": true + }, "node_modules/union-value": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/union-value/-/union-value-1.0.1.tgz", 
@@ -8554,13 +7027,22 @@ "node": ">=0.10.0" } }, + "node_modules/union-value/node_modules/is-extendable": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/is-extendable/-/is-extendable-0.1.1.tgz", + "integrity": "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/universalify": { - "version": "0.2.0", - "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.2.0.tgz", - "integrity": "sha512-CJ1QgKmNg3CwvAv/kOFmtnEN05f0D/cn9QntgNOQlQF9dgvVTHj3t+8JPdjqawCHk7V/KA+fbUqzZ9XWhcqPUg==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.1.tgz", + "integrity": "sha512-gptHNQghINnc/vTGIk0SOFGFNXw7JVrlRUtConJRlvaw6DuX0wO5Jeko9sWrMBhh+PsYAZ7oXAiOnf/UKogyiw==", "dev": true, "engines": { - "node": ">= 4.0.0" + "node": ">= 10.0.0" } }, "node_modules/unset-value": { @@ -8612,9 +7094,9 @@ } }, "node_modules/update-browserslist-db": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.10.tgz", - "integrity": "sha512-OztqDenkfFkbSG+tRxBeAnCVPckDBcvibKd35yDONx6OU8N7sqgwc7rCbkJ/WcYtVRZ4ba68d6byhC21GFh7sQ==", + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.1.4.tgz", + "integrity": "sha512-q0SPT4xyU84saUX+tomz1WLkxUbuaJnR1xWt17M7fJtEJigJeWUNGUqrauFXsHnqev9y9JTRGwk13tFBuKby4A==", "dev": true, "funding": [ { @@ -8624,14 +7106,18 @@ { "type": "tidelift", "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" } ], "dependencies": { - "escalade": "^3.1.1", - "picocolors": "^1.0.0" + "escalade": "^3.2.0", + "picocolors": "^1.1.1" }, "bin": { - "browserslist-lint": "cli.js" + "update-browserslist-db": "cli.js" }, "peerDependencies": { "browserslist": ">= 4.21.0" @@ -8697,12 +7183,12 @@ } }, 
"node_modules/v8-to-istanbul/node_modules/source-map": { - "version": "0.7.4", - "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.7.4.tgz", - "integrity": "sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA==", + "version": "0.7.6", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.7.6.tgz", + "integrity": "sha512-i5uvt8C3ikiWeNZSVZNWcfZPItFQOsYTUAOkcUPGd8DqDy1uOUikjt5dG+uRlwyvR108Fb9DOd4GvXfT0N2/uQ==", "dev": true, "engines": { - "node": ">= 8" + "node": ">= 12" } }, "node_modules/validate-npm-package-license": { @@ -8797,27 +7283,30 @@ } }, "node_modules/which": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/which/-/which-1.3.1.tgz", - "integrity": "sha512-HxJdYWq1MTIQbJ3nw0cqssHoTNU267KlrDuGZ1WYlxDStUtKUhOaJmh112/TZmHxxUfuJqPXSOm7tDyas0OSIQ==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", "dev": true, "dependencies": { "isexe": "^2.0.0" }, "bin": { - "which": "bin/which" + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" } }, "node_modules/which-module": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/which-module/-/which-module-2.0.0.tgz", - "integrity": "sha512-B+enWhmw6cjfVC7kS8Pj9pCrKSc5txArRyaYGe088shv/FGWH+0Rjx/xPgtsWfsUtS27FkP697E4DDhgrgoc0Q==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/which-module/-/which-module-2.0.1.tgz", + "integrity": "sha512-iBdZ57RDvnOR9AGBhML2vFZf7h8vmBjhoaZqODJBFWHVtKkDmKuHai3cx5PgVMrX5YDNp27AofYbAwctSS+vhQ==", "dev": true }, "node_modules/word-wrap": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.3.tgz", - "integrity": "sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==", + "version": "1.2.5", + "resolved": 
"https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", + "integrity": "sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==", "dev": true, "engines": { "node": ">=0.10.0" @@ -8837,39 +7326,6 @@ "node": ">=8" } }, - "node_modules/wrap-ansi/node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/wrap-ansi/node_modules/color-convert": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", - "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/wrap-ansi/node_modules/color-name": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true - }, "node_modules/wrappy": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", @@ -8889,9 +7345,9 @@ } }, "node_modules/ws": { - "version": "7.5.9", - "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz", - "integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==", + "version": "7.5.10", + "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.10.tgz", + "integrity": "sha512-+dbF1tHwZpXcbOJdVOkzLDxZP1ailvSxM6ZweXTegylPny803bFhA+vqBYw4s31NSAk4S2Qz+AKXK9a4wkdjcQ==", "dev": true, "engines": { "node": ">=8.3.0" @@ 
-8928,9 +7384,9 @@ "dev": true }, "node_modules/yallist": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", + "integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==", "dev": true }, "node_modules/yargs": { @@ -8968,6 +7424,58 @@ "node": ">=6" } }, + "node_modules/yargs/node_modules/find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "dev": true, + "dependencies": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "dev": true, + "dependencies": { + "p-locate": "^4.1.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "dev": true, + "dependencies": { + "p-try": "^2.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/yargs/node_modules/p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", 
+ "dev": true, + "dependencies": { + "p-limit": "^2.2.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/yocto-queue": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", diff --git a/web/package.json b/web/package.json index 7893fce407da..e793eed586bb 100644 --- a/web/package.json +++ b/web/package.json @@ -3,7 +3,7 @@ "description": "TVM WASM/WebGPU runtime for JS/TS", "license": "Apache-2.0", "homepage": "https://github.com/apache/tvm/tree/main/web", - "version": "0.22.0-dev0", + "version": "0.23.0-dev1", "files": [ "lib" ], diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index 61ad021c7fef..439f91c88160 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -17,7 +17,7 @@ * under the License. */ -export interface NDArrayCacheEntry { +export interface TensorCacheEntry { name: string; shape: Array; dtype: string; @@ -26,11 +26,11 @@ export interface NDArrayCacheEntry { nbytes: number; } -export interface NDArrayShardEntry { +export interface TensorShardEntry { dataPath: string; format: "raw-shard"; nbytes: number; - records: Array; + records: Array; } /** @@ -357,13 +357,13 @@ export class ArtifactIndexedDBCache implements ArtifactCacheTemplate { /** * Function to check if NDarray is in Cache or not * - * @param ndarrayCacheUrl The cache url which links to the NDArray + * @param tensorCacheUrl The cache url which links to the Tensor * @param cacheScope The scope identifier of the cache * @param cacheType The type of the cache: "cache" or "indexedDB" - * @returns the result if the cache has NDArray + * @returns the result if the cache has Tensor */ -export async function hasNDArrayInCache( - ndarrayCacheUrl: string, +export async function hasTensorInCache( + tensorCacheUrl: string, cacheScope = "tvmjs", cacheType = "cache" ): Promise { @@ -376,25 +376,25 @@ export async function hasNDArrayInCache( console.error("Unsupported cacheType: " + cacheType + ", using default 
ArtifactCache."); artifactCache = new ArtifactCache(cacheScope); } - const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; + const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; const hasJsonUrlInCache = await artifactCache.hasAllKeys([jsonUrl]); if (!hasJsonUrlInCache) { return false; } let list = await artifactCache.fetchWithCache(jsonUrl, "json"); - list = list["records"] as Array; - return await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href)); + list = list["records"] as Array; + return await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, tensorCacheUrl).href)); } /** - * Given cacheUrl, search up items to delete based on cacheUrl/ndarray-cache.json + * Given cacheUrl, search up items to delete based on cacheUrl/tensor-cache.json * * @param cacheUrl The cacheUrl for the items * @param cacheScope The scope identifier of the cache * @param cacheType The type of the cache: "cache" or "indexedDB" */ -export async function deleteNDArrayCache( +export async function deleteTensorCache( cacheUrl: string, cacheScope = "tvmjs", cacheType = "cache" @@ -408,9 +408,9 @@ export async function deleteNDArrayCache( console.error("Unsupported cacheType: " + cacheType + ", using default ArtifactCache."); artifactCache = new ArtifactCache(cacheScope); } - const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href; + const jsonUrl = new URL("tensor-cache.json", cacheUrl).href; const list = await artifactCache.fetchWithCache(jsonUrl, "json"); - const arrayentry = list["records"] as Array; + const arrayentry = list["records"] as Array; const processShard = async (i: number) => { const dataUrl = new URL(arrayentry[i].dataPath, cacheUrl).href; await artifactCache.deleteInCache(dataUrl); diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index 41d848a22886..1f91779692ef 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -41,7 +41,7 @@ export const enum SizeOf { TVMFFIAny = 8 * 2, DLDataType = 
I32, DLDevice = I32 + I32, - ObjectHeader = 8 * 2, + ObjectHeader = 8 * 3, } //---------------The new TVM FFI--------------- @@ -51,6 +51,22 @@ export const enum SizeOf { * We are keeping the same style as C API here. */ export const enum TypeIndex { + /* + * \brief The root type of all FFI objects. + * + * We include it so TypeIndex captures all possible runtime values. + * `kTVMFFIAny` code will never appear in Any::type_index. + * However, it may appear in field annotations during reflection. + */ + kTVMFFIAny = -1, + // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) + // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, + // which is not owned by TVMFFIAny. It is required that the following + // invariant holds: + // - `Any::type_index` is never `kTVMFFIRawStr` + // - `AnyView::type_index` can be `kTVMFFIRawStr` + // + /*! \brief None/nullptr value */ kTVMFFINone = 0, /*! \brief POD int value */ kTVMFFIInt = 1, @@ -66,7 +82,7 @@ export const enum TypeIndex { kTVMFFIDevice = 6, /*! \brief DLTensor* */ kTVMFFIDLTensorPtr = 7, - /*! \brief const char**/ + /*! \brief const char* */ kTVMFFIRawStr = 8, /*! \brief TVMFFIByteArray* */ kTVMFFIByteArrayPtr = 9, @@ -95,20 +111,39 @@ export const enum TypeIndex { kTVMFFIError = 67, /*! \brief Function object. */ kTVMFFIFunction = 68, - /*! \brief Array object. */ - kTVMFFIArray = 69, - /*! \brief Map object. */ - kTVMFFIMap = 70, /*! * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } */ - kTVMFFIShape = 71, + kTVMFFIShape = 69, /*! - * \brief NDArray object, layout = { TVMFFIObject, DLTensor, ... } + * \brief Tensor object, layout = { TVMFFIObject, DLTensor, ... } */ - kTVMFFINDArray = 72, - /*! \brief Runtime module object. */ + kTVMFFITensor = 70, + /*! \brief Array object. 
*/ + kTVMFFIArray = 71, + //---------------------------------------------------------------- + // more complex objects + //---------------------------------------------------------------- + /*! \brief Map object. */ + kTVMFFIMap = 72, + /*! \brief Runtime dynamic loaded module object. */ kTVMFFIModule = 73, + /*! + * \brief Opaque python object. + * + * This is a special type index to indicate we are storing an opaque PyObject. + * Such object may interact with callback functions that are registered to support + * python-related operations. + * + * We only translate the objects that we do not recognize into this type index. + * + * \sa TVMFFIObjectCreateOpaque + */ + kTVMFFIOpaquePyObject = 74, + kTVMFFIStaticObjectEnd, + // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) + /*! \brief Start of type indices that are allocated at runtime. */ + kTVMFFIDynObjectBegin = 128 } // -- TVM Wasm Auxiliary C API -- @@ -142,9 +177,9 @@ export type FTVMFFIWasmFunctionCreate = ( export type FTVMFFIWasmFunctionDeleter = (self: Pointer) => void; /** - * int TVMFFIObjectFree(TVMFFIObjectHandle obj); + * int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); */ -export type FTVMFFIObjectFree = (obj: Pointer) => number; +export type FTVMFFIObjectDecRef = (obj: Pointer) => number; /** * int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); diff --git a/web/src/index.ts b/web/src/index.ts index d4fc9b9187e6..868a26623ae0 100644 --- a/web/src/index.ts +++ b/web/src/index.ts @@ -19,7 +19,7 @@ export { Scalar, DLDevice, DLDataType, - PackedFunc, Module, NDArray, + PackedFunc, Module, Tensor, TVMArray, TVMObject, VirtualMachine, InitProgressCallback, InitProgressReport, Instance, instantiate @@ -28,8 +28,8 @@ export { ArtifactCacheTemplate, ArtifactCache, ArtifactIndexedDBCache, - hasNDArrayInCache, - deleteNDArrayCache + hasTensorInCache, + deleteTensorCache } from "./artifact_cache"; export { Disposable, LibraryProvider } from "./types"; export { RPCServer } from 
"./rpc_server"; diff --git a/web/src/memory.ts b/web/src/memory.ts index 94ecb4e15afa..c57f83854df0 100644 --- a/web/src/memory.ts +++ b/web/src/memory.ts @@ -175,7 +175,8 @@ export class Memory { * @returns The object type index. */ loadObjectTypeIndex(objectHandle: Pointer): number { - return this.loadI32(objectHandle); + // The object layout is [ref_counter (i64), type_index (i32), ...]. + return this.loadI32(objectHandle + SizeOf.I64); } /** * Load the type key from the type info pointer. diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index 1e3af6f6438e..3adab93be103 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -81,8 +81,8 @@ export class RPCServer { state: RPCServerState = RPCServerState.InitHeader; logger: (msg: string) => void; getImports: () => Record; - private ndarrayCacheUrl: string; - private ndarrayCacheDevice: string; + private tensorCacheUrl: string; + private tensorCacheDevice: string; private initProgressCallback?: runtime.InitProgressCallback; private asyncOnServerLoad?: (inst: runtime.Instance) => Promise; private pendingSend: Promise = Promise.resolve(); @@ -102,8 +102,8 @@ export class RPCServer { key: string, getImports: () => Record, logger: (msg: string) => void = console.log, - ndarrayCacheUrl = "", - ndarrayCacheDevice = "cpu", + tensorCacheUrl = "", + tensorCacheDevice = "cpu", initProgressCallback: runtime.InitProgressCallback | undefined = undefined, asyncOnServerLoad: ((inst: runtime.Instance) => Promise) | undefined = undefined, ) { @@ -112,8 +112,8 @@ export class RPCServer { this.name = "WebSocketRPCServer[" + this.key + "]: "; this.getImports = getImports; this.logger = logger; - this.ndarrayCacheUrl = ndarrayCacheUrl; - this.ndarrayCacheDevice = ndarrayCacheDevice; + this.tensorCacheUrl = tensorCacheUrl; + this.tensorCacheDevice = tensorCacheDevice; this.initProgressCallback = initProgressCallback; this.asyncOnServerLoad = asyncOnServerLoad; this.checkLittleEndian(); @@ -145,7 +145,7 @@ export 
class RPCServer { this.log("Automatic reconnecting.."); new RPCServer( this.url, this.key, this.getImports, this.logger, - this.ndarrayCacheUrl, this.ndarrayCacheDevice, + this.tensorCacheUrl, this.tensorCacheDevice, this.initProgressCallback, this.asyncOnServerLoad); } else { this.log("Closing the server, final state=" + this.state); @@ -262,7 +262,7 @@ export class RPCServer { const asyncInitServer = async (): Promise => { assert(args[1] instanceof Uint8Array); const inst = await runtime.instantiate( - args[1].buffer, + args[1].buffer as ArrayBuffer, this.getImports(), this.logger ); @@ -287,12 +287,12 @@ export class RPCServer { this.inst.registerInitProgressCallback(this.initProgressCallback); } - if (this.ndarrayCacheUrl.length != 0) { - if (this.ndarrayCacheDevice === "cpu") { - await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, this.inst.cpu()); + if (this.tensorCacheUrl.length != 0) { + if (this.tensorCacheDevice === "cpu") { + await this.inst.fetchTensorCache(this.tensorCacheUrl, this.inst.cpu()); } else { - assert(this.ndarrayCacheDevice === "webgpu"); - await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, this.inst.webgpu()); + assert(this.tensorCacheDevice === "webgpu"); + await this.inst.fetchTensorCache(this.tensorCacheUrl, this.inst.webgpu()); } } diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 071b2eed68e4..41bc43b54c5f 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -31,7 +31,7 @@ import { ArtifactCache, ArtifactCacheTemplate, ArtifactIndexedDBCache, - NDArrayShardEntry, + TensorShardEntry, } from "./artifact_cache"; import * as compact from "./compact"; import * as ctypes from "./ctypes"; @@ -156,24 +156,24 @@ class RuntimeContext implements Disposable { functionListGlobalNamesFunctor: PackedFunc; moduleGetFunction: PackedFunc; moduleImport: PackedFunc; - ndarrayEmpty: PackedFunc; - ndarrayCopyFromTo: PackedFunc; - ndarrayCopyFromJSBytes: PackedFunc; - ndarrayCopyToJSBytes: PackedFunc; + tensorEmpty: PackedFunc; + 
tensorCopyFromTo: PackedFunc; + tensorCopyFromJSBytes: PackedFunc; + tensorCopyToJSBytes: PackedFunc; arrayGetItem: PackedFunc; arrayGetSize: PackedFunc; arrayMake: PackedFunc; arrayConcat: PackedFunc; getSysLib: PackedFunc; - arrayCacheGet: PackedFunc; - arrayCacheUpdate: PackedFunc; - arrayCacheRemove: PackedFunc; - arrayCacheClear: PackedFunc; + tensorCacheGet: PackedFunc; + tensorCacheUpdate: PackedFunc; + tensorCacheRemove: PackedFunc; + tensorCacheClear: PackedFunc; arrayDecodeStorage: PackedFunc; paramModuleFromCache: PackedFunc; paramModuleFromCacheByName: PackedFunc; makeShapeTuple: PackedFunc; - ndarrayCreateView: PackedFunc; + tensorCreateView: PackedFunc; sampleTopPFromLogits: PackedFunc; sampleTopPFromProb: PackedFunc; applyRepetitionPenalty: PackedFunc; @@ -191,24 +191,24 @@ class RuntimeContext implements Disposable { ); this.moduleGetFunction = getGlobalFunc("ffi.ModuleGetFunction"); this.moduleImport = getGlobalFunc("ffi.ModuleImportModule"); - this.ndarrayEmpty = getGlobalFunc("runtime.TVMArrayAllocWithScope"); - this.ndarrayCopyFromTo = getGlobalFunc("runtime.TVMArrayCopyFromTo"); - this.ndarrayCopyFromJSBytes = getGlobalFunc("tvmjs.runtime.NDArrayCopyFromBytes"); - this.ndarrayCopyToJSBytes = getGlobalFunc("tvmjs.runtime.NDArrayCopyToBytes"); + this.tensorEmpty = getGlobalFunc("runtime.TVMTensorAllocWithScope"); + this.tensorCopyFromTo = getGlobalFunc("runtime.TVMTensorCopyFromTo"); + this.tensorCopyFromJSBytes = getGlobalFunc("tvmjs.runtime.TensorCopyFromBytes"); + this.tensorCopyToJSBytes = getGlobalFunc("tvmjs.runtime.TensorCopyToBytes"); this.arrayGetItem = getGlobalFunc("ffi.ArrayGetItem"); this.arrayGetSize = getGlobalFunc("ffi.ArraySize"); this.arrayMake = getGlobalFunc("ffi.Array"); this.arrayConcat = getGlobalFunc("tvmjs.runtime.ArrayConcat"); this.getSysLib = getGlobalFunc("ffi.SystemLib"); - this.arrayCacheGet = getGlobalFunc("vm.builtin.ndarray_cache.get"); - this.arrayCacheRemove = getGlobalFunc("vm.builtin.ndarray_cache.remove"); - 
this.arrayCacheUpdate = getGlobalFunc("vm.builtin.ndarray_cache.update"); - this.arrayCacheClear = getGlobalFunc("vm.builtin.ndarray_cache.clear"); + this.tensorCacheGet = getGlobalFunc("vm.builtin.tensor_cache.get"); + this.tensorCacheRemove = getGlobalFunc("vm.builtin.tensor_cache.remove"); + this.tensorCacheUpdate = getGlobalFunc("vm.builtin.tensor_cache.update"); + this.tensorCacheClear = getGlobalFunc("vm.builtin.tensor_cache.clear"); this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache"); this.paramModuleFromCacheByName = getGlobalFunc("vm.builtin.param_module_from_cache_by_name"); this.makeShapeTuple = getGlobalFunc("ffi.Shape"); - this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView"); + this.tensorCreateView = getGlobalFunc("runtime.TVMTensorCreateView"); this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); this.sampleTopPFromProb = getGlobalFunc("vm.builtin.sample_top_p_from_prob"); this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty"); @@ -219,20 +219,19 @@ class RuntimeContext implements Disposable { dispose(): void { // call array cache clear to clear all cached items - this.arrayCacheClear.dispose(); + this.tensorCacheClear.dispose(); this.arrayGetItem.dispose(); this.arrayGetSize.dispose(); this.arrayMake.dispose(); this.arrayConcat.dispose(); - this.arrayCacheGet.dispose(); - this.arrayCacheRemove.dispose(); - this.arrayCacheUpdate.dispose(); - this.arrayCacheClear.dispose(); + this.tensorCacheGet.dispose(); + this.tensorCacheRemove.dispose(); + this.tensorCacheUpdate.dispose(); this.arrayDecodeStorage.dispose(); this.paramModuleFromCache.dispose(); this.paramModuleFromCacheByName.dispose(); this.makeShapeTuple.dispose(); - this.ndarrayCreateView.dispose(); + this.tensorCreateView.dispose(); this.sampleTopPFromLogits.dispose(); this.applyRepetitionPenalty.dispose(); 
this.applyPresenceAndFrequencyPenalty.dispose(); @@ -339,7 +338,7 @@ const DeviceStrToEnum: Record = { }; /** - * Represent a runtime context where a NDArray can reside. + * Represent a runtime context where a Tensor can reside. */ export class DLDevice { /** The device type code of the device. */ @@ -399,7 +398,7 @@ const DLDataTypeCodeToStr: Record = { }; /** - * Runtime data type of NDArray. + * Runtime data type of Tensor. */ export class DLDataType { /** The type code */ @@ -450,7 +449,7 @@ export class TVMObject implements Disposable { dispose(): void { if (this.handle != 0) { this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(this.handle) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(this.handle) ); this.handle = 0; } @@ -497,10 +496,10 @@ class PackedFuncCell extends TVMObject { } /** - * n-dimnesional array. + * Tensor( n-dimnesional array). */ -export class NDArray extends TVMObject { +export class Tensor extends TVMObject { /** Number of dimensions. */ ndim: number; /** Data type of the array. */ @@ -572,12 +571,12 @@ export class NDArray extends TVMObject { * @param dtype The data type of the new array. * @returns The new sliced ndarray. */ - view(shape: Array, dtype?: string): NDArray { + view(shape: Array, dtype?: string): Tensor { const shapeArray = shape.map((value) => new Scalar(value, "int")); if (dtype === undefined) { dtype = this.dtype; } - return this.ctx.ndarrayCreateView( + return this.ctx.tensorCreateView( this, this.ctx.makeShapeTuple(...shapeArray), this.dtype, @@ -591,24 +590,24 @@ export class NDArray extends TVMObject { */ getDataPtr(): Pointer { if (this.handle === 0) { - throw Error("NDArray has already been disposed"); + throw Error("Tensor has already been disposed"); } return this.dataPtr; } /** - * Copy data from another NDArray or javascript array. + * Copy data from another Tensor or javascript array. * The number of elements must match. 
* * @param data The source data array. * @returns this */ copyFrom( - data: NDArray | Array | Float32Array | Float64Array | + data: Tensor | Array | Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array | Uint8ClampedArray ): this { - if (data instanceof NDArray) { - this.ctx.ndarrayCopyFromTo(data, this); + if (data instanceof Tensor) { + this.ctx.tensorCopyFromTo(data, this); return this; } else { const size = this.shape.reduce((a, b) => { @@ -660,23 +659,23 @@ export class NDArray extends TVMObject { if (nbytes != data.length) { throw new Error("Expect the data's length equals nbytes=" + nbytes); } - this.ctx.ndarrayCopyFromJSBytes(this, data); + this.ctx.tensorCopyFromJSBytes(this, data); return this; } /** - * Return a copied Uint8Array of the raw bytes in the NDArray. + * Return a copied Uint8Array of the raw bytes in the Tensor. * @returns The result array. */ toRawBytes(): Uint8Array { if (this.device.deviceType != DeviceStrToEnum.cpu) { throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); } - return this.ctx.ndarrayCopyToJSBytes(this) as Uint8Array; + return this.ctx.tensorCopyToJSBytes(this) as Uint8Array; } /** - * Return a TypedArray copy of the NDArray, the specific type depends on - * the dtype of the NDArray. + * Return a TypedArray copy of the Tensor, the specific type depends on + * the dtype of the Tensor. * @returns The result array. */ toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array { @@ -834,7 +833,7 @@ export type InitProgressCallback = (report: InitProgressReport) => void; /** * TVM runtime instance. * - * All objects(NDArray, Module, PackedFunc) returned by TVM runtim function call + * All objects(Tensor, Module, PackedFunc) returned by TVM runtim function call * and PackedFunc instance are tracked through a scope mechanism that will get * auto-released when we call EndScope. 
* @@ -1179,7 +1178,7 @@ export class Instance implements Disposable { } //----------------------------------------------- - // Native NDArray Cache Support + // Native Tensor Cache Support //----------------------------------------------- /** * Register a call back for fetch progress. @@ -1213,53 +1212,53 @@ export class Instance implements Disposable { } /** - * Get NDArray from cache. + * Get Tensor from cache. * @param name The name of array. * @returns The result. */ - ndarrayCacheGet(name: string): NDArray | undefined { - return this.ctx.arrayCacheGet(name); + tensorCacheGet(name: string): Tensor | undefined { + return this.ctx.tensorCacheGet(name); } /** - * Get NDArray from cache. + * Get Tensor from cache. * @param name The name of array. * @returns The result. */ - ndarrayCacheRemove(name: string): NDArray | undefined { - return this.ctx.arrayCacheRemove(name); + tensorCacheRemove(name: string): Tensor | undefined { + return this.ctx.tensorCacheRemove(name); } /** - * Update the ndarray cache. + * Update the tensor cache. * @param name The name of the array. * @param arr The content. */ - ndarrayCacheUpdate(name: string, arr: NDArray, override = false) { - this.ctx.arrayCacheUpdate(name, arr, this.scalar(override ? 1 : 0, "int32")); + tensorCacheUpdate(name: string, arr: Tensor, override = false) { + this.ctx.tensorCacheUpdate(name, arr, this.scalar(override ? 1 : 0, "int32")); } /** - * Update the ndarray cache. + * Update the tensor cache. * @param name The name of the array. * @param arr The content. */ - ndarrayCacheClear() { - this.ctx.arrayCacheClear(); + tensorCacheClear() { + this.ctx.tensorCacheClear(); } /** - * Given cacheUrl, search up items to fetch based on cacheUrl/ndarray-cache.json + * Given cacheUrl, search up items to fetch based on cacheUrl/tensor-cache.json * - * @param ndarrayCacheUrl The cache url. + * @param tensorCacheUrl The cache url. * @param device The device to be fetched to. 
* @param cacheScope The scope identifier of the cache * @param cacheType The type of the cache: "cache" or "indexedDB" * @param signal An optional AbortSignal to abort the fetch * @returns The meta data */ - async fetchNDArrayCache( - ndarrayCacheUrl: string, + async fetchTensorCache( + tensorCacheUrl: string, device: DLDevice, cacheScope = "tvmjs", cacheType = "cache", @@ -1274,28 +1273,28 @@ export class Instance implements Disposable { console.error("Unsupported cacheType: " + cacheType + ", using default ArtifactCache."); artifactCache = new ArtifactCache(cacheScope); } - const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; + const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; const list = await artifactCache.fetchWithCache(jsonUrl, "json"); - await this.fetchNDArrayCacheInternal( - ndarrayCacheUrl, - list["records"] as Array, device, artifactCache, + await this.fetchTensorCacheInternal( + tensorCacheUrl, + list["records"] as Array, device, artifactCache, signal); this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record) }; } /** - * Fetch list of NDArray into the NDArrayCache. + * Fetch list of Tensor into the TensorCache. * - * @param ndarrayCacheUrl The cache url. + * @param tensorCacheUrl The cache url. * @param list The list of array data. * @param device The device to store the data to. 
* @param artifactCache The artifact cache * @param signal An optional AbortSignal to abort the fetch */ - private async fetchNDArrayCacheInternal( - ndarrayCacheUrl: string, - list: Array, + private async fetchTensorCacheInternal( + tensorCacheUrl: string, + list: Array, device: DLDevice, artifactCache: ArtifactCacheTemplate, signal?: AbortSignal, @@ -1310,7 +1309,7 @@ export class Instance implements Disposable { let fetchedShards = 0; let timeElapsed = 0; - const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href)); + const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, tensorCacheUrl).href)); // `loading`: we have finished downloading (or already cacheOnly) and are loading onto WebGPU const reportCallback = (iter: number, loading = false) => { @@ -1351,7 +1350,7 @@ export class Instance implements Disposable { // Download params [start, end) from `list` for (let i = start; i < end; i++) { const shard = list[i]; - const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; + const dataUrl = new URL(shard.dataPath, tensorCacheUrl).href; try { await artifactCache.addToCache(dataUrl, "arraybuffer", signal); } catch (err) { @@ -1360,7 +1359,7 @@ export class Instance implements Disposable { } timeElapsed = Math.ceil((perf.now() - tstart) / 1000); fetchedBytes += shard.nbytes; - reportCallback(fetchedShards++, /*loading=*/false); + reportCallback(++fetchedShards, /*loading=*/false); } } // We launch 4 parallel for loops to limit the max concurrency to 4 download @@ -1374,10 +1373,14 @@ export class Instance implements Disposable { ]); } + // Reset for the loading phase to avoid double counting with download phase + fetchedBytes = 0; + fetchedShards = 0; + // Then iteratively, load the shard from cache for (let i = 0; i < list.length; ++i) { const shard = list[i]; - const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; + const dataUrl = new URL(shard.dataPath, 
tensorCacheUrl).href; let buffer; try { buffer = await artifactCache.fetchWithCache(dataUrl, "arraybuffer"); @@ -1399,7 +1402,7 @@ export class Instance implements Disposable { this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype); // then async stream into GPU if needed if (device.deviceType === DeviceStrToEnum.cpu) { - this.ndarrayCacheUpdate(rec.name, cpu_arr, false); + this.tensorCacheUpdate(rec.name, cpu_arr, false); cpu_arr.dispose(); } else { // allocate a gpu arr and async copy to it. @@ -1410,7 +1413,7 @@ export class Instance implements Disposable { }); gpu_arr.copyFrom(cpu_arr); await device.sync(); - this.ndarrayCacheUpdate(rec.name, gpu_arr, false); + this.tensorCacheUpdate(rec.name, gpu_arr, false); cpu_arr.dispose(); gpu_arr.dispose(); } @@ -1422,7 +1425,9 @@ export class Instance implements Disposable { throw err; } } - reportCallback(i + 1, /*loading=*/true); + fetchedBytes += shard.nbytes; + timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + reportCallback(++fetchedShards, /*loading=*/true); } } @@ -1463,7 +1468,7 @@ export class Instance implements Disposable { } /** - * Create an empty {@link NDArray} with given shape and dtype. + * Create an empty {@link Tensor} with given shape and dtype. * * @param shape The shape of the array. * @param dtype The data type of the array. @@ -1474,13 +1479,13 @@ export class Instance implements Disposable { shape: Array | number, dtype: string | DLDataType = "float32", dev: DLDevice = this.device("cpu", 0) - ): NDArray { + ): Tensor { shape = typeof shape === "number" ? [shape] : shape; - return this.ctx.ndarrayEmpty(this.makeShapeTuple(shape), dtype, dev, null); + return this.ctx.tensorEmpty(this.makeShapeTuple(shape), dtype, dev, null); } /** - * Create am uniform {@link NDArray} with given shape. + * Create am uniform {@link Tensor} with given shape. * * @param shape The shape of the array. * @param low The low value. 
@@ -1493,7 +1498,7 @@ export class Instance implements Disposable { low: number, high: number, dev: DLDevice - ): NDArray { + ): Tensor { const ret = this.empty(shape, "float32", dev); const size = shape.reduce((a, b) => { return a * b; @@ -1521,7 +1526,7 @@ export class Instance implements Disposable { * @param top_p The top_p * @returns The sampled index. */ - sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number { + sampleTopPFromLogits(logits: Tensor, temperature: number, top_p: number): number { return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, this.rng.randomFloat()); } @@ -1532,7 +1537,7 @@ export class Instance implements Disposable { * @param top_p The top_p * @returns The sampled index. */ - sampleTopPFromProb(prob: NDArray, top_p: number): number { + sampleTopPFromProb(prob: Tensor, top_p: number): number { return this.ctx.sampleTopPFromProb(prob, top_p, this.rng.randomFloat()); } @@ -1542,7 +1547,7 @@ export class Instance implements Disposable { * @param token_ids The appeared token ids. * @param penalty The penalty factor. */ - applyRepetitionPenalty(logits: NDArray, token_ids: NDArray, penalty: number) { + applyRepetitionPenalty(logits: Tensor, token_ids: Tensor, penalty: number) { return this.ctx.applyRepetitionPenalty(logits, token_ids, penalty); } @@ -1556,9 +1561,9 @@ export class Instance implements Disposable { * @param frequency_penalty The penalty factor. */ applyPresenceAndFrequencyPenalty( - logits: NDArray, - token_ids: NDArray, - token_freqs: NDArray, + logits: Tensor, + token_ids: Tensor, + token_freqs: Tensor, presence_penalty: number, frequency_penalty: number ) { @@ -1572,7 +1577,7 @@ export class Instance implements Disposable { * @param logits The input logits before softmax w/ temperature. * @param temperature The temperature factor. 
*/ - applySoftmaxWithTemperature(logits: NDArray, temperature: number) { + applySoftmaxWithTemperature(logits: Tensor, temperature: number) { return this.ctx.applySoftmaxWithTemperature(logits, temperature); } @@ -1587,11 +1592,11 @@ export class Instance implements Disposable { /** * Show image in canvas. * - * @param dataRGBA Image array in height x width uint32 NDArray RGBA format on GPU. + * @param dataRGBA Image array in height x width uint32 Tensor RGBA format on GPU. */ - showImage(dataRGBA: NDArray) { + showImage(dataRGBA: Tensor) { if (dataRGBA.shape.length != 2) { - throw Error("Require a height x width uint32 NDArray in RGBA" + + throw Error("Require a height x width uint32 Tensor in RGBA" + "get shape=" + dataRGBA.shape.toString() + " instead." ); } @@ -1600,7 +1605,7 @@ export class Instance implements Disposable { "get " + DeviceEnumToStr[dataRGBA.device.deviceType] + " instead."); } if (dataRGBA.dtype != "uint32") { - throw Error("Require a height x width uint32 NDArray in RGBA, " + + throw Error("Require a height x width uint32 Tensor in RGBA, " + "get " + dataRGBA.dtype + " instead."); } this.lib.webGPUContext?.drawImageFromBuffer( @@ -1644,11 +1649,11 @@ export class Instance implements Disposable { } /** - * Join a sequence of NDArrays that represent embeddings. - * @param inputs A list of embeddings in NDArrays, each array i has shape (m_i, hidden_size). - * @returns An NDArray of shape (\sum_{i} {m}, hidden_size) + * Join a sequence of Tensors that represent embeddings. + * @param inputs A list of embeddings in Tensors, each array i has shape (m_i, hidden_size). + * @returns An Tensor of shape (\sum_{i} {m}, hidden_size) */ - concatEmbeddings(embeddings: Array): NDArray { + concatEmbeddings(embeddings: Array): Tensor { // 1. Check shape validity const hidden_size = embeddings[0].shape[1]; embeddings.forEach((input) => { @@ -1664,7 +1669,7 @@ export class Instance implements Disposable { "not found, but called concatEmbeddings." 
); } - return this.ctx.concatEmbeddings(...embeddings) as NDArray; + return this.ctx.concatEmbeddings(...embeddings) as Tensor; } /** @@ -2033,9 +2038,9 @@ export class Instance implements Disposable { stack.storeI32(argZeroPaddingOffset, 0); // clear off the extra zero padding after ptr storage stack.storeI32(argValueOffset + SizeOf.I32, 0); - if (val instanceof NDArray) { + if (val instanceof Tensor) { if (!val.isView) { - stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFINDArray); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFITensor); stack.storePtr(argValueOffset, val.getHandle()); } else { stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDLTensorPtr); @@ -2225,15 +2230,15 @@ export class Instance implements Disposable { case TypeIndex.kTVMFFIOpaquePtr: { return this.memory.loadPointer(valuePtr); } - case TypeIndex.kTVMFFINDArray: { + case TypeIndex.kTVMFFITensor: { return this.ctx.attachToCurrentScope( - new NDArray(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false) + new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false) ); } case TypeIndex.kTVMFFIDLTensorPtr: { assert(callbackArg); // no need to attach as we are only looking at view - return new NDArray(this.memory.loadPointer(valuePtr), this.lib, this.ctx, true); + return new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, true); } case TypeIndex.kTVMFFIFunction: { return this.ctx.attachToCurrentScope( @@ -2253,7 +2258,7 @@ export class Instance implements Disposable { const strObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(strObjPtr) ); return result; } @@ -2264,7 +2269,7 @@ export class Instance implements Disposable { const strObjPtr = this.memory.loadPointer(valuePtr); const result = 
this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(strObjPtr) ); return result; } @@ -2275,7 +2280,7 @@ export class Instance implements Disposable { const bytesObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsBytes(bytesObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(bytesObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(bytesObjPtr) ); return result; } diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index 27d68d887c32..3c905f3800ef 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -476,7 +476,7 @@ export class WebGPUContext { this.device.queue.writeBuffer( this.gpuBufferFromPtr(toPtr), toOffset, - rawBytes, + rawBytes as GPUAllowSharedBufferSource, 0, nbytes ); @@ -861,7 +861,7 @@ export class WebGPUContext { this.device.queue.writeBuffer( this.gpuBufferFromPtr(to), toOffset, - rawBytes, + rawBytes as GPUAllowSharedBufferSource, 0, nbytes ); diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js index 3c6980cc1f06..83ac61156430 100644 --- a/web/tests/node/test_packed_func.js +++ b/web/tests/node/test_packed_func.js @@ -158,7 +158,7 @@ test("ExceptionPassing", () => { tvm.endScope(); }); -test("NDArrayCbArg", () => { +test("TensorCbArg", () => { tvm.beginScope(); let use_count = tvm.getGlobalFunc("testing.object_use_count"); let record = []; diff --git a/web/tests/node/test_ndarray.js b/web/tests/node/test_tensor.js similarity index 100% rename from web/tests/node/test_ndarray.js rename to web/tests/node/test_tensor.js diff --git a/web/tests/python/relax_rpc_test.py b/web/tests/python/relax_rpc_test.py index e55ad1935122..c21b98564d78 100644 --- a/web/tests/python/relax_rpc_test.py +++ 
b/web/tests/python/relax_rpc_test.py @@ -74,8 +74,8 @@ def check(remote): vm = relax.VirtualMachine(remote.system_lib(), device=dev) adata = np.random.uniform(size=n).astype(dtype) bdata = np.random.uniform(size=n).astype(dtype) - a = tvm.nd.array(adata, dev) - b = tvm.nd.array(bdata, dev) + a = tvm.runtime.tensor(adata, dev) + b = tvm.runtime.tensor(bdata, dev) vm.set_input("main", a, b) vm.invoke_stateful("main") c = vm.get_outputs("main") diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index 8925da00a489..f1e1c828885f 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -64,14 +64,14 @@ def check(remote, size): # basic function checks. dev = remote.webgpu(0) adata = np.random.uniform(size=size).astype(A.dtype) - a = tvm.nd.array(adata, dev) - b = tvm.nd.array(np.zeros(size, dtype=A.dtype), dev) + a = tvm.runtime.tensor(adata, dev) + b = tvm.runtime.tensor(np.zeros(size, dtype=A.dtype), dev) np.testing.assert_equal(a.numpy(), adata) f1 = remote.system_lib() addone = f1.get_function("main") addone(a, b) - np.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1), atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(b.numpy(), np.log(np.abs(a.numpy()) + 1), atol=1e-5, rtol=1e-5) print("Test pass..") check(remote, 71821 * 32)